Skip to content

Commit c0b42c2

Browse files
committed
Preserve type of mappings in composition
1 parent 2469dbc commit c0b42c2

File tree

2 files changed

+136
-8
lines changed

2 files changed

+136
-8
lines changed

src/Bijections.jl

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -269,15 +269,52 @@ function Serialization.deserialize(
269269
return B(f)
270270
end
271271

272+
# WARN this uses internals so it's dangerous!
272273
"""
273-
c = (∘)(a::Bijection{A,B}, b::Bijection{B,C})::Bijection{A,C} where {A,B,C}
274+
C = composed_dict_type(A::Type{<:AbstractDict}, B::Type{<:AbstractDict})
275+
276+
Returns the type of the forward dictionary of `(a ∘ b)` where `A` and `B` are
277+
the types of the forward-dictionaries of `a` and `b`, respectively.
278+
279+
For any combination of a `IdDict` and a `Dict`, the result will be an `IdDict`.
280+
Otherwise, return `A` with the types of keys and values adjusted so that the
281+
resulting dict maps keys of `b` to values of `a`.
282+
"""
283+
function composed_dict_type(
284+
A::Type{<:AbstractDict{AK,AV}}, ::Type{<:AbstractDict{BK,BV}}
285+
) where {AK,AV,BK,BV}
286+
return A.name.wrapper{BK,AV}
287+
end
288+
function composed_dict_type(::Type{Dict{AK,AV}}, ::Type{Dict{BK,BV}}) where {AK,AV,BK,BV}
289+
Dict{BK,AV}
290+
end
291+
function composed_dict_type(::Type{Dict{AK,AV}}, ::Type{IdDict{BK,BV}}) where {AK,AV,BK,BV}
292+
IdDict{BK,AV}
293+
end
294+
function composed_dict_type(::Type{IdDict{AK,AV}}, ::Type{Dict{BK,BV}}) where {AK,AV,BK,BV}
295+
IdDict{BK,AV}
296+
end
297+
function composed_dict_type(
298+
::Type{IdDict{AK,AV}}, ::Type{IdDict{BK,BV}}
299+
) where {AK,AV,BK,BV}
300+
IdDict{BK,AV}
301+
end
302+
303+
"""
304+
c = (∘)(a::Bijection, b::Bijection)
274305
c = compose(a, b)
275306
276307
The result of `a ∘ b` or `compose(a, b)` is a new `Bijection` `c` such that
277-
`c[x]` is `a[b[x]]` for `x` in the domain of `b`.
308+
`c[x]` is `a[b[x]]` for `x` in the domain of `b`. The internal type of the
309+
forward mapping is determined by [`composed_dict_type`](@ref), and the type
310+
of the backward mapping is determined by [`inverse_dict_type`](@ref).
278311
"""
279-
function compose(a::Bijection{B,A}, b::Bijection{C,B}) where {A,B,C}
280-
c = Bijection{C,A}()
312+
function compose(
313+
a::Bijection{AK,AV,AF,AFinv}, b::Bijection{BK,BV,BF,BFinv}
314+
) where {AK,AV,AF,AFinv,BK,BV,BF,BFinv}
315+
CF = composed_dict_type(AF, BF)
316+
CFinv = inverse_dict_type(CF)
317+
c = Bijection{BK,AV,CF,CFinv}()
281318
for x in keys(b)
282319
c[x] = a[b[x]]
283320
end

test/runtests.jl

Lines changed: 95 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,18 +40,109 @@ end
4040

4141
# check composition
4242
@testset "Composition" begin
43-
a = Bijection{Int,Int}()
44-
a[1] = 10
45-
a[2] = 20
43+
a = Bijection{Int,Float64}()
44+
a[1] = 10.0
45+
a[2] = 20.0
4646

4747
b = Bijection{String,Int}()
4848
b["hi"] = 1
4949
b["bye"] = 2
5050

5151
c = a b
52-
@test c["hi"] == 10
52+
@test c["hi"] == 10.0
53+
@test c(10.0) == "hi"
5354

5455
@test compose(a, b) == c
56+
57+
# Mutable objects (1)
58+
# - mutable keys/values in `a`, non-mutable keys/values in `b`
59+
60+
a = Bijection{
61+
Int64,Vector{Int64},IdDict{Int64,Vector{Int64}},IdDict{Vector{Int64},Int64}
62+
}()
63+
A₁ = [1, 2, 3]
64+
A₂ = [3, 4, 5]
65+
a[1] = A₁
66+
a[2] = A₂
67+
68+
c = a b
69+
@test c isa Bijection{
70+
String,Vector{Int64},IdDict{String,Vector{Int64}},IdDict{Vector{Int64},String}
71+
}
72+
@test c["hi"] === A₁
73+
@test c(A₁) == "hi"
74+
A₁[1] = 10
75+
@test c["hi"] === A₁
76+
@test c(A₁) == "hi"
77+
@test_throws KeyError c([1, 2, 3])
78+
79+
# Mutable objects (2)
80+
# - Mutable keys/values both in `a` and `b`
81+
82+
a = Bijection{
83+
Int64,Vector{Int64},IdDict{Int64,Vector{Int64}},IdDict{Vector{Int64},Int64}
84+
}()
85+
A₁ = [1, 2, 3]
86+
A₂ = [3, 4, 5]
87+
a[1] = A₁
88+
a[2] = A₂
89+
90+
b = Bijection{
91+
Vector{Int64},Int64,IdDict{Vector{Int64},Int64},IdDict{Int64,Vector{Int64}}
92+
}()
93+
b₁ = [1, 2, 3]
94+
b₂ = [3, 4, 5]
95+
@test b₁ A₁
96+
@test b₂ A₂
97+
b[b₁] = 1
98+
b[b₂] = 2
99+
100+
c = a b
101+
@test c[b₁] === A₁
102+
@test c(A₁) === b₁
103+
104+
b₁[1] = 10
105+
@test c[b₁] === A₁
106+
107+
# Mutable objects (3)
108+
# - Non-mutable keys/values in `a`, mutable keys/values in `b`
109+
110+
a = Bijection{Int,Float64}()
111+
a[1] = 10.0
112+
a[2] = 20.0
113+
114+
b = Bijection{
115+
Vector{Int64},Int64,IdDict{Vector{Int64},Int64},Dict{Int64,Vector{Int64}}
116+
}()
117+
b₁ = [1, 2, 3]
118+
b₂ = [3, 4, 5]
119+
b[b₁] = 1
120+
b[b₂] = 2
121+
122+
c = a b
123+
@test c isa Bijection{
124+
Vector{Int64},Float64,IdDict{Vector{Int64},Float64},IdDict{Float64,Vector{Int64}}
125+
}
126+
@test c[b₁] == 10.0
127+
@test c(10.0) === b₁
128+
129+
b₁[1] = 10
130+
@test c[b₁] == 10.0
131+
132+
# ImmutableDict (testing the fallback of composed_dict_type)
133+
134+
a = Bijection{Int,Float64}()
135+
a[1] = 10.0
136+
a[2] = 20.0
137+
138+
b = Bijection(
139+
Base.ImmutableDict("hi" => 1, "bye" => 2), Base.ImmutableDict(1 => "hi", 2 => "bye")
140+
)
141+
142+
c = a b
143+
@test c isa Bijection{String,Float64,Dict{String,Float64},Dict{Float64,String}}
144+
@test c["hi"] == 10.0
145+
@test c(10.0) == "hi"
55146
end
56147

57148
# Test empty constructor

0 commit comments

Comments
 (0)