Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions cpp/src/arrow/acero/hash_join_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -370,9 +370,19 @@ Result<Expression> HashJoinSchema::BindFilter(Expression filter,
const Schema& left_schema,
const Schema& right_schema,
ExecContext* exec_context) {
if (filter.IsBound() || filter == literal(true)) {
auto ValidateFilterTypeAndReturn = [](Expression filter) -> Result<Expression> {
if (filter.type()->id() != Type::BOOL) {
return Status::TypeError("Filter expression must evaluate to bool, but ",
filter.ToString(), " evaluates to ",
filter.type()->ToString());
}
return filter;
};

if (filter.IsBound()) {
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IsBound() is implied by literal(true) so one check should suffice.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose there's already a unit test for that?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not any particular test, but most existing tests exercise this path, both if true and if false. So I think we are good.

return ValidateFilterTypeAndReturn(std::move(filter));
}

// Step 1: Construct filter schema
FieldVector fields;
auto left_f_to_i =
Expand Down Expand Up @@ -401,12 +411,8 @@ Result<Expression> HashJoinSchema::BindFilter(Expression filter,

// Step 3: Bind
ARROW_ASSIGN_OR_RAISE(filter, filter.Bind(filter_schema, exec_context));
if (filter.type()->id() != Type::BOOL) {
return Status::TypeError("Filter expression must evaluate to bool, but ",
filter.ToString(), " evaluates to ",
filter.type()->ToString());
}
return filter;

return ValidateFilterTypeAndReturn(std::move(filter));
}

Expression HashJoinSchema::RewriteFilterToUseFilterSchema(
Expand Down
39 changes: 37 additions & 2 deletions cpp/src/arrow/acero/hash_join_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1902,6 +1902,41 @@ TEST(HashJoin, CheckHashJoinNodeOptionsValidation) {
}
}

TEST(HashJoin, CheckResidualFilterType) {
BatchesWithSchema input_left;
input_left.schema = schema({field("lkey", int32()), field("lpayload", int32())});

BatchesWithSchema input_right;
input_right.schema = schema({field("rkey", int32()), field("rpayload", int32())});

Declaration left{"source",
SourceNodeOptions{input_left.schema, input_left.gen(/*parallel=*/false,
/*slow=*/false)}};
Declaration right{
"source", SourceNodeOptions{input_right.schema, input_right.gen(/*parallel=*/false,
/*slow=*/false)}};

for (const auto& filter :
{literal(MakeNullScalar(boolean())), literal(true), literal(false),
equal(field_ref("lpayload"), field_ref("rpayload"))}) {
HashJoinNodeOptions options{
JoinType::INNER, {FieldRef("lkey")}, {FieldRef("rkey")}, filter};
Declaration join{"hashjoin", {left, right}, options};
ASSERT_OK(DeclarationToStatus(std::move(join)));
}

for (const auto& filter :
{literal(NullScalar()), literal(42),
call("add", {field_ref("lpayload"), field_ref("rpayload")})}) {
HashJoinNodeOptions options{
JoinType::INNER, {FieldRef("lkey")}, {FieldRef("rkey")}, filter};
Declaration join{"hashjoin", {left, right}, options};
EXPECT_RAISES_WITH_MESSAGE_THAT(TypeError,
::testing::HasSubstr("must evaluate to bool"),
DeclarationToStatus(std::move(join)));
}
}

class ResidualFilterCaseRunner {
public:
ResidualFilterCaseRunner(BatchesWithSchema left_input, BatchesWithSchema right_input)
Expand Down Expand Up @@ -2369,8 +2404,8 @@ TEST(HashJoin, FineGrainedResidualFilter) {
{
// Literal false, null, and scalar false, null.
for (Expression filter :
{literal(false), literal(NullScalar()), equal(literal(0), literal(1)),
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The null type null (literal(NullScalar())) is invalid. Replacing it with a boolean type null.

equal(literal(1), literal(NullScalar()))}) {
{literal(false), literal(MakeNullScalar(boolean())),
equal(literal(0), literal(1)), equal(literal(1), literal(NullScalar()))}) {
std::vector<FieldRef> left_keys{"l_key", "l_filter"},
right_keys{"r_key", "r_filter"};
{
Expand Down
26 changes: 15 additions & 11 deletions cpp/src/arrow/acero/swiss_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1845,6 +1845,11 @@ void JoinResidualFilter::Init(Expression filter, QueryContext* ctx, MemoryPool*
const HashJoinProjectionMaps* build_schemas,
SwissTableForJoin* hash_table) {
filter_ = std::move(filter);
if (auto lit = filter_.literal(); lit) {
const auto& scalar = lit->scalar_as<BooleanScalar>();
is_trivial_ = true;
is_literal_true_ = scalar.is_valid && scalar.value;
}
ctx_ = ctx;
pool_ = pool;
hardware_flags_ = hardware_flags;
Expand Down Expand Up @@ -1918,14 +1923,14 @@ Status JoinResidualFilter::FilterLeftSemi(const ExecBatch& keypayload_batch,
arrow::util::TempVectorStack* temp_stack,
int* num_passing_ids,
uint16_t* passing_batch_row_ids) const {
if (filter_ == literal(true)) {
if (is_literal_true_) {
CollectPassingBatchIds(1, hardware_flags_, batch_start_row, num_batch_rows,
match_bitvector, num_passing_ids, passing_batch_row_ids);
return Status::OK();
}

*num_passing_ids = 0;
if (filter_.IsNullLiteral() || filter_ == literal(false)) {
if (is_trivial_ && !is_literal_true_) {
return Status::OK();
}

Expand Down Expand Up @@ -1993,7 +1998,7 @@ Status JoinResidualFilter::FilterLeftAnti(const ExecBatch& keypayload_batch,
arrow::util::TempVectorStack* temp_stack,
int* num_passing_ids,
uint16_t* passing_batch_row_ids) const {
if (filter_ == literal(true)) {
if (is_literal_true_) {
CollectPassingBatchIds(0, hardware_flags_, batch_start_row, num_batch_rows,
match_bitvector, num_passing_ids, passing_batch_row_ids);
return Status::OK();
Expand Down Expand Up @@ -2032,12 +2037,12 @@ Status JoinResidualFilter::FilterRightSemiAnti(
int64_t thread_id, const ExecBatch& keypayload_batch, int batch_start_row,
int num_batch_rows, const uint8_t* match_bitvector, const uint32_t* key_ids,
bool no_duplicate_keys, arrow::util::TempVectorStack* temp_stack) const {
if (filter_.IsNullLiteral() || filter_ == literal(false)) {
if (is_trivial_ && !is_literal_true_) {
return Status::OK();
}

int num_matching_ids = 0;
if (filter_ == literal(true)) {
if (is_literal_true_) {
auto match_relative_batch_ids_buf =
arrow::util::TempVectorHolder<uint16_t>(temp_stack, num_batch_rows);
auto match_key_ids_buf =
Expand Down Expand Up @@ -2091,13 +2096,13 @@ Status JoinResidualFilter::FilterInner(
const ExecBatch& keypayload_batch, int num_batch_rows, uint16_t* batch_row_ids,
uint32_t* key_ids, uint32_t* payload_ids_maybe_null, bool output_payload_ids,
arrow::util::TempVectorStack* temp_stack, int* num_passing_rows) const {
if (filter_ == literal(true)) {
if (is_literal_true_) {
*num_passing_rows = num_batch_rows;
return Status::OK();
}

*num_passing_rows = 0;
if (filter_.IsNullLiteral() || filter_ == literal(false)) {
if (is_trivial_ && !is_literal_true_) {
return Status::OK();
}

Expand All @@ -2114,8 +2119,7 @@ Status JoinResidualFilter::FilterOneBatch(const ExecBatch& keypayload_batch,
arrow::util::TempVectorStack* temp_stack,
int* num_passing_rows) const {
// Caller must do shortcuts for trivial filter.
ARROW_DCHECK(!filter_.IsNullLiteral() && filter_ != literal(true) &&
filter_ != literal(false));
ARROW_DCHECK(!is_trivial_);
ARROW_DCHECK(!output_key_ids || key_ids_maybe_null);
ARROW_DCHECK(!output_payload_ids || payload_ids_maybe_null);

Expand All @@ -2128,6 +2132,7 @@ Status JoinResidualFilter::FilterOneBatch(const ExecBatch& keypayload_batch,
ARROW_ASSIGN_OR_RAISE(Datum mask,
EvalFilter(keypayload_batch, num_batch_rows, batch_row_ids,
key_ids_maybe_null, payload_ids_maybe_null));
DCHECK_EQ(mask.type()->id(), Type::BOOL);
if (mask.is_scalar()) {
const auto& mask_scalar = mask.scalar_as<BooleanScalar>();
if (mask_scalar.is_valid && mask_scalar.value) {
Expand Down Expand Up @@ -2162,8 +2167,7 @@ Status JoinResidualFilter::FilterOneBatch(const ExecBatch& keypayload_batch,
Result<Datum> JoinResidualFilter::EvalFilter(
const ExecBatch& keypayload_batch, int num_batch_rows, const uint16_t* batch_row_ids,
const uint32_t* key_ids_maybe_null, const uint32_t* payload_ids_maybe_null) const {
ARROW_DCHECK(!filter_.IsNullLiteral() && filter_ != literal(true) &&
filter_ != literal(false));
ARROW_DCHECK(!is_trivial_);

ARROW_ASSIGN_OR_RAISE(
ExecBatch input,
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/acero/swiss_join_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -980,6 +980,8 @@ class JoinResidualFilter {

private:
Expression filter_;
bool is_trivial_ = false;
bool is_literal_true_ = false;

QueryContext* ctx_;
MemoryPool* pool_;
Expand Down
Loading