|
345 | 345 | function _ot_cost(c, μ::DiscreteNonParametric, ν::DiscreteNonParametric, plan) |
346 | 346 | return dot(plan, StatsBase.pairwise(c, support(μ), support(ν))) |
347 | 347 | end |
| 348 | + |
| 349 | +################ |
| 350 | +# OT Gaussians |
| 351 | +################ |
| 352 | + |
| 353 | +""" |
| 354 | + ot_cost(::SqEuclidean, μ::MvNormal, ν::MvNormal) |
| 355 | +
|
| 356 | +Compute the squared 2-Wasserstein distance between normal distributions `μ` and `ν` as |
| 357 | +source and target marginals. |
| 358 | +
|
| 359 | +In this setting, the optimal transport cost can be computed as |
| 360 | +```math |
| 361 | +W_2^2(\\mu, \\nu) = \\|m_\\mu - m_\\nu \\|^2 + \\mathcal{B}(\\Sigma_\\mu, \\Sigma_\\nu)^2, |
| 362 | +``` |
| 363 | +where ``\\mu = \\mathcal{N}(m_\\mu, \\Sigma_\\mu)``, |
| 364 | +``\\nu = \\mathcal{N}(m_\\nu, \\Sigma_\\nu)``, and ``\\mathcal{B}`` is the Bures metric. |
| 365 | +
|
| 366 | +See also: [`ot_plan`](@ref), [`emd2`](@ref) |
| 367 | +""" |
| 368 | +function ot_cost(::SqEuclidean, μ::MvNormal, ν::MvNormal) |
| 369 | + return sqeuclidean(μ.μ, ν.μ) + sqbures(μ.Σ, ν.Σ) |
| 370 | +end |
| 371 | + |
| 372 | +""" |
| 373 | + ot_cost(::SqEuclidean, μ::Normal, ν::Normal) |
| 374 | +
|
| 375 | +Compute the squared 2-Wasserstein distance between univariate normal distributions `μ` and |
| 376 | +`ν` as source and target marginals. |
| 377 | +
|
| 378 | +See also: [`ot_plan`](@ref), [`emd2`](@ref) |
| 379 | +""" |
| 380 | +function ot_cost(::SqEuclidean, μ::Normal, ν::Normal) |
| 381 | + return (μ.μ - ν.μ)^2 + (μ.σ - ν.σ)^2 |
| 382 | +end |
| 383 | + |
| 384 | +""" |
| 385 | + ot_plan(::SqEuclidean, μ::MvNormal, ν::MvNormal) |
| 386 | +
|
| 387 | +Compute the optimal transport plan for the Monge-Kantorovich problem with multivariate |
| 388 | +normal distributions `μ` and `ν` as source and target marginals and cost function |
| 389 | +``c(x, y) = \\|x - y\\|_2^2``. |
| 390 | +
|
| 391 | +In this setting, for ``\\mu = \\mathcal{N}(m_\\mu, \\Sigma_\\mu)`` and |
| 392 | +``\\nu = \\mathcal{N}(m_\\nu, \\Sigma_\\nu)``, the optimal transport plan is the Monge |
| 393 | +map |
| 394 | +```math |
| 395 | +T \\colon x \\mapsto m_\\nu |
| 396 | ++ \\Sigma_\\mu^{-1/2} |
| 397 | +{\\big(\\Sigma_\\mu^{1/2} \\Sigma_\\nu \\Sigma_\\mu^{1/2}\\big)}^{1/2}\\Sigma_\\mu^{-1/2} |
| 398 | +(x - m_\\mu). |
| 399 | +
|
| 400 | +See also: [`ot_cost`](@ref), [`emd`](@ref) |
| 401 | +""" |
| 402 | +function ot_plan(::SqEuclidean, μ::MvNormal, ν::MvNormal) |
| 403 | + Σμsqrt = μ.Σ^(-1 / 2) |
| 404 | + A = Σμsqrt * sqrt(_gaussian_ot_A(μ.Σ, ν.Σ)) * Σμsqrt |
| 405 | + mμ = μ.μ |
| 406 | + mν = ν.μ |
| 407 | + T(x) = mν + A * (x - mμ) |
| 408 | + return T |
| 409 | +end |
| 410 | + |
| 411 | +""" |
| 412 | + ot_plan(::SqEuclidean, μ::Normal, ν::Normal) |
| 413 | +
|
| 414 | +Compute the optimal transport plan for the Monge-Kantorovich problem with |
| 415 | +normal distributions `μ` and `ν` as source and target marginals and cost function |
| 416 | +``c(x, y) = \\|x - y\\|_2^2``. |
| 417 | +
|
| 418 | +See also: [`ot_cost`](@ref), [`emd`](@ref) |
| 419 | +""" |
| 420 | +function ot_plan(::SqEuclidean, μ::Normal, ν::Normal) |
| 421 | + mμ = μ.μ |
| 422 | + mν = ν.μ |
| 423 | + a = ν.σ / μ.σ |
| 424 | + T(x) = mν + a * (x - mμ) |
| 425 | + return T |
| 426 | +end |
0 commit comments