Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions cpp/src/gandiva/gdv_function_stubs_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,10 @@ TEST(TestGdvFnStubs, TestSubstringIndex) {
std::numeric_limits<int32_t>::min(), &out_len);
EXPECT_EQ(std::string(out_str, out_len), "a.b.c");
EXPECT_FALSE(ctx.has_error());

out_str = gdv_fn_substring_index(ctx_ptr, "a", -2, ".", -1, -50, &out_len);
EXPECT_STREQ(out_str, "");
EXPECT_EQ(out_len, 0);
}

TEST(TestGdvFnStubs, TestUpper) {
Expand Down Expand Up @@ -640,6 +644,26 @@ TEST(TestGdvFnStubs, TestUpper) {
EXPECT_THAT(ctx.get_error(),
::testing::HasSubstr(
"unexpected byte \\c3 encountered while decoding utf8 string"));

ctx.Reset();

// Max Len Test
out_len = -1;
int32_t bad_len = std::numeric_limits<int32_t>::max() / 2 + 1;
const char* out = gdv_fn_upper_utf8(ctx_ptr, "dummy", bad_len, &out_len);
// Expect failure
EXPECT_EQ(out_len, 0);
EXPECT_STREQ(out, "");
EXPECT_THAT(ctx.get_error(),
::testing::HasSubstr("Would overflow maximum output size"));
ctx.Reset();

// Negative length test
out_len = -1;
out = gdv_fn_upper_utf8(ctx_ptr, "abc", -105, &out_len);
EXPECT_EQ(out_len, 0);
EXPECT_STREQ(out, "");
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data length"));
ctx.Reset();

std::string e(
Expand Down Expand Up @@ -697,6 +721,26 @@ TEST(TestGdvFnStubs, TestLower) {
out_str = gdv_fn_lower_utf8(ctx_ptr, "", 0, &out_len);
EXPECT_EQ(std::string(out_str, out_len), "");
EXPECT_FALSE(ctx.has_error());
ctx.Reset();

// Max Len Test
out_len = -1;
int32_t bad_len = std::numeric_limits<int32_t>::max() / 2 + 1;
const char* out = gdv_fn_lower_utf8(ctx_ptr, "dummy", bad_len, &out_len);
// Expect failure
EXPECT_EQ(out_len, 0);
EXPECT_STREQ(out, "");
EXPECT_THAT(ctx.get_error(),
::testing::HasSubstr("Would overflow maximum output size"));
ctx.Reset();

// Negative length test
out_len = -1;
out = gdv_fn_lower_utf8(ctx_ptr, "abc", -105, &out_len);
EXPECT_EQ(out_len, 0);
EXPECT_STREQ(out, "");
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data length"));
ctx.Reset();

std::string d("AbOJjÜoß\xc3");
out_str = gdv_fn_lower_utf8(ctx_ptr, d.data(), static_cast<int>(d.length()), &out_len);
Expand Down Expand Up @@ -796,6 +840,25 @@ TEST(TestGdvFnStubs, TestInitCap) {
"unexpected byte \\c3 encountered while decoding utf8 string"));
ctx.Reset();

// Max Len Test
out_len = -1;
int32_t bad_len = std::numeric_limits<int32_t>::max() / 2 + 1;
const char* out = gdv_fn_initcap_utf8(ctx_ptr, "dummy", bad_len, &out_len);
// Expect failure
EXPECT_EQ(out_len, 0);
EXPECT_STREQ(out, "");
EXPECT_THAT(ctx.get_error(),
::testing::HasSubstr("Would overflow maximum output size"));
ctx.Reset();

// Negative length test
out_len = -1;
out = gdv_fn_initcap_utf8(ctx_ptr, "abc", -105, &out_len);
EXPECT_EQ(out_len, 0);
EXPECT_STREQ(out, "");
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data length"));
ctx.Reset();

std::string e(
"åbÑg\xe0\xa0"
"åBUå");
Expand Down Expand Up @@ -1127,6 +1190,15 @@ TEST(TestGdvFnStubs, TestTranslate) {
result = translate_utf8_utf8_utf8(ctx_ptr, "987654321", 9, "123456789", 9, "0123456789",
10, &out_len);
EXPECT_EQ(expected, std::string(result, out_len));

int32_t bad_in_len = std::numeric_limits<int32_t>::max() / 4 + 1;
out_len = -1;
result =
translate_utf8_utf8_utf8(ctx_ptr, "ABCDE", bad_in_len, "B", 1, "C", 1, &out_len);
EXPECT_EQ(out_len, 0);
EXPECT_STREQ(result, "");
EXPECT_THAT(ctx.get_error(),
::testing::HasSubstr("Would overflow maximum output size"));
}

TEST(TestGdvFnStubs, TestToUtcTimezone) {
Expand Down
109 changes: 85 additions & 24 deletions cpp/src/gandiva/gdv_string_function_stubs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,25 @@ int32_t gdv_fn_utf8_char_length(char c) {
return 0;
}

static inline bool is_datalen_valid(int64_t context, int32_t data_len, int32_t* alloc_len,
int32_t* out_len) {
// Reject negative lengths
if (ARROW_PREDICT_FALSE(data_len < 0)) {
gdv_fn_context_set_error_msg(context, "Invalid (negative) data length");
*out_len = 0;
return false;
}

// Check overflow: 2 * data_len
if (ARROW_PREDICT_FALSE(
arrow::internal::MultiplyWithOverflow(2, data_len, alloc_len))) {
gdv_fn_context_set_error_msg(context, "Would overflow maximum output size");
*out_len = 0;
return false;
}
return true;
}

// Convert an utf8 string to its corresponding lowercase string
GANDIVA_EXPORT
const char* gdv_fn_lower_utf8(int64_t context, const char* data, int32_t data_len,
Expand All @@ -222,10 +241,16 @@ const char* gdv_fn_lower_utf8(int64_t context, const char* data, int32_t data_le
return "";
}

int32_t alloc_length = 0;
if (ARROW_PREDICT_FALSE(
not is_datalen_valid(context, data_len, &alloc_length, out_len))) {
return "";
}

// If it is a single-byte character (ASCII), corresponding lowercase is always 1-byte
// long; if it is >= 2 bytes long, lowercase can be at most 4 bytes long, so length of
// the output can be at most twice the length of the input
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 * data_len));
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));
if (out == nullptr) {
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
*out_len = 0;
Expand Down Expand Up @@ -294,10 +319,16 @@ const char* gdv_fn_upper_utf8(int64_t context, const char* data, int32_t data_le
return "";
}

int32_t alloc_length = 0;
if (ARROW_PREDICT_FALSE(
not is_datalen_valid(context, data_len, &alloc_length, out_len))) {
return "";
}

// If it is a single-byte character (ASCII), corresponding uppercase is always 1-byte
// long; if it is >= 2 bytes long, uppercase can be at most 4 bytes long, so length of
// the output can be at most twice the length of the input
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 * data_len));
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));
if (out == nullptr) {
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
*out_len = 0;
Expand Down Expand Up @@ -367,6 +398,15 @@ const char* gdv_fn_substring_index(int64_t context, const char* txt, int32_t txt
return "";
}

if (ARROW_PREDICT_FALSE(txt_len < 0)) {
*out_len = 0;
return "";
}
if (ARROW_PREDICT_FALSE(pat_len < 0)) {
*out_len = 0;
return "";
}

char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, txt_len));
if (out == nullptr) {
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
Expand Down Expand Up @@ -445,8 +485,12 @@ const char* gdv_fn_substring_index(int64_t context, const char* txt, int32_t txt
return out;

} else {
if (txt_len < 0) {
*out_len = 0;
return "";
}
memcpy(out, txt, static_cast<size_t>(txt_len));
*out_len = txt_len;
memcpy(out, txt, txt_len);
return out;
}
}
Expand Down Expand Up @@ -480,10 +524,16 @@ const char* gdv_fn_initcap_utf8(int64_t context, const char* data, int32_t data_
return "";
}

