Skip to content

Commit 0ccb5ee

Browse files
committed
GH-49753: [C++][Gandiva] Fix overflow in string functions.
Fixes potential integer-overflow/invalid-length issues in Gandiva string functions by adding overflow-checked allocation sizing and expanding unit tests to cover extreme and negative lengths. Fixed memcpy call in gdv_substring_index function since, the length argument is of type "size_t". Incorporated review comments.
1 parent d7a02c1 commit 0ccb5ee

4 files changed

Lines changed: 261 additions & 32 deletions

File tree

cpp/src/gandiva/gdv_function_stubs_test.cc

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,10 @@ TEST(TestGdvFnStubs, TestSubstringIndex) {
587587
std::numeric_limits<int32_t>::min(), &out_len);
588588
EXPECT_EQ(std::string(out_str, out_len), "a.b.c");
589589
EXPECT_FALSE(ctx.has_error());
590+
591+
out_str = gdv_fn_substring_index(ctx_ptr, "a", -2, ".", -1, -50, &out_len);
592+
EXPECT_STREQ(out_str, "");
593+
EXPECT_EQ(out_len, 0);
590594
}
591595

592596
TEST(TestGdvFnStubs, TestUpper) {
@@ -640,6 +644,26 @@ TEST(TestGdvFnStubs, TestUpper) {
640644
EXPECT_THAT(ctx.get_error(),
641645
::testing::HasSubstr(
642646
"unexpected byte \\c3 encountered while decoding utf8 string"));
647+
648+
ctx.Reset();
649+
650+
// Max Len Test
651+
out_len = -1;
652+
int32_t bad_len = std::numeric_limits<int32_t>::max() / 2 + 1;
653+
const char* out = gdv_fn_upper_utf8(ctx_ptr, "dummy", bad_len, &out_len);
654+
// Expect failure
655+
EXPECT_EQ(out_len, 0);
656+
EXPECT_STREQ(out, "");
657+
EXPECT_THAT(ctx.get_error(),
658+
::testing::HasSubstr("Would overflow maximum output size"));
659+
ctx.Reset();
660+
661+
// Negative length test
662+
out_len = -1;
663+
out = gdv_fn_upper_utf8(ctx_ptr, "abc", -105, &out_len);
664+
EXPECT_EQ(out_len, 0);
665+
EXPECT_STREQ(out, "");
666+
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data length"));
643667
ctx.Reset();
644668

645669
std::string e(
@@ -697,6 +721,26 @@ TEST(TestGdvFnStubs, TestLower) {
697721
out_str = gdv_fn_lower_utf8(ctx_ptr, "", 0, &out_len);
698722
EXPECT_EQ(std::string(out_str, out_len), "");
699723
EXPECT_FALSE(ctx.has_error());
724+
ctx.Reset();
725+
726+
// Max Len Test
727+
out_len = -1;
728+
int32_t bad_len = std::numeric_limits<int32_t>::max() / 2 + 1;
729+
const char* out = gdv_fn_lower_utf8(ctx_ptr, "dummy", bad_len, &out_len);
730+
// Expect failure
731+
EXPECT_EQ(out_len, 0);
732+
EXPECT_STREQ(out, "");
733+
EXPECT_THAT(ctx.get_error(),
734+
::testing::HasSubstr("Would overflow maximum output size"));
735+
ctx.Reset();
736+
737+
// Negative length test
738+
out_len = -1;
739+
out = gdv_fn_lower_utf8(ctx_ptr, "abc", -105, &out_len);
740+
EXPECT_EQ(out_len, 0);
741+
EXPECT_STREQ(out, "");
742+
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data length"));
743+
ctx.Reset();
700744

701745
std::string d("AbOJjÜoß\xc3");
702746
out_str = gdv_fn_lower_utf8(ctx_ptr, d.data(), static_cast<int>(d.length()), &out_len);
@@ -796,6 +840,25 @@ TEST(TestGdvFnStubs, TestInitCap) {
796840
"unexpected byte \\c3 encountered while decoding utf8 string"));
797841
ctx.Reset();
798842

843+
// Max Len Test
844+
out_len = -1;
845+
int32_t bad_len = std::numeric_limits<int32_t>::max() / 2 + 1;
846+
const char* out = gdv_fn_initcap_utf8(ctx_ptr, "dummy", bad_len, &out_len);
847+
// Expect failure
848+
EXPECT_EQ(out_len, 0);
849+
EXPECT_STREQ(out, "");
850+
EXPECT_THAT(ctx.get_error(),
851+
::testing::HasSubstr("Would overflow maximum output size"));
852+
ctx.Reset();
853+
854+
// Negative length test
855+
out_len = -1;
856+
out = gdv_fn_initcap_utf8(ctx_ptr, "abc", -105, &out_len);
857+
EXPECT_EQ(out_len, 0);
858+
EXPECT_STREQ(out, "");
859+
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data length"));
860+
ctx.Reset();
861+
799862
std::string e(
800863
"åbÑg\xe0\xa0"
801864
"åBUå");
@@ -1127,6 +1190,15 @@ TEST(TestGdvFnStubs, TestTranslate) {
11271190
result = translate_utf8_utf8_utf8(ctx_ptr, "987654321", 9, "123456789", 9, "0123456789",
11281191
10, &out_len);
11291192
EXPECT_EQ(expected, std::string(result, out_len));
1193+
1194+
int32_t bad_in_len = std::numeric_limits<int32_t>::max() / 4 + 1;
1195+
out_len = -1;
1196+
result =
1197+
translate_utf8_utf8_utf8(ctx_ptr, "ABCDE", bad_in_len, "B", 1, "C", 1, &out_len);
1198+
EXPECT_EQ(out_len, 0);
1199+
EXPECT_STREQ(result, "");
1200+
EXPECT_THAT(ctx.get_error(),
1201+
::testing::HasSubstr("Would overflow maximum output size"));
11301202
}
11311203

11321204
TEST(TestGdvFnStubs, TestToUtcTimezone) {

cpp/src/gandiva/gdv_string_function_stubs.cc

Lines changed: 85 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,25 @@ int32_t gdv_fn_utf8_char_length(char c) {
213213
return 0;
214214
}
215215

216+
static inline bool is_datalen_valid(int64_t context, int32_t data_len, int32_t* alloc_len,
217+
int32_t* out_len) {
218+
// Reject negative lengths
219+
if (ARROW_PREDICT_FALSE(data_len < 0)) {
220+
gdv_fn_context_set_error_msg(context, "Invalid (negative) data length");
221+
*out_len = 0;
222+
return false;
223+
}
224+
225+
// Check overflow: 2 * data_len
226+
if (ARROW_PREDICT_FALSE(
227+
arrow::internal::MultiplyWithOverflow(2, data_len, alloc_len))) {
228+
gdv_fn_context_set_error_msg(context, "Would overflow maximum output size");
229+
*out_len = 0;
230+
return false;
231+
}
232+
return true;
233+
}
234+
216235
// Convert an utf8 string to its corresponding lowercase string
217236
GANDIVA_EXPORT
218237
const char* gdv_fn_lower_utf8(int64_t context, const char* data, int32_t data_len,
@@ -222,10 +241,16 @@ const char* gdv_fn_lower_utf8(int64_t context, const char* data, int32_t data_le
222241
return "";
223242
}
224243

244+
int32_t alloc_length = 0;
245+
if (ARROW_PREDICT_FALSE(
246+
not is_datalen_valid(context, data_len, &alloc_length, out_len))) {
247+
return "";
248+
}
249+
225250
// If it is a single-byte character (ASCII), corresponding lowercase is always 1-byte
226251
// long; if it is >= 2 bytes long, lowercase can be at most 4 bytes long, so length of
227252
// the output can be at most twice the length of the input
228-
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 * data_len));
253+
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));
229254
if (out == nullptr) {
230255
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
231256
*out_len = 0;
@@ -294,10 +319,16 @@ const char* gdv_fn_upper_utf8(int64_t context, const char* data, int32_t data_le
294319
return "";
295320
}
296321

322+
int32_t alloc_length = 0;
323+
if (ARROW_PREDICT_FALSE(
324+
not is_datalen_valid(context, data_len, &alloc_length, out_len))) {
325+
return "";
326+
}
327+
297328
// If it is a single-byte character (ASCII), corresponding uppercase is always 1-byte
298329
// long; if it is >= 2 bytes long, uppercase can be at most 4 bytes long, so length of
299330
// the output can be at most twice the length of the input
300-
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 * data_len));
331+
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));
301332
if (out == nullptr) {
302333
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
303334
*out_len = 0;
@@ -367,6 +398,15 @@ const char* gdv_fn_substring_index(int64_t context, const char* txt, int32_t txt
367398
return "";
368399
}
369400

401+
if (ARROW_PREDICT_FALSE(txt_len < 0)) {
402+
*out_len = 0;
403+
return "";
404+
}
405+
if (ARROW_PREDICT_FALSE(pat_len < 0)) {
406+
*out_len = 0;
407+
return "";
408+
}
409+
370410
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, txt_len));
371411
if (out == nullptr) {
372412
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
@@ -445,8 +485,12 @@ const char* gdv_fn_substring_index(int64_t context, const char* txt, int32_t txt
445485
return out;
446486

447487
} else {
488+
if (txt_len < 0) {
489+
*out_len = 0;
490+
return "";
491+
}
492+
memcpy(out, txt, static_cast<size_t>(txt_len));
448493
*out_len = txt_len;
449-
memcpy(out, txt, txt_len);
450494
return out;
451495
}
452496
}
@@ -480,10 +524,16 @@ const char* gdv_fn_initcap_utf8(int64_t context, const char* data, int32_t data_
480524
return "";
481525
}
482526

