diff --git a/base/task.jl b/base/task.jl index 8b40e73bbc845..9841bd947fdee 100644 --- a/base/task.jl +++ b/base/task.jl @@ -383,8 +383,11 @@ completed tasks, and the other consists of uncompleted tasks. each runs serially, since this needs to scan the list of `tasks` each time and synchronize with each one every time this is called. Or consider using [`waitall(tasks; failfast=true)`](@ref waitall) instead. + +!!! compat "Julia 1.12" + This function requires at least Julia 1.12. """ -waitany(tasks; throw=true) = _wait_multiple(tasks, throw) +waitany(tasks; throw=true) = _wait_multiple(collect_tasks(tasks), throw) """ waitall(tasks; failfast=true, throw=true) -> (done_tasks, remaining_tasks) @@ -400,17 +403,22 @@ given tasks is finished by exception. If `throw` is `true`, throw The return value consists of two task vectors. The first one consists of completed tasks, and the other consists of uncompleted tasks. + +!!! compat "Julia 1.12" + This function requires at least Julia 1.12. """ -waitall(tasks; failfast=true, throw=true) = _wait_multiple(tasks, throw, true, failfast) +waitall(tasks; failfast=true, throw=true) = _wait_multiple(collect_tasks(tasks), throw, true, failfast) -function _wait_multiple(waiting_tasks, throwexc=false, all=false, failfast=false) +function collect_tasks(waiting_tasks) tasks = Task[] - for t in waiting_tasks t isa Task || error("Expected an iterator of `Task` object") push!(tasks, t) end + return tasks +end +function _wait_multiple(tasks::Vector{Task}, throwexc::Bool=false, all::Bool=false, failfast::Bool=false) if (all && !failfast) || length(tasks) <= 1 exception = false # Force everything to finish synchronously for the case of waitall @@ -474,22 +482,36 @@ function _wait_multiple(waiting_tasks, throwexc=false, all=false, failfast=false end while nremaining > 0 + exception && failfast && break i = take!(chan) t = tasks[i] waiter_tasks[i] = sentinel done_mask[i] = true exception |= istaskfailed(t) nremaining -= 1 - - # stop early if requested, unless there is something immediately - # ready to consume from the channel (using a race-y check) - if (!all || (failfast && exception)) && !isready(chan) - break - end + # stop early if requested + all || break end close(chan) + # now just read which tasks finished directly: the channel is not needed anymore for that + # repeat until we get (acquire) the list of all dependent-exited tasks + changed = true + while changed + changed = false + for (i, done) in enumerate(done_mask) + done && continue + t = tasks[i] + if istaskdone(t) + done_mask[i] = true + exception |= istaskfailed(t) + nremaining -= 1 + changed = true + end + end + end + if nremaining == 0 if throwexc && exception exceptions = [TaskFailedException(t) for t in tasks if istaskfailed(t)] @@ -500,6 +522,7 @@ function _wait_multiple(waiting_tasks, throwexc=false, all=false, failfast=false remaining_mask = .~done_mask for i in findall(remaining_mask) waiter = waiter_tasks[i] + waiter === sentinel && continue donenotify = tasks[i].donenotify::ThreadSynchronizer @lock donenotify list_deletefirst!(donenotify.waitq, waiter) end diff --git a/test/threads_exec.jl b/test/threads_exec.jl index 54ab20538e70f..2780888546964 100644 --- a/test/threads_exec.jl +++ b/test/threads_exec.jl @@ -1372,9 +1372,7 @@ end tasks = [Threads.@spawn(div(1, i)) for i = 0:1] wait(tasks[1]; throw=false) wait(tasks[2]; throw=false) - @test_throws CompositeException begin - waitall(Threads.@spawn(div(1, i)) for i = 0:1) - end + @test_throws CompositeException waitall(tasks) end end end