
Problems with "Compiling Lux Models using Reactant.jl" tutorial from docs #1668

@mrarat

Description

I followed the "Compiling Lux Models using Reactant.jl" tutorial and noticed several issues. Below is a simplified, reproducible MWE of the tutorial in which different setups are selected by the argument passed to the script, i.e. julia tutorial.jl <arg>. Here are the issues:

  1. Compiling the train_model function takes about 400 s (almost 7 min) for arg=1 and arg=3, versus about 80 s for arg=0. That seems far too long given how simple the example is and given that Reactant is the recommended backend.
  2. Running train_model a second time crashes the script when Reactant is used (use arg=1 or arg=3 to reproduce); it works fine when cpu_device() is used (arg=0).
  3. The difference between the non-Reactant and Reactant inference results does not match the tutorial. In the tutorial the errors are of order 1e-8, but on the GPU the MWE gives a worryingly high error of about 1e-4, four orders of magnitude higher, which is not acceptable (use arg=3 to reproduce).
  4. Even though this example uses a small neural network, it immediately allocates 75% of the available VRAM (about 12 GB in my case). Is this expected? (use arg=3 to reproduce).
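To put item 3 in scale, the two mean differences reported in the outputs below (5.0582457e-8 for arg=1 and 0.00011303925 for arg=3) can be compared with the Float32 machine epsilon. This is only arithmetic on the logged numbers; it does not diagnose where the GPU discrepancy comes from.

```julia
# Mean absolute differences copied from the "Mean difference" log lines below
diff_cpu = 5.0582457f-8      # arg=1: Reactant device, CPU backend
diff_gpu = 0.00011303925f0   # arg=3: Reactant device, GPU backend

ε = eps(Float32)  # spacing of Float32 values at 1.0 (about 1.19e-7)

# Express each difference in units of ε
println("CPU backend: ", diff_cpu / ε, " ε")
println("GPU backend: ", diff_gpu / ε, " ε")

# How much larger the GPU difference is than the arg=1 one, also in decades
ratio = diff_gpu / diff_cpu
println("ratio: ", ratio, " (about 10^", round(log10(ratio); digits = 2), ")")
```

Measured against ε, the CPU-backend result is within half an ε of the plain Lux result, while the GPU-backend result is off by roughly a thousand ε, about 3.3 decades more than the arg=1 run.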

I would especially appreciate suggestions on how to address the first issue, i.e. how to use Lux with maximum efficiency on both CPU and GPU while avoiding extremely long compilation times.
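For item 1, the @time lines of the first train_model call can be split into compilation and the remainder. As far as I understand, the percentage @time prints counts only time spent in Julia's own compiler, so the remainder still covers running the 32 000 training steps (1000 iterations over 32 batches) plus anything else. A small arithmetic sketch using the logged values:

```julia
# (label, total seconds, compilation fraction) from the first train_model @time lines
runs = [
    ("arg=0 (cpu_device)", 83.286212, 0.9937),
    ("arg=1 (Reactant, CPU)", 403.099325, 0.9929),
    ("arg=3 (Reactant, GPU)", 412.242002, 0.9827),
]

for (name, total, frac) in runs
    compile = total * frac   # seconds attributed to Julia compilation
    rest = total - compile   # everything else
    println(rpad(name, 24), "compile ≈ ", round(compile; digits = 1), " s, rest ≈ ",
            round(rest; digits = 2), " s")
end

# Compile-time ratio between the Reactant CPU run and the plain CPU run
println("ratio ≈ ", round((403.099325 * 0.9929) / (83.286212 * 0.9937); digits = 2))
```

This gives roughly 82.8 s, 400.2 s and 405.1 s of compilation, against a remainder of about 0.52 s, 2.86 s and 7.13 s. The arg=0 remainder is close to the 0.51 s of the second arg=0 call, and the Reactant compile time is about 4.8 times the arg=0 one.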

Here is the MWE code:

# Code adapted from Lux.jl tutorial
# https://lux.csail.mit.edu/stable/manual/compiling_lux_models

using Lux, Reactant, Enzyme, Random
using Optimisers, Printf, Statistics, Logging

mode = isempty(ARGS) ? 0 : parse(Int, ARGS[1])  # 0..3
@assert 0 <= mode <= 3 "Pass one Int in 0..3 (bit 1 =SET_GPU_BACKEND, bit 0 =USE_REACTANT_DEVICE)"
const SET_GPU_BACKEND::Bool = (mode & 0x2) != 0
const USE_REACTANT_DEVICE::Bool = (mode & 0x1) != 0

@info "SET_GPU_BACKEND = $SET_GPU_BACKEND"
@info "USE_REACTANT_DEVICE = $USE_REACTANT_DEVICE"

if SET_GPU_BACKEND
    Reactant.set_default_backend("gpu")
else
    Reactant.set_default_backend("cpu")
end

const xdev = USE_REACTANT_DEVICE ? reactant_device() : cpu_device()
@info "xdev = $xdev"

model = Chain(
    Dense(2 => 32, gelu),
    Dense(32 => 32, gelu),
    Dense(32 => 2)
)
ps, st = Lux.setup(MersenneTwister(42), model)

x = randn(Float32, 2, 32)
y = x .^ 2

x_ra = x |> xdev
y_ra = y |> xdev
ps_ra = ps |> xdev
st_ra = st |> xdev

pred_lux, _ = model(x, ps, Lux.testmode(st))

@info "Compiling model with Reactant..."
@time model_compiled = @compile model(x_ra, ps_ra, Lux.testmode(st_ra))

@info "Inference with compiled model... (1st time)"
@time pred_compiled, _ = model_compiled(x_ra, ps_ra, Lux.testmode(st_ra))
@info "Inference with compiled model... (2nd time)"
@time pred_compiled, _ = model_compiled(x_ra, ps_ra, Lux.testmode(st_ra))

diff = pred_lux .- Array(pred_compiled)
@info "Mean difference: $(mean(abs.(diff)))"

# Second part of tutorial
model = Chain(
    Dense(2 => 4, gelu),
    Dense(4 => 4, gelu),
    Dense(4 => 2)
)
ps, st = Lux.setup(MersenneTwister(42), model)

x_ra = [randn(Float32, 2, 32) for _ in 1:32]
y_ra = [xᵢ .^ 2 for xᵢ in x_ra]
ps_ra = ps |> xdev
st_ra = st |> xdev

dataloader = DeviceIterator(xdev, zip(x_ra, y_ra))

function train_model(model, ps, st, dataloader)
    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

    for iteration in 1:1000
        for (i, (xᵢ, yᵢ)) in enumerate(dataloader)
            _, loss, _, train_state = Training.single_train_step!(
                AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state
            )
            if (iteration % 100 == 0 || iteration == 1) && i == 1
                @printf("Iter: [%4d/%4d]\tLoss: %.8f\n", iteration, 1000, loss)
            end
        end
    end

    return train_state
end

@info "Training model with Reactant... (1st time)"
@time train_model(model, ps_ra, st_ra, dataloader)
@info "Training model with Reactant... (2nd time)"
@time train_model(model, ps_ra, st_ra, dataloader)
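The script's single argument is decoded as a 2-bit mask (bit 1 selects the GPU backend, bit 0 the Reactant device). The snippet below just repeats the two bit tests from the MWE for every valid argument; only 0, 1 and 3 are used in the outputs that follow.

```julia
# Same bit tests as in the MWE: returns (SET_GPU_BACKEND, USE_REACTANT_DEVICE)
decode(mode) = ((mode & 0x2) != 0, (mode & 0x1) != 0)

for mode in 0:3
    set_gpu_backend, use_reactant_device = decode(mode)
    println("arg=$mode  SET_GPU_BACKEND=$set_gpu_backend  USE_REACTANT_DEVICE=$use_reactant_device")
end
```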

Output for arg=0

LuxBenchmark on  master [!+?] via ஃ v1.12.5 
❯ julia --project=. lux_tutorial.jl 0
[ Info: SET_GPU_BACKEND = false
[ Info: USE_REACTANT_DEVICE = false
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1772217755.166700   96056 service.cc:153] XLA service 0x13993ab0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1772217755.166723   96056 service.cc:161]   StreamExecutor [0]: NVIDIA GeForce RTX 5070 Ti, Compute Capability 12.0a (Driver: 13.0.0; Runtime: 13.0.0; Toolkit: 13.0.0; DNN: 9.14.0)
I0000 00:00:1772217755.167400   96056 se_gpu_pjrt_client.cc:1467] Using BFC allocator.
I0000 00:00:1772217755.167429   96056 gpu_helpers.cc:141] XLA backend allocating 12439388160 bytes on device 0 for BFCAllocator.
I0000 00:00:1772217755.167455   96056 gpu_helpers.cc:183] XLA backend will use up to 4146462720 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1772217755.173230   96056 cuda_dnn.cc:461] Loaded cuDNN version 91400
[ Info: xdev = CPUDevice{Missing}()
[ Info: Compiling model with Reactant...
 21.120927 seconds (51.89 M allocations: 2.559 GiB, 5.21% gc time, 99.63% compilation time: 1% of which was recompilation)
[ Info: Inference with compiled model... (1st time)
  0.034475 seconds (105.11 k allocations: 4.617 MiB, 99.86% compilation time)
[ Info: Inference with compiled model... (2nd time)
  0.000017 seconds (15 allocations: 864 bytes)
[ Info: Mean difference: 0.0
[ Info: Training model with Reactant... (1st time)
Iter: [   1/1000]       Loss: 2.73682380
Iter: [ 100/1000]       Loss: 0.85331267
Iter: [ 200/1000]       Loss: 0.24761316
Iter: [ 300/1000]       Loss: 0.04767180
Iter: [ 400/1000]       Loss: 0.01951925
Iter: [ 500/1000]       Loss: 0.01071168
Iter: [ 600/1000]       Loss: 0.00661943
Iter: [ 700/1000]       Loss: 0.00433787
Iter: [ 800/1000]       Loss: 0.00312395
Iter: [ 900/1000]       Loss: 0.00235731
Iter: [1000/1000]       Loss: 0.00182673
 83.286212 seconds (192.08 M allocations: 9.023 GiB, 2.38% gc time, 99.37% compilation time)
[ Info: Training model with Reactant... (2nd time)
Iter: [   1/1000]       Loss: 0.00182255
Iter: [ 100/1000]       Loss: 0.00148260
Iter: [ 200/1000]       Loss: 0.00124297
Iter: [ 300/1000]       Loss: 0.00105778
Iter: [ 400/1000]       Loss: 0.00090997
Iter: [ 500/1000]       Loss: 0.00079514
Iter: [ 600/1000]       Loss: 0.00070684
Iter: [ 700/1000]       Loss: 0.00063873
Iter: [ 800/1000]       Loss: 0.00058609
Iter: [ 900/1000]       Loss: 0.00054562
Iter: [1000/1000]       Loss: 0.00051490
  0.510780 seconds (9.20 M allocations: 606.092 MiB, 8.45% gc time)

LuxBenchmark on  master [+?] via ஃ v1.12.5 took 1m52s 
❯ 
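The two allocator lines in the log above are printed for every run, including this arg=0 run where cpu_device() is used. Assuming the 75% fraction from item 4, they are consistent with a preallocation sized as a fixed fraction of the device's free memory rather than to the model. A quick check on the logged byte counts:

```julia
bfc_bytes = 12_439_388_160        # "XLA backend allocating ... for BFCAllocator"
collective_bytes = 4_146_462_720  # "... will use up to ... for CollectiveBFCAllocator"

gib(b) = b / 2^30  # bytes -> GiB

# Assumption: the BFC pool is 75% of the free device memory
free_bytes = bfc_bytes / 0.75

println("BFC pool:     ", round(gib(bfc_bytes); digits = 2), " GiB")
println("implied free: ", round(gib(free_bytes); digits = 2), " GiB")
println("collective / free = ", collective_bytes / free_bytes)
```

The implied free memory is about 15.45 GiB, which fits a 16 GB card, and the collective allocator limit is exactly the remaining 25%.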

Output for arg=1

LuxBenchmark on  master [+?] via ஃ v1.12.5 
❯ julia --project=. lux_tutorial.jl 1
[ Info: SET_GPU_BACKEND = false
[ Info: USE_REACTANT_DEVICE = true
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1772217899.422195   96881 service.cc:153] XLA service 0x3da6a250 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1772217899.422221   96881 service.cc:161]   StreamExecutor [0]: NVIDIA GeForce RTX 5070 Ti, Compute Capability 12.0a (Driver: 13.0.0; Runtime: 13.0.0; Toolkit: 13.0.0; DNN: 9.14.0)
I0000 00:00:1772217899.422916   96881 se_gpu_pjrt_client.cc:1467] Using BFC allocator.
I0000 00:00:1772217899.422952   96881 gpu_helpers.cc:141] XLA backend allocating 12439388160 bytes on device 0 for BFCAllocator.
I0000 00:00:1772217899.422987   96881 gpu_helpers.cc:183] XLA backend will use up to 4146462720 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1772217899.429490   96881 cuda_dnn.cc:461] Loaded cuDNN version 91400
[ Info: xdev = ReactantDevice{Missing, Missing, Missing, Missing, Union{}}(missing, missing, missing)
[ Info: Compiling model with Reactant...
 28.926677 seconds (72.38 M allocations: 3.536 GiB, 4.14% gc time, 99.47% compilation time: <1% of which was recompilation)
[ Info: Inference with compiled model... (1st time)
  0.311331 seconds (816.71 k allocations: 40.729 MiB, 99.94% compilation time)
[ Info: Inference with compiled model... (2nd time)
  0.000077 seconds (27 allocations: 944 bytes)
[ Info: Mean difference: 5.0582457e-8
[ Info: Training model with Reactant... (1st time)
Iter: [   1/1000]       Loss: 0.81762600
Iter: [ 100/1000]       Loss: 0.13008928
Iter: [ 200/1000]       Loss: 0.03643125
Iter: [ 300/1000]       Loss: 0.02133664
Iter: [ 400/1000]       Loss: 0.01420628
Iter: [ 500/1000]       Loss: 0.00967961
Iter: [ 600/1000]       Loss: 0.00724365
Iter: [ 700/1000]       Loss: 0.00563383
Iter: [ 800/1000]       Loss: 0.00444563
Iter: [ 900/1000]       Loss: 0.00351061
Iter: [1000/1000]       Loss: 0.00274719
403.099325 seconds (956.33 M allocations: 46.126 GiB, 2.17% gc time, 99.29% compilation time: <1% of which was recompilation)
[ Info: Training model with Reactant... (2nd time)
ERROR: LoadError: INVALID_ARGUMENT: Invalid buffer passed to Execute() as argument 2 to replica 0: INVALID_ARGUMENT: Buffer has been deleted or donated.

Stacktrace:
  [1] reactant_err(msg::Cstring)
    @ Reactant.XLA ~/.julia/packages/Reactant/mMNEy/src/xla/Utils.jl:12
  [2] macro expansion
    @ ~/.julia/packages/Reactant/mMNEy/src/xla/PJRT/LoadedExecutable.jl:217 [inlined]
  [3] execute_sharded
    @ ~/.julia/packages/Reactant/mMNEy/src/xla/PJRT/LoadedExecutable.jl:186 [inlined]
  [4] macro expansion
    @ ~/.julia/packages/Reactant/mMNEy/src/Compiler.jl:3788 [inlined]
  [5] (::Reactant.Compiler.Thunk{typeof(ReactantExt.compute_gradients_internal_and_step!), Symbol("##compute_gradients_internal_and_step!_reactant#1148612"), false, Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Chain{@NamedTuple{layer_1::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}}, @NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, 
ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, layer_3::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}}, @NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, Bool}, Reactant.XLA.PJRT.LoadedExecutable, Reactant.XLA.PJRT.Device, Reactant.XLA.PJRT.Client, Tuple{}, Vector{Bool}})(::GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, ::Chain{@NamedTuple{layer_1::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, 
Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, ::Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}}, ::@NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, ::@NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, 
ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, layer_3::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}}, ::@NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, ::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/mMNEy/src/Compiler.jl:4288
  [6] (::ReactantExt.var"#26#27"{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}}})()
    @ ReactantExt ~/.julia/packages/Lux/GCC0y/ext/ReactantExt/training.jl:238
  [7] single_train_step_impl!(backend::Lux.Training.ReactantBackend{Static.True, Missing, Nothing, AutoEnzyme{Nothing, Nothing}}, objective_function::GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, data::Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}}, ts::Lux.Training.TrainState{Nothing, Nothing, Nothing, Chain{@NamedTuple{layer_1::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, 
ConcretePJRTNumber{Float32, 1}}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, layer_3::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}}})
    @ ReactantExt ~/.julia/packages/Lux/GCC0y/ext/ReactantExt/training.jl:182
  [8] single_train_step_impl_with_allocator_cache!
    @ ~/.julia/packages/Lux/GCC0y/src/helpers/training.jl:509 [inlined]
  [9] #single_train_step!#9
    @ ~/.julia/packages/Lux/GCC0y/src/helpers/training.jl:469 [inlined]
 [10] single_train_step!(backend::AutoEnzyme{Nothing, Nothing}, obj_fn::GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, data::Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}}, ts::Lux.Training.TrainState{Nothing, Nothing, Nothing, Chain{@NamedTuple{layer_1::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, 
layer_2::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, layer_3::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}}})
    @ Lux.Training ~/.julia/packages/Lux/GCC0y/src/helpers/training.jl:465
 [11] train_model(model::Chain{@NamedTuple{layer_1::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, ps::@NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, dataloader::DeviceIterator{ReactantDevice{Missing, Missing, Missing, Missing, Union{}}, Base.Iterators.Zip{Tuple{Vector{Matrix{Float32}}, Vector{Matrix{Float32}}}}})
    @ Main ~/Code/Julia/MWE/LuxBenchmark/lux_tutorial.jl:72
 [12] macro expansion
    @ ./timing.jl:697 [inlined]
 [13] top-level scope
    @ ~/Code/Julia/MWE/LuxBenchmark/lux_tutorial.jl:353
 [14] include(mod::Module, _path::String)
    @ Base ./Base.jl:306
 [15] exec_options(opts::Base.JLOptions)
    @ Base ./client.jl:317
 [16] _start()
    @ Base ./client.jl:550
in expression starting at /home/mr/Code/Julia/MWE/LuxBenchmark/lux_tutorial.jl:87

LuxBenchmark on  master [+?] via ஃ v1.12.5 took 7m21s 
❯ 
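Both crashes ("Buffer has been deleted or donated" above, "Donation requested for invalid buffer" for arg=3) read as if the first train_model call donated the device buffers behind ps_ra, so the second call passes XLA buffers that no longer exist. The pure-Julia mock below only illustrates that donation semantics. MockBuffer, donating_step and everything else in it are hypothetical and are not Reactant or Lux APIs.

```julia
# Hypothetical stand-in for a device buffer that can be donated exactly once
mutable struct MockBuffer
    data::Vector{Float32}
    valid::Bool
end
MockBuffer(data) = MockBuffer(data, true)

# Hypothetical compiled step that donates its input: the input buffer is
# invalidated and the result lives in a fresh buffer
function donating_step(p::MockBuffer)
    p.valid || error("Invalid buffer: buffer has been deleted or donated")
    p.valid = false
    return MockBuffer(p.data .- 0.1f0)
end

ps = MockBuffer(Float32[1, 2, 3])

ps_new = donating_step(ps)        # first run: fine, `ps` is now invalid
ps_newer = donating_step(ps_new)  # continuing from the returned buffer works

# Reusing the original binding reproduces the failure mode
try
    donating_step(ps)
catch err
    println(sprint(showerror, err))
end
```

In the MWE, the analogue would be to feed the second train_model call the parameters and states held by the TrainState returned from the first call (its parameters and states fields) instead of the original ps_ra and st_ra. That is an untested suggestion, and it does not answer whether the tutorial should warn about this.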

Output for arg=3

LuxBenchmark on  master [+?] via ஃ v1.12.5 
❯ julia --project=. lux_tutorial.jl 3
[ Info: SET_GPU_BACKEND = true
[ Info: USE_REACTANT_DEVICE = true
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1772218358.347884   98230 service.cc:153] XLA service 0x3a4a0450 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1772218358.347908   98230 service.cc:161]   StreamExecutor [0]: NVIDIA GeForce RTX 5070 Ti, Compute Capability 12.0a (Driver: 13.0.0; Runtime: 13.0.0; Toolkit: 13.0.0; DNN: 9.14.0)
I0000 00:00:1772218358.348604   98230 se_gpu_pjrt_client.cc:1467] Using BFC allocator.
I0000 00:00:1772218358.348640   98230 gpu_helpers.cc:141] XLA backend allocating 12439388160 bytes on device 0 for BFCAllocator.
I0000 00:00:1772218358.348674   98230 gpu_helpers.cc:183] XLA backend will use up to 4146462720 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1772218358.355407   98230 cuda_dnn.cc:461] Loaded cuDNN version 91400
[ Info: xdev = ReactantDevice{Missing, Missing, Missing, Missing, Union{}}(missing, missing, missing)
[ Info: Compiling model with Reactant...
 29.816421 seconds (72.39 M allocations: 3.537 GiB, 4.79% gc time, 99.12% compilation time: <1% of which was recompilation)
[ Info: Inference with compiled model... (1st time)
  0.368544 seconds (816.72 k allocations: 40.730 MiB, 87.53% compilation time)
[ Info: Inference with compiled model... (2nd time)
  0.000175 seconds (27 allocations: 944 bytes)
[ Info: Mean difference: 0.00011303925
[ Info: Training model with Reactant... (1st time)
I0000 00:00:1772218794.696071   98230 dot_merger.cc:481] Merging Dots in computation: main.4
Iter: [   1/1000]       Loss: 1.85156107
Iter: [ 100/1000]       Loss: 0.31370908
Iter: [ 200/1000]       Loss: 0.02984785
Iter: [ 300/1000]       Loss: 0.01746508
Iter: [ 400/1000]       Loss: 0.00948397
Iter: [ 500/1000]       Loss: 0.00613383
Iter: [ 600/1000]       Loss: 0.00411469
Iter: [ 700/1000]       Loss: 0.00279575
Iter: [ 800/1000]       Loss: 0.00192928
Iter: [ 900/1000]       Loss: 0.00136065
Iter: [1000/1000]       Loss: 0.00101000
412.242002 seconds (956.06 M allocations: 46.109 GiB, 2.23% gc time, 98.27% compilation time: <1% of which was recompilation)
[ Info: Training model with Reactant... (2nd time)
I0000 00:00:1772218804.849489   98230 dot_merger.cc:481] Merging Dots in computation: main.4
ERROR: LoadError: INVALID_ARGUMENT: Invalid buffer passed to Execute() as argument 2 to replica 0: INVALID_ARGUMENT: Donation requested for invalid buffer

Stacktrace:
  [1] reactant_err(msg::Cstring)
    @ Reactant.XLA ~/.julia/packages/Reactant/mMNEy/src/xla/Utils.jl:12
  [2] macro expansion
    @ ~/.julia/packages/Reactant/mMNEy/src/xla/PJRT/LoadedExecutable.jl:217 [inlined]
  [3] execute_sharded
    @ ~/.julia/packages/Reactant/mMNEy/src/xla/PJRT/LoadedExecutable.jl:186 [inlined]
  [4] macro expansion
    @ ~/.julia/packages/Reactant/mMNEy/src/Compiler.jl:3788 [inlined]
  [5] (::Reactant.Compiler.Thunk{typeof(ReactantExt.compute_gradients_internal_and_step!), Symbol("##compute_gradients_internal_and_step!_reactant#1148612"), false, Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Chain{@NamedTuple{layer_1::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}}, @NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, 
ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, layer_3::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}}, @NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, Bool}, Reactant.XLA.PJRT.LoadedExecutable, Reactant.XLA.PJRT.Device, Reactant.XLA.PJRT.Client, Tuple{}, Vector{Bool}})(::GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, ::Chain{@NamedTuple{layer_1::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, 
Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, ::Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}}, ::@NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, ::@NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, 
ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, layer_3::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}}, ::@NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, ::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/mMNEy/src/Compiler.jl:4288
  [6] (::ReactantExt.var"#26#27"{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}}})()
    @ ReactantExt ~/.julia/packages/Lux/GCC0y/ext/ReactantExt/training.jl:238
  [7] single_train_step_impl!(backend::Lux.Training.ReactantBackend{Static.True, Missing, Nothing, AutoEnzyme{Nothing, Nothing}}, objective_function::GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, data::Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}}, ts::Lux.Training.TrainState{Nothing, Nothing, Nothing, Chain{@NamedTuple{layer_1::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, 
ConcretePJRTNumber{Float32, 1}}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, layer_3::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}}})
    @ ReactantExt ~/.julia/packages/Lux/GCC0y/ext/ReactantExt/training.jl:182
  [8] single_train_step_impl_with_allocator_cache!
    @ ~/.julia/packages/Lux/GCC0y/src/helpers/training.jl:509 [inlined]
  [9] #single_train_step!#9
    @ ~/.julia/packages/Lux/GCC0y/src/helpers/training.jl:469 [inlined]
 [10] single_train_step!(backend::AutoEnzyme{Nothing, Nothing}, obj_fn::GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, data::Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}}, ts::Lux.Training.TrainState{Nothing, Nothing, Nothing, Chain{@NamedTuple{layer_1::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, 
layer_2::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}, layer_3::@NamedTuple{weight::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1}, Tuple{ConcretePJRTNumber{Float64, 1}, ConcretePJRTNumber{Float64, 1}}, ConcretePJRTNumber{Float64, 1}}}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}}}})
    @ Lux.Training ~/.julia/packages/Lux/GCC0y/src/helpers/training.jl:465
 [11] train_model(model::Chain{@NamedTuple{layer_1::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, ps::@NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1}, bias::ConcretePJRTArray{Float32, 1, 1}}}, st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, dataloader::DeviceIterator{ReactantDevice{Missing, Missing, Missing, Missing, Union{}}, Base.Iterators.Zip{Tuple{Vector{Matrix{Float32}}, Vector{Matrix{Float32}}}}})
    @ Main ~/Code/Julia/MWE/LuxBenchmark/lux_tutorial.jl:72
 [12] macro expansion
    @ ./timing.jl:697 [inlined]
 [13] top-level scope
    @ ~/Code/Julia/MWE/LuxBenchmark/lux_tutorial.jl:353
 [14] include(mod::Module, _path::String)
    @ Base ./Base.jl:306
 [15] exec_options(opts::Base.JLOptions)
    @ Base ./client.jl:317
 [16] _start()
    @ Base ./client.jl:550
in expression starting at /home/mr/Code/Julia/MWE/LuxBenchmark/lux_tutorial.jl:87

LuxBenchmark on  master [+?] via ஃ v1.12.5 took 7m31s 
❯ 

System and Julia info:

LuxBenchmark on  master [+?] via ஃ v1.12.5 
❯ julia --project=. -q
julia> versioninfo()
Julia Version 1.12.5
Commit 5fe89b8ddc1 (2026-02-09 16:05 UTC)
Build Info:
  Official https://julialang.org release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 12 × AMD Ryzen 5 5600X 6-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-18.1.7 (ORCJIT, znver3)
  GC: Built with stock GC
Threads: 1 default, 1 interactive, 1 GC (on 12 virtual cores)

(LuxBenchmark) pkg> st
Status `~/Code/Julia/MWE/LuxBenchmark/Project.toml`
  [7da242da] Enzyme v0.13.129
  [b2108857] Lux v1.31.3
  [d0bbae9a] LuxCUDA v0.3.4
  [da2b9cff] Mooncake v0.5.8
  [3bd65402] Optimisers v0.4.7
  [92933f4c] ProgressMeter v1.11.0
  [3c362404] Reactant v0.2.228
  [10745b16] Statistics v1.11.1
  [e88e6eb3] Zygote v0.7.10
  [56ddb016] Logging v1.11.0
  [de0858da] Printf v1.11.0
  [9a3f8284] Random v1.11.0

(LuxBenchmark) pkg> 
