Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ADTypes"
uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
authors = ["Vaibhav Dixit <[email protected]>, Guillaume Dalle and contributors"]
version = "1.10.0"
version = "1.11.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
1 change: 1 addition & 0 deletions src/ADTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ export AutoChainRules,
AutoReverseDiff,
AutoSymbolics,
AutoTapir,
AutoTaylorDiff,
AutoTracker,
AutoZygote
@public AbstractMode
Expand Down
29 changes: 29 additions & 0 deletions src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,35 @@ function Base.show(io::IO, backend::AutoForwardDiff{chunksize}) where {chunksize
print(io, ")")
end

"""
AutoTaylorDiff{order}

Struct used to select the [TaylorDiff.jl](https://github.com/JuliaDiff/TaylorDiff.jl) backend for automatic differentiation.

Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).

# Constructors

AutoTaylorDiff(; order = 1)

# Type parameters

- `order`: the order of the Taylor-mode automatic differentiation
"""
struct AutoTaylorDiff{order} <: AbstractADType end

function AutoTaylorDiff(; order = 1)
return AutoTaylorDiff{order}()
end

mode(::AutoTaylorDiff) = ForwardMode()

function Base.show(io::IO, ::AutoTaylorDiff{order}) where {order}
print(io, AutoTaylorDiff, "(")
print(io, "order=", repr(order; context = io))
print(io, ")")
end

"""
AutoGTPSA{D}

Expand Down
12 changes: 12 additions & 0 deletions test/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,18 @@ end
@test !ad.safe_mode
end

@testset "AutoTaylorDiff" begin
ad = AutoTaylorDiff{2}()
@test ad isa AbstractADType
@test ad isa AutoTaylorDiff{2}
@test mode(ad) isa ForwardMode

ad = AutoTaylorDiff()
@test ad isa AbstractADType
@test ad isa AutoTaylorDiff{1}
@test mode(ad) isa ForwardMode
end

@testset "AutoTracker" begin
ad = AutoTracker()
@test ad isa AbstractADType
Expand Down