Skip to content

Commit 8ac1f7d

Browse files
authored
Fix and test promote_rule definitions (#207)
* Fix and test `promote_rule` definitions * Update Project.toml
1 parent ac44511 commit 8ac1f7d

File tree

4 files changed

+29
-4
lines changed

4 files changed

+29
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ReverseDiff"
22
uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
3-
version = "1.14.2"
3+
version = "1.14.3"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/ReverseDiff.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ using ChainRulesCore
2424
# Not all operations will be valid over all of these types, but that's okay; such cases
2525
# will simply error when they hit the original operation in the overloaded definition.
2626
const ARRAY_TYPES = (:AbstractArray, :AbstractVector, :AbstractMatrix, :Array, :Vector, :Matrix)
27-
const REAL_TYPES = (:Bool, :Integer, :(Irrational{:e}), :(Irrational{}), :Rational, :BigFloat, :BigInt, :AbstractFloat, :Real, :Dual)
27+
const REAL_TYPES = (:Bool, :Integer, :(Irrational{:}), :(Irrational{}), :Rational, :BigFloat, :BigInt, :AbstractFloat, :Real, :Dual)
2828

2929
const SKIPPED_UNARY_SCALAR_FUNCS = Symbol[:isinf, :isnan, :isfinite, :iseven, :isodd, :isreal, :isinteger]
3030
const SKIPPED_BINARY_SCALAR_FUNCS = Symbol[:isequal, :isless, :<, :>, :(==), :(!=), :(<=), :(>=)]

src/tracked.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,10 +270,22 @@ Base.convert(::Type{T}, t::T) where {T<:TrackedReal} = t
270270
Base.convert(::Type{T}, t::T) where {T<:TrackedArray} = t
271271

272272
for R in REAL_TYPES
273-
@eval Base.promote_rule(::Type{$R}, ::Type{TrackedReal{V,D,O}}) where {V,D,O} = TrackedReal{promote_type($R,V),D,O}
273+
R === :Dual && continue # ForwardDiff.Dual is handled below
274+
@eval begin
275+
if isconcretetype($R) # issue ForwardDiff#322
276+
Base.promote_rule(::Type{TrackedReal{V,D,O}}, ::Type{$R}) where {V,D,O} = TrackedReal{promote_type(V,$R),D,O}
277+
Base.promote_rule(::Type{$R}, ::Type{TrackedReal{V,D,O}}) where {V,D,O} = TrackedReal{promote_type($R,V),D,O}
278+
else
279+
Base.promote_rule(::Type{TrackedReal{V,D,O}}, ::Type{R}) where {V,D,O,R<:$R} = TrackedReal{promote_type(V,R),D,O}
280+
Base.promote_rule(::Type{R}, ::Type{TrackedReal{V,D,O}}) where {R<:$R,V,D,O,} = TrackedReal{promote_type(R,V),D,O}
281+
end
282+
end
274283
end
275284

276-
Base.promote_rule(::Type{R}, ::Type{TrackedReal{V,D,O}}) where {R<:Real,V,D,O} = TrackedReal{promote_type(R,V),D,O}
285+
# Avoid method ambiguities for ForwardDiff.Dual
286+
Base.promote_rule(::Type{TrackedReal{V1,D,O}}, ::Type{Dual{T,V2,N}}) where {V1,D,O,T,V2,N} = TrackedReal{promote_type(V1,Dual{T,V2,N}),D,O}
287+
Base.promote_rule(::Type{Dual{T,V1,N}}, ::Type{TrackedReal{V2,D,O}}) where {T,V1,N,V2,D,O} = TrackedReal{promote_type(Dual{T,V1,N},V2),D,O}
288+
277289
Base.promote_rule(::Type{TrackedReal{V1,D1,O1}}, ::Type{TrackedReal{V2,D2,O2}}) where {V1,V2,D1,D2,O1,O2} = TrackedReal{promote_type(V1,V2),promote_type(D1,D2),Nothing}
278290

279291
###########################

test/TrackedTests.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ module TrackedTests
33
using ReverseDiff, Test
44
using ReverseDiff: TrackedReal, TrackedArray
55

6+
import ForwardDiff
7+
68
include(joinpath(dirname(@__FILE__), "utils.jl"))
79

810
samefields(a, b) = a === b
@@ -601,8 +603,19 @@ empty!(tp)
601603
@test convert(typeof(ta), ta) === ta
602604
@test convert(typeof(ta1), ta1) === ta1
603605

606+
@test promote_type(T, Bool) === T
607+
@test promote_type(T, Int32) === T
608+
@test promote_type(T, Int64) === T
609+
@test promote_type(T, Integer) === TrackedReal{BigInt,Float64,A}
610+
@test promote_type(T, typeof(ℯ)) === TrackedReal{BigFloat,Float64,A}
611+
@test promote_type(T, typeof(π)) === TrackedReal{BigFloat,Float64,A}
612+
@test promote_type(T, Rational{Int}) === TrackedReal{Rational{BigInt},Float64,A}
604613
@test promote_type(T, BigFloat) === TrackedReal{BigFloat,Float64,A}
614+
@test promote_type(T, BigInt) === T
605615
@test promote_type(T, Float64) === TrackedReal{BigFloat,Float64,A}
616+
@test promote_type(T, AbstractFloat) === TrackedReal{BigFloat,Float64,A}
617+
@test promote_type(T, Real) === TrackedReal{Real,Float64,A}
618+
@test promote_type(T, ForwardDiff.Dual{:tag,Float64,1}) === TrackedReal{ForwardDiff.Dual{:tag,BigFloat,1},Float64,A}
606619
@test promote_type(T, TrackedReal{BigFloat,BigFloat,Nothing}) === TrackedReal{BigFloat,BigFloat,Nothing}
607620
@test promote_type(T, T) === T
608621

0 commit comments

Comments
 (0)