diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc b/cpp/src/gandiva/gdv_function_stubs_test.cc index 0b8d14914e82..7e0a30798de4 100644 --- a/cpp/src/gandiva/gdv_function_stubs_test.cc +++ b/cpp/src/gandiva/gdv_function_stubs_test.cc @@ -640,6 +640,24 @@ TEST(TestGdvFnStubs, TestUpper) { EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr( "unexpected byte \\c3 encountered while decoding utf8 string")); + + // Max Len Test + out_len = -1; + int32_t bad_len = std::numeric_limits::max() / 2 + 1; + const char* out = gdv_fn_upper_utf8(ctx_ptr, "dummy", bad_len, &out_len); + // Expect failure + EXPECT_EQ(out_len, 0); + EXPECT_STREQ(out, ""); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Would overflow maximum output size")); + + // Negative length test + out_len = -1; + const char* out = gdv_fn_upper_utf8(ctx_ptr, "abc", -105, &out_len); + EXPECT_EQ(out_len, 0); + EXPECT_STREQ(out, ""); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data length")); + ctx.Reset(); std::string e( @@ -698,6 +716,23 @@ TEST(TestGdvFnStubs, TestLower) { EXPECT_EQ(std::string(out_str, out_len), ""); EXPECT_FALSE(ctx.has_error()); + // Max Len Test + out_len = -1; + int32_t bad_len = std::numeric_limits::max() / 2 + 1; + const char* out = gdv_fn_lower_utf8(ctx_ptr, "dummy", bad_len, &out_len); + // Expect failure + EXPECT_EQ(out_len, 0); + EXPECT_STREQ(out, ""); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Would overflow maximum output size")); + + // Negative length test + out_len = -1; + const char* out = gdv_fn_lower_utf8(ctx_ptr, "abc", -105, &out_len); + EXPECT_EQ(out_len, 0); + EXPECT_STREQ(out, ""); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data length")); + std::string d("AbOJjÜoß\xc3"); out_str = gdv_fn_lower_utf8(ctx_ptr, d.data(), static_cast(d.length()), &out_len); EXPECT_EQ(std::string(out_str, out_len), ""); @@ -794,6 +829,24 @@ TEST(TestGdvFnStubs, TestInitCap) { EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr( "unexpected byte \\c3 encountered while decoding utf8 string")); + + // Max Len Test + out_len = -1; + int32_t bad_len = std::numeric_limits::max() / 2 + 1; + const char* out = gdv_fn_initcap_utf8(ctx_ptr, "dummy", bad_len, &out_len); + // Expect failure + EXPECT_EQ(out_len, 0); + EXPECT_STREQ(out, ""); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Would overflow maximum output size")); + + // Negative length test + out_len = -1; + const char* out = gdv_fn_initcap_utf8(ctx_ptr, "abc", -105, &out_len); + EXPECT_EQ(out_len, 0); + EXPECT_STREQ(out, ""); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data length")); + ctx.Reset(); std::string e( @@ -1127,6 +1180,15 @@ TEST(TestGdvFnStubs, TestTranslate) { result = translate_utf8_utf8_utf8(ctx_ptr, "987654321", 9, "123456789", 9, "0123456789", 10, &out_len); EXPECT_EQ(expected, std::string(result, out_len)); + + int32_t bad_in_len = std::numeric_limits::max() / 4 + 1; + out_len = -1; + const char* result = + translate_utf8_utf8_utf8(ctx_ptr, "ABCDE", bad_in_len, "B", 1, "C", 1, &out_len); + EXPECT_EQ(out_len, 0); + EXPECT_STREQ(result, ""); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Would overflow maximum output size")); } TEST(TestGdvFnStubs, TestToUtcTimezone) { diff --git a/cpp/src/gandiva/gdv_string_function_stubs.cc b/cpp/src/gandiva/gdv_string_function_stubs.cc index d271834fb478..cc80a92a24a9 100644 --- a/cpp/src/gandiva/gdv_string_function_stubs.cc +++ b/cpp/src/gandiva/gdv_string_function_stubs.cc @@ -213,6 +213,25 @@ int32_t gdv_fn_utf8_char_length(char c) { return 0; } +static inline bool compute_alloc_len(int64_t context, int32_t data_len, + int32_t* alloc_len, int32_t* out_len) { + // Reject negative lengths + if (ARROW_PREDICT_FALSE(data_len < 0)) { + gdv_fn_context_set_error_msg(context, "Invalid (negative) data length"); + *out_len = 0; + return false; + } + + // Check overflow: 2 * data_len + if (ARROW_PREDICT_FALSE( + arrow::internal::MultiplyWithOverflow(2, data_len, alloc_len))) { + gdv_fn_context_set_error_msg(context, "Would overflow maximum output size"); + *out_len = 0; + return false; + } + return true; +} + // Convert an utf8 string to its corresponding lowercase string GANDIVA_EXPORT const char* gdv_fn_lower_utf8(int64_t context, const char* data, int32_t data_len, @@ -222,10 +241,16 @@ const char* gdv_fn_lower_utf8(int64_t context, const char* data, int32_t data_le return ""; } + int32_t alloc_length = 0; + if (ARROW_PREDICT_FALSE( + !compute_alloc_len(context, data_len, &alloc_length, out_len))) { + return ""; + } + // If it is a single-byte character (ASCII), corresponding lowercase is always 1-byte // long; if it is >= 2 bytes long, lowercase can be at most 4 bytes long, so length of // the output can be at most twice the length of the input - char* out = reinterpret_cast(gdv_fn_context_arena_malloc(context, 2 * data_len)); + char* out = reinterpret_cast(gdv_fn_context_arena_malloc(context, alloc_length)); if (out == nullptr) { gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); *out_len = 0; @@ -294,10 +319,16 @@ const char* gdv_fn_upper_utf8(int64_t context, const char* data, int32_t data_le return ""; } + int32_t alloc_length = 0; + if (ARROW_PREDICT_FALSE( + !compute_alloc_len(context, data_len, &alloc_length, out_len))) { + return ""; + } + // If it is a single-byte character (ASCII), corresponding uppercase is always 1-byte // long; if it is >= 2 bytes long, uppercase can be at most 4 bytes long, so length of // the output can be at most twice the length of the input - char* out = reinterpret_cast(gdv_fn_context_arena_malloc(context, 2 * data_len)); + char* out = reinterpret_cast(gdv_fn_context_arena_malloc(context, alloc_length)); if (out == nullptr) { gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); *out_len = 0; @@ -480,10 +511,16 @@ const char* gdv_fn_initcap_utf8(int64_t context, const char* data, int32_t data_ return ""; } + int32_t alloc_length = 0; + if (ARROW_PREDICT_FALSE( + !compute_alloc_len(context, data_len, &alloc_length, out_len))) { + return ""; + } + // If it is a single-byte character (ASCII), corresponding uppercase is always 1-byte // long; if it is >= 2 bytes long, uppercase can be at most 4 bytes long, so length of // the output can be at most twice the length of the input - char* out = reinterpret_cast(gdv_fn_context_arena_malloc(context, 2 * data_len)); + char* out = reinterpret_cast(gdv_fn_context_arena_malloc(context, alloc_length)); if (out == nullptr) { gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); *out_len = 0; @@ -579,15 +616,24 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in return in; } + int32_t alloc_length = 0; + // Check overflow: 4 * in_len + if (ARROW_PREDICT_FALSE( + arrow::internal::MultiplyWithOverflow(4, in_len, &alloc_length))) { + gdv_fn_context_set_error_msg(context, "Would overflow maximum output size"); + *out_len = 0; + return ""; + } + // This variable is to control if there are multi-byte utf8 entries bool has_multi_byte = false; // This variable is to store the final result char* result; - int result_len; + int32_t result_len; // Searching multi-bytes in In - for (int i = 0; i < in_len; i++) { + for (int32_t i = 0; i < in_len; i++) { unsigned char char_single_byte = in[i]; if (char_single_byte > 127) { // found a multi-byte utf-8 char @@ -598,7 +644,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in // Searching multi-bytes in From if (!has_multi_byte) { - for (int i = 0; i < from_len; i++) { + for (int32_t i = 0; i < from_len; i++) { unsigned char char_single_byte = from[i]; if (char_single_byte > 127) { // found a multi-byte utf-8 char @@ -610,7 +656,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in // Searching multi-bytes in To if (!has_multi_byte) { - for (int i = 0; i < to_len; i++) { + for (int32_t i = 0; i < to_len; i++) { unsigned char char_single_byte = to[i]; if (char_single_byte > 127) { // found a multi-byte utf-8 char @@ -638,7 +684,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in // This variable is for controlling the position in entry TO, for never repeat the // changes - int start_compare; + int32_t start_compare; if (to_len > 0) { start_compare = 0; @@ -650,7 +696,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in // list, to mark deletion positions const char empty = '\0'; - for (int in_for = 0; in_for < in_len; in_for++) { + for (int32_t in_for = 0; in_for < in_len; in_for++) { if (subs_list.find(in[in_for]) != subs_list.end()) { if (subs_list[in[in_for]] != empty) { // If exist in map, only add the correspondent value in result @@ -658,7 +704,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in result_len++; } } else { - for (int from_for = 0; from_for <= from_len; from_for++) { + for (int32_t from_for = 0; from_for <= from_len; from_for++) { if (from_for == from_len) { // If it's not in the FROM list, just add it to the map and the result. subs_list.insert(std::pair(in[in_for], in[in_for])); @@ -686,10 +732,11 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in } } } - } else { // If there are no multibytes in the input, work with std::strings + } else { + // If there are no multibytes in the input, work with std::strings // This variable is for receive the substitutions, malloc is in_len * 4 to receive // possible inputs with 4 bytes - result = reinterpret_cast(gdv_fn_context_arena_malloc(context, in_len * 4)); + result = reinterpret_cast(gdv_fn_context_arena_malloc(context, alloc_length)); if (result == nullptr) { gdv_fn_context_set_error_msg(context, @@ -704,7 +751,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in // This variable is for controlling the position in entry TO, for never repeat the // changes - int start_compare; + int32_t start_compare; if (to_len > 0) { start_compare = 0; @@ -717,11 +764,11 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in const std::string empty = ""; // This variables is to control len of multi-bytes entries - int len_char_in = 0; - int len_char_from = 0; - int len_char_to = 0; + int32_t len_char_in = 0; + int32_t len_char_from = 0; + int32_t len_char_to = 0; - for (int in_for = 0; in_for < in_len; in_for += len_char_in) { + for (int32_t in_for = 0; in_for < in_len; in_for += len_char_in) { // Updating len to char in this position len_char_in = gdv_fn_utf8_char_length(in[in_for]); // Making copy to std::string with length for this char position @@ -734,7 +781,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in result_len += static_cast(subs_list[insert_copy_key].length()); } } else { - for (int from_for = 0; from_for <= from_len; from_for += len_char_from) { + for (int32_t from_for = 0; from_for <= from_len; from_for += len_char_from) { // Updating len to char in this position len_char_from = gdv_fn_utf8_char_length(from[from_for]); // Making copy to std::string with length for this char position diff --git a/cpp/src/gandiva/precompiled/string_ops.cc b/cpp/src/gandiva/precompiled/string_ops.cc index 035d3c8c62e1..1a2cefe39e11 100644 --- a/cpp/src/gandiva/precompiled/string_ops.cc +++ b/cpp/src/gandiva/precompiled/string_ops.cc @@ -1924,9 +1924,19 @@ const char* quote_utf8(gdv_int64 context, const char* in, gdv_int32 in_len, *out_len = 0; return ""; } + + int32_t alloc_length = 0; + // Check overflow: 2 * in_len + if (ARROW_PREDICT_FALSE( + arrow::internal::MultiplyWithOverflow(2, in_len, &alloc_length))) { + gdv_fn_context_set_error_msg(context, "Would overflow maximum output size"); + *out_len = 0; + return ""; + } + // try to allocate double size output string (worst case) auto out = - reinterpret_cast(gdv_fn_context_arena_malloc(context, (in_len * 2) + 2)); + reinterpret_cast(gdv_fn_context_arena_malloc(context, alloc_length + 2)); if (out == nullptr) { gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); *out_len = 0; @@ -2829,8 +2839,17 @@ const char* to_hex_binary(int64_t context, const char* text, int32_t text_len, return ""; } + int32_t alloc_length = 0; + // Check overflow: 2 * in_len + if (ARROW_PREDICT_FALSE( + arrow::internal::MultiplyWithOverflow(2, in_len, &alloc_length))) { + gdv_fn_context_set_error_msg(context, "Would overflow maximum output size"); + *out_len = 0; + return ""; + } + auto ret = - reinterpret_cast(gdv_fn_context_arena_malloc(context, text_len * 2 + 1)); + reinterpret_cast(gdv_fn_context_arena_malloc(context, alloc_length + 1)); if (ret == nullptr) { gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); diff --git a/cpp/src/gandiva/precompiled/string_ops_test.cc b/cpp/src/gandiva/precompiled/string_ops_test.cc index d57eb437530c..1b670adc4de0 100644 --- a/cpp/src/gandiva/precompiled/string_ops_test.cc +++ b/cpp/src/gandiva/precompiled/string_ops_test.cc @@ -1165,6 +1165,11 @@ TEST(TestStringOps, TestQuote) { out_str = quote_utf8(ctx_ptr, "'''''''''", 9, &out_len); EXPECT_EQ(std::string(out_str, out_len), "'\\'\\'\\'\\'\\'\\'\\'\\'\\''"); EXPECT_FALSE(ctx.has_error()); + + int32_t bad_in_len = std::numeric_limits::max() / 2 + 20; + out_str = quote_utf8(ctx_ptr, "ABCDE", bad_in_len, &out_len); + EXPECT_EQ(out_len, 0); + EXPECT_EQ(out_str, ""); } TEST(TestStringOps, TestLtrim) { @@ -2498,6 +2503,11 @@ TEST(TestStringOps, TestToHex) { output = std::string(out_str, out_len); EXPECT_EQ(out_len, 2 * in_len); EXPECT_EQ(output, "090A090A090A090A0A0A092061206C657474405D6572"); + + int32_t bad_in_len = std::numeric_limits::max() / 2 + 20; + out_str = to_hex_binary(ctx_ptr, binary_string, bad_in_len, &out_len); + EXPECT_EQ(out_len, 0); + EXPECT_EQ(out_str, ""); } TEST(TestStringOps, TestToHexInt64) {