Skip to content

Commit be80053

Browse files
committed
GH-49753: [C++][Gandiva] Fix string based functions in string_ops.
Fixes potential integer-overflow/invalid-length issues in string_ops.cc. Expanded unit-tests for concat_ws*. Incorporated review comments.
1 parent a3f57cb commit be80053

2 files changed

Lines changed: 240 additions & 81 deletions

File tree

cpp/src/gandiva/precompiled/string_ops.cc

Lines changed: 170 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1924,9 +1924,17 @@ const char* quote_utf8(gdv_int64 context, const char* in, gdv_int32 in_len,
19241924
*out_len = 0;
19251925
return "";
19261926
}
1927+
1928+
int32_t alloc_length = 0;
1929+
if (ARROW_PREDICT_FALSE(
1930+
arrow::internal::AddWithOverflow(2, (2 * in_len), &alloc_length))) {
1931+
gdv_fn_context_set_error_msg(context, "Memory allocation size too large");
1932+
*out_len = 0;
1933+
return "";
1934+
}
1935+
19271936
// try to allocate double size output string (worst case)
1928-
auto out =
1929-
reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, (in_len * 2) + 2));
1937+
auto out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));
19301938
if (out == nullptr) {
19311939
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
19321940
*out_len = 0;
@@ -2444,37 +2452,108 @@ void concat_word(char* out_buf, int* out_idx, const char* in_buf, int in_len,
24442452
*out_idx += in_len;
24452453
}
24462454

2455+
// Helper structure to maintain state during safe length accumulation
2456+
struct SafeLengthState {
2457+
int32_t total_len = 0;
2458+
int32_t num_valid = 0;
2459+
bool overflow = false;
2460+
};
2461+
2462+
// Helper to safely add a word length
2463+
static inline bool safe_accumulate_word(SafeLengthState* state, int32_t word_len,
2464+
bool word_validity) {
2465+
if (!word_validity) return true;
2466+
2467+
int32_t temp = 0;
2468+
if (ARROW_PREDICT_FALSE(
2469+
arrow::internal::AddWithOverflow(state->total_len, word_len, &temp))) {
2470+
state->overflow = true;
2471+
return false;
2472+
}
2473+
state->total_len = temp;
2474+
state->num_valid++;
2475+
return true;
2476+
}
2477+
2478+
// Helper to safely add separators based on number of valid words
2479+
static inline bool safe_add_separators(SafeLengthState* state, int32_t separator_len) {
2480+
if (state->num_valid <= 1) return true;
2481+
2482+
int32_t sep_total = 0;
2483+
int32_t temp = 0;
2484+
2485+
if (ARROW_PREDICT_FALSE(arrow::internal::MultiplyWithOverflow(
2486+
separator_len, state->num_valid - 1, &sep_total))) {
2487+
state->overflow = true;
2488+
return false;
2489+
}
2490+
2491+
if (ARROW_PREDICT_FALSE(
2492+
arrow::internal::AddWithOverflow(state->total_len, sep_total, &temp))) {
2493+
state->overflow = true;
2494+
return false;
2495+
}
2496+
2497+
state->total_len = temp;
2498+
return true;
2499+
}
2500+
2501+
// Helper to handle overflow failure (sets output parameters and returns nullptr)
2502+
static inline const char* handle_overflow_failure(bool* out_valid, int32_t* out_len) {
2503+
*out_len = 0;
2504+
*out_valid = false;
2505+
return "";
2506+
}
2507+
2508+
// Helper to handle empty result (all words invalid)
2509+
static inline const char* handle_empty_result(bool* out_valid, int32_t* out_len) {
2510+
*out_len = 0;
2511+
*out_valid = true;
2512+
return "";
2513+
}
2514+
24472515
FORCE_INLINE
24482516
const char* concat_ws_utf8_utf8(int64_t context, const char* separator,
24492517
int32_t separator_len, bool separator_validity,
24502518
const char* word1, int32_t word1_len, bool word1_validity,
24512519
const char* word2, int32_t word2_len, bool word2_validity,
24522520
bool* out_valid, int32_t* out_len) {
24532521
*out_len = 0;
2454-
int numValidInput = 0;
24552522
// If separator is null, always return null
24562523
if (!separator_validity) {
24572524
*out_len = 0;
24582525
*out_valid = false;
24592526
return "";
24602527
}
24612528

2462-
if (word1_validity) {
2463-
*out_len += word1_len;
2464-
numValidInput++;
2529+
// If separator is null, always return null
2530+
if (!separator_validity) {
2531+
return handle_overflow_failure(out_valid, out_len);
24652532
}
2466-
if (word2_validity) {
2467-
*out_len += word2_len;
2468-
numValidInput++;
2533+
2534+
SafeLengthState state;
2535+
2536+
// Accumulate word lengths safely
2537+
safe_accumulate_word(&state, word1_len, word1_validity);
2538+
safe_accumulate_word(&state, word2_len, word2_validity);
2539+
2540+
if (state.overflow) {
2541+
return handle_overflow_failure(out_valid, out_len);
24692542
}
24702543

2471-
*out_len += separator_len * (numValidInput > 1 ? numValidInput - 1 : 0);
2472-
if (*out_len == 0) {
2473-
*out_valid = true;
2474-
return "";
2544+
// Add separator lengths
2545+
if (!safe_add_separators(&state, separator_len)) {
2546+
return handle_overflow_failure(out_valid, out_len);
2547+
}
2548+
2549+
// Handle case with no valid words
2550+
if (state.total_len == 0) {
2551+
return handle_empty_result(out_valid, out_len);
24752552
}
24762553

2477-
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
2554+
// Allocate and concatenate
2555+
char* out =
2556+
reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, state.total_len));
24782557
if (out == nullptr) {
24792558
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
24802559
*out_len = 0;
@@ -2503,36 +2582,30 @@ const char* concat_ws_utf8_utf8_utf8(
25032582
const char* word2, int32_t word2_len, bool word2_validity, const char* word3,
25042583
int32_t word3_len, bool word3_validity, bool* out_valid, int32_t* out_len) {
25052584
*out_len = 0;
2506-
int numValidInput = 0;
2507-
// If separator is null, always return null
25082585
if (!separator_validity) {
2509-
*out_len = 0;
2510-
*out_valid = false;
2511-
return "";
2586+
return handle_overflow_failure(out_valid, out_len);
25122587
}
25132588

2514-
if (word1_validity) {
2515-
*out_len += word1_len;
2516-
numValidInput++;
2517-
}
2518-
if (word2_validity) {
2519-
*out_len += word2_len;
2520-
numValidInput++;
2521-
}
2522-
if (word3_validity) {
2523-
*out_len += word3_len;
2524-
numValidInput++;
2589+
SafeLengthState state;
2590+
2591+
safe_accumulate_word(&state, word1_len, word1_validity);
2592+
safe_accumulate_word(&state, word2_len, word2_validity);
2593+
safe_accumulate_word(&state, word3_len, word3_validity);
2594+
2595+
if (state.overflow) {
2596+
return handle_overflow_failure(out_valid, out_len);
25252597
}
25262598

2527-
*out_len += separator_len * (numValidInput > 1 ? numValidInput - 1 : 0);
2599+
if (!safe_add_separators(&state, separator_len)) {
2600+
return handle_overflow_failure(out_valid, out_len);
2601+
}
25282602

2529-
if (*out_len == 0) {
2530-
*out_len = 0;
2531-
*out_valid = true;
2532-
return "";
2603+
if (state.total_len == 0) {
2604+
return handle_empty_result(out_valid, out_len);
25332605
}
25342606

2535-
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
2607+
char* out =
2608+
reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, state.total_len));
25362609
if (out == nullptr) {
25372610
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
25382611
*out_len = 0;
@@ -2564,39 +2637,44 @@ const char* concat_ws_utf8_utf8_utf8_utf8(
25642637
int32_t word3_len, bool word3_validity, const char* word4, int32_t word4_len,
25652638
bool word4_validity, bool* out_valid, int32_t* out_len) {
25662639
*out_len = 0;
2567-
int numValidInput = 0;
25682640
// If separator is null, always return null
25692641
if (!separator_validity) {
25702642
*out_len = 0;
25712643
*out_valid = false;
25722644
return "";
25732645
}
2574-
if (word1_validity) {
2575-
*out_len += word1_len;
2576-
numValidInput++;
2577-
}
2578-
if (word2_validity) {
2579-
*out_len += word2_len;
2580-
numValidInput++;
2581-
}
2582-
if (word3_validity) {
2583-
*out_len += word3_len;
2584-
numValidInput++;
2585-
}
2586-
if (word4_validity) {
2587-
*out_len += word4_len;
2588-
numValidInput++;
2646+
2647+
SafeLengthState state;
2648+
2649+
// Accumulate all word lengths with overflow checking
2650+
safe_accumulate_word(&state, word1_len, word1_validity);
2651+
safe_accumulate_word(&state, word2_len, word2_validity);
2652+
safe_accumulate_word(&state, word3_len, word3_validity);
2653+
safe_accumulate_word(&state, word4_len, word4_validity);
2654+
2655+
if (state.overflow) {
2656+
*out_len = 0;
2657+
*out_valid = false;
2658+
return "";
25892659
}
25902660

2591-
*out_len += separator_len * (numValidInput > 1 ? numValidInput - 1 : 0);
2661+
// Add separator lengths with overflow checking
2662+
if (!safe_add_separators(&state, separator_len)) {
2663+
*out_len = 0;
2664+
*out_valid = false;
2665+
return "";
2666+
}
25922667

2593-
if (*out_len == 0) {
2668+
// Handle case with no valid words
2669+
if (state.total_len == 0) {
25942670
*out_len = 0;
25952671
*out_valid = true;
25962672
return "";
25972673
}
25982674

2599-
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
2675+
// Allocate memory
2676+
char* out =
2677+
reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, state.total_len));
26002678
if (out == nullptr) {
26012679
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
26022680
*out_valid = false;
@@ -2631,43 +2709,45 @@ const char* concat_ws_utf8_utf8_utf8_utf8_utf8(
26312709
bool word4_validity, const char* word5, int32_t word5_len, bool word5_validity,
26322710
bool* out_valid, int32_t* out_len) {
26332711
*out_len = 0;
2634-
int numValidInput = 0;
26352712
// If separator is null, always return null
26362713
if (!separator_validity) {
26372714
*out_len = 0;
26382715
*out_valid = false;
26392716
return "";
26402717
}
2641-
if (word1_validity) {
2642-
*out_len += word1_len;
2643-
numValidInput++;
2644-
}
2645-
if (word2_validity) {
2646-
*out_len += word2_len;
2647-
numValidInput++;
2648-
}
2649-
if (word3_validity) {
2650-
*out_len += word3_len;
2651-
numValidInput++;
2652-
}
2653-
if (word4_validity) {
2654-
*out_len += word4_len;
2655-
numValidInput++;
2656-
}
2657-
if (word5_validity) {
2658-
*out_len += word5_len;
2659-
numValidInput++;
2718+
2719+
SafeLengthState state;
2720+
2721+
// Accumulate all word lengths with overflow checking
2722+
safe_accumulate_word(&state, word1_len, word1_validity);
2723+
safe_accumulate_word(&state, word2_len, word2_validity);
2724+
safe_accumulate_word(&state, word3_len, word3_validity);
2725+
safe_accumulate_word(&state, word4_len, word4_validity);
2726+
safe_accumulate_word(&state, word5_len, word5_validity);
2727+
2728+
if (state.overflow) {
2729+
*out_len = 0;
2730+
*out_valid = false;
2731+
return "";
26602732
}
26612733

2662-
*out_len += separator_len * (numValidInput > 1 ? numValidInput - 1 : 0);
2734+
// Add separator lengths with overflow checking
2735+
if (!safe_add_separators(&state, separator_len)) {
2736+
*out_len = 0;
2737+
*out_valid = false;
2738+
return "";
2739+
}
26632740

2664-
if (*out_len == 0) {
2741+
// Handle case with no valid words
2742+
if (state.total_len == 0) {
26652743
*out_len = 0;
26662744
*out_valid = true;
26672745
return "";
26682746
}
26692747

2670-
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
2748+
// Allocate memory
2749+
char* out =
2750+
reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, state.total_len));
26712751
if (out == nullptr) {
26722752
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
26732753
*out_len = 0;
@@ -2824,13 +2904,22 @@ const char* elt_int32_utf8_utf8_utf8_utf8_utf8(
28242904
FORCE_INLINE
28252905
const char* to_hex_binary(int64_t context, const char* text, int32_t text_len,
28262906
int32_t* out_len) {
2827-
if (text_len == 0) {
2907+
if (ARROW_PREDICT_FALSE(text_len <= 0)) {
28282908
*out_len = 0;
28292909
return "";
28302910
}
28312911

2832-
auto ret =
2833-
reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, text_len * 2 + 1));
2912+
int32_t alloc_length = 0;
2913+
2914+
// Check overflow for text_len
2915+
if (ARROW_PREDICT_FALSE(
2916+
arrow::internal::AddWithOverflow(1, (2 * text_len), &alloc_length))) {
2917+
gdv_fn_context_set_error_msg(context, "Memory allocation size too large");
2918+
*out_len = 0;
2919+
return "";
2920+
}
2921+
2922+
auto ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));
28342923

28352924
if (ret == nullptr) {
28362925
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");

0 commit comments

Comments
 (0)