Skip to content

Commit 05f455c

Browse files
committed
Error on invalid grid size launches
1 parent 28d2eb3 commit 05f455c

2 files changed

Lines changed: 18 additions & 7 deletions

File tree

src/compiler/execution.jl

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -265,15 +265,22 @@ end
265265

266266
@autoreleasepool function (kernel::HostKernel)(args...; groups=1, threads=1,
267267
queue=global_queue(device()))
268-
threadgroupsPerGrid = MTLSize(groups)
269-
threadsPerThreadgroup = MTLSize(threads)
270-
(threadgroupsPerGrid.width>0 && threadgroupsPerGrid.height>0 && threadgroupsPerGrid.depth>0) ||
268+
gs = MTLSize(groups)
269+
ts = MTLSize(threads)
270+
(gs.width>0 && gs.height>0 && gs.depth>0) ||
271271
throw(ArgumentError("All group dimensions should be non-zero"))
272-
(threadsPerThreadgroup.width>0 && threadsPerThreadgroup.height>0 && threadsPerThreadgroup.depth>0) ||
272+
(ts.width>0 && ts.height>0 && ts.depth>0) ||
273273
throw(ArgumentError("All thread dimensions should be non-zero"))
274274

275-
(threadsPerThreadgroup.width * threadsPerThreadgroup.height * threadsPerThreadgroup.depth) > kernel.pipeline.maxTotalThreadsPerThreadgroup &&
276-
throw(ArgumentError("Number of threads in group ($(threadsPerThreadgroup.width * threadsPerThreadgroup.height * threadsPerThreadgroup.depth)) should not exceed $(kernel.pipeline.maxTotalThreadsPerThreadgroup)"))
275+
(ts.width * ts.height * ts.depth) > kernel.pipeline.maxTotalThreadsPerThreadgroup &&
276+
throw(ArgumentError("Number of threads in group ($(ts.width * ts.height * ts.depth)) should not exceed $(kernel.pipeline.maxTotalThreadsPerThreadgroup)"))
277+
278+
(gs.width * ts.width) > typemax(UInt32) &&
279+
throw(ArgumentError("Total threads per grid in a dimension (threads.width($(gs.width)) * groups.width($(ts.width)) = $(gs.width * ts.width)) must not exceed $(typemax(UInt32))"))
280+
(gs.height * ts.height) > typemax(UInt32) &&
281+
throw(ArgumentError("Total threads per grid in a dimension (threads.height($(gs.height)) * groups.height($(ts.height)) = $(gs.height * ts.height)) must not exceed $(typemax(UInt32))"))
282+
(gs.depth * ts.depth) > typemax(UInt32) &&
283+
throw(ArgumentError("Total threads per grid in a dimension (threads.depth($(gs.depth)) * groups.depth($(ts.depth)) = $(gs.depth * ts.depth)) must not exceed $(typemax(UInt32))"))
277284

278285
kernel_state = KernelState(Random.rand(UInt32))
279286

@@ -283,7 +290,7 @@ end
283290
argument_buffers = try
284291
MTL.set_function!(cce, kernel.pipeline)
285292
bufs = encode_arguments!(cce, kernel, kernel_state, kernel.f, args...)
286-
MTL.append_current_function!(cce, threadgroupsPerGrid, threadsPerThreadgroup)
293+
MTL.append_current_function!(cce, gs, ts)
287294
bufs
288295
finally
289296
close(cce)

test/execution.jl

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -164,6 +164,10 @@ end
164164
@test_throws InexactError @metal groups=(-2) tester(bufferA)
165165
@test_throws ArgumentError @metal threads=(1025) tester(bufferA)
166166
@test_throws ArgumentError @metal threads=(1000,2) tester(bufferA)
167+
168+
@test_throws ArgumentError @metal threads=(1024,1,1) groups=(4194304,1,1) tester(bufferA)
169+
@test_throws ArgumentError @metal threads=(1,1024,1) groups=(1,4194304,1) tester(bufferA)
170+
@test_throws ArgumentError @metal threads=(1,1,1024) groups=(1,1,4194304) tester(bufferA)
167171
end
168172

169173
############################################################################################

0 commit comments

Comments
 (0)