Skip to content

Commit 24ecf5d

Browse files
authored
Simplify tests (#98)
1 parent a9ad75d commit 24ecf5d

File tree

7 files changed

+29
-42
lines changed

7 files changed

+29
-42
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@ julia = "1"
2828

2929
[extras]
3030
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
31+
HCubature = "19dc6840-f33b-545b-b366-655c7e3ffd49"
3132
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
3233
PythonOT = "3c485715-4278-42b2-9b5f-8f00e43c12ef"
3334
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3435
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
3536
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3637
Tulip = "6dd1b50a-3aae-11e9-10b5-ef983d2400fa"
37-
HCubature = "19dc6840-f33b-545b-b366-655c7e3ffd49"
3838

3939
[targets]
4040
test = ["ForwardDiff", "Pkg", "PythonOT", "Random", "SafeTestsets", "Test", "Tulip", "HCubature"]

examples/basic/script.jl

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ sinkhorn2(μ, ν, C, ε)
9292
# ```
9393
# One property of the quadratically regularised optimal transport problem is that the
9494
# resulting transport plan $\gamma$ is *sparse*. We take advantage of this and represent it as
95-
# a sparse matrix.
95+
# a sparse matrix.
9696

9797
quadreg(μ, ν, C, ε; maxiter=500);
9898

@@ -120,7 +120,7 @@ norm(γ - γ_pot, Inf)
120120
γpot = POT.sinkhorn(μ, ν, C, ε; method="sinkhorn_epsilon_scaling", numItermax=5000)
121121
norm- γpot, Inf)
122122

123-
# ## Unbalanced optimal transport
123+
# ## Unbalanced optimal transport
124124
#
125125
# [Unbalanced optimal transport](https://doi.org/10.1090/mcom/3303) deals with general
126126
# positive measures which do not necessarily have the same total mass. For unbalanced
@@ -166,10 +166,8 @@ norm(γ - γpot, Inf)
166166

167167
μsupport = νsupport = range(-2, 2; length=100)
168168
C = pairwise(SqEuclidean(), μsupport', νsupport'; dims=2)
169-
μ = exp.(-μsupport .^ 2 ./ 0.5^2)
170-
μ ./= sum(μ)
171-
ν = νsupport .^ 2 .* exp.(-νsupport .^ 2 ./ 0.5^2)
172-
ν ./= sum(ν)
169+
μ = normalize!(exp.(-μsupport .^ 2 ./ 0.5^2), 1)
170+
ν = normalize!(νsupport .^ 2 .* exp.(-νsupport .^ 2 ./ 0.5^2), 1)
173171

