Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 45 additions & 15 deletions cpp/include/rmm/detail/cccl_adaptors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <rmm/mr/detail/device_memory_resource_view.hpp>
#include <rmm/mr/device_memory_resource.hpp>

#include <cuda/memory_resource>
#include <cuda/std/optional>

#include <cstddef>
Expand Down Expand Up @@ -351,6 +352,20 @@ class cccl_async_resource_ref {
{
}

/**
* @brief Constructs a resource reference from a CCCL any_resource.
*
* This constructor enables constructing a resource_ref from an any_resource,
* which is useful when retrieving resources from containers that store any_resource.
* The optional `view_` member is left empty and `ref_` is initialized directly from `res`.
*
* NOTE(review): `res` is taken by non-const lvalue reference and this class is a reference
* wrapper, so `res` presumably has to outlive the constructed object -- confirm against the
* semantics of constructing ResourceType from an any_resource.
*
* NOTE(review): this constructor is not `explicit`, so it also serves as an implicit conversion
* from any_resource lvalues -- confirm that this is intended.
*
* @tparam Properties The property pack of the any_resource being referenced
* @param res A CCCL any_resource to reference
*/
template <typename... Properties>
cccl_async_resource_ref(cuda::mr::any_resource<Properties...>& res)
: view_{cuda::std::nullopt}, ref_{res}
{
}

/**
* @brief Copy constructor that properly reconstructs the ref to point to the new view.
*
Expand Down Expand Up @@ -398,24 +413,27 @@ class cccl_async_resource_ref {
* @brief Construct a ref from a resource.
*
* This constructor accepts CCCL resource types but NOT CCCL resource_ref types,
* our own wrapper types, or device_memory_resource derived types. The exclusions
* are checked FIRST to prevent recursive constraint satisfaction.
* our own wrapper types, any_resource types, or device_memory_resource derived types.
* The exclusions are checked FIRST to prevent recursive constraint satisfaction.
*
* @tparam OtherResourceType A CCCL resource type (not a resource_ref, wrapper, or DMR)
* @tparam OtherResourceType A CCCL resource type (not a resource_ref, wrapper, any_resource, or
* DMR)
* @param other The resource to construct a ref from
*/
template <typename OtherResourceType,
std::enable_if_t<not is_specialization_of_v<std::remove_cv_t<OtherResourceType>,
cuda::mr::synchronous_resource_ref> and
not is_specialization_of_v<std::remove_cv_t<OtherResourceType>,
cuda::mr::resource_ref> and
not is_specialization_of_v<std::remove_cv_t<OtherResourceType>,
::rmm::detail::cccl_resource_ref> and
not is_specialization_of_v<std::remove_cv_t<OtherResourceType>,
::rmm::detail::cccl_async_resource_ref> and
not std::is_base_of_v<rmm::mr::device_memory_resource,
std::remove_cv_t<OtherResourceType>> and
cuda::mr::resource<OtherResourceType>>* = nullptr>
template <
typename OtherResourceType,
std::enable_if_t<
// The exclusions are listed before the cuda::mr::resource check, which comes last.
not is_specialization_of_v<std::remove_cv_t<OtherResourceType>,
cuda::mr::synchronous_resource_ref> and
not is_specialization_of_v<std::remove_cv_t<OtherResourceType>, cuda::mr::resource_ref> and
// any_resource is accepted by the dedicated any_resource constructor of this class instead.
not is_specialization_of_v<std::remove_cv_t<OtherResourceType>, cuda::mr::any_resource> and
not is_specialization_of_v<std::remove_cv_t<OtherResourceType>,
::rmm::detail::cccl_resource_ref> and
not is_specialization_of_v<std::remove_cv_t<OtherResourceType>,
::rmm::detail::cccl_async_resource_ref> and
not std::is_base_of_v<rmm::mr::device_memory_resource,
std::remove_cv_t<OtherResourceType>> and
cuda::mr::resource<OtherResourceType>>* = nullptr>
// The optional view_ is left empty; ref_ is built as a ResourceType wrapping `other`.
cccl_async_resource_ref(OtherResourceType& other) : view_{}, ref_{ResourceType{other}}
{
}
Expand Down Expand Up @@ -538,6 +556,18 @@ class cccl_async_resource_ref {
return try_get_property(ref.ref_, prop);
}

/**
* @brief Implicit conversion to cuda::mr::any_resource<Properties...>.
*
* This enables reification of the resource_ref to an owning any_resource type.
* The conversion copies the underlying resource into the any_resource.
*
* Only `ref_` is handed to the any_resource; `view_` is not passed.
* NOTE(review): when `ref_` refers to this object's `view_`, the object copied into the
* any_resource is presumably the view rather than the viewed resource -- confirm.
*
* @tparam Properties The property pack of the requested any_resource type
* @return An any_resource constructed from the stored `ref_`
*/
template <typename... Properties>
operator cuda::mr::any_resource<Properties...>() const
{
return cuda::mr::any_resource<Properties...>{ref_};
}

protected:
cuda::std::optional<rmm::mr::detail::device_memory_resource_view> view_;
ResourceType ref_;
Expand Down
34 changes: 20 additions & 14 deletions cpp/include/rmm/mr/per_device_resource.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2020-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand All @@ -11,6 +11,8 @@
#include <rmm/mr/device_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <cuda/memory_resource>

#include <map>
#include <mutex>

Expand Down Expand Up @@ -140,12 +142,13 @@ RMM_EXPORT inline std::mutex& ref_map_lock()

// This symbol must have default visibility, see: https://github.com/rapidsai/rmm/issues/826
/**
* @briefreturn{Reference to the map from device id -> resource_ref}
* @briefreturn{Reference to the map from device id -> any_resource}
*/
RMM_EXPORT inline auto& get_ref_map()
{
static std::map<cuda_device_id::value_type, device_async_resource_ref> device_id_to_resource_ref;
return device_id_to_resource_ref;
static std::map<cuda_device_id::value_type, cuda::mr::any_resource<cuda::mr::device_accessible>>
device_id_to_resource;
return device_id_to_resource;
}

} // namespace detail
Expand Down Expand Up @@ -192,17 +195,17 @@ namespace detail {
inline device_async_resource_ref set_per_device_resource_ref_unsafe(
cuda_device_id device_id, device_async_resource_ref new_resource_ref)
{
// No lock is acquired in this function; presumably callers hold detail::ref_map_lock() -- confirm.
auto& map = detail::get_ref_map();
auto const old_itr = map.find(device_id.value());
// If a resource didn't previously exist for `device_id`, return pointer to initial_resource
// Note: because resource_ref is not default-constructible, we can't use std::map::operator[]
using any_device_resource = cuda::mr::any_resource<cuda::mr::device_accessible>;
auto& map = detail::get_ref_map();
auto const old_itr = map.find(device_id.value());
// If a resource didn't previously exist for `device_id`, return ref to initial_resource
if (old_itr == map.end()) {
map.insert({device_id.value(), new_resource_ref});
map.emplace(device_id.value(), static_cast<any_device_resource>(new_resource_ref));
return device_async_resource_ref{*detail::initial_resource()};
}

auto old_resource_ref = old_itr->second;
old_itr->second = new_resource_ref; // update map directly via iterator
// NOTE(review): `old_resource_ref` is constructed from the map slot `old_itr->second`, and that
// same slot is overwritten on the next line. Confirm that the returned ref still designates the
// previously stored resource and does not alias the new value or dangle once the slot is
// replaced.
device_async_resource_ref old_resource_ref{old_itr->second};
old_itr->second = static_cast<any_device_resource>(new_resource_ref); // reify and store
return old_resource_ref;
}
} // namespace detail
Expand Down Expand Up @@ -333,15 +336,18 @@ inline device_memory_resource* set_current_device_resource(device_memory_resourc
*/
inline device_async_resource_ref get_per_device_resource_ref(cuda_device_id device_id)
{
using any_device_resource = cuda::mr::any_resource<cuda::mr::device_accessible>;
// Hold the map lock for the whole lookup/insert below.
std::lock_guard<std::mutex> lock{detail::ref_map_lock()};
auto& map = detail::get_ref_map();
// If a resource was never set for `device_id`, set to the initial resource
auto const found = map.find(device_id.value());
if (found == map.end()) {
auto item = map.insert({device_id.value(), *detail::initial_resource()});
return item.first->second;
// Create a resource_ref from the initial resource, then reify it to any_resource
device_async_resource_ref initial_ref{*detail::initial_resource()};
auto item = map.emplace(device_id.value(), static_cast<any_device_resource>(initial_ref));
return device_async_resource_ref{item.first->second};
}
return found->second;
// NOTE(review): the returned ref is built from the any_resource stored in the map, while the
// lock only covers this function. Presumably the ref is only valid until the entry for this
// device is replaced -- confirm the intended lifetime guarantee.
return device_async_resource_ref{found->second};
}

/**
Expand Down
23 changes: 16 additions & 7 deletions cpp/tests/mr/mr_ref_test_basic.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand All @@ -19,18 +19,27 @@ TEST_P(mr_ref_test, SetCurrentDeviceResourceRef)
rmm::mr::set_current_device_resource_ref(cuda_ref);
auto old = rmm::mr::set_current_device_resource_ref(this->ref);

// old mr should equal a cuda mr
EXPECT_EQ(old, cuda_ref);
// Old ref should be functional (verify by successful allocation)
constexpr std::size_t size{100};
void* ptr = old.allocate(rmm::cuda_stream_default, size);
EXPECT_NE(ptr, nullptr);
old.deallocate(rmm::cuda_stream_default, ptr, size);

// current dev resource should equal this resource
EXPECT_EQ(this->ref, rmm::mr::get_current_device_resource_ref());
// Current device resource should be usable for allocation
auto current = rmm::mr::get_current_device_resource_ref();
ptr = current.allocate(rmm::cuda_stream_default, size);
EXPECT_NE(ptr, nullptr);
current.deallocate(rmm::cuda_stream_default, ptr, size);

test_get_current_device_resource_ref();

// Resetting should reset to initial cuda resource
rmm::mr::reset_current_device_resource_ref();
EXPECT_EQ(rmm::device_async_resource_ref{rmm::mr::detail::initial_resource()},
rmm::mr::get_current_device_resource_ref());
// Verify reset worked by checking allocation succeeds with initial resource
current = rmm::mr::get_current_device_resource_ref();
ptr = current.allocate(rmm::cuda_stream_default, size);
EXPECT_NE(ptr, nullptr);
current.deallocate(rmm::cuda_stream_default, ptr, size);
}

TEST_P(mr_ref_test, SelfEquality) { EXPECT_TRUE(this->ref == this->ref); }
Expand Down
41 changes: 20 additions & 21 deletions cpp/tests/mr/mr_ref_test_mt.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand Down Expand Up @@ -130,22 +130,23 @@ inline void test_async_allocate_free_different_threads(rmm::device_async_resourc

TEST_P(mr_ref_test_mt, SetCurrentDeviceResourceRef_mt)
{
// single thread changes default resource, then multiple threads use it
auto old = rmm::mr::set_current_device_resource_ref(this->ref);
// Single thread changes default resource, then multiple threads use it
rmm::mr::set_current_device_resource_ref(this->ref);
test_get_current_device_resource_ref();

int device;
RMM_CUDA_TRY(cudaGetDevice(&device));

spawn([device, mr = this->ref]() {
spawn([device]() {
RMM_CUDA_TRY(cudaSetDevice(device));
EXPECT_EQ(mr, rmm::mr::get_current_device_resource_ref());
test_get_current_device_resource_ref(); // test allocating with the new default resource
// Verify the current resource is functional
test_get_current_device_resource_ref();
});

// resetting default resource should reset to initial
// Resetting default resource should reset to initial
rmm::mr::reset_current_device_resource_ref();
EXPECT_EQ(old, rmm::mr::get_current_device_resource_ref());
// Verify reset worked by testing allocation with initial resource
test_get_current_device_resource_ref();
}

TEST_P(mr_ref_test_mt, SetCurrentDeviceResourceRefPerThread_mt)
Expand All @@ -162,19 +163,17 @@ TEST_P(mr_ref_test_mt, SetCurrentDeviceResourceRefPerThread_mt)
threads.emplace_back(
[mr](auto dev_id) {
RMM_CUDA_TRY(cudaSetDevice(dev_id));
auto cuda_ref = rmm::mr::get_current_device_resource_ref();
auto old = rmm::mr::set_current_device_resource_ref(mr);

// initial resource for this device should be CUDA mr
EXPECT_EQ(old, cuda_ref);
// get_current_device_resource_ref should equal the resource we
// just set
EXPECT_EQ(mr, rmm::mr::get_current_device_resource_ref());
// Resetting current dev resource ref should make it
// cuda MR and return the MR we previously set
old = rmm::mr::reset_current_device_resource_ref();
EXPECT_EQ(old, mr);
EXPECT_EQ(cuda_ref, rmm::mr::get_current_device_resource_ref());
// Verify initial resource is functional
test_get_current_device_resource_ref();

rmm::mr::set_current_device_resource_ref(mr);
// Verify newly set resource is functional
test_get_current_device_resource_ref();

// Resetting current dev resource ref should restore initial resource
rmm::mr::reset_current_device_resource_ref();
// Verify reset resource is functional
test_get_current_device_resource_ref();
},
i);
}
Expand Down
15 changes: 6 additions & 9 deletions cpp/tests/mr/statistics_mr_tests.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2020-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand Down Expand Up @@ -121,24 +121,24 @@ TEST(StatisticsTest, PeakAllocations)

TEST(StatisticsTest, MultiTracking)
{
// Test stacking multiple statistics adaptors, using explicit resource refs
// to avoid lifetime issues with the global device resource map
auto orig_device_resource = rmm::mr::get_current_device_resource_ref();
statistics_adaptor mr{orig_device_resource};
rmm::mr::set_current_device_resource_ref(mr);

std::vector<std::shared_ptr<rmm::device_buffer>> allocations;
for (std::size_t i = 0; i < num_allocations; ++i) {
allocations.emplace_back(
std::make_shared<rmm::device_buffer>(ten_MiB, rmm::cuda_stream_default));
std::make_shared<rmm::device_buffer>(ten_MiB, rmm::cuda_stream_default, &mr));
}

EXPECT_EQ(mr.get_allocations_counter().value, 10);

statistics_adaptor inner_mr{rmm::mr::get_current_device_resource_ref()};
rmm::mr::set_current_device_resource_ref(inner_mr);
statistics_adaptor inner_mr{&mr};

for (std::size_t i = 0; i < num_more_allocations; ++i) {
allocations.emplace_back(
std::make_shared<rmm::device_buffer>(ten_MiB, rmm::cuda_stream_default));
std::make_shared<rmm::device_buffer>(ten_MiB, rmm::cuda_stream_default, &inner_mr));
}

// Check the allocated bytes for both MRs
Expand All @@ -164,9 +164,6 @@ TEST(StatisticsTest, MultiTracking)

EXPECT_EQ(mr.get_allocations_counter().peak, 15);
EXPECT_EQ(inner_mr.get_allocations_counter().peak, 5);

// Reset the current device resource
rmm::mr::set_current_device_resource_ref(orig_device_resource);
}

TEST(StatisticsTest, NegativeInnerTracking)
Expand Down
15 changes: 6 additions & 9 deletions cpp/tests/mr/tracking_mr_tests.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2020-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand Down Expand Up @@ -94,24 +94,24 @@ TEST(TrackingTest, AllocationsLeftWithoutStacks)

TEST(TrackingTest, MultiTracking)
{
// Test stacking multiple tracking adaptors, using explicit resource refs
// to avoid lifetime issues with the global device resource map
auto orig_device_resource = rmm::mr::get_current_device_resource_ref();
tracking_adaptor mr{orig_device_resource, true};
rmm::mr::set_current_device_resource_ref(mr);

std::vector<std::shared_ptr<rmm::device_buffer>> allocations;
for (std::size_t i = 0; i < num_allocations; ++i) {
allocations.emplace_back(
std::make_shared<rmm::device_buffer>(ten_MiB, rmm::cuda_stream_default));
std::make_shared<rmm::device_buffer>(ten_MiB, rmm::cuda_stream_default, &mr));
}

EXPECT_EQ(mr.get_outstanding_allocations().size(), num_allocations);

tracking_adaptor inner_mr{rmm::mr::get_current_device_resource_ref()};
rmm::mr::set_current_device_resource_ref(inner_mr);
tracking_adaptor inner_mr{&mr};

for (std::size_t i = 0; i < num_more_allocations; ++i) {
allocations.emplace_back(
std::make_shared<rmm::device_buffer>(ten_MiB, rmm::cuda_stream_default));
std::make_shared<rmm::device_buffer>(ten_MiB, rmm::cuda_stream_default, &inner_mr));
}

// Check the allocated bytes for both MRs
Expand All @@ -132,9 +132,6 @@ TEST(TrackingTest, MultiTracking)

EXPECT_EQ(mr.get_allocated_bytes(), 0);
EXPECT_EQ(inner_mr.get_allocated_bytes(), 0);

// Reset the current device resource
rmm::mr::set_current_device_resource_ref(orig_device_resource);
}

TEST(TrackingTest, NegativeInnerTracking)
Expand Down