|
# Split up copies > 2GiB to avoid silent failures when copying buffers > 4GiB
# to fix JuliaGPU/Metal.jl#710. Solution inspired by
# https://github.com/pytorch/pytorch/pull/126104
#
# Copy `N` elements of type `T` from `src` to `dst` and return `dst`.
# Nothing is copied when `N <= 0`.
#
# Two strategies are used:
# * CPU memcpy, when both buffers use `MTLStorageModeShared` and the copy is
#   smaller than 32 MiB (2^25 bytes). This avoids the overhead of building and
#   submitting a GPU command buffer. The `queue` and `async` keywords are not
#   consulted on this path; the copy has completed when the function returns.
# * GPU blit, for everything else (any non-Shared buffer, or copies of 32 MiB and
#   larger). The copy is encoded on `queue` in chunks of at most 2^31 bytes; unless
#   `async` is true, the function waits for the command buffer to complete.
#
# NOTE(review): the CPU path does not appear to be ordered against work already
# submitted to `queue`; confirm callers synchronize before relying on it.
function Base.unsafe_copyto!(dev::MTLDevice, dst::MtlPtr{T},
                             src::MtlPtr{T}, N::Integer;
                             queue::MTLCommandQueue=global_queue(dev),
                             async::Bool=false) where T
    if N > 0
        nbytes = N * sizeof(T)
        both_shared =
            dst.buffer.storageMode == src.buffer.storageMode == MTL.MTLStorageModeShared
        if both_shared && nbytes < 2^25
            unsafe_copyto!(convert(Ptr{T}, dst), convert(Ptr{T}, src), N)
        else
            @autoreleasepool begin
                chunk_size = 2^31
                cmdbuf = MTLCommandBuffer(queue)
                MTLBlitCommandEncoder(cmdbuf) do enc
                    # Count down on a closure-local variable: reassigning the
                    # captured `nbytes` here would force it into a `Core.Box`.
                    remaining = nbytes
                    offset = 0

                    while remaining > 0
                        transfer_bytes = min(remaining, chunk_size)
                        append_copy!(enc, dst.buffer, dst.offset + offset,
                                     src.buffer, src.offset + offset, transfer_bytes)
                        offset += transfer_bytes
                        remaining -= transfer_bytes
                    end
                end
                commit!(cmdbuf)
                async || wait_completed(cmdbuf)
            end
        end
    end
    return dst
end
|
0 commit comments