diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc b/cpp/src/gandiva/gdv_function_stubs_test.cc
index 0b8d14914e82..6213f7f4d862 100644
--- a/cpp/src/gandiva/gdv_function_stubs_test.cc
+++ b/cpp/src/gandiva/gdv_function_stubs_test.cc
@@ -587,6 +587,10 @@ TEST(TestGdvFnStubs, TestSubstringIndex) {
                                    std::numeric_limits<int32_t>::min(), &out_len);
   EXPECT_EQ(std::string(out_str, out_len), "a.b.c");
   EXPECT_FALSE(ctx.has_error());
+
+  out_str = gdv_fn_substring_index(ctx_ptr, "a", -2, ".", -1, -50, &out_len);
+  EXPECT_STREQ(out_str, "");
+  EXPECT_EQ(out_len, 0);
 }
 
 TEST(TestGdvFnStubs, TestUpper) {
@@ -640,6 +644,26 @@ TEST(TestGdvFnStubs, TestUpper) {
   EXPECT_THAT(ctx.get_error(),
               ::testing::HasSubstr(
                   "unexpected byte \\c3 encountered while decoding utf8 string"));
+
+  ctx.Reset();
+
+  // Max Len Test
+  out_len = -1;
+  int32_t bad_len = std::numeric_limits<int32_t>::max() / 2 + 1;
+  const char* out = gdv_fn_upper_utf8(ctx_ptr, "dummy", bad_len, &out_len);
+  // Expect failure
+  EXPECT_EQ(out_len, 0);
+  EXPECT_STREQ(out, "");
+  EXPECT_THAT(ctx.get_error(),
+              ::testing::HasSubstr("Would overflow maximum output size"));
+  ctx.Reset();
+
+  // Negative length test
+  out_len = -1;
+  out = gdv_fn_upper_utf8(ctx_ptr, "abc", -105, &out_len);
+  EXPECT_EQ(out_len, 0);
+  EXPECT_STREQ(out, "");
+  EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data length"));
   ctx.Reset();
 
   std::string e(
@@ -697,6 +721,26 @@ TEST(TestGdvFnStubs, TestLower) {
   out_str = gdv_fn_lower_utf8(ctx_ptr, "", 0, &out_len);
   EXPECT_EQ(std::string(out_str, out_len), "");
   EXPECT_FALSE(ctx.has_error());
+  ctx.Reset();
+
+  // Max Len Test
+  out_len = -1;
+  int32_t bad_len = std::numeric_limits<int32_t>::max() / 2 + 1;
+  const char* out = gdv_fn_lower_utf8(ctx_ptr, "dummy", bad_len, &out_len);
+  // Expect failure
+  EXPECT_EQ(out_len, 0);
+  EXPECT_STREQ(out, "");
+  EXPECT_THAT(ctx.get_error(),
+              ::testing::HasSubstr("Would overflow maximum output size"));
+  ctx.Reset();
+
+  // Negative length test
+  out_len = -1;
+  out = gdv_fn_lower_utf8(ctx_ptr, "abc", -105, &out_len);
+  EXPECT_EQ(out_len, 0);
+  EXPECT_STREQ(out, "");
+  EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data length"));
+  ctx.Reset();
 
   std::string d("AbOJjÜoß\xc3");
   out_str = gdv_fn_lower_utf8(ctx_ptr, d.data(), static_cast<int>(d.length()), &out_len);
@@ -796,6 +840,25 @@ TEST(TestGdvFnStubs, TestInitCap) {
                   "unexpected byte \\c3 encountered while decoding utf8 string"));
   ctx.Reset();
 
+  // Max Len Test
+  out_len = -1;
+  int32_t bad_len = std::numeric_limits<int32_t>::max() / 2 + 1;
+  const char* out = gdv_fn_initcap_utf8(ctx_ptr, "dummy", bad_len, &out_len);
+  // Expect failure
+  EXPECT_EQ(out_len, 0);
+  EXPECT_STREQ(out, "");
+  EXPECT_THAT(ctx.get_error(),
+              ::testing::HasSubstr("Would overflow maximum output size"));
+  ctx.Reset();
+
+  // Negative length test
+  out_len = -1;
+  out = gdv_fn_initcap_utf8(ctx_ptr, "abc", -105, &out_len);
+  EXPECT_EQ(out_len, 0);
+  EXPECT_STREQ(out, "");
+  EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data length"));
+  ctx.Reset();
+
   std::string e(
       "åbÑg\xe0\xa0"
       "åBUå");
@@ -1127,6 +1190,15 @@ TEST(TestGdvFnStubs, TestTranslate) {
   result = translate_utf8_utf8_utf8(ctx_ptr, "987654321", 9, "123456789", 9,
                                     "0123456789", 10, &out_len);
   EXPECT_EQ(expected, std::string(result, out_len));
+
+  int32_t bad_in_len = std::numeric_limits<int32_t>::max() / 4 + 1;
+  out_len = -1;
+  result =
+      translate_utf8_utf8_utf8(ctx_ptr, "ABCDE", bad_in_len, "B", 1, "C", 1, &out_len);
+  EXPECT_EQ(out_len, 0);
+  EXPECT_STREQ(result, "");
+  EXPECT_THAT(ctx.get_error(),
+              ::testing::HasSubstr("Would overflow maximum output size"));
 }
 
 TEST(TestGdvFnStubs, TestToUtcTimezone) {
diff --git a/cpp/src/gandiva/gdv_string_function_stubs.cc b/cpp/src/gandiva/gdv_string_function_stubs.cc
index d271834fb478..51ca6ab79b86 100644
--- a/cpp/src/gandiva/gdv_string_function_stubs.cc
+++ b/cpp/src/gandiva/gdv_string_function_stubs.cc
@@ -213,6 +213,28 @@ int32_t gdv_fn_utf8_char_length(char c) {
   return 0;
 }
 
+// Validate the input length of a case-conversion call and compute its worst-case output
+// size. On success stores 2 * data_len in *alloc_len and returns true; on a negative
+// length or int32 overflow sets the context error, zeroes *out_len and returns false.
+static inline bool is_datalen_valid(int64_t context, int32_t data_len, int32_t* alloc_len,
+                                    int32_t* out_len) {
+  // Reject negative lengths
+  if (ARROW_PREDICT_FALSE(data_len < 0)) {
+    gdv_fn_context_set_error_msg(context, "Invalid (negative) data length");
+    *out_len = 0;
+    return false;
+  }
+
+  // Check overflow: 2 * data_len
+  if (ARROW_PREDICT_FALSE(
+          arrow::internal::MultiplyWithOverflow(2, data_len, alloc_len))) {
+    gdv_fn_context_set_error_msg(context, "Would overflow maximum output size");
+    *out_len = 0;
+    return false;
+  }
+  return true;
+}
+
 // Convert an utf8 string to its corresponding lowercase string
 GANDIVA_EXPORT
 const char* gdv_fn_lower_utf8(int64_t context, const char* data, int32_t data_len,
@@ -222,10 +244,16 @@ const char* gdv_fn_lower_utf8(int64_t context, const char* data, int32_t data_le
     return "";
   }
 
+  int32_t alloc_length = 0;
+  if (ARROW_PREDICT_FALSE(
+          !is_datalen_valid(context, data_len, &alloc_length, out_len))) {
+    return "";
+  }
+
   // If it is a single-byte character (ASCII), corresponding lowercase is always 1-byte
   // long; if it is >= 2 bytes long, lowercase can be at most 4 bytes long, so length of
   // the output can be at most twice the length of the input
-  char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 * data_len));
+  char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));
   if (out == nullptr) {
     gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
     *out_len = 0;
@@ -294,10 +322,16 @@ const char* gdv_fn_upper_utf8(int64_t context, const char* data, int32_t data_le
     return "";
   }
 
+  int32_t alloc_length = 0;
+  if (ARROW_PREDICT_FALSE(
+          !is_datalen_valid(context, data_len, &alloc_length, out_len))) {
+    return "";
+  }
+
   // If it is a single-byte character (ASCII), corresponding uppercase is always 1-byte
   // long; if it is >= 2 bytes long, uppercase can be at most 4 bytes long, so length of
   // the output can be at most twice the length of the input
-  char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 * data_len));
+  char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));
   if (out == nullptr) {
     gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
     *out_len = 0;
@@ -367,6 +401,15 @@ const char* gdv_fn_substring_index(int64_t context, const char* txt, int32_t txt
     return "";
   }
 
+  if (ARROW_PREDICT_FALSE(txt_len < 0)) {
+    *out_len = 0;
+    return "";
+  }
+  if (ARROW_PREDICT_FALSE(pat_len < 0)) {
+    *out_len = 0;
+    return "";
+  }
+
   char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, txt_len));
   if (out == nullptr) {
     gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
@@ -445,8 +488,12 @@ const char* gdv_fn_substring_index(int64_t context, const char* txt, int32_t txt
 
     return out;
   } else {
+    if (txt_len < 0) {
+      *out_len = 0;
+      return "";
+    }
+    memcpy(out, txt, static_cast<size_t>(txt_len));
     *out_len = txt_len;
-    memcpy(out, txt, txt_len);
     return out;
   }
 }
