From 28b549ef3ef93c01462cd811d6e55c36ae5a76a2 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Wed, 12 Jun 2019 15:50:25 +0200 Subject: add test and example of UOT --- test/test_unbalanced.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 test/test_unbalanced.py (limited to 'test') diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py new file mode 100644 index 0000000..863b6f3 --- /dev/null +++ b/test/test_unbalanced.py @@ -0,0 +1,36 @@ +"""Tests for module Unbalanced OT with entropy regularization""" + +# Author: Hicham Janati +# +# License: MIT License + +import numpy as np +import ot + + +def test_unbalanced(): + # test generalized sinkhorn for unbalanced OT + n = 100 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + a = ot.utils.unif(n) + b = ot.utils.unif(n) * 1.5 + + M = ot.dist(x, x) + epsilon = 1. + alpha = 1. + K = np.exp(- M / epsilon) + + G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha, + stopThr=1e-10, log=True) + + # check fixed point equations + fi = alpha / (alpha + epsilon) + v_final = (b / K.T.dot(log["u"])) ** fi + u_final = (a / K.dot(log["v"])) ** fi + + np.testing.assert_allclose( + u_final, log["u"], atol=1e-05) + np.testing.assert_allclose( + v_final, log["v"], atol=1e-05) -- cgit v1.2.3 From 11381a7ecc79ef719ee9107167c3adc22b5a3f59 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Wed, 12 Jun 2019 17:06:32 +0200 Subject: integrate comments of jmassich --- examples/plot_UOT_1D.py | 6 +++--- ot/unbalanced.py | 54 +++++++++++++++---------------------------------- test/test_unbalanced.py | 9 +++++++-- 3 files changed, 26 insertions(+), 43 deletions(-) (limited to 'test') diff --git a/examples/plot_UOT_1D.py b/examples/plot_UOT_1D.py index 1b1dd9c..59b7e77 100644 --- a/examples/plot_UOT_1D.py +++ b/examples/plot_UOT_1D.py @@ -66,9 +66,9 @@ ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') #%% Sinkhorn -lambd = 0.1 -alpha = 1. -Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, lambd, alpha, verbose=True) +epsilon = 0.1 # entropy parameter +alpha = 1. # Unbalanced KL relaxation parameter +Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True) pl.figure(4, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn') diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 8bd02eb..f4208b5 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -6,6 +6,7 @@ Regularized Unbalanced OT # Author: Hicham Janati # License: MIT License +import warnings import numpy as np # from .utils import unif, dist @@ -29,7 +30,7 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, - a and b are source and target weights - KL is the Kullback-Leibler divergence - The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ Parameters @@ -85,15 +86,14 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [23] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 See Also -------- - ot.lp.emd : Unregularized OT - ot.optim.cg : General regularized OT - ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2] - ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] - ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10] + ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10] + ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10] + ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10] """ @@ -101,17 +101,8 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, def sink(): return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) - # elif method.lower() == 'sinkhorn_stabilized': - # def sink(): - # return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - # stopThr=stopThr, verbose=verbose, log=log, **kwargs) - # elif method.lower() == 'sinkhorn_epsilon_scaling': - # def sink(): - # return sinkhorn_epsilon_scaling( - # a, b, M, reg, numItermax=numItermax, - # stopThr=stopThr, verbose=verbose, log=log, **kwargs) else: - print('Warning : unknown method. Falling back to classic Sinkhorn Knopp') + warnings.warn('Unknown method. Falling back to classic Sinkhorn Knopp') def sink(): return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax, @@ -139,7 +130,7 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, - a and b are source and target weights - KL is the Kullback-Leibler divergence - The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ Parameters @@ -196,18 +187,13 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. - [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 - - + .. [23] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 See Also -------- - ot.lp.emd : Unregularized OT - ot.optim.cg : General regularized OT - ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2] - ot.bregman.greenkhorn : Greenkhorn [21] - ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] - ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10] + ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10] + ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10] + ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10] """ @@ -215,17 +201,8 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, def sink(): return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) - # elif method.lower() == 'sinkhorn_stabilized': - # def sink(): - # return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - # stopThr=stopThr, verbose=verbose, log=log, **kwargs) - # elif method.lower() == 'sinkhorn_epsilon_scaling': - # def sink(): - # return sinkhorn_epsilon_scaling( - # a, b, M, reg, numItermax=numItermax, - # stopThr=stopThr, verbose=verbose, log=log, **kwargs) else: - print('Warning : unknown method using classic Sinkhorn Knopp') + warnings.warn('Unknown method using classic Sinkhorn Knopp') def sink(): return sinkhorn_knopp(a, b, M, reg, alpha, **kwargs) @@ -256,7 +233,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, - a and b are source and target weights - KL is the Kullback-Leibler divergence - The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ Parameters @@ -306,6 +283,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [23] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 See Also -------- @@ -368,7 +346,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, or np.any(np.isinf(u)) or np.any(np.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop - print('Warning: numerical errors at iteration', cpt) + warnings.warn('Numerical errors at iteration', cpt) u = uprev v = vprev break diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index 863b6f3..e37498f 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -6,15 +6,19 @@ import numpy as np import ot +import pytest -def test_unbalanced(): +@pytest.mark.parametrize("metric", ["sinkhorn"]) +def test_unbalanced_convergence(method): # test generalized sinkhorn for unbalanced OT n = 100 rng = np.random.RandomState(42) x = rng.randn(n, 2) a = ot.utils.unif(n) + + # make dists unbalanced b = ot.utils.unif(n) * 1.5 M = ot.dist(x, x) @@ -23,7 +27,8 @@ def test_unbalanced(): K = np.exp(- M / epsilon) G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha, - stopThr=1e-10, log=True) + stopThr=1e-10, method=method, + log=True) # check fixed point equations fi = alpha / (alpha + epsilon) -- cgit v1.2.3 From 12ed1581225f70c7c8777b6ce31710453fda7f51 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Wed, 12 Jun 2019 17:15:40 +0200 Subject: fix typo in test argument --- test/test_unbalanced.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'test') diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index e37498f..b4fa355 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -9,7 +9,7 @@ import ot import pytest -@pytest.mark.parametrize("metric", ["sinkhorn"]) +@pytest.mark.parametrize("method", ["sinkhorn"]) def test_unbalanced_convergence(method): # test generalized sinkhorn for unbalanced OT n = 100 -- cgit v1.2.3 From 50bc90058940645a13e2f3e41129bdc97161dc63 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Wed, 12 Jun 2019 17:52:02 +0200 Subject: add unbalanced barycenters --- examples/plot_UOT_barycenter_1D.py | 164 +++++++++++++++++++++++++++++++++++++ ot/unbalanced.py | 118 ++++++++++++++++++++++++++ test/test_unbalanced.py | 30 +++++++ 3 files changed, 312 insertions(+) create mode 100644 examples/plot_UOT_barycenter_1D.py (limited to 'test') diff --git a/examples/plot_UOT_barycenter_1D.py b/examples/plot_UOT_barycenter_1D.py new file mode 100644 index 0000000..8dfb84f --- /dev/null +++ b/examples/plot_UOT_barycenter_1D.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +""" +=========================================================== +1D Wasserstein barycenter demo for Unbalanced distributions +=========================================================== + +This example illustrates the computation of regularized Wassersyein Barycenter +as proposed in [10] for Unbalanced inputs. + + +[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + +""" + +# Author: Hicham Janati +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot +# necessary for 3d plot even if not used +from mpl_toolkits.mplot3d import Axes3D # noqa +from matplotlib.collections import PolyCollection + +############################################################################## +# Generate data +# ------------- + +#%% parameters + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std +a2 = ot.datasets.make_1D_gauss(n, m=60, s=8) + +# make unbalanced dists +a2 *= 3. + +# creating matrix A containing all distributions +A = np.vstack((a1, a2)).T +n_distributions = A.shape[1] + +# loss matrix + normalization +M = ot.utils.dist0(n) +M /= M.max() + +############################################################################## +# Plot data +# --------- + +#%% plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') +pl.tight_layout() + +############################################################################## +# Barycenter computation +# ---------------------- + +#%% non weighted barycenter computation + +weight = 0.5 # 0<=weight<=1 +weights = np.array([1 - weight, weight]) + +# l2bary +bary_l2 = A.dot(weights) + +# wasserstein +reg = 1e-3 +alpha = 1. + +bary_wass = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights) + +pl.figure(2) +pl.clf() +pl.subplot(2, 1, 1) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') + +pl.subplot(2, 1, 2) +pl.plot(x, bary_l2, 'r', label='l2') +pl.plot(x, bary_wass, 'g', label='Wasserstein') +pl.legend() +pl.title('Barycenters') +pl.tight_layout() + +############################################################################## +# Barycentric interpolation +# ------------------------- + +#%% barycenter interpolation + +n_weight = 11 +weight_list = np.linspace(0, 1, n_weight) + + +B_l2 = np.zeros((n, n_weight)) + +B_wass = np.copy(B_l2) + +for i in range(0, n_weight): + weight = weight_list[i] + weights = np.array([1 - weight, weight]) + B_l2[:, i] = A.dot(weights) + B_wass[:, i] = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights) + + +#%% plot interpolation + +pl.figure(3) + +cmap = pl.cm.get_cmap('viridis') +verts = [] +zs = weight_list +for i, z in enumerate(zs): + ys = B_l2[:, i] + verts.append(list(zip(x, ys))) + +ax = pl.gcf().gca(projection='3d') + +poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list]) +poly.set_alpha(0.7) +ax.add_collection3d(poly, zs=zs, zdir='y') +ax.set_xlabel('x') +ax.set_xlim3d(0, n) +ax.set_ylabel(r'$\alpha$') +ax.set_ylim3d(0, 1) +ax.set_zlabel('') +ax.set_zlim3d(0, B_l2.max() * 1.01) +pl.title('Barycenter interpolation with l2') +pl.tight_layout() + +pl.figure(4) +cmap = pl.cm.get_cmap('viridis') +verts = [] +zs = weight_list +for i, z in enumerate(zs): + ys = B_wass[:, i] + verts.append(list(zip(x, ys))) + +ax = pl.gcf().gca(projection='3d') + +poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list]) +poly.set_alpha(0.7) +ax.add_collection3d(poly, zs=zs, zdir='y') +ax.set_xlabel('x') +ax.set_xlim3d(0, n) +ax.set_ylabel(r'$\alpha$') +ax.set_ylim3d(0, 1) +ax.set_zlabel('') +ax.set_zlim3d(0, B_l2.max() * 1.01) +pl.title('Barycenter interpolation with Wasserstein') +pl.tight_layout() + +pl.show() diff --git a/ot/unbalanced.py b/ot/unbalanced.py index f4208b5..a30fc18 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -380,3 +380,121 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, return u[:, None] * K * v[None, :], log else: return u[:, None] * K * v[None, :] + + +def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False): + """Compute the entropic regularized unbalanced wasserstein barycenter of distributions A + + The function solves the following optimization problem: + + .. math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`W_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` + - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + - alpha is the marginal relaxation hyperparameter + The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + + Parameters + ---------- + A : np.ndarray (d,n) + n training distributions a_i of size d + M : np.ndarray (d,d) + loss matrix for OT + reg : float + Regularization term > 0 + alpha : float + Regularization term > 0 + weights : np.ndarray (n,) + Weights of each histogram a_i on the simplex (barycentric coodinates) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + a : (d,) ndarray + Unbalanced Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + + + """ + p, n_hists = A.shape + if weights is None: + weights = np.ones(n_hists) / n_hists + else: + assert(len(weights) == A.shape[1]) + + if log: + log = {'err': []} + + K = np.exp(- M / reg) + + fi = alpha / (alpha + reg) + + v = np.ones((p, n_hists)) / p + u = np.ones((p, 1)) / p + + cpt = 0 + err = 1. + + while (err > stopThr and cpt < numItermax): + uprev = u + vprev = v + + Kv = K.dot(v) + u = (A / Kv) ** fi + Ktu = K.T.dot(u) + q = ((Ktu ** (1 - fi)).dot(weights)) + q = q ** (1 / (1 - fi)) + Q = q[:, None] + v = (Q / Ktu) ** fi + + if (np.any(Ktu == 0.) + or np.any(np.isnan(u)) or np.any(np.isnan(v)) + or np.any(np.isinf(u)) or np.any(np.isinf(v))): + # we have reached the machine precision + # come back to previous solution and quit loop + warnings.warn('Numerical errors at iteration', cpt) + u = uprev + v = vprev + break + if cpt % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + err = np.sum((u - uprev) ** 2) / np.sum((u) ** 2) + \ + np.sum((v - vprev) ** 2) / np.sum((v) ** 2) + if log: + log['err'].append(err) + if verbose: + if cpt % 50 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + cpt += 1 + if log: + log['niter'] = cpt + log['u'] = u + log['v'] = v + return q, log + else: + return q diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index b4fa355..b39e457 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -39,3 +39,33 @@ def test_unbalanced_convergence(method): u_final, log["u"], atol=1e-05) np.testing.assert_allclose( v_final, log["v"], atol=1e-05) + + +def test_unbalanced_barycenter(): + # test generalized sinkhorn for unbalanced OT barycenter + n = 100 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + A = rng.rand(n, 2) + + # make dists unbalanced + A = A * np.array([1, 2])[None, :] + M = ot.dist(x, x) + epsilon = 1. + alpha = 1. + K = np.exp(- M / epsilon) + + q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, alpha=alpha, + stopThr=1e-10, + log=True) + + # check fixed point equations + fi = alpha / (alpha + epsilon) + v_final = (q[:, None] / K.T.dot(log["u"])) ** fi + u_final = (A / K.dot(log["v"])) ** fi + + np.testing.assert_allclose( + u_final, log["u"], atol=1e-05) + np.testing.assert_allclose( + v_final, log["v"], atol=1e-05) -- cgit v1.2.3 From 897982718a5fd81a9a591d80a7d50839399fc088 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Tue, 18 Jun 2019 16:40:06 +0200 Subject: fix func names + add more tests --- ot/__init__.py | 2 +- ot/bregman.py | 2 +- ot/unbalanced.py | 79 ++++++++++++++++++++++++++++++------------------- test/test_unbalanced.py | 79 +++++++++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 127 insertions(+), 35 deletions(-) (limited to 'test') diff --git a/ot/__init__.py b/ot/__init__.py index 361be02..acb05e6 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -25,7 +25,7 @@ from . import unbalanced # OT functions from .lp import emd, emd2 from .bregman import sinkhorn, sinkhorn2, barycenter -from .unbalanced import sinkhorn_unbalanced +from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced from .da import sinkhorn_lpl1_mm # utils functions diff --git a/ot/bregman.py b/ot/bregman.py index 321712b..09716e6 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -241,7 +241,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, b = np.asarray(b, dtype=np.float64) if len(b.shape) < 2: - b = b.reshape((-1, 1)) + b = b[:, None] return sink() diff --git a/ot/unbalanced.py b/ot/unbalanced.py index a30fc18..97e2576 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -73,8 +73,9 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, >>> a=[.5, .5] >>> b=[.5, .5] >>> M=[[0., 1.], [1., 0.]] - >>> ot.sinkhorn2(a, b, M, 1, 1) - array([0.26894142]) + >>> ot.sinkhorn_unbalanced(a, b, M, 1, 1) + array([[0.51122823, 0.18807035], + [0.18807035, 0.51122823]]) References @@ -91,28 +92,36 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, See Also -------- - ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10] - ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10] - ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10] + ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn [10] + ot.unbalanced.sinkhorn_stabilized_unbalanced: Unbalanced Stabilized sinkhorn [9][10] + ot.unbalanced.sinkhorn_epsilon_scaling_unbalanced: Unbalanced Sinkhorn with epslilon scaling [9][10] """ if method.lower() == 'sinkhorn': def sink(): - return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, **kwargs) - else: - warnings.warn('Unknown method. Falling back to classic Sinkhorn Knopp') + return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + + elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']: + warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') def sink(): - return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, **kwargs) + return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + else: + raise ValueError('Unknown method. Using classic Sinkhorn Knopp') return sink() -def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): +def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn', + numItermax=1000, stopThr=1e-9, verbose=False, + log=False, **kwargs): u""" Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -173,8 +182,8 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, >>> a=[.5, .10] >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] - >>> ot.sinkhorn2(a, b, M, 1., 1.) - array([ 0.26894142]) + >>> ot.unbalanced.sinkhorn_unbalanced2(a, b, M, 1., 1.) + array([0.31912866]) @@ -199,23 +208,31 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, if method.lower() == 'sinkhorn': def sink(): - return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, **kwargs) - else: - warnings.warn('Unknown method using classic Sinkhorn Knopp') + return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + + elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']: + warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') def sink(): - return sinkhorn_knopp(a, b, M, reg, alpha, **kwargs) + return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + else: + raise ValueError('Unknown method. Using classic Sinkhorn Knopp') b = np.asarray(b, dtype=np.float64) if len(b.shape) < 2: - b = b[None, :] + b = b[:, None] return sink() -def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): +def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, + stopThr=1e-9, verbose=False, log=False, **kwargs): """ Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -273,10 +290,9 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, >>> a=[.5, .15] >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] - >>> ot.sinkhorn(a, b, M, 1., 1.) - array([[ 0.36552929, 0.13447071], - [ 0.13447071, 0.36552929]]) - + >>> ot.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.) + array([[0.52761554, 0.22392482], + [0.10286295, 0.32257641]]) References ---------- @@ -303,8 +319,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, if len(b) == 0: b = np.ones(n_b, dtype=np.float64) / n_b - assert n_a == len(a) and n_b == len(b) - if b.ndim > 1: + if len(b.shape) > 1: n_hists = b.shape[1] else: n_hists = 0 @@ -315,8 +330,9 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, # we assume that no distances are null except those of the diagonal of # distances if n_hists: - u = np.ones((n_a, n_hists)) / n_a + u = np.ones((n_a, 1)) / n_a v = np.ones((n_b, n_hists)) / n_b + a = a.reshape(n_a, 1) else: u = np.ones(n_a) / n_a v = np.ones(n_b) / n_b @@ -332,6 +348,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, cpt = 0 err = 1. + while (err > stopThr and cpt < numItermax): uprev = u vprev = v @@ -473,7 +490,7 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, or np.any(np.isinf(u)) or np.any(np.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration', cpt) + warnings.warn('Numerical errors at iteration %s' % cpt) u = uprev v = vprev break diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index b39e457..1395fe1 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -29,7 +29,8 @@ def test_unbalanced_convergence(method): G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha, stopThr=1e-10, method=method, log=True) - + loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, + method=method) # check fixed point equations fi = alpha / (alpha + epsilon) v_final = (b / K.T.dot(log["u"])) ** fi @@ -40,6 +41,44 @@ def test_unbalanced_convergence(method): np.testing.assert_allclose( v_final, log["v"], atol=1e-05) + # check if sinkhorn_unbalanced2 returns the correct loss + np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5) + + +@pytest.mark.parametrize("method", ["sinkhorn"]) +def test_unbalanced_multiple_inputs(method): + # test generalized sinkhorn for unbalanced OT + n = 100 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + a = ot.utils.unif(n) + + # make dists unbalanced + b = rng.rand(n, 2) + + M = ot.dist(x, x) + epsilon = 1. + alpha = 1. + K = np.exp(- M / epsilon) + + loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, + alpha=alpha, + stopThr=1e-10, method=method, + log=True) + # check fixed point equations + fi = alpha / (alpha + epsilon) + v_final = (b / K.T.dot(log["u"])) ** fi + + u_final = (a[:, None] / K.dot(log["v"])) ** fi + + np.testing.assert_allclose( + u_final, log["u"], atol=1e-05) + np.testing.assert_allclose( + v_final, log["v"], atol=1e-05) + + assert len(loss) == b.shape[1] + def test_unbalanced_barycenter(): # test generalized sinkhorn for unbalanced OT barycenter @@ -59,7 +98,6 @@ def test_unbalanced_barycenter(): q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, alpha=alpha, stopThr=1e-10, log=True) - # check fixed point equations fi = alpha / (alpha + epsilon) v_final = (q[:, None] / K.T.dot(log["u"])) ** fi @@ -69,3 +107,40 @@ def test_unbalanced_barycenter(): u_final, log["u"], atol=1e-05) np.testing.assert_allclose( v_final, log["v"], atol=1e-05) + + +def test_implemented_methods(): + IMPLEMENTED_METHODS = ['sinkhorn'] + TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_stabilized', + 'sinkhorn_epsilon_scaling'] + NOT_VALID_TOKENS = ['foo'] + # test generalized sinkhorn for unbalanced OT barycenter + n = 3 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + a = ot.utils.unif(n) + + # make dists unbalanced + b = ot.utils.unif(n) * 1.5 + + M = ot.dist(x, x) + epsilon = 1. + alpha = 1. + for method in IMPLEMENTED_METHODS: + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, + method=method) + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, + method=method) + with pytest.warns(UserWarning, match='not implemented'): + for method in set(TO_BE_IMPLEMENTED_METHODS): + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, + method=method) + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, + method=method) + with pytest.raises(ValueError): + for method in set(NOT_VALID_TOKENS): + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, + method=method) + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, + method=method) -- cgit v1.2.3