From 45d31fce3aa6f758499c4c8f33080ae3fad6d3ae Mon Sep 17 00:00:00 2001
From: Iblis Lin
Date: Tue, 12 Dec 2017 12:55:12 +0800
Subject: [PATCH 1/3] autograd: user custom function

---
 examples/autograd/customfunc.jl |  90 ++++++++++++++++
 src/autograd.jl                 | 180 ++++++++++++++++++++++++++++++
 src/base.jl                     |   2 +
 test/unittest/autograd.jl       |  66 ++++++++++++
 4 files changed, 338 insertions(+)
 create mode 100644 examples/autograd/customfunc.jl

diff --git a/examples/autograd/customfunc.jl b/examples/autograd/customfunc.jl
new file mode 100644
index 000000000..aff1284fe
--- /dev/null
+++ b/examples/autograd/customfunc.jl
@@ -0,0 +1,90 @@
+using MXNet
+
+###############################################################################
+# swish: option 1, with inner constructor
+###############################################################################
+
+"""
+    swish(x) = x * σ(x)
+
+See [Swish: a Self-Gated Activation Function]
+(https://arxiv.org/pdf/1710.05941.pdf).
+"""
+mutable struct swish
+  x
+  y
+  σ
+  @mx.custom function swish(x)
+    σ = @. 1 / (1 + e^(-x)) # assume there is no mx.sigmoid; we need to get our hands dirty
+    y = x .* σ
+    new(x, y, σ) # must return an object instance for @custom
+  end
+end
+
+# the actual return value
+mx.forward(f::swish, x) = f.y
+
+mx.backward!(f::swish, Δy #= coefficient of gradient =#) =
+  @. (f.y + f.σ * (1 - f.y)) * Δy
+
+###############################################################################
+# swish2: option 2, with outer constructor
+###############################################################################
+
+mutable struct swish2
+  x
+  y
+  σ
+end
+
+@mx.custom function swish2(x)
+  σ = @. 1 / (1 + e^(-x))
+  y = x .* σ
+  swish2(x, y, σ) # must return an object instance for @custom
+end
+
+mx.forward(f::swish2, x) = f.y
+
+mx.backward!(f::swish2, Δy #= coefficient of gradient =#) =
+  @. 
(f.y + f.σ * (1 - f.y)) * Δy + +############################################################################### +# example usage +############################################################################### + +x = mx.NDArray(Float32[1 2; 3 4]) +∇ = mx.attach_grad!(x) +y = mx.record() do + swish(x) +end +mx.backward!(y) +∇ + + +# For the record, here shows the overhead of custom function +# +# julia> @benchmark g() # custom func swish +# BenchmarkTools.Trial: +# memory estimate: 29.83 KiB +# allocs estimate: 608 +# -------------- +# minimum time: 372.205 μs (0.00% GC) +# median time: 475.992 μs (0.00% GC) +# mean time: 565.441 μs (3.65% GC) +# maximum time: 33.960 ms (47.71% GC) +# -------------- +# samples: 8723 +# evals/sample: 1 +# +# julia> @benchmark h() # with native NDArray operator +# BenchmarkTools.Trial: +# memory estimate: 9.39 KiB +# allocs estimate: 184 +# -------------- +# minimum time: 179.940 μs (0.00% GC) +# median time: 234.188 μs (0.00% GC) +# mean time: 264.236 μs (1.47% GC) +# maximum time: 35.323 ms (28.11% GC) +# -------------- +# samples: 10000 +# evals/sample: 1 diff --git a/src/autograd.jl b/src/autograd.jl index 4584decb0..a492c29d6 100644 --- a/src/autograd.jl +++ b/src/autograd.jl @@ -2,6 +2,8 @@ # this is a port of Python's autograd module # https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/autograd.py +using Base.Meta: isexpr + ############################################################################### # Private util functions ############################################################################### @@ -385,3 +387,181 @@ end ############################################################################### # TODO: User-defined differentiable function ############################################################################### + +# gc-free holder +const _cbs_r = [Ref{Ptr{Void}}(C_NULL), Ref{Ptr{Void}}(C_NULL)] +const _cbs = [Ptr{Void}(C_NULL), Ptr{Void}(C_NULL)] +const _cbsref = Ref{Ptr{Ptr{Void}}}(C_NULL) +const _frefs = Dict() # hold custom function instance and its args + +function _init_customfunc() # will be invoked in __init__ + global _cbs_r + global _cbs + global _cbsref + _cbs_r[1][] = _cbs[1] = cfunction(_back_wrapper, Cint, + (Cint, Cint, Ptr{Ptr{Void}}, Ptr{Cint}, + Bool, Ptr{Void})) + _cbs_r[2][] = _cbs[2] = cfunction(_del_wrapper, Cint, (Ptr{Void},)) + _cbsref[] = Base.unsafe_convert(Ptr{Ptr{Void}}, _cbs) +end + +function _back_wrapper(num_ograds, num_igrads, ptrs, reqs, is_train, fptr::Ptr{Void}) + hdls = unsafe_wrap(Array, ptrs, num_ograds + num_igrads) + ograds = map(x -> NDArray(MX_NDArrayHandle(x), false), hdls[1:num_ograds]) + igrads = map(NDArray ∘ MX_NDArrayHandle, hdls[num_ograds+1:num_ograds+num_igrads]) + reqs = unsafe_wrap(Array, reqs, num_igrads) + + # passing closure via raw pointer + f = unsafe_pointer_to_objref(fptr) + + Δs = backward!(f, ograds...) + Δs = Δs isa NDArray ? 
[Δs] : Δs
+
+  # update gradient
+  for (i, Δ, req) ∈ zip(igrads, Δs, reqs)
+    req = GRAD_REQ(req)
+    if req == GRAD_NOP
+      continue
+    elseif req ∈ (GRAD_WRITE, GRAD_INPLACE)
+      i[:] = Δ
+    elseif req == GRAD_ADD
+      i[:] += Δ
+    end
+  end
+
+  # release ref for gc
+  delete!(_frefs, f)
+
+  Cint(true)
+end
+
+function _del_wrapper(ref)
+  cblist_ref = unsafe_pointer_to_objref(ref)
+  delete!(_cblists, cblist_ref)
+  Cint(true)
+end
+
+struct MXCallbackList
+  n::Cint              # int num_callbacks;
+  cbs::Ptr{Ptr{Void}}  # int (**callbacks)(void);
+  ctxs::Ptr{Ptr{Void}} # void **contexts;
+
+  # We must provide two callback functions:
+  # the first is the backward function `_back_wrapper`,
+  # the second is the delete callback `_del_wrapper`.
+  # https://github.com/apache/incubator-mxnet/blob/2f8c1e83f94e84a25a48d2cd43136030fb3f2d1e/include/mxnet/c_api.h#L174-L182
+
+  # `ctxs` is an array of the same size as `cbs`.
+  # Its elements are passed to the callback functions as the `state` argument,
+  # usually the last one.
+  # In our case, we push the pointer to the custom function instance as the
+  # first element of `ctxs`, and the pointer to the MXCallbackList instance as
+  # the second element.
+  # The first pointer is how the closure is passed into `cfunction`.
+  # The second pointer is used to free the reference to the MXCallbackList,
+  # so that the function instance can be GC-ed properly.
+
+  function MXCallbackList(f) # `f` is the custom function instance
+    ctxs = [
+      Base.unsafe_convert(Ptr{Void}, Ref(f)),
+      Ptr{Void}(C_NULL),
+    ]
+    ctxsptr = Base.unsafe_convert(Ptr{Ptr{Void}}, ctxs)
+    cblist = new(2, _cbsref[], ctxsptr)
+    # get the reference, and make a self-reference in ctxs[2]
+    cblist_ref = Ref{MXCallbackList}(cblist)
+    ctxs[2] = Base.unsafe_convert(Ptr{Void}, cblist_ref)
+    # insert the ref into a holder to prevent it from being GC-ed
+    # (`xs` and `ys` passed into `MXCustomFunctionRecord` are kept alive via `_frefs`)
+    _cblists[cblist_ref] = Ref(ctxs)
+    cblist_ref
+  end
+end
+
+# hold MXCallbackList to prevent from gc
+const _cblists = Dict{Ref{MXCallbackList},Ref}()
+
+_isparams(ex) =
+  isexpr(ex, :call) && length(ex.args) >= 2 && isexpr(ex.args[2], :parameters)
+
+_isfuncdef(ex) =
+  isexpr(ex, :function) || (isexpr(ex, :(=)) && isexpr(ex.args[1], :call))
+
+"""
+    @custom
+
+Create a callable custom function.
+All the positional arguments should be `NDArray`s.
+The return value should be an instance of your custom type.
+
+Please check out `examples/autograd/customfunc.jl` for an example.
+"""
+macro custom(ex::Expr)
+  @assert(_isfuncdef(ex), "unspport syntax")
+
+  sig = ex.args[1]
+  body = esc(Expr(:let, ex.args[2])) # create a new scope via `let`
+
+  # forward(f, xs...)
+  forward_expr = copy(sig)
+  args = forward_expr.args
+  args[1] = :forward
+  i = !_isparams(sig) ? 2 : 3
+  insert!(args, i, :f)
+  # properly escape
+  if _isparams(sig)
+    args[2] = esc(args[2])
+  end
+  for j ∈ i+1:endof(args)
+    args[j] = esc(args[j])
+  end
+
+  # xs, without keyword arguments
+  xs_len = length(args[i+1:end])
+  xs_expr = Expr(:vect, args[i+1:end]...)
+
+  body′ = quote
+    f, ys = _record(false, nothing) do
+      f = $body # f is the object instance
+      ys = $forward_expr
+      f, ys
+    end
+
+    !is_recording() && return ys
+
+    xs = $xs_expr
+    ys′ = ys isa NDArray ? 
[ys] : ys + + # struct MXCallbackList + cblist_ref = MXCallbackList(f) + + # gc free + _frefs[f] = (Ref(xs), Ref(ys′)) + + @mxcall( + :MXCustomFunctionRecord, + (Cint, # num_inputs + Ptr{MX_handle}, # inputs + + Cint, # num_outputs + Ptr{MX_handle}, # outputs + + Ref{MXCallbackList}), # callbacks + $xs_len, + xs, + + length(ys′), + ys′, + + cblist_ref) + + ys + end + + Expr(:function, esc(sig), body′) +end + +# custom function should overload these functions. +# the # of forward return values is the inputs of backward!. +function forward end +function backward! end diff --git a/src/base.jl b/src/base.jl index b8f73eb4e..ab23a60f7 100644 --- a/src/base.jl +++ b/src/base.jl @@ -50,6 +50,8 @@ function __init__() global const LIB_VERSION = _get_lib_version() + _init_customfunc() + atexit() do # notify libmxnet we are shutting down ccall( ("MXNotifyShutdown", MXNET_LIB), Cint, () ) diff --git a/test/unittest/autograd.jl b/test/unittest/autograd.jl index 12c1022bd..05efd3d3c 100644 --- a/test/unittest/autograd.jl +++ b/test/unittest/autograd.jl @@ -361,7 +361,72 @@ function test_power() x.^.5 end end +end # function test_power + + +include(joinpath(@__DIR__, "..", "..", "examples", "autograd", "customfunc.jl")) + +struct foo end +@mx.custom foo(x) = foo() # test the compat form of func def +mx.forward(f::foo, x) = x + +struct bar + @mx.custom bar(x) = new() # test the compat form of func def end +mx.forward(f::bar, x) = x + +function test_custom_func() + info("AutoGrad::custom function") + @test isbits(mx.MXCallbackList) + + """ + swish with custom function + """ + function g() + x = mx.NDArray(Float32[1 2; 3 4]) + ∇ = mx.attach_grad!(x) + y = mx.record() do + swish(x) # from examples/autograd/customfunc.jl + end + mx.backward!(y, mx.NDArray(Float32[.5 .5; .5 .5])) + ∇ + end + + """ + swish2 with custom function + """ + function g2() + x = mx.NDArray(Float32[1 2; 3 4]) + ∇ = mx.attach_grad!(x) + y = mx.record() do + swish2(x) # from examples/autograd/customfunc.jl + end + mx.backward!(y, mx.NDArray(Float32[.5 .5; .5 .5])) + ∇ + end + + """ + swish without custom function + """ + function h() + x = mx.NDArray(Float32[1 2; 3 4]) + ∇ = mx.attach_grad!(x) + y = mx.record() do + x .* mx.sigmoid(x) + end + mx.backward!(y, mx.NDArray(Float32[.5 .5; .5 .5])) + ∇ + end + + @test copy(g()) ≈ copy(h()) + @test copy(g2()) ≈ copy(h()) + + let x = mx.NDArray(Float32[1, 2, 3, 4]) + @test copy(foo(x)) == copy(x) + + @test copy(bar(x)) == copy(x) + end +end # function test_custom_func @testset "AutoGrad Test" begin @@ -380,6 +445,7 @@ end test_mul() test_div() test_power() + test_custom_func() end From 947b69cf0f4709dce05c02944560d5624c034a93 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Thu, 21 Dec 2017 13:38:34 +0800 Subject: [PATCH 2/3] exports --- NEWS.md | 16 +++++++++++++++- examples/autograd/customfunc.jl | 12 ++++++------ src/MXNet.jl | 14 ++++++++++++++ test/unittest/autograd.jl | 4 ++-- 4 files changed, 37 insertions(+), 9 deletions(-) diff --git a/NEWS.md b/NEWS.md index 4540cba50..e01740d24 100644 --- a/NEWS.md +++ b/NEWS.md @@ -75,7 +75,21 @@ ### `NDArray` -* A port of Python's `autograd` for `NDArray` (#274) +* A port of Python's `autograd` for `NDArray`. (#274, #385) + Following APIs exported: + + * `attach_grad!()` + * `backward!()` + * `getgrad()` + * `is_recording()` + * `is_training()` + * `mark_variables()` + * `pause()` + * `predict_mode()` + * `record()` + * `symbol()` + * `train_mode()` + * `@custom` * `size(x, dims...)` is supported now. 
(#TBD)

diff --git a/examples/autograd/customfunc.jl b/examples/autograd/customfunc.jl
index aff1284fe..b16598d7a 100644
--- a/examples/autograd/customfunc.jl
+++ b/examples/autograd/customfunc.jl
@@ -14,7 +14,7 @@ mutable struct swish
   x
   y
   σ
-  @mx.custom function swish(x)
+  @custom function swish(x)
     σ = @. 1 / (1 + e^(-x)) # assume there is no mx.sigmoid; we need to get our hands dirty
     y = x .* σ
     new(x, y, σ) # must return an object instance for @custom
@@ -37,7 +37,7 @@ mutable struct swish2
   σ
 end
 
-@mx.custom function swish2(x)
+@custom function swish2(x)
   σ = @. 1 / (1 + e^(-x))
   y = x .* σ
   swish2(x, y, σ) # must return an object instance for @custom
@@ -52,12 +52,12 @@ mx.backward!(f::swish2, Δy #= coefficient of gradient =#) =
 # example usage
 ###############################################################################
 
-x = mx.NDArray(Float32[1 2; 3 4])
-∇ = mx.attach_grad!(x)
-y = mx.record() do
+x = NDArray(Float32[1 2; 3 4])
+∇ = attach_grad!(x)
+y = record() do
   swish(x)
 end
-mx.backward!(y)
+backward!(y)
 ∇
 
 
diff --git a/src/MXNet.jl b/src/MXNet.jl
index 352d20aad..492ef2d05 100644
--- a/src/MXNet.jl
+++ b/src/MXNet.jl
@@ -34,6 +34,20 @@ export NDArray,
        context,
        empty
 
+# autograd.jl
+export attach_grad!,
+       backward!,
+       getgrad,
+       is_recording,
+       is_training,
+       mark_variables,
+       pause,
+       predict_mode,
+       record,
+       symbol,
+       train_mode,
+       @custom
+
 # executor.jl
 export Executor,
        bind,
diff --git a/test/unittest/autograd.jl b/test/unittest/autograd.jl
index 05efd3d3c..b5f0b0b1d 100644
--- a/test/unittest/autograd.jl
+++ b/test/unittest/autograd.jl
@@ -367,11 +367,11 @@ end # function test_power
 
 include(joinpath(@__DIR__, "..", "..", "examples", "autograd", "customfunc.jl"))
 
 struct foo end
-@mx.custom foo(x) = foo() # test the compat form of func def
+@custom foo(x) = foo() # test the compat form of func def
 mx.forward(f::foo, x) = x
 
 struct bar
-  @mx.custom bar(x) = new() # test the compat form of func def
+  @custom bar(x) = new() # test the compat form of func def
 end
 mx.forward(f::bar, x) = x

From 61821c5bcfff6556c04bbc67cc0500e99e5e3b8f Mon Sep 17 00:00:00 2001
From: Iblis Lin
Date: Sun, 24 Dec 2017 00:49:43 +0800
Subject: [PATCH 3/3] where syntax support

---
 src/autograd.jl           | 39 +++++++++++++------------------
 test/unittest/autograd.jl | 48 +++++++++++++++++++++++++++++++--------
 2 files changed, 55 insertions(+), 32 deletions(-)

diff --git a/src/autograd.jl b/src/autograd.jl
index a492c29d6..37cac0493 100644
--- a/src/autograd.jl
+++ b/src/autograd.jl
@@ -385,7 +385,7 @@ function symbol(x::NDArray)
 end
 
 ###############################################################################
-# TODO: User-defined differentiable function
+# User-defined differentiable function
 ###############################################################################
 
 # gc-free holder
@@ -481,12 +481,6 @@ end
 # hold MXCallbackList to prevent from gc
 const _cblists = Dict{Ref{MXCallbackList},Ref}()
 
-_isparams(ex) =
-  isexpr(ex, :call) && length(ex.args) >= 2 && isexpr(ex.args[2], :parameters)
-
-_isfuncdef(ex) =
-  isexpr(ex, :function) || (isexpr(ex, :(=)) && isexpr(ex.args[1], :call))
-
 """
     @custom
 
 Create a callable custom function.
 All the positional arguments should be `NDArray`s.
 The return value should be an instance of your custom type.
 
 Please check out `examples/autograd/customfunc.jl` for an example.
""" macro custom(ex::Expr) - @assert(_isfuncdef(ex), "unspport syntax") - + fdef = splitdef(ex) # by MacroTools sig = ex.args[1] body = esc(Expr(:let, ex.args[2])) # create a new scope via `let` + # only extract symbols, get rid of all annotations and default values + args = map(x -> esc(splitarg(x)[1]), fdef[:args]) # forward(f, xs...) - forward_expr = copy(sig) - args = forward_expr.args - args[1] = :forward - i = !_isparams(sig) ? 2 : 3 - insert!(args, i, :f) - # properly escape - if _isparams(sig) - args[2] = esc(args[2]) - end - for j ∈ i+1:endof(args) - args[j] = esc(args[j]) + forward_expr = Expr(:call, :forward, :f, args...) + # insert keyword args + if !isempty(fdef[:kwargs]) + # only extract symbols, get rid of all annotations and default values + kwargs = map(fdef[:kwargs]) do x + sym = splitarg(x)[1] + Expr(:kw, sym, esc(sym)) + end + append!(forward_expr.args, kwargs) end - # xs, without keyword arguments - xs_len = length(args[i+1:end]) - xs_expr = Expr(:vect, args[i+1:end]...) + # xs, FIXME: a list of NDArray from positional argument + xs_len = length(args) + xs_expr = Expr(:vect, args...) body′ = quote f, ys = _record(false, nothing) do diff --git a/test/unittest/autograd.jl b/test/unittest/autograd.jl index b5f0b0b1d..68938df3a 100644 --- a/test/unittest/autograd.jl +++ b/test/unittest/autograd.jl @@ -375,6 +375,30 @@ struct bar end mx.forward(f::bar, x) = x +struct baz{T} # test parametric type + x::T + @custom baz{T}(x) where T = new(x) # test `where` syntax +end +mx.forward(f::baz, x) = x + +struct qaz{T} + x::T + @custom function qaz{T}(x) where T # test `where` syntax + new(x) + end +end +mx.forward(f::qaz, x) = x + +# test keyword args +struct test_kw + x + @custom test_kw(x; magic = false) = new(x) +end +function mx.forward(f::test_kw, x; magic = false) + @assert magic + x +end + function test_custom_func() info("AutoGrad::custom function") @test isbits(mx.MXCallbackList) @@ -383,9 +407,9 @@ function test_custom_func() swish with custom function """ function g() - x = mx.NDArray(Float32[1 2; 3 4]) - ∇ = mx.attach_grad!(x) - y = mx.record() do + x = NDArray(Float32[1 2; 3 4]) + ∇ = attach_grad!(x) + y = record() do swish(x) # from examples/autograd/customfunc.jl end mx.backward!(y, mx.NDArray(Float32[.5 .5; .5 .5])) @@ -396,9 +420,9 @@ function test_custom_func() swish2 with custom function """ function g2() - x = mx.NDArray(Float32[1 2; 3 4]) - ∇ = mx.attach_grad!(x) - y = mx.record() do + x = NDArray(Float32[1 2; 3 4]) + ∇ = attach_grad!(x) + y = record() do swish2(x) # from examples/autograd/customfunc.jl end mx.backward!(y, mx.NDArray(Float32[.5 .5; .5 .5])) @@ -409,9 +433,9 @@ function test_custom_func() swish without custom function """ function h() - x = mx.NDArray(Float32[1 2; 3 4]) - ∇ = mx.attach_grad!(x) - y = mx.record() do + x = NDArray(Float32[1 2; 3 4]) + ∇ = attach_grad!(x) + y = record() do x .* mx.sigmoid(x) end mx.backward!(y, mx.NDArray(Float32[.5 .5; .5 .5])) @@ -425,6 +449,12 @@ function test_custom_func() @test copy(foo(x)) == copy(x) @test copy(bar(x)) == copy(x) + + @test copy(baz{Any}(x)) == copy(x) + + @test copy(qaz{Any}(x)) == copy(x) + + @test copy(test_kw(x, magic = true)) == copy(x) end end # function test_custom_func