-
-
Notifications
You must be signed in to change notification settings - Fork 160
Julia ODE-lstm file added #973
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
src/Odelstm.jl
Outdated
Original paper: https://arxiv.org/abs/2006.04418 | ||
""" | ||
|
||
module ODELSTM |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shouldn't be done in a submodule
src/Odelstm.jl
Outdated
|
||
module ODELSTM | ||
|
||
using DifferentialEquations |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just take the ODE solver from the user, no new deps needed here
src/Odelstm.jl
Outdated
module ODELSTM | ||
|
||
using DifferentialEquations | ||
using Flux |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should use Lux
src/Odelstm.jl
Outdated
|
||
using DifferentialEquations | ||
using Flux | ||
using DiffEqFlux |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cannot import the module from within the same module
src/Odelstm.jl
Outdated
function get_solver(solver_type::Symbol) | ||
solver_map = Dict( | ||
:dopri5 => Tsit5(), | ||
:tsit5 => Tsit5(), | ||
:euler => Euler(), | ||
:heun => Heun(), | ||
:rk4 => RK4() | ||
) | ||
return get(solver_map, solver_type, Tsit5()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just take the solver instead of doing this
src/Odelstm.jl
Outdated
end | ||
|
||
function ODELSTMCell(input_size::Int, hidden_size::Int, solver_type::Symbol=:dopri5) | ||
lstm_cell = Flux.LSTMCell(input_size => hidden_size) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is using the recurrent model, which is not the algorithm and would break adaptivity
src/Odelstm.jl
Outdated
function solve_fixed_step(cell::ODELSTMCell, h, ts) | ||
dt = ts / 3.0 | ||
h_evolved = h | ||
for i in 1:3 | ||
if cell.solver_type == :euler | ||
h_evolved = euler_step(cell.f_node, h_evolved, dt) | ||
elseif cell.solver_type == :heun | ||
h_evolved = heun_step(cell.f_node, h_evolved, dt) | ||
elseif cell.solver_type == :rk4 | ||
h_evolved = rk4_step(cell.f_node, h_evolved, dt) | ||
end | ||
end | ||
return h_evolved | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unnecessary: just use adaptive=false
Updated as per feedback, please review the changes |
end | ||
return results, st | ||
else | ||
t_span = (0.0f0, Float32(ts)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use the types set by the user's input.
else | ||
t_span = (0.0f0, Float32(ts)) | ||
prob = ODEProblem((u,p,t)->cell.f_node(u,p,st)[1], h, t_span) | ||
sol = solve(prob, cell.solver, saveat=[t_span[2]], dense=false) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sol = solve(prob, cell.solver, saveat=[t_span[2]], dense=false) | |
sol = solve(prob, cell.solver, saveat=[t_span[2]]) |
redundant, since if saveat is used then it's false.
|
||
function solve_fixed_step(cell::ODELSTMCell, h, ts, p, st) | ||
prob = ODEProblem((u,p,t)->cell.f_node(u,p,st)[1], h, (0.0f0, Float32(ts))) | ||
sol = solve(prob, cell.solver; adaptive=false) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to make completely separate, just allow a keyword argument in ODELSTMCell
and just use the bool in the other function, delete the extra code.
for i in 1:batch_size | ||
h_i = h[:, i] | ||
ts_i = ts isa AbstractVector ? ts[i] : ts | ||
t_span = (0.0f0, Float32(ts_i)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't make sense, the tspan can only be 2 values.
t_span = (0.0f0, Float32(ts_i)) | ||
|
||
prob = ODEProblem((u,p,t)->cell.f_node(u,p,st)[1], h_i, t_span) | ||
sol = solve(prob, cell.solver, saveat=[t_span[2]], dense=false) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did you mean to put saveat = ts
?
export train_model!, evaluate_model, load_dataset | ||
|
||
mutable struct ODELSTMCell{F,S} | ||
lstm_cell::Lux.LSTMCell |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure this is actually the algorithm? Derive it?
Checklist
contributor guidelines, in particular the SciML Style Guide and
COLPRAC.
Additional context
I have implemented the ODE-LSTM code of python to Julia in single file with all functions working right.