I have followed the "Compiling Lux Models using Reactant.jl" tutorial and noticed several issues. Below is a simplified, reproducible MWE of this tutorial, where different setups can be selected by an argument passed to the script, i.e. `julia tutorial.jl <arg>`. Here are the issues:
- Compile time for the `train_model` function is ~400 s (almost 7 min) for arg=1 and arg=3 (vs. ~80 s for arg=0), which seems far too long given how simple the example is and that Reactant is the recommended backend.
- Running `train_model` a second time crashes the script when Reactant is used (use arg=1 or arg=3 to reproduce); it works fine when cpu_device() is used (arg=0).
- The error between non-Reactant and Reactant inference results differs from the tutorial. In the tutorial the errors are on the order of 1e-8, but when using the GPU the MWE shows a worryingly high error of ~1e-4, four orders of magnitude larger and not acceptable (use arg=3 to reproduce).
- Even though this example uses a small neural network, it still immediately allocates 75% of available VRAM, which in my case is 12 GB. Is that expected? (use arg=3 to reproduce).
I would especially appreciate suggestions on how to address the first issue, i.e. how to use Lux with maximum efficiency on both CPU and GPU while avoiding extremely long compilation times.
Here is the MWE code:
# Code adopted from Lux.jl tutorial
# https://lux.csail.mit.edu/stable/manual/compiling_lux_models
using Lux, Reactant, Enzyme, Random
using Optimisers, Printf, Statistics, Logging
mode = isempty(ARGS) ? 0 : parse(Int, ARGS[1]) # 0..3
@assert 0 <= mode <= 3 "Pass one Int in 0..3 (bit 1 = SET_GPU_BACKEND, bit 0 = USE_REACTANT_DEVICE)"
const SET_GPU_BACKEND::Bool = (mode & 0x2) != 0
const USE_REACTANT_DEVICE::Bool = (mode & 0x1) != 0
@info "SET_GPU_BACKEND = $SET_GPU_BACKEND"
@info "USE_REACTANT_DEVICE = $USE_REACTANT_DEVICE"
if SET_GPU_BACKEND
    Reactant.set_default_backend("gpu")
else
    Reactant.set_default_backend("cpu")
end
const xdev = USE_REACTANT_DEVICE ? reactant_device() : cpu_device()
@info "xdev = $xdev"
model = Chain(
    Dense(2 => 32, gelu),
    Dense(32 => 32, gelu),
    Dense(32 => 2)
)
ps, st = Lux.setup(MersenneTwister(42), model)
x = randn(Float32, 2, 32)
y = x .^ 2
x_ra = x |> xdev
y_ra = y |> xdev
ps_ra = ps |> xdev
st_ra = st |> xdev
pred_lux, _ = model(x, ps, Lux.testmode(st))
@info "Compiling model with Reactant..."
@time model_compiled = @compile model(x_ra, ps_ra, Lux.testmode(st_ra))
@info "Inference with compiled model... (1st time)"
@time pred_compiled, _ = model_compiled(x_ra, ps_ra, Lux.testmode(st_ra))
@info "Inference with compiled model... (2nd time)"
@time pred_compiled, _ = model_compiled(x_ra, ps_ra, Lux.testmode(st_ra))
diff = pred_lux .- Array(pred_compiled)
@info "Mean difference: $(mean(abs.(diff)))"
# Second part of tutorial
model = Chain(
    Dense(2 => 4, gelu),
    Dense(4 => 4, gelu),
    Dense(4 => 2)
)
ps, st = Lux.setup(MersenneTwister(42), model)
x_ra = [randn(Float32, 2, 32) for _ in 1:32]
y_ra = [xᵢ .^ 2 for xᵢ in x_ra]
ps_ra = ps |> xdev
st_ra = st |> xdev
dataloader = DeviceIterator(xdev, zip(x_ra, y_ra))
function train_model(model, ps, st, dataloader)
    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))
    for iteration in 1:1000
        for (i, (xᵢ, yᵢ)) in enumerate(dataloader)
            _, loss, _, train_state = Training.single_train_step!(
                AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state
            )
            if (iteration % 100 == 0 || iteration == 1) && i == 1
                @printf("Iter: [%4d/%4d]\tLoss: %.8f\n", iteration, 1000, loss)
            end
        end
    end
    return train_state
end
@info "Training model with Reactant... (1st time)"
@time train_model(model, ps_ra, st_ra, dataloader)
@info "Training model with Reactant... (2nd time)"
@time train_model(model, ps_ra, st_ra, dataloader)
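For reference, here is a workaround sketch for the 2nd-run crash that I would expect to help, assuming (my speculation, untested) that the first run donates the `ps_ra`/`st_ra` buffers to XLA, so reusing the same arrays passes already-deleted buffers:

```julia
# Assumption: the compiled train step donates the parameter/state buffers,
# so the second call must not reuse the stale device arrays.

# Option A: re-materialize fresh device buffers before calling again.
ps_ra = ps |> xdev
st_ra = st |> xdev
train_model(model, ps_ra, st_ra, dataloader)

# Option B: continue from the returned TrainState instead of the stale arrays.
ts = train_model(model, ps |> xdev, st |> xdev, dataloader)
train_model(model, ts.parameters, ts.states, dataloader)
```

Even if this avoids the crash, it does not explain why reusing the arrays should fail, so I would still consider the underlying behavior a bug (or at least under-documented).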
Output for arg=0
LuxBenchmark on master [!+?] via ஃ v1.12.5
❯ julia --project=. lux_tutorial.jl 0
[ Info: SET_GPU_BACKEND = false
[ Info: USE_REACTANT_DEVICE = false
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1772217755.166700 96056 service.cc:153] XLA service 0x13993ab0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1772217755.166723 96056 service.cc:161] StreamExecutor [0]: NVIDIA GeForce RTX 5070 Ti, Compute Capability 12.0a (Driver: 13.0.0; Runtime: 13.0.0; Toolkit: 13.0.0; DNN: 9.14.0)
I0000 00:00:1772217755.167400 96056 se_gpu_pjrt_client.cc:1467] Using BFC allocator.
I0000 00:00:1772217755.167429 96056 gpu_helpers.cc:141] XLA backend allocating 12439388160 bytes on device 0 for BFCAllocator.
I0000 00:00:1772217755.167455 96056 gpu_helpers.cc:183] XLA backend will use up to 4146462720 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1772217755.173230 96056 cuda_dnn.cc:461] Loaded cuDNN version 91400
[ Info: xdev = CPUDevice{Missing}()
[ Info: Compiling model with Reactant...
21.120927 seconds (51.89 M allocations: 2.559 GiB, 5.21% gc time, 99.63% compilation time: 1% of which was recompilation)
[ Info: Inference with compiled model... (1st time)
0.034475 seconds (105.11 k allocations: 4.617 MiB, 99.86% compilation time)
[ Info: Inference with compiled model... (2nd time)
0.000017 seconds (15 allocations: 864 bytes)
[ Info: Mean difference: 0.0
[ Info: Training model with Reactant... (1st time)
Iter: [ 1/1000] Loss: 2.73682380
Iter: [ 100/1000] Loss: 0.85331267
Iter: [ 200/1000] Loss: 0.24761316
Iter: [ 300/1000] Loss: 0.04767180
Iter: [ 400/1000] Loss: 0.01951925
Iter: [ 500/1000] Loss: 0.01071168
Iter: [ 600/1000] Loss: 0.00661943
Iter: [ 700/1000] Loss: 0.00433787
Iter: [ 800/1000] Loss: 0.00312395
Iter: [ 900/1000] Loss: 0.00235731
Iter: [1000/1000] Loss: 0.00182673
83.286212 seconds (192.08 M allocations: 9.023 GiB, 2.38% gc time, 99.37% compilation time)
[ Info: Training model with Reactant... (2nd time)
Iter: [ 1/1000] Loss: 0.00182255
Iter: [ 100/1000] Loss: 0.00148260
Iter: [ 200/1000] Loss: 0.00124297
Iter: [ 300/1000] Loss: 0.00105778
Iter: [ 400/1000] Loss: 0.00090997
Iter: [ 500/1000] Loss: 0.00079514
Iter: [ 600/1000] Loss: 0.00070684
Iter: [ 700/1000] Loss: 0.00063873
Iter: [ 800/1000] Loss: 0.00058609
Iter: [ 900/1000] Loss: 0.00054562
Iter: [1000/1000] Loss: 0.00051490
0.510780 seconds (9.20 M allocations: 606.092 MiB, 8.45% gc time)
LuxBenchmark on master [+?] via ஃ v1.12.5 took 1m52s
Output for arg=1
LuxBenchmark on master [+?] via ஃ v1.12.5
❯ julia --project=. lux_tutorial.jl 1
[ Info: SET_GPU_BACKEND = false
[ Info: USE_REACTANT_DEVICE = true
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1772217899.422195 96881 service.cc:153] XLA service 0x3da6a250 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1772217899.422221 96881 service.cc:161] StreamExecutor [0]: NVIDIA GeForce RTX 5070 Ti, Compute Capability 12.0a (Driver: 13.0.0; Runtime: 13.0.0; Toolkit: 13.0.0; DNN: 9.14.0)
I0000 00:00:1772217899.422916 96881 se_gpu_pjrt_client.cc:1467] Using BFC allocator.
I0000 00:00:1772217899.422952 96881 gpu_helpers.cc:141] XLA backend allocating 12439388160 bytes on device 0 for BFCAllocator.
I0000 00:00:1772217899.422987 96881 gpu_helpers.cc:183] XLA backend will use up to 4146462720 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1772217899.429490 96881 cuda_dnn.cc:461] Loaded cuDNN version 91400
[ Info: xdev = ReactantDevice{Missing, Missing, Missing, Missing, Union{}}(missing, missing, missing)
[ Info: Compiling model with Reactant...
28.926677 seconds (72.38 M allocations: 3.536 GiB, 4.14% gc time, 99.47% compilation time: <1% of which was recompilation)
[ Info: Inference with compiled model... (1st time)
0.311331 seconds (816.71 k allocations: 40.729 MiB, 99.94% compilation time)
[ Info: Inference with compiled model... (2nd time)
0.000077 seconds (27 allocations: 944 bytes)
[ Info: Mean difference: 5.0582457e-8
[ Info: Training model with Reactant... (1st time)
Iter: [ 1/1000] Loss: 0.81762600
Iter: [ 100/1000] Loss: 0.13008928
Iter: [ 200/1000] Loss: 0.03643125
Iter: [ 300/1000] Loss: 0.02133664
Iter: [ 400/1000] Loss: 0.01420628
Iter: [ 500/1000] Loss: 0.00967961
Iter: [ 600/1000] Loss: 0.00724365
Iter: [ 700/1000] Loss: 0.00563383
Iter: [ 800/1000] Loss: 0.00444563
Iter: [ 900/1000] Loss: 0.00351061
Iter: [1000/1000] Loss: 0.00274719
403.099325 seconds (956.33 M allocations: 46.126 GiB, 2.17% gc time, 99.29% compilation time: <1% of which was recompilation)
[ Info: Training model with Reactant... (2nd time)
ERROR: LoadError: INVALID_ARGUMENT: Invalid buffer passed to Execute() as argument 2 to replica 0: INVALID_ARGUMENT: Buffer has been deleted or donated.
Stacktrace:
[1] reactant_err(msg::Cstring)
@ Reactant.XLA ~/.julia/packages/Reactant/mMNEy/src/xla/Utils.jl:12
[2] macro expansion
@ ~/.julia/packages/Reactant/mMNEy/src/xla/PJRT/LoadedExecutable.jl:217 [inlined]
[3] execute_sharded
@ ~/.julia/packages/Reactant/mMNEy/src/xla/PJRT/LoadedExecutable.jl:186 [inlined]
[4] macro expansion
@ ~/.julia/packages/Reactant/mMNEy/src/Compiler.jl:3788 [inlined]
[5] (::Reactant.Compiler.Thunk{typeof(ReactantExt.compute_gradients_internal_and_step!), Symbol("##compute_gradients_internal_and_step!_reactant#1148612"), false, Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Chain{@NamedTuple{layer_1::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}}, @NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, 
ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, layer_3::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}}, @NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, Bool}, Reactant.XLA.PJRT.LoadedExecutable, Reactant.XLA.PJRT.Device, Reactant.XLA.PJRT.Client, Tuple{}, Vector{Bool}})(::GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, ::Chain{@NamedTuple{layer_1::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, 
Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, ::Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}}, ::@NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, ::@NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, 
ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, layer_3::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}}, ::@NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, ::Bool)
@ Reactant.Compiler ~/.julia/packages/Reactant/mMNEy/src/Compiler.jl:4288
[6] (::ReactantExt.var"#26#27"{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}}})()
@ ReactantExt ~/.julia/packages/Lux/GCC0y/ext/ReactantExt/training.jl:238
[7] single_train_step_impl!(backend::Lux.Training.ReactantBackend{Static.True, Missing, Nothing, AutoEnzyme{Nothing, Nothing}}, objective_function::GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, data::Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}}, ts::Lux.Training.TrainState{Nothing, Nothing, Nothing, Chain{@NamedTuple{layer_1::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, 
ConcretePJRTNumber{Float32, 1}}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, layer_3::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}}})
@ ReactantExt ~/.julia/packages/Lux/GCC0y/ext/ReactantExt/training.jl:182
[8] single_train_step_impl_with_allocator_cache!
@ ~/.julia/packages/Lux/GCC0y/src/helpers/training.jl:509 [inlined]
[9] #single_train_step!#9
@ ~/.julia/packages/Lux/GCC0y/src/helpers/training.jl:469 [inlined]
[10] single_train_step!(backend::AutoEnzyme{Nothing, Nothing}, obj_fn::GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, data::Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}}, ts::Lux.Training.TrainState{Nothing, Nothing, Nothing, Chain{@NamedTuple{layer_1::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, 
layer_2::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, layer_3::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}}})
@ Lux.Training ~/.julia/packages/Lux/GCC0y/src/helpers/training.jl:465
[11] train_model(model::Chain{@NamedTuple{layer_1::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, ps::@NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, dataloader::DeviceIterator{ReactantDevice{Missing, Missing, Missing, Missing, Union{}}, Base.Iterators.Zip{Tuple{Vector{Matrix{Float32}}, Vector{Matrix{Float32}}}}})
@ Main ~/Code/Julia/MWE/LuxBenchmark/lux_tutorial.jl:72
[12] macro expansion
@ ./timing.jl:697 [inlined]
[13] top-level scope
@ ~/Code/Julia/MWE/LuxBenchmark/lux_tutorial.jl:353
[14] include(mod::Module, _path::String)
@ Base ./Base.jl:306
[15] exec_options(opts::Base.JLOptions)
@ Base ./client.jl:317
[16] _start()
@ Base ./client.jl:550
in expression starting at /home/mr/Code/Julia/MWE/LuxBenchmark/lux_tutorial.jl:87
LuxBenchmark on master [+?] via ஃ v1.12.5 took 7m21s
Output for arg=3
LuxBenchmark on master [+?] via ஃ v1.12.5
❯ julia --project=. lux_tutorial.jl 3
[ Info: SET_GPU_BACKEND = true
[ Info: USE_REACTANT_DEVICE = true
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1772218358.347884 98230 service.cc:153] XLA service 0x3a4a0450 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1772218358.347908 98230 service.cc:161] StreamExecutor [0]: NVIDIA GeForce RTX 5070 Ti, Compute Capability 12.0a (Driver: 13.0.0; Runtime: 13.0.0; Toolkit: 13.0.0; DNN: 9.14.0)
I0000 00:00:1772218358.348604 98230 se_gpu_pjrt_client.cc:1467] Using BFC allocator.
I0000 00:00:1772218358.348640 98230 gpu_helpers.cc:141] XLA backend allocating 12439388160 bytes on device 0 for BFCAllocator.
I0000 00:00:1772218358.348674 98230 gpu_helpers.cc:183] XLA backend will use up to 4146462720 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1772218358.355407 98230 cuda_dnn.cc:461] Loaded cuDNN version 91400
[ Info: xdev = ReactantDevice{Missing, Missing, Missing, Missing, Union{}}(missing, missing, missing)
[ Info: Compiling model with Reactant...
29.816421 seconds (72.39 M allocations: 3.537 GiB, 4.79% gc time, 99.12% compilation time: <1% of which was recompilation)
[ Info: Inference with compiled model... (1st time)
0.368544 seconds (816.72 k allocations: 40.730 MiB, 87.53% compilation time)
[ Info: Inference with compiled model... (2nd time)
0.000175 seconds (27 allocations: 944 bytes)
[ Info: Mean difference: 0.00011303925
[ Info: Training model with Reactant... (1st time)
I0000 00:00:1772218794.696071 98230 dot_merger.cc:481] Merging Dots in computation: main.4
Iter: [ 1/1000] Loss: 1.85156107
Iter: [ 100/1000] Loss: 0.31370908
Iter: [ 200/1000] Loss: 0.02984785
Iter: [ 300/1000] Loss: 0.01746508
Iter: [ 400/1000] Loss: 0.00948397
Iter: [ 500/1000] Loss: 0.00613383
Iter: [ 600/1000] Loss: 0.00411469
Iter: [ 700/1000] Loss: 0.00279575
Iter: [ 800/1000] Loss: 0.00192928
Iter: [ 900/1000] Loss: 0.00136065
Iter: [1000/1000] Loss: 0.00101000
412.242002 seconds (956.06 M allocations: 46.109 GiB, 2.23% gc time, 98.27% compilation time: <1% of which was recompilation)
[ Info: Training model with Reactant... (2nd time)
I0000 00:00:1772218804.849489 98230 dot_merger.cc:481] Merging Dots in computation: main.4
ERROR: LoadError: INVALID_ARGUMENT: Invalid buffer passed to Execute() as argument 2 to replica 0: INVALID_ARGUMENT: Donation requested for invalid buffer
Stacktrace:
[1] reactant_err(msg::Cstring)
@ Reactant.XLA ~/.julia/packages/Reactant/mMNEy/src/xla/Utils.jl:12
[2] macro expansion
@ ~/.julia/packages/Reactant/mMNEy/src/xla/PJRT/LoadedExecutable.jl:217 [inlined]
[3] execute_sharded
@ ~/.julia/packages/Reactant/mMNEy/src/xla/PJRT/LoadedExecutable.jl:186 [inlined]
[4] macro expansion
@ ~/.julia/packages/Reactant/mMNEy/src/Compiler.jl:3788 [inlined]
[5] (::Reactant.Compiler.Thunk{typeof(ReactantExt.compute_gradients_internal_and_step!), Symbol("##compute_gradients_internal_and_step!_reactant#1148612"), false, Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Chain{@NamedTuple{layer_1::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}}, @NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, 
ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, layer_3::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}}, @NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, Bool}, Reactant.XLA.PJRT.LoadedExecutable, Reactant.XLA.PJRT.Device, Reactant.XLA.PJRT.Client, Tuple{}, Vector{Bool}})(::GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, ::Chain{@NamedTuple{layer_1::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, 
Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, ::Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}}, ::@NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, ::@NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, 
ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, layer_3::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}}, ::@NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, ::Bool)
@ Reactant.Compiler ~/.julia/packages/Reactant/mMNEy/src/Compiler.jl:4288
[6] (::ReactantExt.var"#26#27"{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}}})()
@ ReactantExt ~/.julia/packages/Lux/GCC0y/ext/ReactantExt/training.jl:238
[7] single_train_step_impl!(backend::Lux.Training.ReactantBackend{Static.True, Missing, Nothing, AutoEnzyme{Nothing, Nothing}}, objective_function::GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, data::Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}}, ts::Lux.Training.TrainState{Nothing, Nothing, Nothing, Chain{@NamedTuple{layer_1::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, 
ConcretePJRTNumber{Float32, 1}}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, layer_3::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}}})
@ ReactantExt ~/.julia/packages/Lux/GCC0y/ext/ReactantExt/training.jl:182
[8] single_train_step_impl_with_allocator_cache!
@ ~/.julia/packages/Lux/GCC0y/src/helpers/training.jl:509 [inlined]
[9] #single_train_step!#9
@ ~/.julia/packages/Lux/GCC0y/src/helpers/training.jl:469 [inlined]
[10] single_train_step!(backend::AutoEnzyme{Nothing, Nothing}, obj_fn::GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, data::Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}}, ts::Lux.Training.TrainState{Nothing, Nothing, Nothing, Chain{@NamedTuple{layer_1::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, 
layer_2::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, layer_3::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}}})
@ Lux.Training ~/.julia/packages/Lux/GCC0y/src/helpers/training.jl:465
[11] train_model(model::Chain{@NamedTuple{layer_1::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, ps::@NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, dataloader::DeviceIterator{ReactantDevice{Missing, Missing, Missing, Missing, Union{}}, Base.Iterators.Zip{Tuple{Vector{Matrix{Float32}}, Vector{Matrix{Float32}}}}})
@ Main ~/Code/Julia/MWE/LuxBenchmark/lux_tutorial.jl:72
[12] macro expansion
@ ./timing.jl:697 [inlined]
[13] top-level scope
@ ~/Code/Julia/MWE/LuxBenchmark/lux_tutorial.jl:353
[14] include(mod::Module, _path::String)
@ Base ./Base.jl:306
[15] exec_options(opts::Base.JLOptions)
@ Base ./client.jl:317
[16] _start()
@ Base ./client.jl:550
in expression starting at /home/mr/Code/Julia/MWE/LuxBenchmark/lux_tutorial.jl:87
LuxBenchmark on master [+?] via ஃ v1.12.5 took 7m31s
System and Julia info:
LuxBenchmark on master [+?] via ஃ v1.12.5
❯ julia --project=. -q
julia> versioninfo()
Julia Version 1.12.5
Commit 5fe89b8ddc1 (2026-02-09 16:05 UTC)
Build Info:
Official https://julialang.org release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 12 × AMD Ryzen 5 5600X 6-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-18.1.7 (ORCJIT, znver3)
GC: Built with stock GC
Threads: 1 default, 1 interactive, 1 GC (on 12 virtual cores)
(LuxBenchmark) pkg> st
Status `~/Code/Julia/MWE/LuxBenchmark/Project.toml`
[7da242da] Enzyme v0.13.129
[b2108857] Lux v1.31.3
[d0bbae9a] LuxCUDA v0.3.4
[da2b9cff] Mooncake v0.5.8
[3bd65402] Optimisers v0.4.7
[92933f4c] ProgressMeter v1.11.0
[3c362404] Reactant v0.2.228
[10745b16] Statistics v1.11.1
[e88e6eb3] Zygote v0.7.10
[56ddb016] Logging v1.11.0
[de0858da] Printf v1.11.0
[9a3f8284] Random v1.11.0
(LuxBenchmark) pkg>