Skip to content

Commit 71c3378

Browse files
committed
added ot_reg_cost
1 parent 3e8a3c8 commit 71c3378

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

src/OptimalTransport.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ export sinkhorn, sinkhorn2
1313
export emd, emd2
1414
export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter
1515
export sinkhorn_unbalanced, sinkhorn_unbalanced2
16-
export ot_reg_plan
16+
export ot_reg_plan, ot_reg_cost
1717

1818
const MOI = MathOptInterface
1919

@@ -534,6 +534,25 @@ function ot_reg_plan(mu, nu, C, eps; reg_func="L2", method="lorenz", kwargs...)
534534
end
535535
end
536536

537+
"""
538+
ot_reg_cost(mu, nu, C, eps; reg_func = "L2", method = "lorenz", kwargs...)
539+
540+
Compute the optimal transport cost between `mu` and `nu` for optimal transport with a
541+
general choice of regulariser `math Ω(γ)`.
542+
543+
See also: [`ot_reg_plan`](@ref)
544+
545+
"""
546+
function ot_reg_cost(mu, nu, C, eps; reg_func="L2", method="lorenz", kwargs...)
547+
γ = if (reg_func == "L2") && (method == "lorenz")
548+
quadreg(mu, nu, C, eps; kwargs...)
549+
else
550+
@warn "Unimplemented"
551+
nothing
552+
end
553+
return dot(γ, C)
554+
end
555+
537556

538557
"""
539558
quadreg(mu, nu, C, ϵ; θ = 0.1, tol = 1e-5,maxiter = 50,κ = 0.5,δ = 1e-5)

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ end
174174
>>>>>>> d6c9ee3 (updated tests and docstrings)
175175
# need to use a larger tolerance here because of a quirk with the POT solver
176176
@test norm- γ_pot, Inf) < 1e-4
177+
c = ot_reg_cost(μ, ν, C, eps; reg_func="L2", method="lorenz")
178+
c_pot = dot(γ_pot, C)
179+
@test c c_pot atol = 1e-4
177180
end
178181
end
179182

0 commit comments

Comments
 (0)