Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion ext/AtomixCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# TODO: respect ordering
module AtomixCUDAExt

using Atomix: Atomix, IndexableRef
using Atomix: Atomix, IndexableRef, right
using CUDA: CUDA, CuDeviceArray

const CuIndexableRef{Indexable<:CuDeviceArray} = IndexableRef{Indexable}
Expand Down Expand Up @@ -48,6 +48,8 @@ end
CUDA.atomic_min!(ptr, x)
elseif op === max
CUDA.atomic_max!(ptr, x)
elseif op === right
CUDA.atomic_xchg!(ptr, x)
else
error("not implemented")
end
Expand Down
4 changes: 3 additions & 1 deletion ext/AtomixMetalExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# TODO: respect ordering
module AtomixMetalExt

using Atomix: Atomix, IndexableRef
using Atomix: Atomix, IndexableRef, right
using Metal: Metal, MtlDeviceArray

const MtlIndexableRef{Indexable<:MtlDeviceArray} = IndexableRef{Indexable}
Expand Down Expand Up @@ -57,6 +57,8 @@ end
Metal.atomic_fetch_min_explicit(ptr, x)
elseif op === max
Metal.atomic_fetch_max_explicit(ptr, x)
elseif op === right
Metal.atomic_exchange_explicit(ptr, x)
else
error("not implemented")
end
Expand Down
4 changes: 3 additions & 1 deletion ext/AtomixOpenCLExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# TODO: respect ordering
module AtomixOpenCLExt

using Atomix: Atomix, IndexableRef
using Atomix: Atomix, IndexableRef, right
using OpenCL: SPIRVIntrinsics, CLDeviceArray

const CLIndexableRef{Indexable<:CLDeviceArray} = IndexableRef{Indexable}
Expand Down Expand Up @@ -48,6 +48,8 @@ end
SPIRVIntrinsics.atomic_min!(ptr, x)
elseif op === max
SPIRVIntrinsics.atomic_max!(ptr, x)
elseif op === right
SPIRVIntrinsics.atomic_xchg!(ptr, x)
else
error("not implemented")
end
Expand Down
4 changes: 3 additions & 1 deletion ext/AtomixoneAPIExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# TODO: respect ordering
module AtomixoneAPIExt

using Atomix: Atomix, IndexableRef
using Atomix: Atomix, IndexableRef, right
using oneAPI: oneAPI, oneDeviceArray

const oneIndexableRef{Indexable<:oneDeviceArray} = IndexableRef{Indexable}
Expand Down Expand Up @@ -48,6 +48,8 @@ end
oneAPI.atomic_min!(ptr, x)
elseif op === max
oneAPI.atomic_max!(ptr, x)
elseif op === right
oneAPI.atomic_xchg!(ptr, x)
else
error("not implemented")
end
Expand Down
12 changes: 12 additions & 0 deletions test/test_atomix_cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,15 @@ end
end
@test collect(A) == [2, 1, 1]
end

@testset "AtomixCUDAExt:test_swap_sugar" begin
A = CUDA.ones(Int, 3)
B = CUDA.zeros(Int, 3)
cuda() do
GC.@preserve A B begin
B[begin] = @atomicswap A[begin] = 4
end
end
@test collect(A) == [4, 1, 1]
@test collect(B) == [1, 0, 0]
end
12 changes: 12 additions & 0 deletions test/test_atomix_metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,15 @@ end
end
@test collect(A) == [2, 1, 1]
end

@testset "AtomixMetalExt:test_swap_sugar" begin
A = Metal.ones(Int32, 3)
B = Metal.zeros(Int32, 3)
metal() do
GC.@preserve A B begin
B[begin] = @atomicswap A[begin] = 4
end
end
@test collect(A) == [4, 1, 1]
@test collect(B) == [1, 0, 0]
end
12 changes: 12 additions & 0 deletions test/test_atomix_oneapi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,15 @@ end
end
@test collect(A) == [2, 1, 1]
end

@testset "AtomixoneAPIExt:test_swap_sugar" begin
A = oneAPI.ones(Int32, 3)
B = oneAPI.zeros(Int32, 3)
oneapi() do
GC.@preserve A B begin
B[begin] = @atomicswap A[begin] = 4
end
end
@test collect(A) == [4, 1, 1]
@test collect(B) == [1, 0, 0]
end
12 changes: 12 additions & 0 deletions test/test_atomix_opencl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,15 @@ end
end
@test collect(A) == [2, 1, 1]
end

@testset "AtomixOpenCLExt:test_swap_sugar" begin
A = OpenCL.ones(Int32, 3)
B = OpenCL.zeros(Int32, 3)
opencl() do
GC.@preserve A B begin
B[begin] = @atomicswap A[begin] = 4
end
end
@test collect(A) == [4, 1, 1]
@test collect(B) == [1, 0, 0]
end