From 28b549ef3ef93c01462cd811d6e55c36ae5a76a2 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Wed, 12 Jun 2019 15:50:25 +0200 Subject: add test and example of UOT --- examples/plot_UOT_1D.py | 76 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 examples/plot_UOT_1D.py (limited to 'examples/plot_UOT_1D.py') diff --git a/examples/plot_UOT_1D.py b/examples/plot_UOT_1D.py new file mode 100644 index 0000000..1b1dd9c --- /dev/null +++ b/examples/plot_UOT_1D.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +""" +==================== +1D Unbalanced optimal transport +==================== + +This example illustrates the computation of Unbalanced Optimal transport +using a Kullback-Leibler relaxation. +""" + +# Author: Hicham Janati +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot +import ot.plot +from ot.datasets import make_1D_gauss as gauss + +############################################################################## +# Generate data +# ------------- + + +#%% parameters + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a = gauss(n, m=20, s=5) # m= mean, s= std +b = gauss(n, m=60, s=10) + +# make distributions unbalanced +b *= 5. + +# loss matrix +M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1))) +M /= M.max() + + +############################################################################## +# Plot distributions and loss matrix +# ---------------------------------- + +#%% plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') +pl.legend() + +#%% plot distributions and loss matrix + +pl.figure(2, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') + + +############################################################################## +# Solve Unbalanced Sinkhorn +# -------------- + + +#%% Sinkhorn + +lambd = 0.1 +alpha = 1. +Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, lambd, alpha, verbose=True) + +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn') + +pl.show() -- cgit v1.2.3 From 11381a7ecc79ef719ee9107167c3adc22b5a3f59 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Wed, 12 Jun 2019 17:06:32 +0200 Subject: integrate comments of jmassich --- examples/plot_UOT_1D.py | 6 +++--- ot/unbalanced.py | 54 +++++++++++++++---------------------------------- test/test_unbalanced.py | 9 +++++++-- 3 files changed, 26 insertions(+), 43 deletions(-) (limited to 'examples/plot_UOT_1D.py') diff --git a/examples/plot_UOT_1D.py b/examples/plot_UOT_1D.py index 1b1dd9c..59b7e77 100644 --- a/examples/plot_UOT_1D.py +++ b/examples/plot_UOT_1D.py @@ -66,9 +66,9 @@ ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') #%% Sinkhorn -lambd = 0.1 -alpha = 1. -Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, lambd, alpha, verbose=True) +epsilon = 0.1 # entropy parameter +alpha = 1. # Unbalanced KL relaxation parameter +Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True) pl.figure(4, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn') diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 8bd02eb..f4208b5 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -6,6 +6,7 @@ Regularized Unbalanced OT # Author: Hicham Janati # License: MIT License +import warnings import numpy as np # from .utils import unif, dist @@ -29,7 +30,7 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, - a and b are source and target weights - KL is the Kullback-Leibler divergence - The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ Parameters @@ -85,15 +86,14 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [23] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 See Also -------- - ot.lp.emd : Unregularized OT - ot.optim.cg : General regularized OT - ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2] - ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] - ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10] + ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10] + ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10] + ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10] """ @@ -101,17 +101,8 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, def sink(): return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) - # elif method.lower() == 'sinkhorn_stabilized': - # def sink(): - # return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - # stopThr=stopThr, verbose=verbose, log=log, **kwargs) - # elif method.lower() == 'sinkhorn_epsilon_scaling': - # def sink(): - # return sinkhorn_epsilon_scaling( - # a, b, M, reg, numItermax=numItermax, - # stopThr=stopThr, verbose=verbose, log=log, **kwargs) else: - print('Warning : unknown method. Falling back to classic Sinkhorn Knopp') + warnings.warn('Unknown method. Falling back to classic Sinkhorn Knopp') def sink(): return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax, @@ -139,7 +130,7 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, - a and b are source and target weights - KL is the Kullback-Leibler divergence - The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ Parameters @@ -196,18 +187,13 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. - [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 - - + .. [23] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 See Also -------- - ot.lp.emd : Unregularized OT - ot.optim.cg : General regularized OT - ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2] - ot.bregman.greenkhorn : Greenkhorn [21] - ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] - ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10] + ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10] + ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10] + ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10] """ @@ -215,17 +201,8 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, def sink(): return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) - # elif method.lower() == 'sinkhorn_stabilized': - # def sink(): - # return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - # stopThr=stopThr, verbose=verbose, log=log, **kwargs) - # elif method.lower() == 'sinkhorn_epsilon_scaling': - # def sink(): - # return sinkhorn_epsilon_scaling( - # a, b, M, reg, numItermax=numItermax, - # stopThr=stopThr, verbose=verbose, log=log, **kwargs) else: - print('Warning : unknown method using classic Sinkhorn Knopp') + warnings.warn('Unknown method using classic Sinkhorn Knopp') def sink(): return sinkhorn_knopp(a, b, M, reg, alpha, **kwargs) @@ -256,7 +233,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, - a and b are source and target weights - KL is the Kullback-Leibler divergence - The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ Parameters @@ -306,6 +283,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [23] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 See Also -------- @@ -368,7 +346,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, or np.any(np.isinf(u)) or np.any(np.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop - print('Warning: numerical errors at iteration', cpt) + warnings.warn('Numerical errors at iteration', cpt) u = uprev v = vprev break diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index 863b6f3..e37498f 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -6,15 +6,19 @@ import numpy as np import ot +import pytest -def test_unbalanced(): +@pytest.mark.parametrize("metric", ["sinkhorn"]) +def test_unbalanced_convergence(method): # test generalized sinkhorn for unbalanced OT n = 100 rng = np.random.RandomState(42) x = rng.randn(n, 2) a = ot.utils.unif(n) + + # make dists unbalanced b = ot.utils.unif(n) * 1.5 M = ot.dist(x, x) @@ -23,7 +27,8 @@ def test_unbalanced(): K = np.exp(- M / epsilon) G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha, - stopThr=1e-10, log=True) + stopThr=1e-10, method=method, + log=True) # check fixed point equations fi = alpha / (alpha + epsilon) -- cgit v1.2.3 From adf9d046445bf142a29d914352f397b36f7905c0 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Tue, 18 Jun 2019 22:26:48 +0200 Subject: update Readme + minor rendering in examples --- README.md | 4 ++++ examples/plot_UOT_1D.py | 8 ++++---- examples/plot_UOT_barycenter_1D.py | 10 +++++----- ot/unbalanced.py | 6 +++--- 4 files changed, 16 insertions(+), 12 deletions(-) (limited to 'examples/plot_UOT_1D.py') diff --git a/README.md b/README.md index b6b215c..d24d8b9 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ It provides the following solvers: * Gromov-Wasserstein distances and barycenters ([13] and regularized [12]) * Stochastic Optimization for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19]) * Non regularized free support Wasserstein barycenters [20]. +* Unbalanced OT with KL relaxation distance and barycenter [10, 25]. Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder. @@ -165,6 +166,7 @@ The contributors to this library are: * [Kilian Fatras](https://kilianfatras.github.io/) * [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home) * [Vayer Titouan](https://tvayer.github.io/) +* [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT) This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages): @@ -236,3 +238,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t [23] Aude, G., Peyré, G., Cuturi, M., [Learning Generative Models with Sinkhorn Divergences](https://arxiv.org/abs/1706.00292), Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018 [24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N. (2019). [Optimal Transport for structured data with application on graphs](http://proceedings.mlr.press/v97/titouan19a.html) Proceedings of the 36th International Conference on Machine Learning (ICML). + +[25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. (2019). [Learning with a Wasserstein Loss](http://cbcl.mit.edu/wasserstein/) Advances in Neural Information Processing Systems (NIPS). diff --git a/examples/plot_UOT_1D.py b/examples/plot_UOT_1D.py index 59b7e77..2ea8b05 100644 --- a/examples/plot_UOT_1D.py +++ b/examples/plot_UOT_1D.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- """ -==================== +=============================== 1D Unbalanced optimal transport -==================== +=============================== This example illustrates the computation of Unbalanced Optimal transport using a Kullback-Leibler relaxation. @@ -53,7 +53,7 @@ pl.plot(x, a, 'b', label='Source distribution') pl.plot(x, b, 'r', label='Target distribution') pl.legend() -#%% plot distributions and loss matrix +# plot distributions and loss matrix pl.figure(2, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') @@ -64,7 +64,7 @@ ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') # -------------- -#%% Sinkhorn +# Sinkhorn epsilon = 0.1 # entropy parameter alpha = 1. # Unbalanced KL relaxation parameter diff --git a/examples/plot_UOT_barycenter_1D.py b/examples/plot_UOT_barycenter_1D.py index 8dfb84f..c8d9d3b 100644 --- a/examples/plot_UOT_barycenter_1D.py +++ b/examples/plot_UOT_barycenter_1D.py @@ -27,7 +27,7 @@ from matplotlib.collections import PolyCollection # Generate data # ------------- -#%% parameters +# parameters n = 100 # nb bins @@ -53,7 +53,7 @@ M /= M.max() # Plot data # --------- -#%% plot the distributions +# plot the distributions pl.figure(1, figsize=(6.4, 3)) for i in range(n_distributions): @@ -65,7 +65,7 @@ pl.tight_layout() # Barycenter computation # ---------------------- -#%% non weighted barycenter computation +# non weighted barycenter computation weight = 0.5 # 0<=weight<=1 weights = np.array([1 - weight, weight]) @@ -97,7 +97,7 @@ pl.tight_layout() # Barycentric interpolation # ------------------------- -#%% barycenter interpolation +# barycenter interpolation n_weight = 11 weight_list = np.linspace(0, 1, n_weight) @@ -114,7 +114,7 @@ for i in range(0, n_weight): B_wass[:, i] = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights) -#%% plot interpolation +# plot interpolation pl.figure(3) diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 97e2576..918dda4 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -87,7 +87,7 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. - .. [23] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 + .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 See Also @@ -196,7 +196,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn', .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. - .. [23] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 + .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 See Also -------- @@ -299,7 +299,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. - .. [23] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 + .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 See Also -------- -- cgit v1.2.3