summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHicham Janati <hicham.janati@inria.fr>2019-06-12 17:52:02 +0200
committerHicham Janati <hicham.janati@inria.fr>2019-06-12 17:52:02 +0200
commit50bc90058940645a13e2f3e41129bdc97161dc63 (patch)
tree24031123549ee349344c83875903d5d313e26292
parent12ed1581225f70c7c8777b6ce31710453fda7f51 (diff)
add unbalanced barycenters
-rw-r--r--examples/plot_UOT_barycenter_1D.py164
-rw-r--r--ot/unbalanced.py118
-rw-r--r--test/test_unbalanced.py30
3 files changed, 312 insertions, 0 deletions
diff --git a/examples/plot_UOT_barycenter_1D.py b/examples/plot_UOT_barycenter_1D.py
new file mode 100644
index 0000000..8dfb84f
--- /dev/null
+++ b/examples/plot_UOT_barycenter_1D.py
@@ -0,0 +1,164 @@
+# -*- coding: utf-8 -*-
+"""
+===========================================================
+1D Wasserstein barycenter demo for Unbalanced distributions
+===========================================================
+
+This example illustrates the computation of regularized Wassersyein Barycenter
+as proposed in [10] for Unbalanced inputs.
+
+
+[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+
+"""
+
+# Author: Hicham Janati <hicham.janati@inria.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+# necessary for 3d plot even if not used
+from mpl_toolkits.mplot3d import Axes3D # noqa
+from matplotlib.collections import PolyCollection
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% parameters
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+
+# make unbalanced dists
+a2 *= 3.
+
+# creating matrix A containing all distributions
+A = np.vstack((a1, a2)).T
+n_distributions = A.shape[1]
+
+# loss matrix + normalization
+M = ot.utils.dist0(n)
+M /= M.max()
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+pl.tight_layout()
+
+##############################################################################
+# Barycenter computation
+# ----------------------
+
+#%% non weighted barycenter computation
+
+weight = 0.5 # 0<=weight<=1
+weights = np.array([1 - weight, weight])
+
+# l2bary
+bary_l2 = A.dot(weights)
+
+# wasserstein
+reg = 1e-3
+alpha = 1.
+
+bary_wass = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)
+
+pl.figure(2)
+pl.clf()
+pl.subplot(2, 1, 1)
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+
+pl.subplot(2, 1, 2)
+pl.plot(x, bary_l2, 'r', label='l2')
+pl.plot(x, bary_wass, 'g', label='Wasserstein')
+pl.legend()
+pl.title('Barycenters')
+pl.tight_layout()
+
+##############################################################################
+# Barycentric interpolation
+# -------------------------
+
+#%% barycenter interpolation
+
+n_weight = 11
+weight_list = np.linspace(0, 1, n_weight)
+
+
+B_l2 = np.zeros((n, n_weight))
+
+B_wass = np.copy(B_l2)
+
+for i in range(0, n_weight):
+ weight = weight_list[i]
+ weights = np.array([1 - weight, weight])
+ B_l2[:, i] = A.dot(weights)
+ B_wass[:, i] = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)
+
+
+#%% plot interpolation
+
+pl.figure(3)
+
+cmap = pl.cm.get_cmap('viridis')
+verts = []
+zs = weight_list
+for i, z in enumerate(zs):
+ ys = B_l2[:, i]
+ verts.append(list(zip(x, ys)))
+
+ax = pl.gcf().gca(projection='3d')
+
+poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])
+poly.set_alpha(0.7)
+ax.add_collection3d(poly, zs=zs, zdir='y')
+ax.set_xlabel('x')
+ax.set_xlim3d(0, n)
+ax.set_ylabel(r'$\alpha$')
+ax.set_ylim3d(0, 1)
+ax.set_zlabel('')
+ax.set_zlim3d(0, B_l2.max() * 1.01)
+pl.title('Barycenter interpolation with l2')
+pl.tight_layout()
+
+pl.figure(4)
+cmap = pl.cm.get_cmap('viridis')
+verts = []
+zs = weight_list
+for i, z in enumerate(zs):
+ ys = B_wass[:, i]
+ verts.append(list(zip(x, ys)))
+
+ax = pl.gcf().gca(projection='3d')
+
+poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])
+poly.set_alpha(0.7)
+ax.add_collection3d(poly, zs=zs, zdir='y')
+ax.set_xlabel('x')
+ax.set_xlim3d(0, n)
+ax.set_ylabel(r'$\alpha$')
+ax.set_ylim3d(0, 1)
+ax.set_zlabel('')
+ax.set_zlim3d(0, B_l2.max() * 1.01)
+pl.title('Barycenter interpolation with Wasserstein')
+pl.tight_layout()
+
+pl.show()
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index f4208b5..a30fc18 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -380,3 +380,121 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
return u[:, None] * K * v[None, :], log
else:
return u[:, None] * K * v[None, :]
+
+
+def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False):
+ """Compute the entropic regularized unbalanced wasserstein barycenter of distributions A
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`W_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+ - alpha is the marginal relaxation hyperparameter
+ The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+
+ Parameters
+ ----------
+ A : np.ndarray (d,n)
+ n training distributions a_i of size d
+ M : np.ndarray (d,d)
+ loss matrix for OT
+ reg : float
+ Regularization term > 0
+ alpha : float
+ Regularization term > 0
+ weights : np.ndarray (n,)
+ Weights of each histogram a_i on the simplex (barycentric coodinates)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshol on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (d,) ndarray
+ Unbalanced Wasserstein barycenter
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+
+
+ """
+ p, n_hists = A.shape
+ if weights is None:
+ weights = np.ones(n_hists) / n_hists
+ else:
+ assert(len(weights) == A.shape[1])
+
+ if log:
+ log = {'err': []}
+
+ K = np.exp(- M / reg)
+
+ fi = alpha / (alpha + reg)
+
+ v = np.ones((p, n_hists)) / p
+ u = np.ones((p, 1)) / p
+
+ cpt = 0
+ err = 1.
+
+ while (err > stopThr and cpt < numItermax):
+ uprev = u
+ vprev = v
+
+ Kv = K.dot(v)
+ u = (A / Kv) ** fi
+ Ktu = K.T.dot(u)
+ q = ((Ktu ** (1 - fi)).dot(weights))
+ q = q ** (1 / (1 - fi))
+ Q = q[:, None]
+ v = (Q / Ktu) ** fi
+
+ if (np.any(Ktu == 0.)
+ or np.any(np.isnan(u)) or np.any(np.isnan(v))
+ or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ warnings.warn('Numerical errors at iteration', cpt)
+ u = uprev
+ v = vprev
+ break
+ if cpt % 10 == 0:
+ # we can speed up the process by checking for the error only all
+ # the 10th iterations
+ err = np.sum((u - uprev) ** 2) / np.sum((u) ** 2) + \
+ np.sum((v - vprev) ** 2) / np.sum((v) ** 2)
+ if log:
+ log['err'].append(err)
+ if verbose:
+ if cpt % 50 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+ if log:
+ log['niter'] = cpt
+ log['u'] = u
+ log['v'] = v
+ return q, log
+ else:
+ return q
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index b4fa355..b39e457 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -39,3 +39,33 @@ def test_unbalanced_convergence(method):
u_final, log["u"], atol=1e-05)
np.testing.assert_allclose(
v_final, log["v"], atol=1e-05)
+
+
+def test_unbalanced_barycenter():
+ # test generalized sinkhorn for unbalanced OT barycenter
+ n = 100
+ rng = np.random.RandomState(42)
+
+ x = rng.randn(n, 2)
+ A = rng.rand(n, 2)
+
+ # make dists unbalanced
+ A = A * np.array([1, 2])[None, :]
+ M = ot.dist(x, x)
+ epsilon = 1.
+ alpha = 1.
+ K = np.exp(- M / epsilon)
+
+ q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, alpha=alpha,
+ stopThr=1e-10,
+ log=True)
+
+ # check fixed point equations
+ fi = alpha / (alpha + epsilon)
+ v_final = (q[:, None] / K.T.dot(log["u"])) ** fi
+ u_final = (A / K.dot(log["v"])) ** fi
+
+ np.testing.assert_allclose(
+ u_final, log["u"], atol=1e-05)
+ np.testing.assert_allclose(
+ v_final, log["v"], atol=1e-05)