174172
plot(μsupport, μ; label=raw"$\mu$", size=(600, 400))
175173
plot!(νsupport, ν; label=raw"$\nu$")
@@ -216,10 +214,8 @@ heatmap(
216214
# $\lambda_1 \in \{0.25, 0.5, 0.75\}$.
217215

218216
support = range(-1, 1; length=250)
219-
mu1 = exp.(-(support .+ 0.5) .^ 2 ./ 0.1^2)
220-
mu1 ./= sum(mu1)
221-
mu2 = exp.(-(support .- 0.5) .^ 2 ./ 0.1^2)
222-
mu2 ./= sum(mu2)
217+
mu1 = normalize!(exp.(-(support .+ 0.5) .^ 2 ./ 0.1^2), 1)
218+
mu2 = normalize!(exp.(-(support .- 0.5) .^ 2 ./ 0.1^2), 1)
223219

224220
plt = plot(; size=(800, 400), legend=:outertopright)
225221
plot!(plt, support, mu1; label=raw"$\mu_1$")

test/entropic.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using ForwardDiff
55
using LogExpFunctions
66
using PythonOT: PythonOT
77

8+
using LinearAlgebra
89
using Random
910
using Test
1011

@@ -219,10 +220,8 @@ Random.seed!(100)
219220
@testset "example" begin
220221
# set up support
221222
support = range(-1; stop=1, length=250)
222-
μ1 = exp.(-(support .+ 0.5) .^ 2 ./ 0.1^2)
223-
μ1 ./= sum(μ1)
224-
μ2 = exp.(-(support .- 0.5) .^ 2 ./ 0.1^2)
225-
μ2 ./= sum(μ2)
223+
μ1 = normalize!(exp.(-(support .+ 0.5) .^ 2 ./ 0.1^2), 1)
224+
μ2 = normalize!(exp.(-(support .- 0.5) .^ 2 ./ 0.1^2), 1)
226225
μ_all = hcat(μ1, μ2)
227226
# create cost matrix
228227
C = pairwise(SqEuclidean(), support'; dims=2)

test/exact.jl

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,8 @@ Random.seed!(100)
2020
@testset "Earth-Movers Distance" begin
2121
M = 200
2222
N = 250
23-
μ = rand(M)
24-
ν = rand(N)
25-
μ ./= sum(μ)
26-
ν ./= sum(ν)
23+
μ = normalize!(rand(M), 1)
24+
ν = normalize!(rand(N), 1)
2725

2826
@testset "example" begin
2927
# create random cost matrix
@@ -87,8 +85,7 @@ Random.seed!(100)
8785

8886
@testset "semidiscrete case" begin
8987
μ = Normal(randn(), rand())
90-
νprobs = rand(30)
91-
νprobs ./= sum(νprobs)
88+
νprobs = normalize!(rand(30), 1)
9289
ν = Categorical(νprobs)
9390

9491
# compute OT plan
@@ -113,14 +110,12 @@ Random.seed!(100)
113110
@testset "discrete case" begin
114111
# random source and target marginal
115112
m = 30
116-
μprobs = rand(m)
117-
μprobs ./= sum(μprobs)
113+
μprobs = normalize!(rand(m), 1)
118114
μsupport = randn(m)
119115
μ = DiscreteNonParametric(μsupport, μprobs)
120116

121117
n = 50
122-
νprobs = rand(n)
123-
νprobs ./= sum(νprobs)
118+
νprobs = normalize!(rand(n), 1)
124119
νsupport = randn(n)
125120
ν = DiscreteNonParametric(νsupport, νprobs)
126121

@@ -200,11 +195,9 @@ Random.seed!(100)
200195
end
201196

202197
# Create discretized distribution
203-
μprobs = pdf(μ, μsupp')
204-
μprobs = μprobs ./ sum(μprobs)
205-
νprobs = pdf(ν, νsupp')
206-
νprobs = νprobs ./ sum(νprobs)
207-
C = pairwise(SqEuclidean(), μsupp', νsupp')
198+
μprobs = normalize!(pdf(μ, μsupp'), 1)
199+
νprobs = normalize!(pdf(ν, νsupp'), 1)
200+
C = pairwise(SqEuclidean(), μsupp', νsupp'; dims=2)
208201
@test emd2(μprobs, νprobs, C, Tulip.Optimizer()) ot_cost(SqEuclidean(), μ, ν) rtol =
209202
1e-3
210203

test/gpu/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[deps]
22
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
33
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
4+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
45
OptimalTransport = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
56
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
67

test/gpu/simple_gpu.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using OptimalTransport
33
using CUDA
44
using Distances
55

6+
using LinearAlgebra
67
using Random
78
using Test
89

@@ -18,14 +19,12 @@ Random.seed!(100)
1819
@testset "sinkhorn" begin
1920
# source histogram
2021
m = 200
21-
μ = rand(Float32, m)
22-
μ ./= sum(μ)
22+
μ = normalize!(rand(Float32, m), 1)
2323
cu_μ = cu(μ)
2424

2525
# target histogram
2626
n = 250
27-
ν = rand(Float32, n)
28-
ν ./= sum(ν)
27+
ν = normalize!(rand(Float32, n), 1)
2928
cu_ν = cu(ν)
3029

3130
# random cost matrix
@@ -71,13 +70,12 @@ Random.seed!(100)
7170
@testset "sinkhorn_unbalanced" begin
7271
# source histogram
7372
m = 200
74-
μ = rand(Float32, m)
75-
μ ./= 1.5f0 * sum(μ)
73+
μ = normalize!(rand(Float32, m), 1)
74+
μ .*= 1.5f0
7675

7776
# target histogram
7877
n = 250
79-
ν = rand(Float32, n)
80-
ν ./= sum(ν)
78+
ν = normalize!(rand(Float32, n), 1)
8179

8280
# random cost matrix
8381
C = pairwise(SqEuclidean(), randn(Float32, 1, m), randn(Float32, 1, n); dims=2)

test/utils.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,10 @@ Random.seed!(100)
6666
@testset "checkbalanced" begin
6767
mass = rand()
6868

69-
x1 = rand(20)
70-
x1 .*= mass / sum(x1)
71-
y1 = rand(30)
72-
y1 .*= mass / sum(y1)
69+
x1 = normalize!(rand(20), 1)
70+
x1 .*= mass
71+
y1 = normalize!(rand(30), 1)
72+
y1 .*= mass
7373
@test OptimalTransport.checkbalanced(x1, y1) === nothing
7474
@test OptimalTransport.checkbalanced(y1, x1) === nothing
7575
@test_throws ArgumentError OptimalTransport.checkbalanced(rand() .* x1, y1)

0 commit comments

Comments
 (0)