Skip to content

Commit a2bd7f2

Browse files
Merge pull request #1162 from ChrisRackauckas-Claude/fix-splitfunction-jvp-1109
Fix SplitFunction to use user-provided jvp function
2 parents 92d8460 + 4c8ae35 commit a2bd7f2

File tree

1 file changed

+31
-12
lines changed

1 file changed

+31
-12
lines changed

src/scimlfunctions.jl

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4972,13 +4972,16 @@ function __has_initializeprob(f)
49724972
has_initialization_data(f) && hasfield(typeof(f.initialization_data), :initializeprob)
49734973
end
49744974
function __has_update_initializeprob!(f)
4975-
has_initialization_data(f) && hasfield(typeof(f.initialization_data), :update_initializeprob!)
4975+
has_initialization_data(f) &&
4976+
hasfield(typeof(f.initialization_data), :update_initializeprob!)
49764977
end
49774978
function __has_initializeprobmap(f)
4978-
has_initialization_data(f) && hasfield(typeof(f.initialization_data), :initializeprobmap)
4979+
has_initialization_data(f) &&
4980+
hasfield(typeof(f.initialization_data), :initializeprobmap)
49794981
end
49804982
function __has_initializeprobpmap(f)
4981-
has_initialization_data(f) && hasfield(typeof(f.initialization_data), :initializeprobpmap)
4983+
has_initialization_data(f) &&
4984+
hasfield(typeof(f.initialization_data), :initializeprobpmap)
49824985
end
49834986
__has_initialization_data(f) = hasfield(typeof(f), :initialization_data)
49844987
__has_polynomialize(f) = hasfield(typeof(f), :polynomialize)
@@ -5046,15 +5049,31 @@ function has_observed(f::AbstractSciMLFunction)
50465049
end
50475050
has_colorvec(f::AbstractSciMLFunction) = __has_colorvec(f) && f.colorvec !== nothing
50485051

5049-
# TODO: find an appropriate way to check `has_*`
5050-
has_jac(f::Union{SplitFunction, SplitSDEFunction}) = has_jac(f.f1)
5051-
has_jvp(f::Union{SplitFunction, SplitSDEFunction}) = has_jvp(f.f1)
5052-
has_vjp(f::Union{SplitFunction, SplitSDEFunction}) = has_vjp(f.f1)
5053-
has_tgrad(f::Union{SplitFunction, SplitSDEFunction}) = has_tgrad(f.f1)
5054-
has_Wfact(f::Union{SplitFunction, SplitSDEFunction}) = has_Wfact(f.f1)
5055-
has_Wfact_t(f::Union{SplitFunction, SplitSDEFunction}) = has_Wfact_t(f.f1)
5056-
has_paramjac(f::Union{SplitFunction, SplitSDEFunction}) = has_paramjac(f.f1)
5057-
has_colorvec(f::Union{SplitFunction, SplitSDEFunction}) = has_colorvec(f.f1)
5052+
# Check the SplitFunction's own fields first before delegating to f.f1
5053+
function has_jac(f::Union{SplitFunction, SplitSDEFunction})
5054+
(__has_jac(f) && f.jac !== nothing) || has_jac(f.f1)
5055+
end
5056+
function has_jvp(f::Union{SplitFunction, SplitSDEFunction})
5057+
(__has_jvp(f) && f.jvp !== nothing) || has_jvp(f.f1)
5058+
end
5059+
function has_vjp(f::Union{SplitFunction, SplitSDEFunction})
5060+
(__has_vjp(f) && f.vjp !== nothing) || has_vjp(f.f1)
5061+
end
5062+
function has_tgrad(f::Union{SplitFunction, SplitSDEFunction})
5063+
(__has_tgrad(f) && f.tgrad !== nothing) || has_tgrad(f.f1)
5064+
end
5065+
function has_Wfact(f::Union{SplitFunction, SplitSDEFunction})
5066+
(__has_Wfact(f) && f.Wfact !== nothing) || has_Wfact(f.f1)
5067+
end
5068+
function has_Wfact_t(f::Union{SplitFunction, SplitSDEFunction})
5069+
(__has_Wfact_t(f) && f.Wfact_t !== nothing) || has_Wfact_t(f.f1)
5070+
end
5071+
function has_paramjac(f::Union{SplitFunction, SplitSDEFunction})
5072+
(__has_paramjac(f) && f.paramjac !== nothing) || has_paramjac(f.f1)
5073+
end
5074+
function has_colorvec(f::Union{SplitFunction, SplitSDEFunction})
5075+
(__has_colorvec(f) && f.colorvec !== nothing) || has_colorvec(f.f1)
5076+
end
50585077

50595078
has_jac(f::Union{DynamicalODEFunction, DynamicalDDEFunction}) = has_jac(f.f1)
50605079
has_jvp(f::Union{DynamicalODEFunction, DynamicalDDEFunction}) = has_jvp(f.f1)

0 commit comments

Comments
 (0)