Skip to content

Commit d4c1613

Browse files
authored
Fix key equality check order (#479)
Closes #474 This PR makes the reference value always the right-hand side for key equality checks. The updates for heterogeneous lookup tests indicate that it will be probably a breaking change for libcudf byte pair encoding.
1 parent a41c046 commit d4c1613

9 files changed

Lines changed: 36 additions & 36 deletions

File tree

include/cuco/detail/equal_wrapper.cuh

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,9 @@ struct equal_wrapper {
7878
/**
7979
* @brief Order-sensitive equality operator.
8080
*
81-
* @note This function always compares the left-hand side element against sentinel values first
81+
* @note This function always compares the right-hand side element against sentinel values first
8282
* then performs a equality check with the given `equal_` callable, i.e., `equal_(lhs, rhs)`.
83-
* @note Container (like set or map) keys MUST be always on the left-hand side.
83+
* @note Container (like set or map) buckets MUST be always on the right-hand side.
8484
*
8585
* @tparam IsInsert Flag indicating whether it's an insert equality check or not. Insert probing
8686
* stops when it's an empty or erased slot while query probing stops only when it's empty.
@@ -96,12 +96,12 @@ struct equal_wrapper {
9696
__device__ constexpr equal_result operator()(LHS const& lhs, RHS const& rhs) const noexcept
9797
{
9898
if constexpr (IsInsert == is_insert::YES) {
99-
return (cuco::detail::bitwise_compare(lhs, empty_sentinel_) or
100-
cuco::detail::bitwise_compare(lhs, erased_sentinel_))
99+
return (cuco::detail::bitwise_compare(rhs, empty_sentinel_) or
100+
cuco::detail::bitwise_compare(rhs, erased_sentinel_))
101101
? equal_result::AVAILABLE
102102
: this->equal_to(lhs, rhs);
103103
} else {
104-
return cuco::detail::bitwise_compare(lhs, empty_sentinel_) ? equal_result::EMPTY
104+
return cuco::detail::bitwise_compare(rhs, empty_sentinel_) ? equal_result::EMPTY
105105
: this->equal_to(lhs, rhs);
106106
}
107107
}

include/cuco/detail/open_addressing/functors.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,8 @@ struct slot_is_filled {
102102
return slot;
103103
}
104104
}();
105-
return not(cuco::detail::bitwise_compare(empty_sentinel_, key) or
106-
cuco::detail::bitwise_compare(erased_sentinel_, key));
105+
return not(cuco::detail::bitwise_compare(key, empty_sentinel_) or
106+
cuco::detail::bitwise_compare(key, erased_sentinel_));
107107
}
108108
};
109109

include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ class open_addressing_ref_impl {
371371

372372
for (auto& slot_content : window_slots) {
373373
auto const eq_res =
374-
this->predicate_.operator()<is_insert::YES>(this->extract_key(slot_content), key);
374+
this->predicate_.operator()<is_insert::YES>(key, this->extract_key(slot_content));
375375

376376
if constexpr (not allows_duplicates) {
377377
// If the key is already in the container, return false
@@ -422,7 +422,7 @@ class open_addressing_ref_impl {
422422
auto const [state, intra_window_index] = [&]() {
423423
for (auto i = 0; i < window_size; ++i) {
424424
switch (
425-
this->predicate_.operator()<is_insert::YES>(this->extract_key(window_slots[i]), key)) {
425+
this->predicate_.operator()<is_insert::YES>(key, this->extract_key(window_slots[i]))) {
426426
case detail::equal_result::AVAILABLE:
427427
return window_probing_results{detail::equal_result::AVAILABLE, i};
428428
case detail::equal_result::EQUAL: {
@@ -506,7 +506,7 @@ class open_addressing_ref_impl {
506506

507507
for (auto i = 0; i < window_size; ++i) {
508508
auto const eq_res =
509-
this->predicate_.operator()<is_insert::YES>(this->extract_key(window_slots[i]), key);
509+
this->predicate_.operator()<is_insert::YES>(key, this->extract_key(window_slots[i]));
510510
auto* window_ptr = (storage_ref_.data() + *probing_iter)->data();
511511

512512
// If the key is already in the container, return false
@@ -579,7 +579,7 @@ class open_addressing_ref_impl {
579579
auto res = detail::equal_result::UNEQUAL;
580580
for (auto i = 0; i < window_size; ++i) {
581581
res =
582-
this->predicate_.operator()<is_insert::YES>(this->extract_key(window_slots[i]), key);
582+
this->predicate_.operator()<is_insert::YES>(key, this->extract_key(window_slots[i]));
583583
if (res != detail::equal_result::UNEQUAL) { return window_probing_results{res, i}; }
584584
}
585585
// returns dummy index `-1` for UNEQUAL
@@ -662,7 +662,7 @@ class open_addressing_ref_impl {
662662

663663
for (auto& slot_content : window_slots) {
664664
auto const eq_res =
665-
this->predicate_.operator()<is_insert::NO>(this->extract_key(slot_content), key);
665+
this->predicate_.operator()<is_insert::NO>(key, this->extract_key(slot_content));
666666

667667
// Key doesn't exist, return false
668668
if (eq_res == detail::equal_result::EMPTY) { return false; }
@@ -704,7 +704,7 @@ class open_addressing_ref_impl {
704704
auto const [state, intra_window_index] = [&]() {
705705
auto res = detail::equal_result::UNEQUAL;
706706
for (auto i = 0; i < window_size; ++i) {
707-
res = this->predicate_.operator()<is_insert::NO>(this->extract_key(window_slots[i]), key);
707+
res = this->predicate_.operator()<is_insert::NO>(key, this->extract_key(window_slots[i]));
708708
if (res != detail::equal_result::UNEQUAL) { return window_probing_results{res, i}; }
709709
}
710710
// returns dummy index `-1` for UNEQUAL
@@ -758,7 +758,7 @@ class open_addressing_ref_impl {
758758
auto const window_slots = storage_ref_[*probing_iter];
759759

760760
for (auto& slot_content : window_slots) {
761-
switch (this->predicate_.operator()<is_insert::NO>(this->extract_key(slot_content), key)) {
761+
switch (this->predicate_.operator()<is_insert::NO>(key, this->extract_key(slot_content))) {
762762
case detail::equal_result::UNEQUAL: continue;
763763
case detail::equal_result::EMPTY: return false;
764764
case detail::equal_result::EQUAL: return true;
@@ -793,7 +793,7 @@ class open_addressing_ref_impl {
793793
auto const state = [&]() {
794794
auto res = detail::equal_result::UNEQUAL;
795795
for (auto& slot : window_slots) {
796-
res = this->predicate_.operator()<is_insert::NO>(this->extract_key(slot), key);
796+
res = this->predicate_.operator()<is_insert::NO>(key, this->extract_key(slot));
797797
if (res != detail::equal_result::UNEQUAL) { return res; }
798798
}
799799
return res;
@@ -830,7 +830,7 @@ class open_addressing_ref_impl {
830830

831831
for (auto i = 0; i < window_size; ++i) {
832832
switch (
833-
this->predicate_.operator()<is_insert::NO>(this->extract_key(window_slots[i]), key)) {
833+
this->predicate_.operator()<is_insert::NO>(key, this->extract_key(window_slots[i]))) {
834834
case detail::equal_result::EMPTY: {
835835
return this->end();
836836
}
@@ -869,7 +869,7 @@ class open_addressing_ref_impl {
869869
auto const [state, intra_window_index] = [&]() {
870870
auto res = detail::equal_result::UNEQUAL;
871871
for (auto i = 0; i < window_size; ++i) {
872-
res = this->predicate_.operator()<is_insert::NO>(this->extract_key(window_slots[i]), key);
872+
res = this->predicate_.operator()<is_insert::NO>(key, this->extract_key(window_slots[i]));
873873
if (res != detail::equal_result::UNEQUAL) { return window_probing_results{res, i}; }
874874
}
875875
// returns dummy index `-1` for UNEQUAL
@@ -1097,7 +1097,7 @@ class open_addressing_ref_impl {
10971097
if (cuco::detail::bitwise_compare(this->extract_key(*old_ptr), this->extract_key(expected))) {
10981098
return insert_result::SUCCESS;
10991099
} else {
1100-
return this->predicate_.equal_to(this->extract_key(*old_ptr), this->extract_key(desired)) ==
1100+
return this->predicate_.equal_to(this->extract_key(desired), this->extract_key(*old_ptr)) ==
11011101
detail::equal_result::EQUAL
11021102
? insert_result::DUPLICATE
11031103
: insert_result::CONTINUE;
@@ -1144,7 +1144,7 @@ class open_addressing_ref_impl {
11441144

11451145
// Our key was already present in the slot, so our key is a duplicate
11461146
// Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare
1147-
if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) {
1147+
if (this->predicate_.equal_to(desired.first, *old_key_ptr) == detail::equal_result::EQUAL) {
11481148
return insert_result::DUPLICATE;
11491149
}
11501150

@@ -1183,7 +1183,7 @@ class open_addressing_ref_impl {
11831183

11841184
// Our key was already present in the slot, so our key is a duplicate
11851185
// Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare
1186-
if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) {
1186+
if (this->predicate_.equal_to(desired.first, *old_key_ptr) == detail::equal_result::EQUAL) {
11871187
return insert_result::DUPLICATE;
11881188
}
11891189

include/cuco/detail/static_map/static_map_ref.inl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ class operator_impl<
398398

399399
for (auto& slot_content : window_slots) {
400400
auto const eq_res =
401-
ref_.impl_.predicate_.operator()<is_insert::YES>(slot_content.first, key);
401+
ref_.impl_.predicate_.operator()<is_insert::YES>(key, slot_content.first);
402402

403403
// If the key is already in the container, update the payload and return
404404
if (eq_res == detail::equal_result::EQUAL) {
@@ -449,7 +449,7 @@ class operator_impl<
449449
auto const [state, intra_window_index] = [&]() {
450450
auto res = detail::equal_result::UNEQUAL;
451451
for (auto i = 0; i < window_size; ++i) {
452-
res = ref_.impl_.predicate_.operator()<is_insert::YES>(window_slots[i].first, key);
452+
res = ref_.impl_.predicate_.operator()<is_insert::YES>(key, window_slots[i].first);
453453
if (res != detail::equal_result::UNEQUAL) {
454454
return detail::window_probing_results{res, i};
455455
}
@@ -514,7 +514,7 @@ class operator_impl<
514514

515515
// if key success or key was already present in the map
516516
if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key) or
517-
(ref_.impl_.predicate().equal_to(*old_key_ptr, value.first) ==
517+
(ref_.impl_.predicate().equal_to(value.first, *old_key_ptr) ==
518518
detail::equal_result::EQUAL)) {
519519
// Update payload
520520
ref_.impl_.atomic_store(&slot->second, value.second);

include/cuco/static_map.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ namespace cuco {
6161
* construction.
6262
*
6363
* @note Allows constant time concurrent modify or lookup operations from threads in device code.
64-
* @note cuCollections data structures always place the slot keys on the left-hand side when
65-
* invoking the key comparison predicate, i.e., `pred(slot_key, query_key)`. Order-sensitive
64+
* @note cuCollections data structures always place the slot keys on the right-hand side when
65+
* invoking the key comparison predicate, i.e., `pred(query_key, slot_key)`. Order-sensitive
6666
* `KeyEqual` should be used with caution.
6767
* @note `ProbingScheme::cg_size` indicates how many threads are used to handle one independent
6868
* device operation. `cg_size == 1` uses the scalar (or non-CG) code paths.

include/cuco/static_multiset.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ namespace cuco {
5656
* construction.
5757
*
5858
* @note Allows constant time concurrent modify or lookup operations from threads in device code.
59-
* @note cuCollections data structures always place the slot keys on the left-hand side when
60-
* invoking the key comparison predicate, i.e., `pred(slot_key, query_key)`. Order-sensitive
59+
* @note cuCollections data structures always place the slot keys on the right-hand side when
60+
* invoking the key comparison predicate, i.e., `pred(query_key, slot_key)`. Order-sensitive
6161
* `KeyEqual` should be used with caution.
6262
* @note `ProbingScheme::cg_size` indicates how many threads are used to handle one independent
6363
* device operation. `cg_size == 1` uses the scalar (or non-CG) code paths.

include/cuco/static_set.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ namespace cuco {
6161
* construction.
6262
*
6363
* @note Allows constant time concurrent modify or lookup operations from threads in device code.
64-
* @note cuCollections data structures always place the slot keys on the left-hand side when
65-
* invoking the key comparison predicate, i.e., `pred(slot_key, query_key)`. Order-sensitive
64+
* @note cuCollections data structures always place the slot keys on the right-hand side when
65+
* invoking the key comparison predicate, i.e., `pred(query_key, slot_key)`. Order-sensitive
6666
* `KeyEqual` should be used with caution.
6767
* @note `ProbingScheme::cg_size` indicates how many threads are used to handle one independent
6868
* device operation. `cg_size == 1` uses the scalar (or non-CG) code paths.

tests/static_map/heterogeneous_lookup_test.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,12 @@ struct custom_hasher {
7474
};
7575
};
7676

77-
// User-defined device key equality
77+
// User-defined device key equality, Slot key always on the right-hand side
7878
struct custom_key_equal {
79-
template <typename SlotKey, typename InputKey>
80-
__device__ bool operator()(SlotKey const& lhs, InputKey const& rhs) const
79+
template <typename InputKey, typename SlotKey>
80+
__device__ bool operator()(InputKey const& lhs, SlotKey const& rhs) const
8181
{
82-
return lhs == rhs.a;
82+
return lhs.a == rhs;
8383
}
8484
};
8585

tests/static_set/heterogeneous_lookup_test.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,10 @@ struct custom_hasher {
7676

7777
// User-defined device key equality
7878
struct custom_key_equal {
79-
template <typename SlotKey, typename InputKey>
80-
__device__ bool operator()(SlotKey const& lhs, InputKey const& rhs) const
79+
template <typename InsertKey, typename SlotKey>
80+
__device__ bool operator()(InsertKey const& lhs, SlotKey const& rhs) const
8181
{
82-
return lhs == rhs.a;
82+
return lhs.a == rhs;
8383
}
8484
};
8585

0 commit comments

Comments
 (0)