Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.0.1"
version = "1.0.2"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
28 changes: 23 additions & 5 deletions src/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,32 @@ ProjectTo(::Bool) = ProjectTo{NoTangent}() # same projector as ProjectTo(::Abst
ProjectTo(::Real) = ProjectTo{Real}()
ProjectTo(::Complex) = ProjectTo{Complex}()
ProjectTo(::Number) = ProjectTo{Number}()

ProjectTo(x::Integer) = ProjectTo(float(x))
ProjectTo(x::Complex{<:Integer}) = ProjectTo(float(x))

# Preserve low-precision floats as accidental promotion is a common performance bug
for T in (Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64)
# Preserve low-precision floats as accidental promotion is a common perforance bug
@eval ProjectTo(::$T) = ProjectTo{$T}()
end
ProjectTo(x::Integer) = ProjectTo(float(x))
ProjectTo(x::Complex{<:Integer}) = ProjectTo(float(x))
(::ProjectTo{T})(dx::Number) where {T<:Number} = convert(T, dx)
(::ProjectTo{T})(dx::Number) where {T<:Real} = convert(T, real(dx))

# In these cases we can just `convert` as we know we are dealing with plain and simple types
(::ProjectTo{T})(dx::AbstractFloat) where T<:AbstractFloat = convert(T, dx)
(::ProjectTo{T})(dx::Integer) where T<:AbstractFloat = convert(T, dx) #needed to avoid ambiguity
# simple Complex{<:AbstractFloat}} cases
(::ProjectTo{T})(dx::Complex{<:AbstractFloat}) where {T<:Complex{<:AbstractFloat}} = convert(T, dx)
(::ProjectTo{T})(dx::AbstractFloat) where {T<:Complex{<:AbstractFloat}} = convert(T, dx)
(::ProjectTo{T})(dx::Complex{<:Integer}) where {T<:Complex{<:AbstractFloat}} = convert(T, dx)
(::ProjectTo{T})(dx::Integer) where {T<:Complex{<:AbstractFloat}} = convert(T, dx)

# Other numbers, including e.g. ForwardDiff.Dual and Symbolics.Sym, should pass through.
# We assume (lacking evidence to the contrary) that it is the right subspace of numebers
# The (::ProjectTo{T})(::T) method doesn't work because we are allowing a different
# Number type that might not be a subtype of the `project_type`.
Comment on lines +157 to +158
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this true? project_type should be Real, Complex, or Number. I get super-confused by dispatch with type parameters on LHS like {T}(.... But not important really.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for e.g. Float32 or ComplexF64 etc it is not the case that it is Real, Complex, or Number.
I could add those in as alternative special cases instead of this.

Copy link
Member Author

@oxinabox oxinabox Jul 27, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would not be shocked if this come backs to bite us, but we can remove it if and when that happens

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh right, I wasn't thinking clearly, that's exactly the use case.

(::ProjectTo{<:Number})(dx::Number) = dx

(project::ProjectTo{<:Real})(dx::Complex) = project(real(dx))
(project::ProjectTo{<:Complex})(dx::Real) = project(complex(dx))

# Arrays
# If we don't have a more specialized `ProjectTo` rule, we just assume that there is
Expand Down
25 changes: 24 additions & 1 deletion test/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@ using ChainRulesCore, Test
using LinearAlgebra, SparseArrays
using OffsetArrays, BenchmarkTools

# Like ForwardDiff.jl's Dual
struct Dual{T<:Real} <: Real
value::T
partial::T
end
Base.real(x::Dual) = x
Base.float(x::Dual) = Dual(float(x.value), float(x.partial))
Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))

@testset "projection" begin

#####
Expand All @@ -12,14 +21,28 @@ using OffsetArrays, BenchmarkTools
# real / complex
@test ProjectTo(1.0)(2.0 + 3im) === 2.0
@test ProjectTo(1.0 + 2.0im)(3.0) === 3.0 + 0.0im
@test ProjectTo(2.0+3.0im)(1+1im) === 1.0+1.0im
@test ProjectTo(2.0)(1+1im) === 1.0


# storage
@test ProjectTo(1)(pi) === Float64(pi)
@test ProjectTo(1)(pi) === pi
@test ProjectTo(1 + im)(pi) === ComplexF64(pi)
@test ProjectTo(1//2)(3//4) === 3//4
@test ProjectTo(1.0f0)(1 / 2) === 0.5f0
@test ProjectTo(1.0f0 + 2im)(3) === 3.0f0 + 0im
@test ProjectTo(big(1.0))(2) === 2
@test ProjectTo(1.0)(2) === 2.0
end

@testset "Dual" begin # some weird Real subtype that we should basically leave alone
@test ProjectTo(1.0)(Dual(1.0, 2.0)) isa Dual
@test ProjectTo(1.0)(Dual(1, 2)) isa Dual
@test ProjectTo(1.0 + 1im)(Dual(1.0, 2.0)) isa Complex{<:Dual}
@test ProjectTo(1.0 + 1im)(
Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))
) isa Complex{<:Dual}
@test ProjectTo(1.0)(Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))) isa Dual
end

@testset "Base: arrays of numbers" begin
Expand Down