Thanks for your contributions to deep learning in Julia.
I have a package (NeuralEstimators.jl) that currently uses Flux. I'd like to make the package backend agnostic, in the sense that users can choose either Flux or Lux.
I'm nearly there, but I've run into a problem that I'm finding difficult to solve. The public training API (e.g., TrainState, single_train_step!) requires an AbstractLuxLayer. However, declaring a struct to be an AbstractLuxLayer messes up the Flux training API (Flux.withgradient).
To simplify, I essentially have:
```julia
struct MyContainer
    field1
    field2
    # ...
end
```
where the fields of MyContainer contain Flux models or Lux models. If the fields are Lux models, then I want to pass my_container::MyContainer into TrainState and single_train_step!.
Is this possible? For instance, could the type annotations on TrainState and single_train_step! be loosened, and we instead just define the necessary methods for our custom structs (e.g., initialparameters, initialstates)?
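For concreteness, here is a minimal sketch of the kind of workaround I'm considering in the meantime. It assumes the Lux v1 `Training` API and Zygote as the AD backend. `MyContainer` stays a plain struct, so the Flux path is unaffected. Only when the user picks Lux are its fields wrapped in `LuxContainer`, which subtypes `Lux.AbstractLuxContainerLayer` so that Lux derives `initialparameters`/`initialstates` from the named fields. `LuxContainer`, `mse_objective`, and the forward pass (`field1` followed by `field2`) are hypothetical names and choices for illustration, not anything from NeuralEstimators.jl.

```julia
using Lux, Optimisers, Random, Zygote
using ADTypes: AutoZygote

# Plain container, as above: no Lux supertype, so Flux code paths
# (e.g. Flux.withgradient) can keep using it unchanged.
struct MyContainer{A,B}
    field1::A
    field2::B
end

# Hypothetical wrapper used only on the Lux path. Declaring it as an
# AbstractLuxContainerLayer over (:field1, :field2) lets Lux derive
# initialparameters/initialstates from those two fields automatically.
struct LuxContainer{A,B} <: Lux.AbstractLuxContainerLayer{(:field1, :field2)}
    field1::A
    field2::B
end

LuxContainer(c::MyContainer) = LuxContainer(c.field1, c.field2)

# Assumed forward pass (field1 then field2), for illustration only.
function (m::LuxContainer)(x, ps, st)
    h, st1 = m.field1(x, ps.field1, st.field1)
    z, st2 = m.field2(h, ps.field2, st.field2)
    return z, (field1 = st1, field2 = st2)
end

# Objective in the form single_train_step! expects:
# (model, ps, st, data) -> (loss, updated_state, stats)
function mse_objective(model, ps, st, (x, y))
    ŷ, st_new = model(x, ps, st)
    return sum(abs2, ŷ .- y) / length(y), st_new, (;)
end

rng = Random.Xoshiro(0)
container = MyContainer(Dense(2 => 8, tanh), Dense(8 => 1))  # fields are Lux layers
model = LuxContainer(container)
ps, st = Lux.setup(rng, model)
ts = Lux.Training.TrainState(model, ps, st, Optimisers.Adam(0.01f0))

x = randn(rng, Float32, 2, 16)
y = randn(rng, Float32, 1, 16)
_, loss, _, ts = Lux.Training.single_train_step!(AutoZygote(), mse_objective, (x, y), ts)
```

This works around the `AbstractLuxLayer` annotation rather than removing it, at the cost of maintaining a parallel wrapper type for each container, which is why loosening the annotation would still be nicer.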
I think this is something that could come up with other packages in the ecosystem looking to adopt Lux while maintaining compatibility with Flux.
Thanks again for your work, and thanks in advance for any comments you might have.