@@ -480,10 +527,16 @@ const char* gdv_fn_initcap_utf8(int64_t context, const char* data, int32_t data_
     return "";
   }
 
+  int32_t alloc_length = 0;
+  if (ARROW_PREDICT_FALSE(
+          !is_datalen_valid(context, data_len, &alloc_length, out_len))) {
+    return "";
+  }
+
   // If it is a single-byte character (ASCII), corresponding uppercase is always 1-byte
   // long; if it is >= 2 bytes long, uppercase can be at most 4 bytes long, so length of
   // the output can be at most twice the length of the input
-  char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 * data_len));
+  char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));
   if (out == nullptr) {
     gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
     *out_len = 0;
@@ -579,15 +632,24 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
     return in;
   }
 
+  int32_t alloc_length = 0;
+  // Check overflow: 4 * in_len
+  if (ARROW_PREDICT_FALSE(
+          arrow::internal::MultiplyWithOverflow(4, in_len, &alloc_length))) {
+    gdv_fn_context_set_error_msg(context, "Would overflow maximum output size");
+    *out_len = 0;
+    return "";
+  }
+
   // This variable is to control if there are multi-byte utf8 entries
   bool has_multi_byte = false;
 
   // This variable is to store the final result
   char* result;
-  int result_len;
+  int32_t result_len;
 
   // Searching multi-bytes in In
