diff --git a/docs/src/operations.md b/docs/src/operations.md index 64e9779..6f7a8ec 100644 --- a/docs/src/operations.md +++ b/docs/src/operations.md @@ -89,3 +89,19 @@ using Bijections # hide b = Bijection(1 => "alpha", 2 => "beta", 3 => "gamma"); for (x, y) in b; println("$x --> $y"); end ``` + +## Composition + +Given two `Bijection`s `a` and `b`, their composition `c = a ∘ b` or `c = compose(a, b)` is a new `Bijection` with the property that `c[x] = a[b[x]]` for all `x` in the +domain of `b`. + +```jldoctest +julia> a = Bijection{Int,Int}(1 => 10, 2 => 20); + +julia> b = Bijection{String,Int}("hi" => 1, "bye" => 2); + +julia> c = a ∘ b; + +julia> c["hi"] +10 +``` diff --git a/src/Bijections.jl b/src/Bijections.jl index 020bae5..5710fcf 100644 --- a/src/Bijections.jl +++ b/src/Bijections.jl @@ -2,7 +2,7 @@ module Bijections using Serialization: Serialization -export Bijection, active_inv, inverse, hasvalue +export Bijection, active_inv, inverse, hasvalue, compose struct Bijection{K,V,F,Finv} <: AbstractDict{K,V} f::F # map from domain to range @@ -269,4 +269,58 @@ function Serialization.deserialize( return B(f) end +# WARN this uses internals so it's dangerous! +""" + C = composed_dict_type(A::Type{<:AbstractDict}, B::Type{<:AbstractDict}) + +Returns the type of the forward dictionary of `(a ∘ b)` where `A` and `B` are +the types of the forward-dictionaries of `a` and `b`, respectively. + +For any combination of a `IdDict` and a `Dict`, the result will be an `IdDict`. +Otherwise, return `A` with the types of keys and values adjusted so that the +resulting dict maps keys of `b` to values of `a`. +""" +function composed_dict_type( + A::Type{<:AbstractDict{AK,AV}}, ::Type{<:AbstractDict{BK,BV}} +) where {AK,AV,BK,BV} + return A.name.wrapper{BK,AV} +end +function composed_dict_type(::Type{Dict{AK,AV}}, ::Type{Dict{BK,BV}}) where {AK,AV,BK,BV} + Dict{BK,AV} +end +function composed_dict_type(::Type{Dict{AK,AV}}, ::Type{IdDict{BK,BV}}) where {AK,AV,BK,BV} + IdDict{BK,AV} +end +function composed_dict_type(::Type{IdDict{AK,AV}}, ::Type{Dict{BK,BV}}) where {AK,AV,BK,BV} + IdDict{BK,AV} +end +function composed_dict_type( + ::Type{IdDict{AK,AV}}, ::Type{IdDict{BK,BV}} +) where {AK,AV,BK,BV} + IdDict{BK,AV} +end + +""" + c = (∘)(a::Bijection, b::Bijection) + c = compose(a, b) + +The result of `a ∘ b` or `compose(a, b)` is a new `Bijection` `c` such that +`c[x]` is `a[b[x]]` for `x` in the domain of `b`. The internal type of the + forward mapping is determined by [`composed_dict_type`](@ref), and the type + of the backward mapping is determined by [`inverse_dict_type`](@ref). +""" +function compose( + a::Bijection{AK,AV,AF,AFinv}, b::Bijection{BK,BV,BF,BFinv} +) where {AK,AV,AF,AFinv,BK,BV,BF,BFinv} + CF = composed_dict_type(AF, BF) + CFinv = inverse_dict_type(CF) + c = Bijection{BK,AV,CF,CFinv}() + for x in keys(b) + c[x] = a[b[x]] + end + return c +end + +Base.:(∘)(a::Bijection, b::Bijection) = compose(a, b) + end # end of module Bijections diff --git a/test/runtests.jl b/test/runtests.jl index 495a02a..3d033f2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,6 +38,113 @@ using Serialization @test Bijection(collect(b)) == b end +# check composition +@testset "Composition" begin + a = Bijection{Int,Float64}() + a[1] = 10.0 + a[2] = 20.0 + + b = Bijection{String,Int}() + b["hi"] = 1 + b["bye"] = 2 + + c = a ∘ b + @test c["hi"] == 10.0 + @test c(10.0) == "hi" + + @test compose(a, b) == c + + # Mutable objects (1) + # - mutable keys/values in `a`, non-mutable keys/values in `b` + + a = Bijection{ + Int64,Vector{Int64},IdDict{Int64,Vector{Int64}},IdDict{Vector{Int64},Int64} + }() + A₁ = [1, 2, 3] + A₂ = [3, 4, 5] + a[1] = A₁ + a[2] = A₂ + + c = a ∘ b + @test c isa Bijection{ + String,Vector{Int64},IdDict{String,Vector{Int64}},IdDict{Vector{Int64},String} + } + @test c["hi"] === A₁ + @test c(A₁) == "hi" + A₁[1] = 10 + @test c["hi"] === A₁ + @test c(A₁) == "hi" + @test_throws KeyError c([1, 2, 3]) + + # Mutable objects (2) + # - Mutable keys/values both in `a` and `b` + + a = Bijection{ + Int64,Vector{Int64},IdDict{Int64,Vector{Int64}},IdDict{Vector{Int64},Int64} + }() + A₁ = [1, 2, 3] + A₂ = [3, 4, 5] + a[1] = A₁ + a[2] = A₂ + + b = Bijection{ + Vector{Int64},Int64,IdDict{Vector{Int64},Int64},IdDict{Int64,Vector{Int64}} + }() + b₁ = [1, 2, 3] + b₂ = [3, 4, 5] + @test b₁ ≢ A₁ + @test b₂ ≢ A₂ + b[b₁] = 1 + b[b₂] = 2 + + c = a ∘ b + @test c[b₁] === A₁ + @test c(A₁) === b₁ + + b₁[1] = 10 + @test c[b₁] === A₁ + + # Mutable objects (3) + # - Non-mutable keys/values in `a`, mutable keys/values in `b` + + a = Bijection{Int,Float64}() + a[1] = 10.0 + a[2] = 20.0 + + b = Bijection{ + Vector{Int64},Int64,IdDict{Vector{Int64},Int64},Dict{Int64,Vector{Int64}} + }() + b₁ = [1, 2, 3] + b₂ = [3, 4, 5] + b[b₁] = 1 + b[b₂] = 2 + + c = a ∘ b + @test c isa Bijection{ + Vector{Int64},Float64,IdDict{Vector{Int64},Float64},IdDict{Float64,Vector{Int64}} + } + @test c[b₁] == 10.0 + @test c(10.0) === b₁ + + b₁[1] = 10 + @test c[b₁] == 10.0 + + # ImmutableDict (testing the fallback of composed_dict_type) + + a = Bijection{Int,Float64}() + a[1] = 10.0 + a[2] = 20.0 + + b = Bijection( + Base.ImmutableDict("hi" => 1, "bye" => 2), Base.ImmutableDict(1 => "hi", 2 => "bye") + ) + + c = a ∘ b + @test c isa Bijection{String,Float64,Dict{String,Float64},Dict{Float64,String}} + @test c["hi"] == 10.0 + @test c(10.0) == "hi" +end + # Test empty constructor @testset "empty_constructor" begin b = Bijection{Int,String}()