Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions cpp/include/rmm/mr/detail/failure_callback_resource_adaptor_impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2020-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once

#include <rmm/detail/export.hpp>
#include <rmm/resource_ref.hpp>

#include <cuda/memory_resource>

#include <cstddef>
#include <functional>
#include <utility>

namespace RMM_NAMESPACE {
namespace mr {

/**
* @brief Callback function type used by failure_callback_resource_adaptor
*
* The resource adaptor calls this function when a memory allocation throws a specified exception
* type. The function decides whether the resource adaptor should try to allocate the memory again
* or re-throw the exception.
*
* The callback function signature is:
* `bool failure_callback_t(std::size_t bytes, void* callback_arg)`
*
* The callback function is passed two parameters: `bytes` is the size of the failed memory
* allocation and `arg` is the extra argument passed to the constructor of the
* `failure_callback_resource_adaptor`. The callback function returns a Boolean where true means to
* retry the memory allocation and false means to re-throw the exception.
*/
using failure_callback_t = std::function<bool(std::size_t, void*)>;

namespace detail {

/**
* @brief Implementation class for failure_callback_resource_adaptor.
*
* @tparam ExceptionType The type of exception that this adaptor should respond to.
*/
template <typename ExceptionType>
class failure_callback_resource_adaptor_impl {
public:
failure_callback_resource_adaptor_impl(device_async_resource_ref upstream,
failure_callback_t callback,
void* callback_arg)
: upstream_mr_{upstream}, callback_{std::move(callback)}, callback_arg_{callback_arg}
{
}

~failure_callback_resource_adaptor_impl() = default;

failure_callback_resource_adaptor_impl(failure_callback_resource_adaptor_impl const&) = delete;
failure_callback_resource_adaptor_impl(failure_callback_resource_adaptor_impl&&) = delete;
failure_callback_resource_adaptor_impl& operator=(failure_callback_resource_adaptor_impl const&) =
delete;
failure_callback_resource_adaptor_impl& operator=(failure_callback_resource_adaptor_impl&&) =
delete;

bool operator==(failure_callback_resource_adaptor_impl const& other) const noexcept
{
return this == std::addressof(other);
}

bool operator!=(failure_callback_resource_adaptor_impl const& other) const noexcept
{
return !(*this == other);
}

[[nodiscard]] device_async_resource_ref get_upstream_resource() const noexcept
{
return device_async_resource_ref{
const_cast<cuda::mr::any_resource<cuda::mr::device_accessible>&>(upstream_mr_)};
}

void* allocate(cuda::stream_ref stream,
std::size_t bytes,
std::size_t /*alignment*/ = alignof(std::max_align_t))
{
void* ret{};
while (true) {
try {
ret = upstream_mr_.allocate(stream, bytes);
break;
} catch (ExceptionType const&) {
if (!callback_(bytes, callback_arg_)) { throw; }
}
}
return ret;
}

void deallocate(cuda::stream_ref stream,
void* ptr,
std::size_t bytes,
std::size_t /*alignment*/ = alignof(std::max_align_t)) noexcept
{
upstream_mr_.deallocate(stream, ptr, bytes);
}

void* allocate_sync(std::size_t bytes, std::size_t alignment = alignof(std::max_align_t))
{
return allocate(cuda_stream_view{}, bytes, alignment);
}

void deallocate_sync(void* ptr,
std::size_t bytes,
std::size_t alignment = alignof(std::max_align_t)) noexcept
{
deallocate(cuda_stream_view{}, ptr, bytes, alignment);
}

RMM_CONSTEXPR_FRIEND void get_property(failure_callback_resource_adaptor_impl const&,
cuda::mr::device_accessible) noexcept
{
}

private:
cuda::mr::any_resource<cuda::mr::device_accessible> upstream_mr_;
failure_callback_t callback_;
void* callback_arg_;
};

} // namespace detail
} // namespace mr
} // namespace RMM_NAMESPACE
181 changes: 63 additions & 118 deletions cpp/include/rmm/mr/failure_callback_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@
*/
#pragma once

#include <rmm/aligned.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/detail/error.hpp>
#include <rmm/detail/export.hpp>
#include <rmm/mr/detail/failure_callback_resource_adaptor_impl.hpp>
#include <rmm/mr/device_memory_resource.hpp>
#include <rmm/mr/per_device_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <cuda/memory_resource>

#include <cstddef>
#include <functional>
#include <memory>
#include <utility>