-  for (int i = 0; i < in_len; i++) {
+  for (int32_t i = 0; i < in_len; i++) {
     unsigned char char_single_byte = in[i];
     if (char_single_byte > 127) {
       // found a multi-byte utf-8 char
@@ -598,7 +660,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
 
   // Searching multi-bytes in From
   if (!has_multi_byte) {
-    for (int i = 0; i < from_len; i++) {
+    for (int32_t i = 0; i < from_len; i++) {
       unsigned char char_single_byte = from[i];
       if (char_single_byte > 127) {
         // found a multi-byte utf-8 char
@@ -610,7 +672,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
 
   // Searching multi-bytes in To
   if (!has_multi_byte) {
-    for (int i = 0; i < to_len; i++) {
+    for (int32_t i = 0; i < to_len; i++) {
       unsigned char char_single_byte = to[i];
       if (char_single_byte > 127) {
         // found a multi-byte utf-8 char
@@ -638,7 +700,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
 
     // This variable is for controlling the position in entry TO, for never repeat the
     // changes
-    int start_compare;
+    int32_t start_compare;
 
     if (to_len > 0) {
       start_compare = 0;
@@ -650,7 +712,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
     // list, to mark deletion positions
     const char empty = '\0';
 
-    for (int in_for = 0; in_for < in_len; in_for++) {
+    for (int32_t in_for = 0; in_for < in_len; in_for++) {
       if (subs_list.find(in[in_for]) != subs_list.end()) {
         if (subs_list[in[in_for]] != empty) {
           // If exist in map, only add the correspondent value in result
@@ -658,7 +720,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
           result_len++;
         }
       } else {
-        for (int from_for = 0; from_for <= from_len; from_for++) {
+        for (int32_t from_for = 0; from_for <= from_len; from_for++) {
           if (from_for == from_len) {
             // If it's not in the FROM list, just add it to the map and the result.
             subs_list.insert(std::pair<char, char>(in[in_for], in[in_for]));
@@ -686,10 +748,11 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
         }
       }
     }
-  } else {  // If there are no multibytes in the input, work with std::strings
+  } else {
+    // If there are multibytes in the input, work with std::strings
     // This variable is for receive the substitutions, malloc is in_len * 4 to receive
     // possible inputs with 4 bytes
-    result = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, in_len * 4));
+    result = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));
 
     if (result == nullptr) {
       gdv_fn_context_set_error_msg(context,
@@ -704,7 +767,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
 
     // This variable is for controlling the position in entry TO, for never repeat the
     // changes
-    int start_compare;
+    int32_t start_compare;
 
     if (to_len > 0) {
       start_compare = 0;
@@ -717,11 +780,11 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
     const std::string empty = "";
 
     // This variables is to control len of multi-bytes entries
-    int len_char_in = 0;
-    int len_char_from = 0;
-    int len_char_to = 0;
+    int32_t len_char_in = 0;
+    int32_t len_char_from = 0;
+    int32_t len_char_to = 0;
 
-    for (int in_for = 0; in_for < in_len; in_for += len_char_in) {
+    for (int32_t in_for = 0; in_for < in_len; in_for += len_char_in) {
       // Updating len to char in this position
       len_char_in = gdv_fn_utf8_char_length(in[in_for]);
       // Making copy to std::string with length for this char position
@@ -734,11 +797,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
           result_len += static_cast<int>(subs_list[insert_copy_key].length());
         }
       } else {
-        for (int from_for = 0; from_for <= from_len; from_for += len_char_from) {
-          // Updating len to char in this position
-          len_char_from = gdv_fn_utf8_char_length(from[from_for]);
-          // Making copy to std::string with length for this char position
-          std::string copy_from_compare(from + from_for, len_char_from);
+        for (int32_t from_for = 0; from_for <= from_len; from_for += len_char_from) {
           if (from_for == from_len) {
             // If it's not in the FROM list, just add it to the map and the result.
             std::string insert_copy_value(in + in_for, len_char_in);
@@ -751,6 +810,11 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
             break;
           }
 
+          // Updating len to char in this position
+          len_char_from = gdv_fn_utf8_char_length(from[from_for]);
+          // Making copy to std::string with length for this char position
+          std::string copy_from_compare(from + from_for, len_char_from);
+
           if (insert_copy_key != copy_from_compare) {
             // If this character does not exist in FROM list, don't need treatment
             continue;
diff --git a/cpp/src/gandiva/precompiled/string_ops.cc b/cpp/src/gandiva/precompiled/string_ops.cc
index 035d3c8c62e1..ae8d014ff143 100644
--- a/cpp/src/gandiva/precompiled/string_ops.cc
+++ b/cpp/src/gandiva/precompiled/string_ops.cc
@@ -1924,9 +1924,20 @@ const char* quote_utf8(gdv_int64 context, const char* in, gdv_int32 in_len,
     *out_len = 0;
     return "";
   }
+
+  // Worst case is 2 * in_len + 2 bytes. Each step is overflow-checked separately: a
+  // plain `2 * in_len` would be evaluated (and could wrap) before the checked add runs.
+  int32_t alloc_length = 0;
+  if (ARROW_PREDICT_FALSE(
+          arrow::internal::MultiplyWithOverflow(2, in_len, &alloc_length) ||
+          arrow::internal::AddWithOverflow(2, alloc_length, &alloc_length))) {
+    gdv_fn_context_set_error_msg(context, "Memory allocation size too large");
+    *out_len = 0;
+    return "";
+  }
+
   // try to allocate double size output string (worst case)
-  auto out =
-      reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, (in_len * 2) + 2));
+  auto out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));
   if (out == nullptr) {
     gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
     *out_len = 0;
@@ -2444,6 +2455,74 @@ void concat_word(char* out_buf, int* out_idx, const char* in_buf, int in_len,
   *out_idx += in_len;
 }
 
+// Running totals used while computing the output size of concat_ws: the accumulated
+// byte length, the number of non-null words seen, and whether any step failed.
+struct SafeLengthState {
+  int32_t total_len = 0;
+  int32_t num_valid = 0;
+  bool overflow = false;
+};
+
+// Add the length of one word to the running total. Null words are skipped (and
+// reported as success). Returns false and sets state->overflow when the length is
+// negative or the sum does not fit in int32.
+static inline bool safe_accumulate_word(SafeLengthState* state, int32_t word_len,
+                                        bool word_validity) {
+  if (!word_validity) return true;
+
+  int32_t temp = 0;
+  if (ARROW_PREDICT_FALSE(
+          word_len < 0 ||
+          arrow::internal::AddWithOverflow(state->total_len, word_len, &temp))) {
+    state->overflow = true;
+    return false;
+  }
+  state->total_len = temp;
+  state->num_valid++;
+  return true;
+}
+
+// Add the bytes of the (num_valid - 1) separators that go between the valid words.
+// Returns false and sets state->overflow when the separator length is negative or the
+// result does not fit in int32.
+static inline bool safe_add_separators(SafeLengthState* state, int32_t separator_len) {
+  if (state->num_valid <= 1) return true;
+
+  int32_t sep_total = 0;
+  int32_t temp = 0;
+
+  if (ARROW_PREDICT_FALSE(separator_len < 0 ||
+                          arrow::internal::MultiplyWithOverflow(
+                              separator_len, state->num_valid - 1, &sep_total))) {
+    state->overflow = true;
+    return false;
+  }
+
+  if (ARROW_PREDICT_FALSE(
+          arrow::internal::AddWithOverflow(state->total_len, sep_total, &temp))) {
+    state->overflow = true;
+    return false;
+  }
+
+  state->total_len = temp;
+  return true;
+}
+
+// Produce a null result (empty string, *out_valid = false). Used when a length
+// computation failed and when the separator itself is null.
+static inline const char* handle_overflow_failure(bool* out_valid, int32_t* out_len) {
+  *out_len = 0;
+  *out_valid = false;
+  return "";
+}
+
+// Produce a valid empty string; used when the computed output length is zero.
+static inline const char* handle_empty_result(bool* out_valid, int32_t* out_len) {
+  *out_len = 0;
+  *out_valid = true;
+  return "";
+}
+
 FORCE_INLINE
 const char* concat_ws_utf8_utf8(int64_t context, const char* separator,
                                 int32_t separator_len, bool separator_validity,
@@ -2451,7 +2530,6 @@ const char* concat_ws_utf8_utf8(int64_t context, const char* separator,
                                 const char* word2, int32_t word2_len, bool word2_validity,
                                 bool* out_valid, int32_t* out_len) {
   *out_len = 0;
-  int numValidInput = 0;
   // If separator is null, always return null
   if (!separator_validity) {
     *out_len = 0;
@@ -2459,22 +2537,29 @@ const char* concat_ws_utf8_utf8(int64_t context, const char* separator,
     return "";
   }
 
-  if (word1_validity) {
-    *out_len += word1_len;
-    numValidInput++;
+  SafeLengthState state;
+
+  // Accumulate word lengths safely
+  safe_accumulate_word(&state, word1_len, word1_validity);
+  safe_accumulate_word(&state, word2_len, word2_validity);
+
+  if (state.overflow) {
+    return handle_overflow_failure(out_valid, out_len);
   }
-  if (word2_validity) {
-    *out_len += word2_len;
-    numValidInput++;
+
+  // Add separator lengths
+  if (!safe_add_separators(&state, separator_len)) {
+    return handle_overflow_failure(out_valid, out_len);
   }
-  *out_len += separator_len * (numValidInput > 1 ? numValidInput - 1 : 0);
 
-  if (*out_len == 0) {
-    *out_valid = true;
-    return "";
+  // Handle case with no valid words
+  if (state.total_len == 0) {
+    return handle_empty_result(out_valid, out_len);
   }
 
-  char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
+  // Allocate and concatenate
+  char* out =
+      reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, state.total_len));
   if (out == nullptr) {
     gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
     *out_len = 0;
@@ -2503,36 +2588,30 @@ const char* concat_ws_utf8_utf8_utf8(
     const char* word2, int32_t word2_len, bool word2_validity, const char* word3,
     int32_t word3_len, bool word3_validity, bool* out_valid, int32_t* out_len) {
   *out_len = 0;
-  int numValidInput = 0;
-  // If separator is null, always return null
   if (!separator_validity) {
-    *out_len = 0;
-    *out_valid = false;
-    return "";
+    return handle_overflow_failure(out_valid, out_len);
   }
 
-  if (word1_validity) {
-    *out_len += word1_len;
-    numValidInput++;
-  }
-  if (word2_validity) {
-    *out_len += word2_len;
-    numValidInput++;
-  }
-  if (word3_validity) {
-    *out_len += word3_len;
-    numValidInput++;
+  SafeLengthState state;
+
+  safe_accumulate_word(&state, word1_len, word1_validity);
+  safe_accumulate_word(&state, word2_len, word2_validity);
+  safe_accumulate_word(&state, word3_len, word3_validity);
+
+  if (state.overflow) {
+    return handle_overflow_failure(out_valid, out_len);
   }
 
-  *out_len += separator_len * (numValidInput > 1 ? numValidInput - 1 : 0);
+  if (!safe_add_separators(&state, separator_len)) {
+    return handle_overflow_failure(out_valid, out_len);
+  }
 
-  if (*out_len == 0) {
-    *out_len = 0;
-    *out_valid = true;
-    return "";
+  if (state.total_len == 0) {
+    return handle_empty_result(out_valid, out_len);
   }
 
-  char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
+  char* out =
+      reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, state.total_len));
   if (out == nullptr) {
     gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
     *out_len = 0;
@@ -2564,39 +2643,44 @@ const char* concat_ws_utf8_utf8_utf8_utf8(
     int32_t word3_len, bool word3_validity, const char* word4, int32_t word4_len,
     bool word4_validity, bool* out_valid, int32_t* out_len) {
   *out_len = 0;
-  int numValidInput = 0;
   // If separator is null, always return null
   if (!separator_validity) {
     *out_len = 0;
     *out_valid = false;
     return "";
   }
-  if (word1_validity) {
-    *out_len += word1_len;
-    numValidInput++;
-  }
-  if (word2_validity) {
-    *out_len += word2_len;
-    numValidInput++;
-  }
-  if (word3_validity) {
-    *out_len += word3_len;
-    numValidInput++;
-  }
-  if (word4_validity) {
-    *out_len += word4_len;
-    numValidInput++;
+
+  SafeLengthState state;
+
+  // Accumulate all word lengths with overflow checking
+  safe_accumulate_word(&state, word1_len, word1_validity);
+  safe_accumulate_word(&state, word2_len, word2_validity);
+  safe_accumulate_word(&state, word3_len, word3_validity);
+  safe_accumulate_word(&state, word4_len, word4_validity);
+
+  if (state.overflow) {
+    *out_len = 0;
+    *out_valid = false;
+    return "";
   }
 
-  *out_len += separator_len * (numValidInput > 1 ? numValidInput - 1 : 0);
+  // Add separator lengths with overflow checking
+  if (!safe_add_separators(&state, separator_len)) {
+    *out_len = 0;
+    *out_valid = false;
+    return "";
+  }
 
-  if (*out_len == 0) {
+  // Handle case with no valid words
+  if (state.total_len == 0) {
     *out_len = 0;
     *out_valid = true;
     return "";
   }
 
-  char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
+  // Allocate memory
+  char* out =
+      reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, state.total_len));
   if (out == nullptr) {
     gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
     *out_valid = false;
@@ -2631,43 +2715,45 @@ const char* concat_ws_utf8_utf8_utf8_utf8_utf8(
     bool word4_validity, const char* word5, int32_t word5_len, bool word5_validity,
     bool* out_valid, int32_t* out_len) {
   *out_len = 0;
-  int numValidInput = 0;
   // If separator is null, always return null
   if (!separator_validity) {
     *out_len = 0;
     *out_valid = false;
     return "";
   }
-  if (word1_validity) {
-    *out_len += word1_len;
-    numValidInput++;
-  }
-  if (word2_validity) {
-    *out_len += word2_len;
-    numValidInput++;
-  }
-  if (word3_validity) {
-    *out_len += word3_len;
-    numValidInput++;
-  }
-  if (word4_validity) {
-    *out_len += word4_len;
-    numValidInput++;
-  }
-  if (word5_validity) {
-    *out_len += word5_len;
-    numValidInput++;
+
+  SafeLengthState state;
+
+  // Accumulate all word lengths with overflow checking
+  safe_accumulate_word(&state, word1_len, word1_validity);
+  safe_accumulate_word(&state, word2_len, word2_validity);
+  safe_accumulate_word(&state, word3_len, word3_validity);
+  safe_accumulate_word(&state, word4_len, word4_validity);
+  safe_accumulate_word(&state, word5_len, word5_validity);
+
+  if (state.overflow) {
+    *out_len = 0;
+    *out_valid = false;
+    return "";
   }
 
-  *out_len += separator_len * (numValidInput > 1 ? numValidInput - 1 : 0);
+  // Add separator lengths with overflow checking
+  if (!safe_add_separators(&state, separator_len)) {
+    *out_len = 0;
+    *out_valid = false;
+    return "";
+  }
 
-  if (*out_len == 0) {
+  // Handle case with no valid words
+  if (state.total_len == 0) {
     *out_len = 0;
     *out_valid = true;
     return "";
   }
 
-  char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
+  // Allocate memory
+  char* out =
+      reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, state.total_len));
   if (out == nullptr) {
     gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
     *out_len = 0;
@@ -2824,13 +2910,24 @@ const char* elt_int32_utf8_utf8_utf8_utf8_utf8(
 FORCE_INLINE
 const char* to_hex_binary(int64_t context, const char* text, int32_t text_len,
                           int32_t* out_len) {
-  if (text_len == 0) {
+  if (ARROW_PREDICT_FALSE(text_len <= 0)) {
     *out_len = 0;
     return "";
   }
 
-  auto ret =
-      reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, text_len * 2 + 1));
+  int32_t alloc_length = 0;
+
+  // Output needs 2 * text_len + 1 bytes. Each step is overflow-checked separately: a
+  // plain `2 * text_len` would be evaluated (and could wrap) before the checked add.
+  if (ARROW_PREDICT_FALSE(
+          arrow::internal::MultiplyWithOverflow(2, text_len, &alloc_length) ||
+          arrow::internal::AddWithOverflow(1, alloc_length, &alloc_length))) {
+    gdv_fn_context_set_error_msg(context, "Memory allocation size too large");
+    *out_len = 0;
+    return "";
+  }
+
+  auto ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));
 
   if (ret == nullptr) {
     gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
diff --git a/cpp/src/gandiva/precompiled/string_ops_test.cc b/cpp/src/gandiva/precompiled/string_ops_test.cc
index d57eb437530c..7da6dc0b10a5 100644
--- a/cpp/src/gandiva/precompiled/string_ops_test.cc
+++ b/cpp/src/gandiva/precompiled/string_ops_test.cc
@@ -1165,6 +1165,16 @@ TEST(TestStringOps, TestQuote) {
   out_str = quote_utf8(ctx_ptr, "'''''''''", 9, &out_len);
   EXPECT_EQ(std::string(out_str, out_len), "'\\'\\'\\'\\'\\'\\'\\'\\'\\''");
   EXPECT_FALSE(ctx.has_error());
+
+  int32_t bad_in_len = std::numeric_limits<int32_t>::max() / 2 + 1;
+  out_str = quote_utf8(ctx_ptr, "YYZ", bad_in_len, &out_len);
+  EXPECT_EQ(out_len, 0);
+  EXPECT_STREQ(out_str, "");
+
+  bad_in_len = std::numeric_limits<int32_t>::max() / 2 + 20;
+  out_str = quote_utf8(ctx_ptr, "ABCDE", bad_in_len, &out_len);
+  EXPECT_EQ(out_len, 0);
+  EXPECT_STREQ(out_str, "");
 }
 
 TEST(TestStringOps, TestLtrim) {
@@ -2298,11 +2308,42 @@ TEST(TestStringOps, TestConcatWs) {
   EXPECT_EQ(std::string(out, out_len), "hey");
   EXPECT_EQ(out_result, true);
 
+  // Max word1_len
+  out = concat_ws_utf8_utf8(ctx_ptr, separator, sep_len, true, word1,
+                            std::numeric_limits<int32_t>::max(), true, word2, word2_len,
+                            true, &out_result, &out_len);
+  EXPECT_STREQ(out, "");
+  EXPECT_EQ(out_len, 0);
+  EXPECT_EQ(out_result, false);
+
+  // Max word2 len
+  out = concat_ws_utf8_utf8(ctx_ptr, separator, sep_len, true, word1, word1_len, true,
+                            word2, std::numeric_limits<int32_t>::max(), true, &out_result,
+                            &out_len);
+  EXPECT_STREQ(out, "");
+  EXPECT_EQ(out_len, 0);
+  EXPECT_EQ(out_result, false);
+
+  // Max separator len
+  out = concat_ws_utf8_utf8(ctx_ptr, separator, std::numeric_limits<int32_t>::max(), true,
+                            word1, word1_len, true, word2, word2_len, true, &out_result,
+                            &out_len);
+  EXPECT_STREQ(out, "");
+  EXPECT_EQ(out_len, 0);
+  EXPECT_EQ(out_result, false);
+
   separator = "#";
   sep_len = static_cast<int32_t>(strlen(separator));
   const char* word3 = "wow";
   int32_t word3_len = static_cast<int32_t>(strlen(word3));
 
+  out = concat_ws_utf8_utf8_utf8(ctx_ptr, separator, std::numeric_limits<int32_t>::max(),
+                                 true, word1, word1_len, true, word2, word2_len, true,
+                                 word3, word3_len, true, &out_result, &out_len);
+  EXPECT_STREQ(out, "");
+  EXPECT_EQ(out_len, 0);
+  EXPECT_EQ(out_result, false);
+
   out = concat_ws_utf8_utf8_utf8(ctx_ptr, separator, sep_len, true, word1, word1_len,
                                  true, word2, word2_len, true, word3, word3_len, true,
                                  &out_result, &out_len);
@@ -2344,6 +2385,14 @@ TEST(TestStringOps, TestConcatWs) {
   const char* word4 = "awesome";
   int32_t word4_len = static_cast<int32_t>(strlen(word4));
 
+  out = concat_ws_utf8_utf8_utf8_utf8(ctx_ptr, separator, sep_len, true, word1,
+                                      std::numeric_limits<int32_t>::max(), true, word2,
+                                      word2_len, true, word3, word3_len, true, word4,
+                                      word4_len, true, &out_result, &out_len);
+  EXPECT_STREQ(out, "");
+  EXPECT_EQ(out_len, 0);
+  EXPECT_EQ(out_result, false);
+
   out = concat_ws_utf8_utf8_utf8_utf8(
       ctx_ptr, separator, sep_len, true, word1, word1_len, true, word2, word2_len, true,
       word3, word3_len, true, word4, word4_len, true, &out_result, &out_len);
@@ -2355,6 +2404,14 @@ TEST(TestStringOps, TestConcatWs) {
   const char* word5 = "super";
   int32_t word5_len = static_cast<int32_t>(strlen(word5));
 
+  out = concat_ws_utf8_utf8_utf8_utf8_utf8(
+      ctx_ptr, separator, sep_len, true, word1, word1_len, true, word2, word2_len, true,
+      word3, word3_len, true, word4, std::numeric_limits<int32_t>::max(), true, word5,
+      std::numeric_limits<int32_t>::max(), true, &out_result, &out_len);
+  EXPECT_STREQ(out, "");
+  EXPECT_EQ(out_len, 0);
+  EXPECT_EQ(out_result, false);
+
   out = concat_ws_utf8_utf8_utf8_utf8_utf8(ctx_ptr, separator, sep_len, true, word1,
                                            word1_len, true, word2, word2_len, true, word3,
                                            word3_len, true, word4, word4_len, true, word5,
@@ -2498,6 +2555,25 @@ TEST(TestStringOps, TestToHex) {
   output = std::string(out_str, out_len);
   EXPECT_EQ(out_len, 2 * in_len);
   EXPECT_EQ(output, "090A090A090A090A0A0A092061206C657474405D6572");
+  ctx.Reset();
+
+  int32_t bad_text_len = std::numeric_limits<int32_t>::max() / 2 + 20;
+  out_str = to_hex_binary(ctx_ptr, binary_string, bad_text_len, &out_len);
+  EXPECT_EQ(out_len, 0);
+  EXPECT_STREQ(out_str, "");
+  ctx.Reset();
+
+  bad_text_len = (std::numeric_limits<int32_t>::max() / 2) + 1;
+  out_str = to_hex_binary(ctx_ptr, binary_string, bad_text_len, &out_len);
+  EXPECT_EQ(out_len, 0);
+  EXPECT_STREQ(out_str, "");
+  ctx.Reset();
+
+  int32_t neg_in_len = -20;
+  out_str = to_hex_binary(ctx_ptr, binary_string, neg_in_len, &out_len);
+  EXPECT_EQ(out_len, 0);
+  EXPECT_STREQ(out_str, "");
+  ctx.Reset();
 }
 
 TEST(TestStringOps, TestToHexInt64) {