527+
int32_t alloc_length = 0;
528+
if (ARROW_PREDICT_FALSE(
529+
not is_datalen_valid(context, data_len, &alloc_length, out_len))) {
530+
return "";
531+
}
532+
483533
// If it is a single-byte character (ASCII), corresponding uppercase is always 1-byte
484534
// long; if it is >= 2 bytes long, uppercase can be at most 4 bytes long, so length of
485535
// the output can be at most twice the length of the input
486-
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 * data_len));
536+
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));
487537
if (out == nullptr) {
488538
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
489539
*out_len = 0;
@@ -579,15 +629,24 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
579629
return in;
580630
}
581631

632+
int32_t alloc_length = 0;
633+
// Check overflow: 4 * in_len
634+
if (ARROW_PREDICT_FALSE(
635+
arrow::internal::MultiplyWithOverflow(4, in_len, &alloc_length))) {
636+
gdv_fn_context_set_error_msg(context, "Would overflow maximum output size");
637+
*out_len = 0;
638+
return "";
639+
}
640+
582641
// This variable is to control if there are multi-byte utf8 entries
583642
bool has_multi_byte = false;
584643

585644
// This variable is to store the final result
586645
char* result;
587-
int result_len;
646+
int32_t result_len;
588647

589648
// Searching multi-bytes in In
590-
for (int i = 0; i < in_len; i++) {
649+
for (int32_t i = 0; i < in_len; i++) {
591650
unsigned char char_single_byte = in[i];
592651
if (char_single_byte > 127) {
593652
// found a multi-byte utf-8 char
@@ -598,7 +657,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
598657

599658
// Searching multi-bytes in From
600659
if (!has_multi_byte) {
601-
for (int i = 0; i < from_len; i++) {
660+
for (int32_t i = 0; i < from_len; i++) {
602661
unsigned char char_single_byte = from[i];
603662
if (char_single_byte > 127) {
604663
// found a multi-byte utf-8 char
@@ -610,7 +669,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
610669

611670
// Searching multi-bytes in To
612671
if (!has_multi_byte) {
613-
for (int i = 0; i < to_len; i++) {
672+
for (int32_t i = 0; i < to_len; i++) {
614673
unsigned char char_single_byte = to[i];
615674
if (char_single_byte > 127) {
616675
// found a multi-byte utf-8 char
@@ -621,7 +680,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
621680
}
622681

623682
// If there are no multibytes in the input, work only with char
624-
if (!has_multi_byte) {
683+
if (not has_multi_byte) {
625684
// This variable is for receive the substitutions
626685
result = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, in_len));
627686

@@ -638,7 +697,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
638697

639698
// This variable is for controlling the position in entry TO, for never repeat the
640699
// changes
641-
int start_compare;
700+
int32_t start_compare;
642701

643702
if (to_len > 0) {
644703
start_compare = 0;
@@ -650,15 +709,15 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
650709
// list, to mark deletion positions
651710
const char empty = '\0';
652711

653-
for (int in_for = 0; in_for < in_len; in_for++) {
712+
for (int32_t in_for = 0; in_for < in_len; in_for++) {
654713
if (subs_list.find(in[in_for]) != subs_list.end()) {
655714
if (subs_list[in[in_for]] != empty) {
656715
// If exist in map, only add the correspondent value in result
657716
result[result_len] = subs_list[in[in_for]];
658717
result_len++;
659718
}
660719
} else {
661-
for (int from_for = 0; from_for <= from_len; from_for++) {
720+
for (int32_t from_for = 0; from_for <= from_len; from_for++) {
662721
if (from_for == from_len) {
663722
// If it's not in the FROM list, just add it to the map and the result.
664723
subs_list.insert(std::pair<char, char>(in[in_for], in[in_for]));
@@ -686,10 +745,11 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
686745
}
687746
}
688747
}
689-
} else { // If there are no multibytes in the input, work with std::strings
748+
} else {
749+
// If there are multibytes in the input, work with std::strings
690750
// This variable is for receive the substitutions, malloc is in_len * 4 to receive
691751
// possible inputs with 4 bytes
692-
result = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, in_len * 4));
752+
result = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));
693753

694754
if (result == nullptr) {
695755
gdv_fn_context_set_error_msg(context,
@@ -704,7 +764,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
704764

705765
// This variable is for controlling the position in entry TO, for never repeat the
706766
// changes
707-
int start_compare;
767+
int32_t start_compare;
708768

709769
if (to_len > 0) {
710770
start_compare = 0;
@@ -717,11 +777,11 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
717777
const std::string empty = "";
718778

719779
// This variables is to control len of multi-bytes entries
720-
int len_char_in = 0;
721-
int len_char_from = 0;
722-
int len_char_to = 0;
780+
int32_t len_char_in = 0;
781+
int32_t len_char_from = 0;
782+
int32_t len_char_to = 0;
723783

724-
for (int in_for = 0; in_for < in_len; in_for += len_char_in) {
784+
for (int32_t in_for = 0; in_for < in_len; in_for += len_char_in) {
725785
// Updating len to char in this position
726786
len_char_in = gdv_fn_utf8_char_length(in[in_for]);
727787
// Making copy to std::string with length for this char position
@@ -734,11 +794,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
734794
result_len += static_cast<int>(subs_list[insert_copy_key].length());
735795
}
736796
} else {
737-
for (int from_for = 0; from_for <= from_len; from_for += len_char_from) {
738-
// Updating len to char in this position
739-
len_char_from = gdv_fn_utf8_char_length(from[from_for]);
740-
// Making copy to std::string with length for this char position
741-
std::string copy_from_compare(from + from_for, len_char_from);
797+
for (int32_t from_for = 0; from_for <= from_len; from_for += len_char_from) {
742798
if (from_for == from_len) {
743799
// If it's not in the FROM list, just add it to the map and the result.
744800
std::string insert_copy_value(in + in_for, len_char_in);
@@ -751,6 +807,11 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
751807
break;
752808
}
753809

810+
// Updating len to char in this position
811+
len_char_from = gdv_fn_utf8_char_length(from[from_for]);
812+
// Making copy to std::string with length for this char position
813+
std::string copy_from_compare(from + from_for, len_char_from);
814+
754815
if (insert_copy_key != copy_from_compare) {
755816
// If this character does not exist in FROM list, don't need treatment
756817
continue;

0 commit comments

Comments
 (0)