Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,22 @@

### `NDArray`

* A port of Python's `autograd` for `NDArray`. (#274, #385)
The following APIs are exported:

* `attach_grad!()`
* `backward!()`
* `getgrad()`
* `is_recording()`
* `is_training()`
* `mark_variables()`
* `pause()`
* `predict_mode()`
* `record()`
* `symbol()`
* `train_mode()`
* `@custom`

* A handy constructor: `NDArray(Type, AbstractArray)` is added. (#TBD)

E.g.
Expand Down
90 changes: 90 additions & 0 deletions examples/autograd/customfunc.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
using MXNet

###############################################################################
# swish: option 1, with inner constructor
###############################################################################

"""
    swish(x)

Swish activation, `swish(x) = x * σ(x)`, written as a custom differentiable
function via `@custom`.

The instance caches what the backward pass needs:

* `x`: the input
* `y`: the forward output, `x .* σ`
* `σ`: the element-wise sigmoid of `x`

See [Swish: a Self-Gated Activation Function](https://arxiv.org/pdf/1710.05941.pdf).
"""
mutable struct swish
    x
    y
    σ
    @custom function swish(x)
        # assume there is no mx.sigmoid; we need to get our hands dirty
        σ = @. 1 / (1 + e^(-x))
        y = x .* σ
        # `new` fills the fields positionally, so the arguments must follow the
        # declaration order above: (x, y, σ). Passing (x, σ, y) would make
        # `f.y` hold the sigmoid and `mx.forward` return the wrong value.
        #
        # Keep this as the last expression (no `return`): `@custom` uses the
        # value of the body as the object instance.
        new(x, y, σ)
    end
end

# `forward` provides the actual return value of the custom function:
# the `y` field stored by the constructor.
function mx.forward(f::swish, x)
    return f.y
end

# Gradient of swish w.r.t. its input, scaled by the incoming gradient `Δy`
# (the coefficient of gradient): (y + σ * (1 - y)) * Δy, element-wise.
function mx.backward!(f::swish, Δy)
    return @. (f.y + f.σ * (1 - f.y)) * Δy
end

###############################################################################
# swish2: option 2, with outer constructor
###############################################################################

# Container type for the outer-constructor variant of swish.
# The fields are untyped; `mx.forward` reads `y` and `mx.backward!` reads
# `y` and `σ`.
mutable struct swish2
x  # the input
y  # intended to hold the forward output, `x .* σ`
σ  # intended to hold the element-wise sigmoid of `x`
end

# Outer constructor for `swish2`, wrapped by `@custom` so that calling
# `swish2(x)` yields the forward output and records the custom function.
@custom function swish2(x)
    σ = @. 1 / (1 + e^(-x))
    y = x .* σ
    # The default constructor takes the fields in declaration order, i.e.
    # (x, y, σ). Passing (x, σ, y) would store the sigmoid in `y`, and
    # `mx.forward` would then return the wrong value.
    #
    # Keep this as the last expression (no `return`): `@custom` uses the
    # value of the body as the object instance.
    swish2(x, y, σ)
end

# The value returned by the custom function is the stored `y` field.
function mx.forward(f::swish2, x)
    return f.y
end

# Gradient of swish w.r.t. its input, scaled by the incoming gradient `Δy`
# (the coefficient of gradient): (y + σ * (1 - y)) * Δy, element-wise.
function mx.backward!(f::swish2, Δy)
    return @. (f.y + f.σ * (1 - f.y)) * Δy
end

###############################################################################
# example usage
###############################################################################

# Evaluate the custom `swish` on a 2×2 input inside `record`, then
# run the backward pass from its result.
x = NDArray(Float32[1 2; 3 4])
∇ = attach_grad!(x)
y = record(() -> swish(x))
backward!(y)


# For the record, the benchmarks below show the overhead of a custom function
#
# julia> @benchmark g() # custom func swish
# BenchmarkTools.Trial:
# memory estimate: 29.83 KiB
# allocs estimate: 608
# --------------
# minimum time: 372.205 μs (0.00% GC)
# median time: 475.992 μs (0.00% GC)
# mean time: 565.441 μs (3.65% GC)
# maximum time: 33.960 ms (47.71% GC)
# --------------
# samples: 8723
# evals/sample: 1
#
# julia> @benchmark h() # with native NDArray operator
# BenchmarkTools.Trial:
# memory estimate: 9.39 KiB
# allocs estimate: 184
# --------------
# minimum time: 179.940 μs (0.00% GC)
# median time: 234.188 μs (0.00% GC)
# mean time: 264.236 μs (1.47% GC)
# maximum time: 35.323 ms (28.11% GC)
# --------------
# samples: 10000
# evals/sample: 1
14 changes: 14 additions & 0 deletions src/MXNet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,20 @@ export NDArray,
softmax,
log_softmax

# autograd.jl
export attach_grad!,
backward!,
getgrad,
is_recording,
is_training,
mark_variables,
pause,
predict_mode,
record,
symbol,
train_mode,
@custom

# executor.jl
export Executor,
bind,
Expand Down
175 changes: 174 additions & 1 deletion src/autograd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# this is a port of Python's autograd module
# https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/autograd.py

using Base.Meta: isexpr

###############################################################################
# Private util functions
###############################################################################
Expand Down Expand Up @@ -383,5 +385,176 @@ function symbol(x::NDArray)
end

###############################################################################
# TODO: User-defined differentiable function
# User-defined differentiable function
###############################################################################

# gc-free holders: module-level containers referenced by the custom function
# machinery below.
# `_cbs_r` and `_cbs` both receive the C function pointers of `_back_wrapper`
# (slot 1) and `_del_wrapper` (slot 2); they are filled in by `_init_customfunc`.
const _cbs_r = [Ref{Ptr{Void}}(C_NULL), Ref{Ptr{Void}}(C_NULL)]
const _cbs = [Ptr{Void}(C_NULL), Ptr{Void}(C_NULL)]
# pointer to the data of `_cbs`; used as the `cbs` field of every MXCallbackList
const _cbsref = Ref{Ptr{Ptr{Void}}}(C_NULL)
# maps a custom function instance to `(Ref(xs), Ref(ys′))`; entries are added by
# `@custom` and removed by `_back_wrapper`
const _frefs = Dict() # hold custom function instance and its args

# Create the C-callable pointers for the two callbacks and store them in the
# module-level holders (`_cbs_r`, `_cbs`, `_cbsref`).
# Will be invoked in `__init__`.
function _init_customfunc()
    # Mirrors `CustomFunctionBwdFunc` from libmxnet's c_api.h:
    #   int (*)(int num_ograds, int num_igrads, void **ptrs,
    #           const int *reqs, const int is_train, void *state)
    # `is_train` is a C `int`, so it is declared as `Cint` rather than the
    # one-byte `Bool`, to match the C calling convention on every platform.
    back_ptr = cfunction(_back_wrapper, Cint,
                         (Cint, Cint, Ptr{Ptr{Void}}, Ptr{Cint},
                          Cint, Ptr{Void}))
    # Mirrors `CustomFunctionDelFunc`: int (*)(void *state)
    del_ptr = cfunction(_del_wrapper, Cint, (Ptr{Void},))

    # The holders are only mutated, never rebound, so no `global` is needed.
    _cbs_r[1][] = _cbs[1] = back_ptr
    _cbs_r[2][] = _cbs[2] = del_ptr
    # pointer to the callback array, shared by every MXCallbackList
    _cbsref[] = Base.unsafe_convert(Ptr{Ptr{Void}}, _cbs)
end

# Backward callback registered with libmxnet (first entry of the callback list).
#
# Arguments, as laid out by the `cfunction` signature in `_init_customfunc`:
# - `num_ograds`, `num_igrads`: number of handles of each kind stored in `ptrs`
# - `ptrs`: `num_ograds` handles followed by `num_igrads` handles
# - `reqs`: one gradient request code per input gradient
# - `is_train`: unused here
# - `fptr`: raw pointer to the custom function instance (first entry of `ctxs`)
#
# Always returns `Cint(true)`.
# NOTE(review): an exception thrown by the user's `backward!` is not caught
# here and would propagate out of the C callback — confirm this is acceptable.
function _back_wrapper(num_ograds, num_igrads, ptrs, reqs, is_train, fptr::Ptr{Void})
# view the raw handle array, then split it into the two groups
hdls = unsafe_wrap(Array, ptrs, num_ograds + num_igrads)
# the second argument `false` is presumably the `writable` flag — TODO confirm
ograds = map(x -> NDArray(MX_NDArrayHandle(x), false), hdls[1:num_ograds])
igrads = map(NDArray ∘ MX_NDArrayHandle, hdls[num_ograds+1:num_ograds+num_igrads])
reqs = unsafe_wrap(Array, reqs, num_igrads)

# passing closure via raw pointer
f = unsafe_pointer_to_objref(fptr)

# user-defined backward; a single NDArray result is normalized to a 1-element list
Δs = backward!(f, ograds...)
Δs = Δs isa NDArray ? [Δs] : Δs

# update gradient
# NOTE(review): `zip` stops at the shortest collection, so a `backward!` that
# returns fewer gradients than `num_igrads` leaves the rest untouched.
for (i, Δ, req) ∈ zip(igrads, Δs, reqs)
req = GRAD_REQ(req)
if req == GRAD_NOP
continue
elseif req ∈ (GRAD_WRITE, GRAD_INPLACE)
# overwrite the gradient buffer
i[:] = Δ
elseif req == GRAD_ADD
# accumulate into the gradient buffer
i[:] += Δ
end
end

# release ref for gc
delete!(_frefs, f)

Cint(true)
end

# Delete callback registered with libmxnet (second entry of the callback list).
# `ref` is the raw pointer stored in `ctxs[2]`; the object it refers to is the
# key under which the callback list was registered in `_cblists`, so removing
# that entry releases the reference. Always returns `Cint(true)`.
function _del_wrapper(ref)
    delete!(_cblists, unsafe_pointer_to_objref(ref))
    return Cint(true)
end

# Julia mirror of the C struct `MXCallbackList`, used to hand the backward and
# delete callbacks of a custom function to libmxnet.
struct MXCallbackList
n::Cint # int num_callbacks;
cbs::Ptr{Ptr{Void}} # int (**callbacks)(void);
ctxs::Ptr{Ptr{Void}} # void **contexts;

# We must provide two callback functions:
# the first is the backward function `_back_wrapper`,
# the second is the delete callback `_del_wrapper`.
# https://github.com/apache/incubator-mxnet/blob/2f8c1e83f94e84a25a48d2cd43136030fb3f2d1e/include/mxnet/c_api.h#L174-L182

# `ctxs` is an array of the same size as `cbs`.
# Its elements will be passed as `state` to the callback functions,
# usually as the last argument.
# In our case, we push the pointer of the custom func instance as the
# first element of `ctxs`, and the pointer of the MXCallbackList reference
# as the second element.
# The purpose of the first pointer is to pass a closure into `cfunction`.
# The second pointer is used to free the reference of the MXCallbackList,
# and let the function instance be GC-ed properly.

# NOTE: this inner constructor returns a `Ref{MXCallbackList}`, not a bare
# `MXCallbackList`; `f` is the custom function instance.
function MXCallbackList(f) # where all args are Refs
# NOTE(review): the temporary `Ref(f)` is not stored anywhere; this relies on
# `f` itself being kept alive elsewhere (`@custom` puts it in `_frefs`) — confirm.
ctxs = [
Base.unsafe_convert(Ptr{Void}, Ref(f)),
Ptr{Void}(C_NULL),
]
ctxsptr = Base.unsafe_convert(Ptr{Ptr{Void}}, ctxs)
# two callbacks, the shared callback pointer array, and our context array
cblist = new(2, _cbsref[], ctxsptr)
# get the reference, and make a self-reference in ctxs[2]
cblist_ref = Ref{MXCallbackList}(cblist)
ctxs[2] = Base.unsafe_convert(Ptr{Void}, cblist_ref)
# Insert the ref into a holder to prevent it from being GC-ed.
# The stored value keeps `ctxs` (whose data pointer is in `cblist`) alive too.
_cblists[cblist_ref] = Ref(ctxs)
cblist_ref
end
end

# Holds every live MXCallbackList to prevent it from being GC-ed: maps the
# `Ref{MXCallbackList}` returned by the constructor to a `Ref` of its `ctxs`
# array. Entries are removed by `_del_wrapper`.
const _cblists = Dict{Ref{MXCallbackList},Ref}()

"""
    @custom

Create a callable custom function.
All the positional arguments should be `NDArray`.
The return value should be an instance of your custom type.

Please check out `examples/autograd/customfunc.jl` for an example.
"""
macro custom(ex::Expr)
# `ex` is the function definition being wrapped; its signature is reused
# unchanged and its body is replaced by the generated code below.
fdef = splitdef(ex) # by MacroTools
sig = ex.args[1]
body = esc(Expr(:let, ex.args[2])) # create a new scope via `let`

# only extract symbols, get rid of all annotations and default values
args = map(x -> esc(splitarg(x)[1]), fdef[:args])
# forward(f, xs...)
forward_expr = Expr(:call, :forward, :f, args...)
# insert keyword args
if !isempty(fdef[:kwargs])
# only extract symbols, get rid of all annotations and default values
kwargs = map(fdef[:kwargs]) do x
sym = splitarg(x)[1]
Expr(:kw, sym, esc(sym))
end
append!(forward_expr.args, kwargs)
end

# xs, FIXME: a list of NDArray from positional argument
xs_len = length(args)
xs_expr = Expr(:vect, args...)

body′ = quote
# Run the user's body (which yields the instance) and then `forward`.
# NOTE(review): `_record(false, nothing)` presumably pauses recording while
# this runs — confirm against its definition.
f, ys = _record(false, nothing) do
f = $body # f is the object instance
ys = $forward_expr
f, ys
end

# outside of a recording scope there is nothing to register
!is_recording() && return ys

xs = $xs_expr
# a single NDArray output is normalized to a 1-element list
ys′ = ys isa NDArray ? [ys] : ys

# struct MXCallbackList
cblist_ref = MXCallbackList(f)

# gc free: keep the instance and its inputs/outputs referenced until
# `_back_wrapper` deletes the entry
_frefs[f] = (Ref(xs), Ref(ys′))

@mxcall(
:MXCustomFunctionRecord,
(Cint, # num_inputs
Ptr{MX_handle}, # inputs

Cint, # num_outputs
Ptr{MX_handle}, # outputs

Ref{MXCallbackList}), # callbacks
$xs_len,
xs,

length(ys′),
ys′,

cblist_ref)

# the caller receives the original (non-normalized) forward result
ys
end

# reassemble the function: the caller's signature with the generated body
Expr(:function, esc(sig), body′)
end

# Custom functions should overload these two functions for their own type.
# `forward(f, xs...)` is called by `@custom` with the instance and the
# function's arguments; its result is the value returned to the caller.
# `backward!(f, ograds...)` is called by `_back_wrapper` with one output
# gradient per forward return value, and should return the input gradient(s).
function forward end
function backward! end
2 changes: 2 additions & 0 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ function __init__()

global const LIB_VERSION = _get_lib_version()

_init_customfunc()

atexit() do
# notify libmxnet we are shutting down
ccall( ("MXNotifyShutdown", MXNET_LIB), Cint, () )
Expand Down
Loading