namespace RMM_NAMESPACE {
Expand All @@ -23,23 +25,6 @@ namespace mr {
* @file
*/

/**
* @brief Callback function type used by failure_callback_resource_adaptor
*
* The resource adaptor calls this function when a memory allocation throws a specified exception
* type. The function decides whether the resource adaptor should try to allocate the memory again
* or re-throw the exception.
*
* The callback function signature is:
* `bool failure_callback_t(std::size_t bytes, void* callback_arg)`
*
* The callback function is passed two parameters: `bytes` is the size of the failed memory
* allocation and `arg` is the extra argument passed to the constructor of the
* `failure_callback_resource_adaptor`. The callback function returns a Boolean where true means to
* retry the memory allocation and false means to re-throw the exception.
*/
using failure_callback_t = std::function<bool(std::size_t, void*)>;

/**
* @brief A device memory resource that calls a callback function when allocations
* throw a specified exception type.
Expand All @@ -51,152 +36,112 @@ using failure_callback_t = std::function<bool(std::size_t, void*)>;
* a bool representing whether to retry the allocation (true) or re-throw the caught exception
* (false).
*
* When implementing a callback function for allocation retry, care must be taken to avoid an
* infinite loop. The following example makes sure to only retry the allocation once:
*
* @code{.cpp}
* using failure_callback_adaptor =
* rmm::mr::failure_callback_resource_adaptor<rmm::mr::device_memory_resource>;
*
* bool failure_handler(std::size_t bytes, void* arg)
* {
* bool& retried = *reinterpret_cast<bool*>(arg);
* if (!retried) {
* retried = true;
* return true; // First time we request an allocation retry
* }
* return false; // Second time we let the adaptor throw std::bad_alloc
* }
* This class is copyable and shares ownership of its internal state via
* `cuda::mr::shared_resource`.
*
* int main()
* {
* bool retried{false};
* failure_callback_adaptor mr{
* rmm::mr::get_current_device_resource_ref(), failure_handler, &retried
* };
* rmm::mr::set_current_device_resource_ref(mr);
* }
* @endcode
*
* @tparam Upstream The type of the upstream resource used for allocation/deallocation.
* @tparam ExceptionType The type of exception that this adaptor should respond to
*/
template <typename Upstream, typename ExceptionType = rmm::out_of_memory>
class failure_callback_resource_adaptor final : public device_memory_resource {
template <typename ExceptionType = rmm::out_of_memory>
class failure_callback_resource_adaptor
: public device_memory_resource,
private cuda::mr::shared_resource<
detail::failure_callback_resource_adaptor_impl<ExceptionType>> {
using shared_base =
cuda::mr::shared_resource<detail::failure_callback_resource_adaptor_impl<ExceptionType>>;

public:
using exception_type = ExceptionType; ///< The type of exception this object catches/throws

// Begin legacy device_memory_resource compatibility layer
using device_memory_resource::allocate;
using device_memory_resource::allocate_sync;
using device_memory_resource::deallocate;
using device_memory_resource::deallocate_sync;

/**
* @brief Construct a new `failure_callback_resource_adaptor` using `upstream` to satisfy
* allocation requests.
* @brief Compare two adaptors for equality (shared-impl identity).
*
* @param upstream The resource used for allocating/deallocating device memory
* @param callback Callback function @see failure_callback_t
* @param callback_arg Extra argument passed to `callback`
* @param other The other failure_callback_resource_adaptor to compare against.
* @return true if both adaptors share the same underlying state.
*/
failure_callback_resource_adaptor(device_async_resource_ref upstream,
failure_callback_t callback,
void* callback_arg)
: upstream_{upstream}, callback_{std::move(callback)}, callback_arg_{callback_arg}
[[nodiscard]] bool operator==(failure_callback_resource_adaptor const& other) const noexcept
{
return static_cast<shared_base const&>(*this) == static_cast<shared_base const&>(other);
}

/**
* @brief Compare two adaptors for inequality.
*
* @param other The other failure_callback_resource_adaptor to compare against.
* @return true if the adaptors do not share the same underlying state.
*/
[[nodiscard]] bool operator!=(failure_callback_resource_adaptor const& other) const noexcept
{
return !(*this == other);
}
// End legacy device_memory_resource compatibility layer

/**
* @brief Enables the `cuda::mr::device_accessible` property
*/
RMM_CONSTEXPR_FRIEND void get_property(failure_callback_resource_adaptor const&,
cuda::mr::device_accessible) noexcept
{
}

/**
* @brief Construct a new `failure_callback_resource_adaptor` using `upstream` to satisfy
* allocation requests.
*
* @throws rmm::logic_error if `upstream == nullptr`
*
* @param upstream The resource used for allocating/deallocating device memory
* @param callback Callback function @see failure_callback_t
* @param callback_arg Extra argument passed to `callback`
*/
failure_callback_resource_adaptor(Upstream* upstream,
failure_callback_resource_adaptor(device_async_resource_ref upstream,
failure_callback_t callback,
void* callback_arg)
: upstream_{to_device_async_resource_ref_checked(upstream)},
callback_{std::move(callback)},
callback_arg_{callback_arg}
: shared_base(cuda::mr::make_shared_resource<
detail::failure_callback_resource_adaptor_impl<ExceptionType>>(
upstream, std::move(callback), callback_arg))
{
}

failure_callback_resource_adaptor() = delete;
~failure_callback_resource_adaptor() override = default;
failure_callback_resource_adaptor(failure_callback_resource_adaptor const&) = delete;
failure_callback_resource_adaptor& operator=(failure_callback_resource_adaptor const&) = delete;
failure_callback_resource_adaptor(failure_callback_resource_adaptor&&) noexcept =
default; ///< @default_move_constructor
failure_callback_resource_adaptor& operator=(failure_callback_resource_adaptor&&) noexcept =
default; ///< @default_move_assignment{failure_callback_resource_adaptor}
~failure_callback_resource_adaptor() = default;

/**
* @briefreturn{rmm::device_async_resource_ref to the upstream resource}
*/
[[nodiscard]] rmm::device_async_resource_ref get_upstream_resource() const noexcept
[[nodiscard]] device_async_resource_ref get_upstream_resource() const noexcept
{
return upstream_;
return this->get().get_upstream_resource();
}

// Begin legacy device_memory_resource compatibility layer
private:
/**
* @brief Allocates memory of size at least `bytes` using the upstream
* resource.
*
* @throws `exception_type` if the requested allocation could not be fulfilled
* by the upstream resource.
*
* @param bytes The size, in bytes, of the allocation
* @param stream Stream on which to perform the allocation
* @return void* Pointer to the newly allocated memory
*/
void* do_allocate(std::size_t bytes, cuda_stream_view stream) override
{
void* ret{};

while (true) {
try {
ret = get_upstream_resource().allocate(stream, bytes);
break;
} catch (exception_type const& e) {
if (!callback_(bytes, callback_arg_)) { throw; }
}
}
return ret;
return shared_base::allocate(stream, bytes, rmm::CUDA_ALLOCATION_ALIGNMENT);
}

/**
* @brief Free allocation of size `bytes` pointed to by `ptr`
*
* @param ptr Pointer to be deallocated
* @param bytes Size of the allocation
* @param stream Stream on which to perform the deallocation
*/
void do_deallocate(void* ptr, std::size_t bytes, cuda_stream_view stream) noexcept override
{
get_upstream_resource().deallocate(stream, ptr, bytes);
shared_base::deallocate(stream, ptr, bytes, rmm::CUDA_ALLOCATION_ALIGNMENT);
}

/**
* @brief Compare the upstream resource to another.
*
* @param other The other resource to compare to
* @return true If the two resources are equivalent
* @return false If the two resources are not equal
*/
[[nodiscard]] bool do_is_equal(device_memory_resource const& other) const noexcept override
{
if (this == std::addressof(other)) { return true; }
auto cast = dynamic_cast<failure_callback_resource_adaptor<Upstream> const*>(&other);
auto const* cast = dynamic_cast<failure_callback_resource_adaptor const*>(&other);
if (cast == nullptr) { return false; }
return get_upstream_resource() == cast->get_upstream_resource();
return static_cast<shared_base const&>(*this) == static_cast<shared_base const&>(*cast);
}

// the upstream resource used for satisfying allocation requests
device_async_resource_ref upstream_;
failure_callback_t callback_;
void* callback_arg_;
// End legacy device_memory_resource compatibility layer
};

static_assert(
cuda::mr::resource_with<failure_callback_resource_adaptor<>, cuda::mr::device_accessible>,
"failure_callback_resource_adaptor does not satisfy the cuda::mr::resource concept");

/** @} */ // end of group
} // namespace mr
} // namespace RMM_NAMESPACE
1 change: 1 addition & 0 deletions cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ ConfigureTest(TRACKING_MR_REF_TEST mr/mr_ref_tracking_tests.cpp GPUS 1 PERCENT 5

# out-of-memory callback adaptor tests
ConfigureTest(FAILURE_CALLBACK_TEST mr/failure_callback_mr_tests.cpp)
ConfigureTest(FAILURE_CALLBACK_MR_REF_TEST mr/mr_ref_failure_callback_tests.cpp GPUS 1 PERCENT 5)

# prefetch adaptor tests
ConfigureTest(PREFETCH_ADAPTOR_TEST mr/prefetch_resource_adaptor_tests.cpp)
Expand Down
Loading
Loading