Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ add_library(
src/mr/detail/aligned_resource_adaptor_impl.cpp
src/mr/detail/arena_memory_resource_impl.cpp
src/mr/detail/binning_memory_resource_impl.cpp
src/mr/detail/callback_memory_resource_impl.cpp
src/mr/callback_memory_resource.cpp
src/mr/detail/fixed_size_memory_resource_impl.cpp
src/mr/detail/logging_resource_adaptor_impl.cpp
src/mr/detail/logging_resource_adaptor_impl.cpp
Expand Down
125 changes: 65 additions & 60 deletions cpp/include/rmm/mr/callback_memory_resource.hpp
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once

#include <rmm/cuda_stream_view.hpp>
#include <rmm/detail/export.hpp>
#include <rmm/mr/detail/callback_memory_resource_impl.hpp>
#include <rmm/mr/device_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <cuda/memory_resource>

#include <cstddef>
#include <functional>
#include <utility>

namespace RMM_NAMESPACE {
namespace mr {
/**
* @addtogroup memory_resources
* @{
* @file
*/

/**
Expand Down Expand Up @@ -54,12 +57,58 @@ using allocate_callback_t = std::function<void*(std::size_t, cuda_stream_view, v
*/
using deallocate_callback_t = std::function<void(void*, std::size_t, cuda_stream_view, void*)>;

namespace detail {
class callback_memory_resource_impl;
}

/**
* @brief A device memory resource that uses the provided callbacks for memory allocation
* and deallocation.
*
* This class is copyable and shares ownership of its internal state via
* `cuda::mr::shared_resource`.
*/
class callback_memory_resource final : public device_memory_resource {
class RMM_EXPORT callback_memory_resource
: public device_memory_resource,
private cuda::mr::shared_resource<detail::callback_memory_resource_impl> {
using shared_base = cuda::mr::shared_resource<detail::callback_memory_resource_impl>;

public:
using device_memory_resource::allocate;
using device_memory_resource::allocate_sync;
using device_memory_resource::deallocate;
using device_memory_resource::deallocate_sync;

/**
* @brief Compare two resources for equality (shared-impl identity).
*
* @param other The other callback_memory_resource to compare against.
* @return true if both resources share the same underlying state.
*/
[[nodiscard]] bool operator==(callback_memory_resource const& other) const noexcept
{
return static_cast<shared_base const&>(*this) == static_cast<shared_base const&>(other);
}

/**
* @brief Compare two resources for inequality.
*
* @param other The other callback_memory_resource to compare against.
* @return true if the resources do not share the same underlying state.
*/
[[nodiscard]] bool operator!=(callback_memory_resource const& other) const noexcept
{
return !(*this == other);
}

/**
* @brief Enables the `cuda::mr::device_accessible` property
*/
RMM_CONSTEXPR_FRIEND void get_property(callback_memory_resource const&,
cuda::mr::device_accessible) noexcept
{
}

/**
* @brief Construct a new callback memory resource.
*
Expand All @@ -75,67 +124,23 @@ class callback_memory_resource final : public device_memory_resource {
* It is the caller's responsibility to maintain the lifetime of the pointed-to data
* for the duration of the lifetime of the `callback_memory_resource`.
*/
callback_memory_resource(
allocate_callback_t allocate_callback,
deallocate_callback_t deallocate_callback,
void* allocate_callback_arg = nullptr, // NOLINT(bugprone-easily-swappable-parameters)
void* deallocate_callback_arg = nullptr) noexcept
: allocate_callback_(std::move(allocate_callback)),
deallocate_callback_(std::move(deallocate_callback)),
allocate_callback_arg_(allocate_callback_arg),
deallocate_callback_arg_(deallocate_callback_arg)
{
}
callback_memory_resource(allocate_callback_t allocate_callback,
deallocate_callback_t deallocate_callback,
void* allocate_callback_arg = nullptr,
void* deallocate_callback_arg = nullptr);

callback_memory_resource() = delete;
~callback_memory_resource() override = default;
callback_memory_resource(callback_memory_resource const&) = delete;
callback_memory_resource& operator=(callback_memory_resource const&) = delete;
callback_memory_resource(callback_memory_resource&&) noexcept =
default; ///< @default_move_constructor
callback_memory_resource& operator=(callback_memory_resource&&) noexcept =
default; ///< @default_move_assignment{callback_memory_resource}
callback_memory_resource() = delete;
~callback_memory_resource() = default;

private:
/**
* @brief Allocates memory of size at least \p bytes.
*
* The returned pointer will have at minimum 256 byte alignment.
*
* If supported by the callback, this operation may optionally be executed on
* a stream. Otherwise, the stream is ignored and the null stream is used.
*
* @param bytes The size of the allocation
* @param stream Stream on which to perform allocation
* @return void* Pointer to the newly allocated memory
*/
void* do_allocate(std::size_t bytes, cuda_stream_view stream) override
{
return allocate_callback_(bytes, stream, allocate_callback_arg_);
}

/**
* @brief Deallocate memory pointed to by \p ptr.
*
* If supported by the callback, this operation may optionally be executed on
* a stream. Otherwise, the stream is ignored and the null stream is used.
*
* @param ptr Pointer to be deallocated
* @param bytes The size in bytes of the allocation. This must be equal to the
* value of `bytes` that was passed to the `allocate` call that returned `ptr`.
* @param stream Stream on which to perform deallocation
*/
void do_deallocate(void* ptr, std::size_t bytes, cuda_stream_view stream) noexcept override
{
deallocate_callback_(ptr, bytes, stream, deallocate_callback_arg_);
}

allocate_callback_t allocate_callback_;
deallocate_callback_t deallocate_callback_;
void* allocate_callback_arg_;
void* deallocate_callback_arg_;
void* do_allocate(std::size_t bytes, cuda_stream_view stream) override;
void do_deallocate(void* ptr, std::size_t bytes, cuda_stream_view stream) noexcept override;
[[nodiscard]] bool do_is_equal(device_memory_resource const& other) const noexcept override;
};

static_assert(cuda::mr::resource_with<callback_memory_resource, cuda::mr::device_accessible>,
"callback_memory_resource does not satisfy the cuda::mr::resource concept");

/** @} */ // end of group
} // namespace mr
} // namespace RMM_NAMESPACE
82 changes: 82 additions & 0 deletions cpp/include/rmm/mr/detail/callback_memory_resource_impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once

#include <rmm/cuda_stream_view.hpp>
#include <rmm/detail/export.hpp>
#include <rmm/resource_ref.hpp>

#include <cuda/memory_resource>

#include <cstddef>
#include <functional>

namespace RMM_NAMESPACE {
namespace mr {
namespace detail {

/**
* @brief Implementation class for callback_memory_resource.
*
* Holds the allocate/deallocate callbacks and their arguments. This class
* satisfies the CCCL `cuda::mr::resource` concept and is held by
* `callback_memory_resource` via `cuda::mr::shared_resource` for
* reference-counted ownership.
*/
class callback_memory_resource_impl {
public:
callback_memory_resource_impl(
std::function<void*(std::size_t, cuda_stream_view, void*)> allocate_callback,
std::function<void(void*, std::size_t, cuda_stream_view, void*)> deallocate_callback,
void* allocate_callback_arg,
void* deallocate_callback_arg) noexcept;

~callback_memory_resource_impl() = default;

callback_memory_resource_impl(callback_memory_resource_impl const&) = delete;
callback_memory_resource_impl(callback_memory_resource_impl&&) = delete;
callback_memory_resource_impl& operator=(callback_memory_resource_impl const&) = delete;
callback_memory_resource_impl& operator=(callback_memory_resource_impl&&) = delete;

bool operator==(callback_memory_resource_impl const& other) const noexcept
{
return this == std::addressof(other);
}

bool operator!=(callback_memory_resource_impl const& other) const noexcept
{
return !(*this == other);
}

void* allocate(cuda::stream_ref stream,
std::size_t bytes,
std::size_t alignment = alignof(std::max_align_t));

void deallocate(cuda::stream_ref stream,
void* ptr,
std::size_t bytes,
std::size_t alignment = alignof(std::max_align_t)) noexcept;

void* allocate_sync(std::size_t bytes, std::size_t alignment = alignof(std::max_align_t));

void deallocate_sync(void* ptr,
std::size_t bytes,
std::size_t alignment = alignof(std::max_align_t)) noexcept;

RMM_CONSTEXPR_FRIEND void get_property(callback_memory_resource_impl const&,
cuda::mr::device_accessible) noexcept
{
}

private:
std::function<void*(std::size_t, cuda_stream_view, void*)> allocate_callback_;
std::function<void(void*, std::size_t, cuda_stream_view, void*)> deallocate_callback_;
void* allocate_callback_arg_;
void* deallocate_callback_arg_;
};

} // namespace detail
} // namespace mr
} // namespace RMM_NAMESPACE
50 changes: 50 additions & 0 deletions cpp/src/mr/callback_memory_resource.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#include <rmm/aligned.hpp>
#include <rmm/mr/callback_memory_resource.hpp>
#include <rmm/mr/detail/callback_memory_resource_impl.hpp>

#include <utility>

namespace RMM_NAMESPACE {
namespace mr {

callback_memory_resource::callback_memory_resource(allocate_callback_t allocate_callback,
deallocate_callback_t deallocate_callback,
void* allocate_callback_arg,
void* deallocate_callback_arg)
: shared_base(cuda::mr::make_shared_resource<detail::callback_memory_resource_impl>(
std::move(allocate_callback),
std::move(deallocate_callback),
allocate_callback_arg,
deallocate_callback_arg))
{
}

// Begin legacy device_memory_resource compatibility layer
void* callback_memory_resource::do_allocate(std::size_t bytes, cuda_stream_view stream)
{
return shared_base::allocate(stream, bytes, rmm::CUDA_ALLOCATION_ALIGNMENT);
}

void callback_memory_resource::do_deallocate(void* ptr,
std::size_t bytes,
cuda_stream_view stream) noexcept
{
shared_base::deallocate(stream, ptr, bytes, rmm::CUDA_ALLOCATION_ALIGNMENT);
}

bool callback_memory_resource::do_is_equal(device_memory_resource const& other) const noexcept
{
if (this == std::addressof(other)) { return true; }
auto const* cast = dynamic_cast<callback_memory_resource const*>(&other);
if (cast == nullptr) { return false; }
return static_cast<shared_base const&>(*this) == static_cast<shared_base const&>(*cast);
}
// End legacy device_memory_resource compatibility layer

} // namespace mr
} // namespace RMM_NAMESPACE
55 changes: 55 additions & 0 deletions cpp/src/mr/detail/callback_memory_resource_impl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#include <rmm/mr/detail/callback_memory_resource_impl.hpp>

#include <utility>

namespace RMM_NAMESPACE {
namespace mr {
namespace detail {

callback_memory_resource_impl::callback_memory_resource_impl(
std::function<void*(std::size_t, cuda_stream_view, void*)> allocate_callback,
std::function<void(void*, std::size_t, cuda_stream_view, void*)> deallocate_callback,
void* allocate_callback_arg,
void* deallocate_callback_arg) noexcept
: allocate_callback_(std::move(allocate_callback)),
deallocate_callback_(std::move(deallocate_callback)),
allocate_callback_arg_(allocate_callback_arg),
deallocate_callback_arg_(deallocate_callback_arg)
{
}

void* callback_memory_resource_impl::allocate(cuda::stream_ref stream,
std::size_t bytes,
std::size_t /*alignment*/)
{
return allocate_callback_(bytes, cuda_stream_view{stream.get()}, allocate_callback_arg_);
}

void callback_memory_resource_impl::deallocate(cuda::stream_ref stream,
void* ptr,
std::size_t bytes,
std::size_t /*alignment*/) noexcept
{
deallocate_callback_(ptr, bytes, cuda_stream_view{stream.get()}, deallocate_callback_arg_);
}

void* callback_memory_resource_impl::allocate_sync(std::size_t bytes, std::size_t alignment)
{
return allocate(cuda_stream_view{}, bytes, alignment);
}

void callback_memory_resource_impl::deallocate_sync(void* ptr,
std::size_t bytes,
std::size_t alignment) noexcept
{
deallocate(cuda_stream_view{}, ptr, bytes, alignment);
}

} // namespace detail
} // namespace mr
} // namespace RMM_NAMESPACE
1 change: 1 addition & 0 deletions cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ ConfigureTest(BINNING_MR_TEST mr/binning_mr_tests.cpp)

# callback memory resource tests
ConfigureTest(CALLBACK_MR_TEST mr/callback_mr_tests.cpp)
ConfigureTest(CALLBACK_MR_REF_TEST mr/mr_ref_callback_tests.cpp GPUS 1 PERCENT 5)

# system memory resource tests
ConfigureTest(SYSTEM_MR_TEST mr/system_mr_tests.cu GPUS 1 PERCENT 100)
Expand Down
Loading
Loading