
FD Operator CUDAExt #2466

Merged
imreddyTeja merged 8 commits into main from tr/matmul on Mar 31, 2026

Conversation

@imreddyTeja
Member

@imreddyTeja imreddyTeja commented Mar 10, 2026

TODO before merge:

  • Generic Nv support

  • Squash commits

  • Delete old copyto_stencil_64

  • rename files and kernel

  • add shmem support to auto_launch

  • Code follows the style guidelines OR N/A.

  • Unit tests are included OR N/A.

  • Code is exercised in an integration test OR N/A.

  • Documentation has been added/updated OR N/A.

Comment thread test/Operators/finitedifference/convergence_column.jl Outdated
Comment thread src/MatrixFields/MatrixFields.jl Outdated
Comment on lines +127 to +133
arg1_isa_matrix =
    eltype(arg1) <: BandMatrixRow || arg1 isa LazyOperatorBroadcasted
if arg1 isa LazyOperatorBroadcasted && length(arg1.args) > 0
    arg1_isa_matrix =
        eltype(arg1.args[1]) <: BandMatrixRow ||
        arg1.args[1] isa LazyOperatorBroadcasted
end
Member

I've seen this sort of conditional variable updating hurt inference inside recursive calls. Better to have something like

arg1_isa_matrix =
    arg1 isa LazyOperatorBroadcasted && length(arg1.args) > 0 ?
    eltype(arg1.args[1]) <: BandMatrixRow || arg1.args[1] isa LazyOperatorBroadcasted :
    eltype(arg1) <: BandMatrixRow || arg1 isa LazyOperatorBroadcasted

But also, this looks like it should be defined recursively? Only going down one level into arg1.args feels a bit arbitrary.
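A recursive formulation along those lines might look like the following (a hypothetical sketch, not the PR's code: `arg_isa_matrix` is an invented helper name, and the base case for a `LazyOperatorBroadcasted` with no args mirrors the existing `|| arg isa LazyOperatorBroadcasted` fallback):

```julia
# Hypothetical sketch: recurse through nested LazyOperatorBroadcasted
# arguments instead of only peeking one level into arg.args.
arg_isa_matrix(arg) = eltype(arg) <: BandMatrixRow
arg_isa_matrix(arg::LazyOperatorBroadcasted) =
    isempty(arg.args) ? true : arg_isa_matrix(arg.args[1])
```

Splitting the cases across method dispatch rather than a runtime conditional also sidesteps the inference issue mentioned above.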

Comment thread src/MatrixFields/MatrixFields.jl Outdated
end
end

# TODO: move into CUDAExt
Member

This could also live in Geometry/rmul_with_projection.jl for now

Member Author

I think that would create a dependency loop, because MatrixFields already depends on Geometry

else
return Operators.return_eltype(matrix1.op.op, matrix1.args[1], arg)
end
end
Member
Does the call to rmul_return_type below not generate the same result? I don't see why this new branch is needed.

Member Author
Not in the case with divgrad of a vec

Comment thread .buildkite/pipeline.yml Outdated
Comment thread ext/cuda/operators_fd_eager.jl Outdated
@inline @inbounds project_row2_for_mul(mat1_row, mat2_row, mat2_space)
# It should be possible to use static shared memory here, but it allocates new shared memory
# for each layer of recursion
CUDA.sync_threads()
Member

What is the purpose of this first synchronization? The second one is to ensure that every level sees the same values in mat2, but there are no matrix values being synchronized here.

Member Author

To ensure that any shared-memory use inside the recursive call has completed before this level reuses the buffer
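As an illustration of that pattern (a hedged sketch with invented names — `fill_level!` and `entry_for` are not the PR's actual kernel): the recursive call may have read or written the same dynamically allocated shared-memory buffer, so one barrier is needed before this level overwrites it, and a second after writing so every thread sees the new values.

```julia
function fill_level!(shmem, level)
    level == Int32(0) && return nothing
    fill_level!(shmem, level - Int32(1))      # recursion may touch `shmem`
    CUDA.sync_threads()   # barrier 1: recursion's shmem use is complete
    shmem[threadIdx().x] = entry_for(level)   # `entry_for` is a placeholder
    CUDA.sync_threads()   # barrier 2: all threads see this level's values
    return nothing
end
```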

Comment thread ext/cuda/newmm.jl Outdated
Comment thread ext/cuda/newmm.jl Outdated
Comment thread ext/cuda/newmm.jl Outdated
project_onto =
    ClimaCore.Geometry.recursively_find_dual_axes_for_projection(typeof(mat1_row))
if space.staggering isa Spaces.CellCenter && v == Int32(64)
    lg = rzero(Spaces.local_geometry_type(typeof(space)))
Member
Suggested change
- lg = rzero(Spaces.local_geometry_type(typeof(space)))
+ lg = new_struct(Spaces.local_geometry_type(typeof(space)))

You can avoid all these calls to rzero (which come with a decently large latency penalty) by using something like

@generated new_struct(::Type{T}) where {T} = Expr(:new, :T)
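For context on why this helps (a sketch under the assumption that the type is an isbits struct; `Point3` is an invented example type): `Expr(:new, :T)` emits a bare struct allocation with uninitialized fields, whereas `rzero` must recursively construct a zero value for every field.

```julia
@generated new_struct(::Type{T}) where {T} = Expr(:new, :T)

struct Point3{FT}
    x::FT
    y::FT
    z::FT
end

# Produces a Point3{Float64} whose fields hold arbitrary (uninitialized)
# bits; fine here because the value is only needed to satisfy the
# local-geometry argument's type before being overwritten.
lg_placeholder = new_struct(Point3{Float64})
```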

Comment thread ext/cuda/matmul.jl
Comment on lines +307 to +339
# row_mul_vec! handles banded matrix * vector. There are four methods, but they all have the
# same structure, so they could be written as a single method.
# The others can be obtained by copy-pasting and changing the indices appropriately.
# Note that these are all specialized for 64 faces, so the indices are hardcoded.
Base.@propagate_inbounds function row_mul_vec!(
    ::Type{P},
    mat1_row,
    matrix2,
    ::FaceToCenter,
) where {P}
    @inbounds begin
        prod_eltype = P
        v = threadIdx().x
        i = threadIdx().y
        mat1_eltype = typeof(mat1_row)
        mat2_eltype = eltype(matrix2)
        ld1, ud1 = MatrixFields.outer_diagonals(mat1_eltype)
        li = Int32(1)
        ri = Int32(63)
        zero_entry = rzero(prod_eltype)
        return UnrolledUtilities.unrolled_mapreduce(
            ⊞,
            ld1:ud1;
            init = zero_entry,
        ) do mat1_row_d
            if (Int32(0) < v + mat1_row_d + half <= Int32(64))
                @inbounds outer_or_mul(
                    mat1_row[mat1_row_d],
                    matrix2[v + mat1_row_d + half + (i - Int32(1)) * Int32(64)],
                )
            else
                zero_entry
            end
        end
    end
end
Member
Suggested change (replacing the four specialized methods above with a single generic one):
@inline function row_mul_vec!(::Type{P}, mat_row, vec, shape) where {P}
    v_mat = threadIdx().x
    i = threadIdx().y
    zero_entry = rzero(P)
    ld, ud = MatrixFields.outer_diagonals(typeof(mat_row))
    d_offset = shape == FaceToCenter() ? half : shape == CenterToFace() ? -half : 0
    return UnrolledUtilities.unrolled_mapreduce(⊞, ld:ud; init = zero_entry) do d
        v_vec = v_mat + d + d_offset
        Int32(1) <= v_vec <= Spaces.nlevels(axes(vec)) || return zero_entry
        @inbounds outer_or_mul(mat_row[d], vec[v_vec + (i - Int32(1)) * Int32(64)])
    end
end

I think this method covers all 4 cases of matrix-vector multiplication. And it should be straightforward to extend to matrix-matrix multiplication, letting you get rid of all the code duplication in this file.

@imreddyTeja imreddyTeja force-pushed the tr/matmul branch 2 times, most recently from 17c2e6c to 743c371 on March 24, 2026 at 22:18
Add gpu support to column_convergence.jl
and unit_column.jl
@imreddyTeja imreddyTeja force-pushed the tr/matmul branch 2 times, most recently from 6ce8a50 to be7c8c4 on March 25, 2026 at 16:58
@imreddyTeja imreddyTeja marked this pull request as ready for review March 25, 2026 18:15
@imreddyTeja imreddyTeja force-pushed the tr/matmul branch 4 times, most recently from 0f340b4 to 4f7f19b on March 25, 2026 at 23:20
@imreddyTeja imreddyTeja enabled auto-merge (rebase) March 26, 2026 00:00
rename new_entry

cleanup

test cleanup

enable test

frmt

renaming

Add back inbounds
@imreddyTeja imreddyTeja force-pushed the tr/matmul branch 7 times, most recently from 5c0316c to 1fa3401 on March 27, 2026 at 16:30
@imreddyTeja imreddyTeja force-pushed the tr/matmul branch 2 times, most recently from e3dbde6 to ee17c08 on March 30, 2026 at 18:45
@imreddyTeja imreddyTeja disabled auto-merge March 30, 2026 20:20
@imreddyTeja imreddyTeja merged commit 5143dee into main Mar 31, 2026
34 of 36 checks passed
@imreddyTeja imreddyTeja deleted the tr/matmul branch March 31, 2026 22:27