diff options
author | Hicham Janati <hicham.janati@inria.fr> | 2019-06-12 15:50:25 +0200 |
---|---|---|
committer | Hicham Janati <hicham.janati@inria.fr> | 2019-06-12 15:50:25 +0200 |
commit | 28b549ef3ef93c01462cd811d6e55c36ae5a76a2 (patch) | |
tree | 75605c0ca1e501d1e5950ad09c6efd73f5e1c85a | |
parent | 3c53834d46f093f5770ec76748beb5667bebb6fa (diff) |
add test and example of UOT
-rw-r--r-- | examples/plot_UOT_1D.py | 76 | ||||
-rw-r--r-- | test/test_unbalanced.py | 36 |
2 files changed, 112 insertions, 0 deletions
diff --git a/examples/plot_UOT_1D.py b/examples/plot_UOT_1D.py new file mode 100644 index 0000000..1b1dd9c --- /dev/null +++ b/examples/plot_UOT_1D.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +""" +==================== +1D Unbalanced optimal transport +==================== + +This example illustrates the computation of Unbalanced Optimal transport +using a Kullback-Leibler relaxation. +""" + +# Author: Hicham Janati <hicham.janati@inria.fr> +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot +import ot.plot +from ot.datasets import make_1D_gauss as gauss + +############################################################################## +# Generate data +# ------------- + + +#%% parameters + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a = gauss(n, m=20, s=5) # m= mean, s= std +b = gauss(n, m=60, s=10) + +# make distributions unbalanced +b *= 5. + +# loss matrix +M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1))) +M /= M.max() + + +############################################################################## +# Plot distributions and loss matrix +# ---------------------------------- + +#%% plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') +pl.legend() + +#%% plot distributions and loss matrix + +pl.figure(2, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') + + +############################################################################## +# Solve Unbalanced Sinkhorn +# -------------- + + +#%% Sinkhorn + +lambd = 0.1 +alpha = 1. +Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, lambd, alpha, verbose=True) + +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn') + +pl.show() diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py new file mode 100644 index 0000000..863b6f3 --- /dev/null +++ b/test/test_unbalanced.py @@ -0,0 +1,36 @@ +"""Tests for module Unbalanced OT with entropy regularization""" + +# Author: Hicham Janati <hicham.janati@inria.fr> +# +# License: MIT License + +import numpy as np +import ot + + +def test_unbalanced(): + # test generalized sinkhorn for unbalanced OT + n = 100 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + a = ot.utils.unif(n) + b = ot.utils.unif(n) * 1.5 + + M = ot.dist(x, x) + epsilon = 1. + alpha = 1. + K = np.exp(- M / epsilon) + + G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha, + stopThr=1e-10, log=True) + + # check fixed point equations + fi = alpha / (alpha + epsilon) + v_final = (b / K.T.dot(log["u"])) ** fi + u_final = (a / K.dot(log["v"])) ** fi + + np.testing.assert_allclose( + u_final, log["u"], atol=1e-05) + np.testing.assert_allclose( + v_final, log["v"], atol=1e-05) |