Skip to content

Commit b29f3f1

Browse files
committed
Add macro mt_async
1 parent acd21f3 commit b29f3f1

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

src/onthreads.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,33 @@ function ThreadLocal{T}(f::Base.Callable) where {T}
193193
@onthreads allthreads() result.value[threadid()] = f()
194194
result
195195
end
196+
197+
198+
"""
199+
@mt_async expr
200+
201+
Spawn a Julia task running `expr` asynchronously.
202+
203+
Compatible with `@sync`. Uses a multi-threaded task scheduler if available (on
204+
Julia >= v1.3).
205+
206+
Equivalent to `Base.@async` on Julia <= v1.2, equivalent to
207+
`Base.Threads.@spawn` on Julia >= v1.3.
208+
"""
209+
macro mt_async(expr)
210+
# Code taken from Base.@async and Base.Threads.@spawn:
211+
thunk = esc(:(()->($expr)))
212+
var = esc(Base.sync_varname)
213+
quote
214+
local task = Task($thunk)
215+
@static if VERSION >= v"1.3.0-alpha.0"
216+
task.sticky = false
217+
end
218+
if $(Expr(:isdefined, var))
219+
push!($var, task)
220+
end
221+
schedule(task)
222+
task
223+
end
224+
end
225+
export @mt_async

test/test_onthreads.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,22 @@ using Base.Threads
1111
@warn "JULIA multithreading not enabled"
1212
end
1313

14+
15+
function do_work(n)
16+
if n < 0
17+
throw(ArgumentError("n must be >= 0"))
18+
end
19+
s::Float64 = 0
20+
for i in 1:n
21+
if n % 1000 == 0
22+
yield()
23+
end
24+
s += log(abs(asin(sin(Complex(log(i), log(i))))) + 1)
25+
end
26+
s
27+
end
28+
29+
1430
@testset "macro onthreads" begin
1531
@test (begin
1632
tl = ThreadLocal(0)
@@ -19,6 +35,20 @@ using Base.Threads
1935
end) == 1:nthreads()
2036
end
2137

38+
@testset "macro mt_async" begin
39+
@test begin
40+
n = 128
41+
A = zeros(n)
42+
@sync for i in eachindex(A)
43+
@mt_async begin
44+
do_work(10^3)
45+
A[i] = log(i)
46+
end
47+
end
48+
A == log.(1:n)
49+
end
50+
end
51+
2252
@testset "Examples" begin
2353
@testset "Example 1" begin
2454
tlsum = ThreadLocal(0.0)

0 commit comments

Comments
 (0)