Merged

Commits (25)
d138a03  Add support for CUMSUM and TRI for CUDA. (pwilkin, Nov 28, 2025)
67207d2  Minor optimizations. (pwilkin, Nov 28, 2025)
fab0029  Correct warp_prefix_inclusive_sum in float2 variant to return float2 (pwilkin, Nov 28, 2025)
51c40a5  Optimize TRI (pwilkin, Dec 1, 2025)
c30f565  Whitespace (pwilkin, Dec 1, 2025)
31b55fa  Fix strides. (pwilkin, Dec 1, 2025)
d1ca1c2  Implement double loop (pwilkin, Dec 1, 2025)
5289b53  Whitespace (pwilkin, Dec 1, 2025)
f422ba8  Fix HIP compilation bugs (pwilkin, Dec 1, 2025)
df917cc  Optimizations + big case performance tests (pwilkin, Dec 2, 2025)
76382d7  Implement using CUB with fallback to custom kernel (pwilkin, Dec 2, 2025)
01d4033  Remove error message. (pwilkin, Dec 2, 2025)
10a2ea9  Fixes from code review (pwilkin, Dec 3, 2025)
7a83b05  Comment out CPU-unsupported F16/BF16 cases to fix CI (pwilkin, Dec 3, 2025)
bbe3743  Fine, you win :P (pwilkin, Dec 4, 2025)
069413a  Fix last cast, use NO_DEVICE_CODE and GGML_UNUSED_VARS (pwilkin, Dec 4, 2025)
5aa7438  Vary warp-size based on physical warp size (pwilkin, Dec 4, 2025)
579eba6  Add GGML_UNUSED_VARS in tri as well (pwilkin, Dec 4, 2025)
08b3f2d  Use constexpr and call prefix_inclusive with warp_size template param (pwilkin, Dec 4, 2025)
9cd0eff  Update ggml/src/ggml-cuda/cumsum.cu (pwilkin, Dec 4, 2025)
9574264  Apply suggestions from code review (pwilkin, Dec 4, 2025)
efd619a  Change to tid % warp_size (pwilkin, Dec 4, 2025)
86a0853  Fix strides; hardcode mask; add ggml_lane_mask_t (pwilkin, Dec 4, 2025)
de45c63  Missing renames, remove unused get_warp_mask(), explicit calls to ggm… (pwilkin, Dec 4, 2025)
8a7375c  Too hasty... (pwilkin, Dec 4, 2025)
58 changes: 58 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
@@ -461,6 +461,64 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
return x;
}

static __device__ __forceinline__ unsigned int get_warp_mask() {
#ifdef __HIP_PLATFORM_AMD__
return __ballot(1); // HIP equivalent
Collaborator: I know basically nothing about HIP, but according to this doc it seems like __activemask() should be supported? The main difference referenced there is the warp size of 64 vs. 32, which I could absolutely imagine being accidentally hard-coded somewhere.

Collaborator: Specifically, I see #define WARP_SIZE 32 at the top of this file.

Collaborator: cc/ @IMbackK

Collaborator: WARP_SIZE is deprecated and the remaining uses should only be in places affecting performance, not correctness; the non-deprecated equivalent is ggml_cuda_get_physical_warp_size.

__activemask is indeed supported and works, but I will need to check for how long; I will do that later.

We will need to change the return type of this and the kernel below. @pwilkin, you can do so, or skip the kernel on HIP and I will fix it in a follow-up.

Collaborator (Author): @IMbackK okay, I'll comment it out then and add a TODO; I prefer to leave it to someone who knows what they're doing rather than leave an untested vibe-coded patch :)

#else
return __activemask(); // CUDA
#endif
}

template<typename T, int width = WARP_SIZE>
static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) {
const int lane_id = threadIdx.x % width;
const auto mask = get_warp_mask();
#pragma unroll
for (int offset = 1; offset < width; offset <<= 1) {
const T t = __shfl_up_sync(mask, x, offset, width);
if (lane_id >= offset) {
x += t;
}
}
return x;
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ float2 warp_prefix_inclusive_sum(float2 a) {
const int lane_id = threadIdx.x % width;
const auto mask = get_warp_mask();
#pragma unroll
for (int offset = 1; offset < width; offset <<= 1) {
const float t_x = __shfl_up_sync(mask, a.x, offset, width);
const float t_y = __shfl_up_sync(mask, a.y, offset, width);
if (lane_id >= offset) {
a.x += t_x;
a.y += t_y;
}
}
return a;
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) {
#ifdef FP16_AVAILABLE
const int lane_id = threadIdx.x % width;
const auto mask = get_warp_mask();
#pragma unroll
for (int offset = 1; offset < width; offset <<= 1) {
const half2 t = __shfl_up_sync(mask, a, offset, width);
if (lane_id >= offset) {
a = __hadd2(a, t);
}
}
return a;

#else
NO_DEVICE_CODE;
return a;
#endif // FP16_AVAILABLE
}

static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
#ifdef FP16_AVAILABLE

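
As context for the warp-size discussion above, a minimal sketch of how a caller could pass the physical warp size as the width template parameter, in the spirit of the later commits "Vary warp-size based on physical warp size" and "Use constexpr and call prefix_inclusive with warp_size template param". The demo kernel and the constexpr use of ggml_cuda_get_physical_warp_size are assumptions for illustration, not code from this diff.

// Illustrative only: not part of this diff. Assumes ggml_cuda_get_physical_warp_size()
// is a constexpr device helper (as the review suggests) and that n <= blockDim.x.
static __global__ void warp_scan_demo(const float * src, float * dst, const int n) {
    constexpr int warp_size = ggml_cuda_get_physical_warp_size(); // 32 on NVIDIA, 64 on some AMD GPUs
    const int tid = threadIdx.x;
    if (tid >= n) {
        return;
    }
    // Each group of warp_size consecutive threads produces an inclusive prefix sum
    // over its own warp_size-wide slice of src.
    dst[tid] = warp_prefix_inclusive_sum<float, warp_size>(src[tid]);
}
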
136 changes: 136 additions & 0 deletions ggml/src/ggml-cuda/cumsum.cu
@@ -0,0 +1,136 @@
#include <algorithm>

#include "cumsum.cuh"

// Kernel to compute cumulative sum along the innermost dimension (ne[0])
// Each block processes one row (ne[0] elements)
// Algorithm matches Metal implementation:
// 1. Each warp computes prefix sum within itself
// 2. Last thread of each warp stores result in shared memory
// 3. All warps sync
// 4. Each element adds the sum of all preceding warps

template<typename T>
static __global__ void cumsum_kernel(
const T * src, T * dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) {

// Shared memory to store warp sums (always use float for accumulation)
extern __shared__ float shmem[];

const int64_t i3 = blockIdx.z;
const int64_t i2 = blockIdx.y;
const int64_t i1 = blockIdx.x;

if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}

const T * src_row = src + i1 * nb01 + i2*nb02 + i3*nb03;
T * dst_row = dst + i1 * nb1 + i2*nb2 + i3*nb3;

const int tid = threadIdx.x;
const int lane_id = tid % WARP_SIZE;

if (tid >= ne00) {
return;
}

// Phase 1: Each thread processes elements at stride blockDim.x
// Compute warp-level prefix sums
for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) {
// Load value and compute prefix sum within warp
float val = static_cast<float>(src_row[i0]);
val = warp_prefix_inclusive_sum(val);
dst_row[i0] = static_cast<T>(val);
Collaborator: It would be much preferable to store the temporary results in registers or shared memory rather than global memory.

Collaborator (Author): Isn't val here already stored in a register, though? I'm afraid I'll need some more guidance here.

Collaborator: dst_row is in global memory. With this code you are writing data to VRAM on this line, only to later read this data again, add a value to it, and write it back. So you have 3x as much I/O to the comparatively slow VRAM vs. the comparatively fast SRAM or registers, where you could be storing it instead until you write the data once at the end of the kernel.

Collaborator (Author): Now I get it, thanks!


// Last thread of warp stores its sum to shared memory at position based on data index
if (lane_id == WARP_SIZE - 1 || i0 == ne00 - 1) {
const int shmem_idx = i0 / WARP_SIZE;
shmem[shmem_idx] = val;
}
}

// Sync once after all warp prefix sums are computed
__syncthreads();

// Phase 2: Add the sum of all preceding warp groups to each element
for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) {
const int shmem_idx = i0 / WARP_SIZE;
float sum = 0.0f;
for (int j = 0; j < shmem_idx; ++j) {
sum += shmem[j];
}
dst_row[i0] = static_cast<T>(static_cast<float>(dst_row[i0]) + sum);
}
}

template<typename T>
static void cumsum_cuda(
const T * src, T * dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
cudaStream_t stream) {

dim3 grid_dims(ne01, ne02, ne03);

// Shared memory size: one float per warp
const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE;
const size_t shmem_size = num_warps * sizeof(float);
const size_t type_size = sizeof(T);

int block_size = num_warps * WARP_SIZE;
block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE);
dim3 block_dims(block_size, 1, 1);

cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
);
}

void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
cudaStream_t stream = ctx.stream();

GGML_ASSERT(src0->type == dst->type);
switch(src0->type) {
case GGML_TYPE_F32:
{
cumsum_cuda(
(const float *)src0->data, (float *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
stream
);
} break;
case GGML_TYPE_F16:
{
cumsum_cuda(
(const half *)src0->data, (half *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
stream
);
} break;
case GGML_TYPE_BF16:
{
cumsum_cuda(
(const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
stream
);
} break;
default:
GGML_ABORT("fatal error");
}
}
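
For reference, a self-contained sketch of the block-wide scan described in the kernel comment above (warp-level scans, warp totals in shared memory, then adding the totals of preceding warps), written so the per-thread partial sum stays in a register until the final store, which is the pattern suggested in the review thread. This is an illustration, not the PR's kernel: it assumes a single pass with n <= blockDim.x, blockDim.x a multiple of 32, and a hard-coded warp size of 32, whereas cumsum.cu handles strided rows and multiple element types.

// Simplified sketch (not this PR's kernel): block-wide inclusive scan built from warp scans.
static __global__ void block_inclusive_scan_sketch(const float * src, float * dst, const int n) {
    __shared__ float warp_sums[32];            // one partial sum per warp

    const int tid     = threadIdx.x;
    const int lane_id = tid % 32;
    const int warp_id = tid / 32;

    float val = tid < n ? src[tid] : 0.0f;

    // 1. Inclusive scan within each warp via shuffles.
    #pragma unroll
    for (int offset = 1; offset < 32; offset <<= 1) {
        const float t = __shfl_up_sync(0xFFFFFFFF, val, offset, 32);
        if (lane_id >= offset) {
            val += t;
        }
    }

    // 2. The last lane of each warp publishes the warp total.
    if (lane_id == 32 - 1) {
        warp_sums[warp_id] = val;
    }
    __syncthreads();

    // 3. Each element adds the totals of all preceding warps, then writes exactly once.
    float carry = 0.0f;
    for (int j = 0; j < warp_id; ++j) {
        carry += warp_sums[j];
    }
    if (tid < n) {
        dst[tid] = val + carry;
    }
}
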
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/cumsum.cuh
@@ -0,0 +1,5 @@
#include "common.cuh"

#define CUDA_CUMSUM_BLOCK_SIZE 256

void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
10 changes: 10 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -54,6 +54,8 @@
#include "ggml-cuda/set-rows.cuh"
#include "ggml-cuda/pad_reflect_1d.cuh"
#include "ggml-cuda/solve_tri.cuh"
#include "ggml-cuda/tri.cuh"
#include "ggml-cuda/cumsum.cuh"
#include "ggml.h"

#include <algorithm>
@@ -2700,6 +2702,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_CROSS_ENTROPY_LOSS:
ggml_cuda_cross_entropy_loss(ctx, dst);
break;
case GGML_OP_CUMSUM:
ggml_cuda_op_cumsum(ctx, dst);
break;
case GGML_OP_TRI:
ggml_cuda_op_tri(ctx, dst);
break;
case GGML_OP_RWKV_WKV6:
ggml_cuda_op_rwkv_wkv6(ctx, dst);
break;
@@ -4262,6 +4270,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
case GGML_OP_CUMSUM:
case GGML_OP_TRI:
return true;
case GGML_OP_SOLVE_TRI:
return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;
133 changes: 133 additions & 0 deletions ggml/src/ggml-cuda/tri.cu
@@ -0,0 +1,133 @@
#include "ggml-cuda/common.cuh"
#include "tri.cuh"
#include "ggml.h"

template<typename T, bool prefix_keep, int add_to_split>
static __global__ void tri_kernel(
const T * src, T * dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) {
const int64_t i3 = blockIdx.z;
const int64_t i2 = blockIdx.y;
const int64_t i1 = blockIdx.x;
const int64_t split_point = i1 + add_to_split;

if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}

const T * src_row = src + i1*nb01 + i2*nb02 + i3*nb03;
T * dst_row = dst + i1*nb1 + i2*nb2 + i3*nb3;

if constexpr (prefix_keep) {
for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
dst_row[i0] = src_row[i0];
}
for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
dst_row[i0] = T(0);
}
} else {
for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
dst_row[i0] = T(0);
}
for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
dst_row[i0] = src_row[i0];
}
}
}

template<typename T>
static void tri_cuda(
const T * src, T * dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
const ggml_tri_type ttype,
cudaStream_t stream) {

dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1);
dim3 grid_dims(ne01, ne02, ne03);
const size_t type_size = sizeof(T);

const int add_to_split = (ttype == GGML_TRI_TYPE_LOWER_DIAG || ttype == GGML_TRI_TYPE_UPPER) ? 1 : 0;
const bool prefix_keep = (ttype == GGML_TRI_TYPE_LOWER || ttype == GGML_TRI_TYPE_LOWER_DIAG);

if (prefix_keep) {
if (add_to_split == 0) {
tri_kernel<T, true, 0><<<grid_dims, block_dims, 0, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
);
} else { // only 0 and 1 supported
tri_kernel<T, true, 1><<<grid_dims, block_dims, 0, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
);
}
} else {
if (add_to_split == 0) {
tri_kernel<T, false, 0><<<grid_dims, block_dims, 0, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
);
} else {
tri_kernel<T, false, 1><<<grid_dims, block_dims, 0, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
);
}
}
}

void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
cudaStream_t stream = ctx.stream();

const ggml_tri_type ttype = static_cast<ggml_tri_type>(ggml_get_op_params_i32(dst, 0));

GGML_ASSERT(src0->type == dst->type);

switch(src0->type) {
case GGML_TYPE_F32:
{
tri_cuda(
(const float *)src0->data, (float *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
ttype, stream
);
} break;
case GGML_TYPE_F16:
{
tri_cuda(
(const half *)src0->data, (half *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
ttype, stream
);
} break;
case GGML_TYPE_BF16:
{
tri_cuda(
(const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
ttype, stream
);
} break;
default:
GGML_ABORT("fatal error");
}
}
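
As a reading aid, a small CPU-side restatement of the element rule the four template instantiations implement, derived from the prefix_keep / add_to_split selection in tri_cuda above; tri_element_ref is a hypothetical helper for illustration and not part of this diff.

// Per-element rule encoded by the (prefix_keep, add_to_split) pair in tri_cuda:
//   GGML_TRI_TYPE_LOWER      -> prefix_keep = true,  add_to_split = 0  (keep i0 <  i1)
//   GGML_TRI_TYPE_LOWER_DIAG -> prefix_keep = true,  add_to_split = 1  (keep i0 <= i1)
//   GGML_TRI_TYPE_UPPER      -> prefix_keep = false, add_to_split = 1  (keep i0 >  i1)
//   GGML_TRI_TYPE_UPPER_DIAG -> prefix_keep = false, add_to_split = 0  (keep i0 >= i1)
static float tri_element_ref(const float * src_row, const int64_t i0, const int64_t i1,
                             const bool prefix_keep, const int add_to_split) {
    const int64_t split_point = i1 + add_to_split;
    const bool in_prefix = i0 < split_point;
    return in_prefix == prefix_keep ? src_row[i0] : 0.0f;
}
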
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/tri.cuh
@@ -0,0 +1,5 @@
#include "common.cuh"

#define CUDA_TRI_BLOCK_SIZE 256

void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
6 changes: 6 additions & 0 deletions tests/test-backend-ops.cpp
@@ -7938,6 +7938,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));

test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F32, { 256, 256, 4, 4 }));
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 }));

test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 1 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 1 }));

for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
for (ggml_type type_a : all_types) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
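
The new perf cases above can be exercised with test-backend-ops directly; the commands below assume the usual build output path and the binary's customary -o (op filter) and -b (backend filter) flags, neither of which is shown in this diff.

./build/bin/test-backend-ops perf -o CUMSUM
./build/bin/test-backend-ops perf -o TRI
./build/bin/test-backend-ops test -o CUMSUM -b CUDA0
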