Skip to content

Commit a5acbf5

Browse files
committed
Move shared storage gpu -> gpu transfer heuristic to MtlPtr code
1 parent 2abede9 commit a5acbf5

2 files changed

Lines changed: 22 additions & 31 deletions

File tree

src/array.jl

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -422,23 +422,6 @@ function Base.unsafe_copyto!(dev::MTLDevice, dest::MtlArray{T}, doffs, src::MtlA
422422
end
423423
return dest
424424
end
425-
function Base.unsafe_copyto!(dev::MTLDevice, dest::MtlArray{T, <:Any, Metal.SharedStorage}, doffs, src::MtlArray{T, <:Any, Metal.SharedStorage}, soffs, n) where {T}
426-
synchronize()
427-
bytes = n * sizeof(T)
428-
# Use GPU blit for large copies (>32MiB) where it's faster than CPU memcpy.
429-
# For small copies, CPU memcpy avoids GPU command buffer overhead.
430-
if bytes >= 32 * 2^20 # If changed, also change in tests
431-
GC.@preserve src dest unsafe_copyto!(dev, pointer(dest, doffs), pointer(src, soffs), n)
432-
if Base.isbitsunion(T)
433-
error("Not implemented")
434-
end
435-
else
436-
# use the raw CPU pointers directly so this also works with non-aligned offsets
437-
# (which can arise from e.g. reinterpret of a view); unsafe_wrap would refuse them
438-
GC.@preserve src dest unsafe_copyto!(pointer(dest, doffs; storage=SharedStorage), pointer(src, soffs; storage=SharedStorage), n)
439-
end
440-
return dest
441-
end
442425

443426

444427
## regular gpu array adaptor

src/memory.jl

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -78,26 +78,34 @@ end
7878
# Split up copies > 2GiB to avoid silent failures when copying buffers > 4GiB
7979
# to fix JuliaGPU/Metal.jl#710. Solution inspired by
8080
# https://github.com/pytorch/pytorch/pull/126104
81-
@autoreleasepool function Base.unsafe_copyto!(dev::MTLDevice, dst::MtlPtr{T},
81+
function Base.unsafe_copyto!(dev::MTLDevice, dst::MtlPtr{T},
8282
src::MtlPtr{T}, N::Integer;
8383
queue::MTLCommandQueue=global_queue(dev),
8484
async::Bool=false) where T
8585
if N > 0
86-
chunk_size = 2^31
87-
cmdbuf = MTLCommandBuffer(queue)
88-
MTLBlitCommandEncoder(cmdbuf) do enc
89-
nbytes = N * sizeof(T)
90-
offset = 0
91-
92-
while nbytes > 0
93-
transfer_bytes = min(nbytes, chunk_size)
94-
append_copy!(enc, dst.buffer, dst.offset + offset, src.buffer, src.offset + offset, transfer_bytes)
95-
offset += transfer_bytes
96-
nbytes -= transfer_bytes
86+
nbytes = N * sizeof(T)
87+
# For small copies of shared-memory arrays, CPU memcpy avoids GPU command buffer overhead.
88+
# Otherwise, use GPU blit for large copies (>32MiB) where it's faster than CPU memcpy.
89+
if dst.buffer.storageMode == src.buffer.storageMode == MTL.MTLStorageModeShared && nbytes < 2^25
90+
unsafe_copyto!(convert(Ptr{T}, dst), convert(Ptr{T}, src), N)
91+
else
92+
@autoreleasepool begin
93+
chunk_size = 2^31
94+
cmdbuf = MTLCommandBuffer(queue)
95+
MTLBlitCommandEncoder(cmdbuf) do enc
96+
offset = 0
97+
98+
while nbytes > 0
99+
transfer_bytes = min(nbytes, chunk_size)
100+
append_copy!(enc, dst.buffer, dst.offset + offset, src.buffer, src.offset + offset, transfer_bytes)
101+
offset += transfer_bytes
102+
nbytes -= transfer_bytes
103+
end
104+
end
105+
commit!(cmdbuf)
106+
async || wait_completed(cmdbuf)
97107
end
98108
end
99-
commit!(cmdbuf)
100-
async || wait_completed(cmdbuf)
101109
end
102110
return dst
103111
end

0 commit comments

Comments
 (0)