
Training API with models that are not AbstractLuxLayers #1690

@msainsburydale


Thanks for your contributions to deep learning in Julia.

I have a package (NeuralEstimators.jl) that currently uses Flux. I'd like to make the package backend-agnostic, in the sense that users can choose either Flux or Lux.

I'm nearly there, but I've run into a problem that I'm finding difficult to solve. The public training API (e.g., TrainState, single_train_step!) requires an AbstractLuxLayer. However, making a struct a subtype of AbstractLuxLayer breaks the Flux training API (Flux.withgradient).

To simplify, I essentially have:

```julia
struct MyContainer
    field1  # a Flux model or a Lux model
    field2
    # ... further fields
end
```

where the fields of MyContainer contain Flux models or Lux models. If the fields are Lux models, then I want to pass my_container::MyContainer into TrainState and single_train_step!.

Is this possible? For instance, could the type annotations on TrainState and single_train_step! be loosened, and we instead just define the necessary methods for our custom structs (e.g., initialparameters, initialstates)?
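To make the proposal concrete, here is a minimal toy sketch of what "loosened annotations plus method definitions" could look like. It does not use Lux. `initialparameters`, `initialstates`, `ToyDense` and `ToyTrainState` are all hypothetical local stand-ins. In Lux, the real counterparts of the first two are `LuxCore.initialparameters` and `LuxCore.initialstates`, but this is not Lux's actual `TrainState`. The point is that the training state accepts any model for which those two methods exist, so `MyContainer` can stay a plain struct that Flux still handles normally.

```julia
using Random

# Hypothetical stand-ins for the two generic functions the proposal relies on.
# (Lux's real versions live in LuxCore; these are local so the sketch runs alone.)
function initialparameters end
function initialstates end

# Toy "Lux-like" layer computing y = W * x.
struct ToyDense
    n_in::Int
    n_out::Int
end
initialparameters(rng::AbstractRNG, l::ToyDense) = (; weight = randn(rng, Float32, l.n_out, l.n_in))
initialstates(::AbstractRNG, ::ToyDense) = NamedTuple()

# The user's container: a plain struct with no Lux supertype.
struct MyContainer
    field1
    field2
end

# Opt in by defining the methods, recursing into the fields.
initialparameters(rng::AbstractRNG, c::MyContainer) =
    (; field1 = initialparameters(rng, c.field1), field2 = initialparameters(rng, c.field2))
initialstates(rng::AbstractRNG, c::MyContainer) =
    (; field1 = initialstates(rng, c.field1), field2 = initialstates(rng, c.field2))

# Hypothetical training state whose model field is unconstrained: any object
# works as long as initialparameters/initialstates have methods for it.
struct ToyTrainState{M,P,S}
    model::M
    parameters::P
    states::S
end
ToyTrainState(rng::AbstractRNG, model) =
    ToyTrainState(model, initialparameters(rng, model), initialstates(rng, model))

rng = Xoshiro(0)
ts = ToyTrainState(rng, MyContainer(ToyDense(3, 4), ToyDense(4, 2)))
size(ts.parameters.field1.weight)  # → (4, 3)
```

In this design, missing methods surface as a `MethodError` at construction time rather than as a type-annotation failure. That is the trade-off the question suggests.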

I expect this will also come up for other packages in the ecosystem that want to adopt Lux while remaining compatible with Flux.

Thanks again for your work, and thanks in advance for any comments you might have.