int32_t alloc_length = 0;
if (ARROW_PREDICT_FALSE(
not is_datalen_valid(context, data_len, &alloc_length, out_len))) {
return "";
}

// If it is a single-byte character (ASCII), corresponding uppercase is always 1-byte
// long; if it is >= 2 bytes long, uppercase can be at most 4 bytes long, so length of
// the output can be at most twice the length of the input
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 * data_len));
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));
if (out == nullptr) {
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
*out_len = 0;
Expand Down Expand Up @@ -579,15 +629,24 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
return in;
}

int32_t alloc_length = 0;
// Check overflow: 4 * in_len
if (ARROW_PREDICT_FALSE(
arrow::internal::MultiplyWithOverflow(4, in_len, &alloc_length))) {
gdv_fn_context_set_error_msg(context, "Would overflow maximum output size");
*out_len = 0;
return "";
}

// This variable is to control if there are multi-byte utf8 entries
bool has_multi_byte = false;

// This variable is to store the final result
char* result;
int result_len;
int32_t result_len;

// Searching multi-bytes in In
for (int i = 0; i < in_len; i++) {
for (int32_t i = 0; i < in_len; i++) {
unsigned char char_single_byte = in[i];
if (char_single_byte > 127) {
// found a multi-byte utf-8 char
Expand All @@ -598,7 +657,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in

// Searching multi-bytes in From
if (!has_multi_byte) {
for (int i = 0; i < from_len; i++) {
for (int32_t i = 0; i < from_len; i++) {
unsigned char char_single_byte = from[i];
if (char_single_byte > 127) {
// found a multi-byte utf-8 char
Expand All @@ -610,7 +669,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in

// Searching multi-bytes in To
if (!has_multi_byte) {
for (int i = 0; i < to_len; i++) {
for (int32_t i = 0; i < to_len; i++) {
unsigned char char_single_byte = to[i];
if (char_single_byte > 127) {
// found a multi-byte utf-8 char
Expand All @@ -621,7 +680,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
}

// If there are no multibytes in the input, work only with char
if (!has_multi_byte) {
if (not has_multi_byte) {
// This variable is for receive the substitutions
result = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, in_len));

Expand All @@ -638,7 +697,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in

// This variable is for controlling the position in entry TO, for never repeat the
// changes
int start_compare;
int32_t start_compare;

if (to_len > 0) {
start_compare = 0;
Expand All @@ -650,15 +709,15 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
// list, to mark deletion positions
const char empty = '\0';

for (int in_for = 0; in_for < in_len; in_for++) {
for (int32_t in_for = 0; in_for < in_len; in_for++) {
if (subs_list.find(in[in_for]) != subs_list.end()) {
if (subs_list[in[in_for]] != empty) {
// If exist in map, only add the correspondent value in result
result[result_len] = subs_list[in[in_for]];
result_len++;
}
} else {
for (int from_for = 0; from_for <= from_len; from_for++) {
for (int32_t from_for = 0; from_for <= from_len; from_for++) {
if (from_for == from_len) {
// If it's not in the FROM list, just add it to the map and the result.
subs_list.insert(std::pair<char, char>(in[in_for], in[in_for]));
Expand Down Expand Up @@ -686,10 +745,11 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
}
}
}
} else { // If there are no multibytes in the input, work with std::strings
} else {
// If there are multibytes in the input, work with std::strings
// This variable is for receive the substitutions, malloc is in_len * 4 to receive
// possible inputs with 4 bytes
result = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, in_len * 4));
result = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));

if (result == nullptr) {
gdv_fn_context_set_error_msg(context,
Expand All @@ -704,7 +764,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in

// This variable is for controlling the position in entry TO, for never repeat the
// changes
int start_compare;
int32_t start_compare;

if (to_len > 0) {
start_compare = 0;
Expand All @@ -717,11 +777,11 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
const std::string empty = "";

// This variables is to control len of multi-bytes entries
int len_char_in = 0;
int len_char_from = 0;
int len_char_to = 0;
int32_t len_char_in = 0;
int32_t len_char_from = 0;
int32_t len_char_to = 0;

for (int in_for = 0; in_for < in_len; in_for += len_char_in) {
for (int32_t in_for = 0; in_for < in_len; in_for += len_char_in) {
// Updating len to char in this position
len_char_in = gdv_fn_utf8_char_length(in[in_for]);
// Making copy to std::string with length for this char position
Expand All @@ -734,11 +794,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
result_len += static_cast<int>(subs_list[insert_copy_key].length());
}
} else {
for (int from_for = 0; from_for <= from_len; from_for += len_char_from) {
// Updating len to char in this position
len_char_from = gdv_fn_utf8_char_length(from[from_for]);
// Making copy to std::string with length for this char position
std::string copy_from_compare(from + from_for, len_char_from);
for (int32_t from_for = 0; from_for <= from_len; from_for += len_char_from) {
if (from_for == from_len) {
// If it's not in the FROM list, just add it to the map and the result.
std::string insert_copy_value(in + in_for, len_char_in);
Expand All @@ -751,6 +807,11 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
break;
}

// Updating len to char in this position
len_char_from = gdv_fn_utf8_char_length(from[from_for]);
// Making copy to std::string with length for this char position
std::string copy_from_compare(from + from_for, len_char_from);

if (insert_copy_key != copy_from_compare) {
// If this character does not exist in FROM list, don't need treatment
continue;
Expand Down
Loading