diff --git a/cpp/include/rmm/device_buffer.hpp b/cpp/include/rmm/device_buffer.hpp index b210a1174..54e041dfe 100644 --- a/cpp/include/rmm/device_buffer.hpp +++ b/cpp/include/rmm/device_buffer.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -11,6 +11,7 @@ #include #include +#include #include #include @@ -320,7 +321,7 @@ class device_buffer { /** * @briefreturn{The resource used to allocate and deallocate} */ - [[nodiscard]] rmm::device_async_resource_ref memory_resource() const noexcept { return _mr; } + [[nodiscard]] rmm::device_async_resource_ref memory_resource() noexcept { return _mr; } private: void* _data{nullptr}; ///< Pointer to device memory allocation @@ -328,9 +329,8 @@ class device_buffer { std::size_t _capacity{}; ///< The actual size of the device memory allocation cuda_stream_view _stream{}; ///< Stream to use for device memory deallocation - rmm::device_async_resource_ref _mr{ - rmm::mr::get_current_device_resource_ref()}; ///< The memory resource used to - ///< allocate/deallocate device memory + cuda::mr::any_resource _mr; ///< The memory resource used to + ///< allocate/deallocate device memory cuda_device_id _device{get_current_cuda_device()}; /** diff --git a/cpp/include/rmm/device_uvector.hpp b/cpp/include/rmm/device_uvector.hpp index b24df64ad..7e674f570 100644 --- a/cpp/include/rmm/device_uvector.hpp +++ b/cpp/include/rmm/device_uvector.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2020-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -602,7 +602,7 @@ class device_uvector { * @briefreturn{The resource used to allocate and deallocate the device * storage} */ - [[nodiscard]] rmm::device_async_resource_ref memory_resource() const noexcept + [[nodiscard]] rmm::device_async_resource_ref memory_resource() noexcept { return _storage.memory_resource(); } diff --git a/cpp/src/device_buffer.cpp b/cpp/src/device_buffer.cpp index 5a2ce2aef..9958c75db 100644 --- a/cpp/src/device_buffer.cpp +++ b/cpp/src/device_buffer.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -44,7 +44,7 @@ device_buffer::device_buffer(device_buffer&& other) noexcept _size{other._size}, _capacity{other._capacity}, _stream{other.stream()}, - _mr{other._mr}, + _mr{std::move(other._mr)}, _device{other._device} { other._data = nullptr; @@ -64,7 +64,7 @@ device_buffer& device_buffer::operator=(device_buffer&& other) noexcept _size = other._size; _capacity = other._capacity; set_stream(other.stream()); - _mr = other._mr; + _mr = std::move(other._mr); _device = other._device; other._data = nullptr; diff --git a/cpp/tests/mr/mr_ref_default_tests.cpp b/cpp/tests/mr/mr_ref_default_tests.cpp index b9326d631..52863af75 100644 --- a/cpp/tests/mr/mr_ref_default_tests.cpp +++ b/cpp/tests/mr/mr_ref_default_tests.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -63,6 +63,34 @@ TEST(DefaultTest, GetCurrentDeviceResourceRef) EXPECT_EQ(mr, rmm::device_async_resource_ref{rmm::mr::detail::initial_resource()}); } +TEST(DefaultTest, SetCurrentDeviceResourceRefFromPointer) +{ + // Construct a cuda_memory_resource + rmm::mr::cuda_memory_resource cuda_mr{}; + + // Get a pointer to it (device_memory_resource*) + rmm::mr::device_memory_resource* mr_ptr = &cuda_mr; + + // Set with set_current_device_resource_ref using the pointer + rmm::mr::set_current_device_resource_ref(mr_ptr); + + // Get the ref with get_current_device_resource_ref + auto ref = rmm::mr::get_current_device_resource_ref(); + + // Use that ref to allocate + constexpr std::size_t size{1024}; + void* ptr = ref.allocate_sync(size); + EXPECT_NE(ptr, nullptr); + EXPECT_TRUE(is_properly_aligned(ptr)); + EXPECT_TRUE(is_device_accessible_memory(ptr)); + + // Deallocate + ref.deallocate_sync(ptr, size); + + // Reset to initial resource + rmm::mr::reset_current_device_resource_ref(); +} + // Multi-threaded default resource tests TEST(DefaultTest, UseCurrentDeviceResource_mt) { spawn(test_get_current_device_resource); }