From 0d23718409b1f0ac41b9302d98ca3d1ab9577855 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Fri, 19 Jul 2019 17:04:14 +0200 Subject: remove square in convergence check --- ot/unbalanced.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 50ec03c..f6c2d5f 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -371,8 +371,9 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = np.sum((u - uprev)**2) / np.sum((u)**2) + \ - np.sum((v - vprev)**2) / np.sum((v)**2) + err_u = abs(u - uprev).max() / max(abs(u), abs(uprev), 1.) + err_v = abs(v - vprev).max() / max(abs(v), abs(vprev), 1.) + err = 0.5 * (err_u + err_v) if log: log['err'].append(err) if verbose: @@ -498,8 +499,9 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = np.sum((u - uprev) ** 2) / np.sum((u) ** 2) + \ - np.sum((v - vprev) ** 2) / np.sum((v) ** 2) + err_u = abs(u - uprev).max() / max(abs(u), abs(uprev), 1.) + err_v = abs(v - vprev).max() / max(abs(v), abs(vprev), 1.) + err = 0.5 * (err_u + err_v) if log: log['err'].append(err) if verbose: -- cgit v1.2.3 From 10accb13c2f22c946b65b249d7aae6e4f6af7579 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Mon, 22 Jul 2019 14:53:45 +0200 Subject: add unbalanced with stabilization --- ot/unbalanced.py | 279 ++++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 245 insertions(+), 34 deletions(-) diff --git a/ot/unbalanced.py b/ot/unbalanced.py index f6c2d5f..ca24e8b 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -9,10 +9,12 @@ Regularized Unbalanced OT from __future__ import division import warnings import numpy as np +from scipy.misc import logsumexp + # from .utils import unif, dist -def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, +def sinkhorn_unbalanced(a, b, M, reg, mu, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): r""" Solve the unbalanced entropic regularization optimal transport problem and return the loss @@ -20,7 +22,7 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b) + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\mu KL(\gamma 1, a) + \\mu KL(\gamma^T 1, b) s.t. \gamma\geq 0 @@ -45,11 +47,11 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, loss matrix reg : float Entropy regularization term > 0 - alpha : float + mu : float Marginal relaxation term > 0 method : str method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or - 'sinkhorn_epsilon_scaling', see those function for specific parameters + 'sinkhorn_reg_scaling', see those function for specific parameters numItermax : int, optional Max number of iterations stopThr : float, optional @@ -95,22 +97,29 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, -------- ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn [10] ot.unbalanced.sinkhorn_stabilized_unbalanced: Unbalanced Stabilized sinkhorn [9][10] - ot.unbalanced.sinkhorn_epsilon_scaling_unbalanced: Unbalanced Sinkhorn with epslilon scaling [9][10] + ot.unbalanced.sinkhorn_reg_scaling_unbalanced: Unbalanced Sinkhorn with epslilon scaling [9][10] """ if method.lower() == 'sinkhorn': def sink(): - return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, + return sinkhorn_knopp_unbalanced(a, b, M, reg, mu, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) - elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']: + elif method.lower() == 'sinkhorn_stabilized': + def sink(): + return sinkhorn_stabilized_unbalanced(a, b, M, reg, mu, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, **kwargs) + elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') def sink(): - return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, + return sinkhorn_knopp_unbalanced(a, b, M, reg, mu, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) @@ -120,7 +129,7 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, return sink() -def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn', +def sinkhorn_unbalanced2(a, b, M, reg, mu, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): r""" @@ -129,7 +138,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn', The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b) + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\mu KL(\gamma 1, a) + \\mu KL(\gamma^T 1, b) s.t. \gamma\geq 0 @@ -154,11 +163,11 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn', loss matrix reg : float Entropy regularization term > 0 - alpha : float + mu : float Marginal relaxation term > 0 method : str method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or - 'sinkhorn_epsilon_scaling', see those function for specific parameters + 'sinkhorn_reg_scaling', see those function for specific parameters numItermax : int, optional Max number of iterations stopThr : float, optional @@ -203,22 +212,29 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn', -------- ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10] ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10] - ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10] + ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10] """ if method.lower() == 'sinkhorn': def sink(): - return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, + return sinkhorn_knopp_unbalanced(a, b, M, reg, mu, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) - elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']: + elif method.lower() == 'sinkhorn_stabilized': + def sink(): + return sinkhorn_stabilized_unbalanced(a, b, M, reg, mu, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, **kwargs) + elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') def sink(): - return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, + return sinkhorn_knopp_unbalanced(a, b, M, reg, mu, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) @@ -232,7 +248,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn', return sink() -def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, +def sinkhorn_knopp_unbalanced(a, b, M, reg, mu, numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -240,7 +256,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b) + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\mu KL(\gamma 1, a) + \\mu KL(\gamma^T 1, b) s.t. \gamma\geq 0 @@ -265,7 +281,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, loss matrix reg : float Entropy regularization term > 0 - alpha : float + mu : float Marginal relaxation term > 0 numItermax : int, optional Max number of iterations @@ -338,14 +354,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, u = np.ones(n_a) / n_a v = np.ones(n_b) / n_b - # print(reg) # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute K = np.empty(M.shape, dtype=M.dtype) np.divide(M, -reg, out=K) np.exp(K, out=K) - # print(np.min(K)) - fi = alpha / (alpha + reg) + fi = mu / (mu + reg) cpt = 0 err = 1. @@ -371,8 +385,8 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err_u = abs(u - uprev).max() / max(abs(u), abs(uprev), 1.) - err_v = abs(v - vprev).max() / max(abs(v), abs(vprev), 1.) + err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.) + err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.) err = 0.5 * (err_u + err_v) if log: log['err'].append(err) @@ -383,8 +397,8 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, print('{:5d}|{:8e}|'.format(cpt, err)) cpt = cpt + 1 if log: - log['u'] = u - log['v'] = v + log['logu'] = np.log(u + 1e-16) + log['logv'] = np.log(v + 1e-16) if n_hists: # return only loss res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) @@ -401,7 +415,204 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, return u[:, None] * K * v[None, :] -def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, +def sinkhorn_stabilized_unbalanced(a, b, M, reg, mu, tau=1e5, numItermax=1000, + stopThr=1e-9, verbose=False, log=False, + **kwargs): + r""" + Solve the entropic regularization unbalanced optimal transport problem and return the loss + + The function solves the following optimization problem using log-domain + stabilization as proposed in [10]: + + .. math:: + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\mu KL(\gamma 1, a) + \\mu KL(\gamma^T 1, b) + + s.t. + \gamma\geq 0 + where : + + - M is the (ns, nt) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target weights + - KL is the Kullback-Leibler divergence + + The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ + + + Parameters + ---------- + a : np.ndarray (ns,) + samples weights in the source domain + b : np.ndarray (nt,) or np.ndarray (nt, n_hists) + samples in the target domain, compute sinkhorn with multiple targets + and fixed M if b is a matrix (return OT loss + dual variables in log) + M : np.ndarray (ns,nt) + loss matrix + reg : float + Entropy regularization term > 0 + mu : float + Marginal relaxation term > 0 + tau : float + thershold for max value in u or v for log scaling + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[0., 1.],[1., 0.]] + >>> ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, 1., 1.) + array([[0.51122823, 0.18807035], + [0.18807035, 0.51122823]]) + + References + ---------- + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + + .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + + """ + + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + M = np.asarray(M, dtype=np.float64) + + n_a, n_b = M.shape + + if len(a) == 0: + a = np.ones(n_a, dtype=np.float64) / n_a + if len(b) == 0: + b = np.ones(n_b, dtype=np.float64) / n_b + + if len(b.shape) > 1: + n_hists = b.shape[1] + else: + n_hists = 0 + + if log: + log = {'err': []} + + # we assume that no distances are null except those of the diagonal of + # distances + if n_hists: + u = np.ones((n_a, n_hists)) / n_a + v = np.ones((n_b, n_hists)) / n_b + a = a.reshape(n_a, 1) + else: + u = np.ones(n_a) / n_a + v = np.ones(n_b) / n_b + + # print(reg) + # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute + K = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=K) + np.exp(K, out=K) + + fi = mu / (mu + reg) + + cpt = 0 + err = 1. + alpha = np.zeros(n_a) + beta = np.zeros(n_b) + while (err > stopThr and cpt < numItermax): + uprev = u + vprev = v + + Kv = K.dot(v) + f_alpha = np.exp(- alpha / (reg + mu)) + f_beta = np.exp(- beta / (reg + mu)) + + if n_hists: + f_alpha = f_alpha[:, None] + f_beta = f_beta[:, None] + u = ((a / (Kv + 1e-16)) ** fi) * f_alpha + Ktu = K.T.dot(u) + v = ((b / (Ktu + 1e-16)) ** fi) * f_beta + if (u > tau).any() or (v > tau).any(): + if n_hists: + alpha = alpha + reg * np.log(np.max(u, 1)) + beta = beta + reg * np.log(np.max(v, 1)) + else: + alpha = alpha + reg * np.log(np.max(u)) + beta = beta + reg * np.log(np.max(v)) + K = np.exp((alpha[:, None] + beta[None, :] - + M) / reg) + v = np.ones_like(v) + Kv = K.dot(v) + + if (np.any(Ktu == 0.) + or np.any(np.isnan(u)) or np.any(np.isnan(v)) + or np.any(np.isinf(u)) or np.any(np.isinf(v))): + # we have reached the machine precision + # come back to previous solution and quit loop + warnings.warn('Numerical errors at iteration %d' % cpt) + u = uprev + v = vprev + break + if cpt % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + err = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), + 1.) + if log: + log['err'].append(err) + if verbose: + if cpt % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + cpt = cpt + 1 + + if n_hists: + logu = alpha[:, None] / reg + np.log(u) + logv = beta[:, None] / reg + np.log(v) + else: + logu = alpha / reg + np.log(u) + logv = beta / reg + np.log(v) + if log: + log['logu'] = logu + log['logv'] = logv + if n_hists: # return only loss + res = logsumexp(np.log(M + 1e-100)[:, :, None] + logu[:, None, :] + + logv[None, :, :] - M[:, :, None] / reg, axis=(0, 1)) + res = np.exp(res) + if log: + return res, log + else: + return res + + else: # return OT matrix + ot_matrix = np.exp(logu[:, None] + logv[None, :] - M / reg) + if log: + return ot_matrix, log + else: + return ot_matrix + + +def barycenter_unbalanced(A, M, reg, mu, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False): r"""Compute the entropic regularized unbalanced wasserstein barycenter of distributions A @@ -415,7 +626,7 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced) - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT - - alpha is the marginal relaxation hyperparameter + - mu is the marginal relaxation hyperparameter The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ Parameters @@ -426,7 +637,7 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, loss matrix for OT reg : float Entropy regularization term > 0 - alpha : float + mu : float Marginal relaxation term > 0 weights : np.ndarray (n,) Weights of each histogram a_i on the simplex (barycentric coodinates) @@ -467,7 +678,7 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, K = np.exp(- M / reg) - fi = alpha / (alpha + reg) + fi = mu / (mu + reg) v = np.ones((p, n_hists)) / p u = np.ones((p, 1)) / p @@ -499,8 +710,8 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err_u = abs(u - uprev).max() / max(abs(u), abs(uprev), 1.) - err_v = abs(v - vprev).max() / max(abs(v), abs(vprev), 1.) + err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.) + err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.) err = 0.5 * (err_u + err_v) if log: log['err'].append(err) @@ -513,8 +724,8 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, cpt += 1 if log: log['niter'] = cpt - log['u'] = u - log['v'] = v + log['logu'] = np.log(u + 1e-16) + log['logv'] = np.log(v + 1e-16) return q, log else: return q -- cgit v1.2.3 From 5c0ed104b2890c609bdadfe0fcb0e836ba7a6ef1 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Mon, 22 Jul 2019 14:54:01 +0200 Subject: add unbalanced tests with stabilization --- test/test_unbalanced.py | 116 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 77 insertions(+), 39 deletions(-) diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index 1395fe1..fc7aa5e 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -8,8 +8,10 @@ import numpy as np import ot import pytest +from scipy.misc import logsumexp -@pytest.mark.parametrize("method", ["sinkhorn"]) + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) def test_unbalanced_convergence(method): # test generalized sinkhorn for unbalanced OT n = 100 @@ -23,29 +25,34 @@ def test_unbalanced_convergence(method): M = ot.dist(x, x) epsilon = 1. - alpha = 1. - K = np.exp(- M / epsilon) + mu = 1. - G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha, + G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, mu=mu, stopThr=1e-10, method=method, log=True) - loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, + loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu, method=method) # check fixed point equations - fi = alpha / (alpha + epsilon) - v_final = (b / K.T.dot(log["u"])) ** fi - u_final = (a / K.dot(log["v"])) ** fi + # in log-domain + fi = mu / (mu + epsilon) + logb = np.log(b + 1e-16) + loga = np.log(a + 1e-16) + logKtu = logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1) + logKv = logsumexp(log["logv"][None, :] - M / epsilon, axis=1) + + v_final = fi * (logb - logKtu) + u_final = fi * (loga - logKv) np.testing.assert_allclose( - u_final, log["u"], atol=1e-05) + u_final, log["logu"], atol=1e-05) np.testing.assert_allclose( - v_final, log["v"], atol=1e-05) + v_final, log["logv"], atol=1e-05) # check if sinkhorn_unbalanced2 returns the correct loss np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5) -@pytest.mark.parametrize("method", ["sinkhorn"]) +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) def test_unbalanced_multiple_inputs(method): # test generalized sinkhorn for unbalanced OT n = 100 @@ -59,27 +66,55 @@ def test_unbalanced_multiple_inputs(method): M = ot.dist(x, x) epsilon = 1. - alpha = 1. - K = np.exp(- M / epsilon) + mu = 1. - loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, - alpha=alpha, + loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, mu=mu, stopThr=1e-10, method=method, log=True) # check fixed point equations - fi = alpha / (alpha + epsilon) - v_final = (b / K.T.dot(log["u"])) ** fi - - u_final = (a[:, None] / K.dot(log["v"])) ** fi + # in log-domain + fi = mu / (mu + epsilon) + logb = np.log(b + 1e-16) + loga = np.log(a + 1e-16)[:, None] + logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon, + axis=0) + logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) + v_final = fi * (logb - logKtu) + u_final = fi * (loga - logKv) np.testing.assert_allclose( - u_final, log["u"], atol=1e-05) + u_final, log["logu"], atol=1e-05) np.testing.assert_allclose( - v_final, log["v"], atol=1e-05) + v_final, log["logv"], atol=1e-05) assert len(loss) == b.shape[1] +def test_stabilized_vs_sinkhorn(): + # test if stable version matches sinkhorn + n = 100 + + # Gaussian distributions + a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std + b1 = ot.datasets.make_1D_gauss(n, m=60, s=8) + b2 = ot.datasets.make_1D_gauss(n, m=30, s=4) + + # creating matrix A containing all distributions + b = np.vstack((b1, b2)).T + + M = ot.utils.dist0(n) + M /= np.median(M) + epsilon = 0.1 + mu = 1. + G, log = ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, reg=epsilon, + mu=mu, + log=True) + G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu, + method="sinkhorn", log=True) + + np.testing.assert_allclose(G, G2) + + def test_unbalanced_barycenter(): # test generalized sinkhorn for unbalanced OT barycenter n = 100 @@ -92,27 +127,30 @@ def test_unbalanced_barycenter(): A = A * np.array([1, 2])[None, :] M = ot.dist(x, x) epsilon = 1. - alpha = 1. - K = np.exp(- M / epsilon) + mu = 1. - q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, alpha=alpha, + q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, mu=mu, stopThr=1e-10, log=True) # check fixed point equations - fi = alpha / (alpha + epsilon) - v_final = (q[:, None] / K.T.dot(log["u"])) ** fi - u_final = (A / K.dot(log["v"])) ** fi + fi = mu / (mu + epsilon) + logA = np.log(A + 1e-16) + logq = np.log(q + 1e-16)[:, None] + logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon, + axis=0) + logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) + v_final = fi * (logq - logKtu) + u_final = fi * (logA - logKv) np.testing.assert_allclose( - u_final, log["u"], atol=1e-05) + u_final, log["logu"], atol=1e-05) np.testing.assert_allclose( - v_final, log["v"], atol=1e-05) + v_final, log["logv"], atol=1e-05) def test_implemented_methods(): - IMPLEMENTED_METHODS = ['sinkhorn'] - TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_stabilized', - 'sinkhorn_epsilon_scaling'] + IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized'] + TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling'] NOT_VALID_TOKENS = ['foo'] # test generalized sinkhorn for unbalanced OT barycenter n = 3 @@ -126,21 +164,21 @@ def test_implemented_methods(): M = ot.dist(x, x) epsilon = 1. - alpha = 1. + mu = 1. for method in IMPLEMENTED_METHODS: - ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, mu, method=method) - ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu, method=method) with pytest.warns(UserWarning, match='not implemented'): for method in set(TO_BE_IMPLEMENTED_METHODS): - ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, mu, method=method) - ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu, method=method) with pytest.raises(ValueError): for method in set(NOT_VALID_TOKENS): - ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, mu, method=method) - ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu, method=method) -- cgit v1.2.3 From 50a5a4111ada5e8c208da1acf731608930d0a278 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Mon, 22 Jul 2019 15:28:59 +0200 Subject: fix doctest examples --- ot/unbalanced.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ot/unbalanced.py b/ot/unbalanced.py index ca24e8b..1453b31 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -77,8 +77,8 @@ def sinkhorn_unbalanced(a, b, M, reg, mu, method='sinkhorn', numItermax=1000, >>> b=[.5, .5] >>> M=[[0., 1.], [1., 0.]] >>> ot.sinkhorn_unbalanced(a, b, M, 1, 1) - array([[0.51122823, 0.18807035], - [0.18807035, 0.51122823]]) + array([[0.51122818, 0.18807034], + [0.18807034, 0.51122818]]) References @@ -193,7 +193,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, mu, method='sinkhorn', >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] >>> ot.unbalanced.sinkhorn_unbalanced2(a, b, M, 1., 1.) - array([0.31912866]) + array([0.31912862]) @@ -308,8 +308,8 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, mu, numItermax=1000, >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] >>> ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.) - array([[0.51122823, 0.18807035], - [0.18807035, 0.51122823]]) + array([[0.51122818, 0.18807034], + [0.18807034, 0.51122818]]) References ---------- @@ -479,8 +479,8 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, mu, tau=1e5, numItermax=1000, >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] >>> ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, 1., 1.) - array([[0.51122823, 0.18807035], - [0.18807035, 0.51122823]]) + array([[0.51122818, 0.18807034], + [0.18807034, 0.51122818]]) References ---------- -- cgit v1.2.3 From 09f3f640fc46ba4905d5508b704f2e5a90dda295 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Tue, 23 Jul 2019 21:28:30 +0200 Subject: fix issue 94 + add test --- ot/bregman.py | 10 +++++++--- test/test_bregman.py | 25 +++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index f39145d..70e4208 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -765,10 +765,14 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, cpt = cpt + 1 - # print('err=',err,' cpt=',cpt) if log: - log['logu'] = alpha / reg + np.log(u) - log['logv'] = beta / reg + np.log(v) + if nbb: + alpha = alpha[:, None] + beta = beta[:, None] + logu = alpha / reg + np.log(u) + logv = beta / reg + np.log(v) + log['logu'] = logu + log['logv'] = logv log['alpha'] = alpha + reg * np.log(u) log['beta'] = beta + reg * np.log(v) log['warmstart'] = (log['alpha'], log['beta']) diff --git a/test/test_bregman.py b/test/test_bregman.py index 7f4972c..83ebba8 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -254,3 +254,28 @@ def test_empirical_sinkhorn_divergence(): emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn np.testing.assert_allclose( emp_sinkhorn_div_log, sink_div_log, atol=1e-05) # cf conv emp sinkhorn + + +def test_stabilized_vs_sinkhorn_multidim(): + # test if stable version matches sinkhorn + # for multidimensional inputs + n = 100 + + # Gaussian distributions + a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std + b1 = ot.datasets.make_1D_gauss(n, m=60, s=8) + b2 = ot.datasets.make_1D_gauss(n, m=30, s=4) + + # creating matrix A containing all distributions + b = np.vstack((b1, b2)).T + + M = ot.utils.dist0(n) + M /= np.median(M) + epsilon = 0.1 + G, log = ot.bregman.sinkhorn(a, b, M, reg=epsilon, + method="sinkhorn_stabilized", + log=True) + G2, log2 = ot.bregman.sinkhorn(a, b, M, epsilon, + method="sinkhorn", log=True) + + np.testing.assert_allclose(G, G2) -- cgit v1.2.3 From a725f1dc0ac63ac919461ab8f2a23b111a410c00 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Tue, 23 Jul 2019 21:51:10 +0200 Subject: rebase unbalanced --- ot/unbalanced.py | 291 ++++++++----------------------------------------------- 1 file changed, 39 insertions(+), 252 deletions(-) diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 14e9e36..467fda2 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -9,12 +9,10 @@ Regularized Unbalanced OT from __future__ import division import warnings import numpy as np -from scipy.misc import logsumexp - # from .utils import unif, dist -def sinkhorn_unbalanced(a, b, M, reg, mu, method='sinkhorn', numItermax=1000, +def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): r""" Solve the unbalanced entropic regularization optimal transport problem and return the loss @@ -22,7 +20,7 @@ def sinkhorn_unbalanced(a, b, M, reg, mu, method='sinkhorn', numItermax=1000, The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\mu KL(\gamma 1, a) + \\mu KL(\gamma^T 1, b) + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b) s.t. \gamma\geq 0 @@ -47,11 +45,11 @@ def sinkhorn_unbalanced(a, b, M, reg, mu, method='sinkhorn', numItermax=1000, loss matrix reg : float Entropy regularization term > 0 - mu : float + alpha : float Marginal relaxation term > 0 method : str method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or - 'sinkhorn_reg_scaling', see those function for specific parameters + 'sinkhorn_epsilon_scaling', see those function for specific parameters numItermax : int, optional Max number of iterations stopThr : float, optional @@ -77,8 +75,8 @@ def sinkhorn_unbalanced(a, b, M, reg, mu, method='sinkhorn', numItermax=1000, >>> b=[.5, .5] >>> M=[[0., 1.], [1., 0.]] >>> ot.sinkhorn_unbalanced(a, b, M, 1, 1) - array([[0.51122818, 0.18807034], - [0.18807034, 0.51122818]]) + array([[0.51122823, 0.18807035], + [0.18807035, 0.51122823]]) References @@ -97,29 +95,22 @@ def sinkhorn_unbalanced(a, b, M, reg, mu, method='sinkhorn', numItermax=1000, -------- ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn [10] ot.unbalanced.sinkhorn_stabilized_unbalanced: Unbalanced Stabilized sinkhorn [9][10] - ot.unbalanced.sinkhorn_reg_scaling_unbalanced: Unbalanced Sinkhorn with epslilon scaling [9][10] + ot.unbalanced.sinkhorn_epsilon_scaling_unbalanced: Unbalanced Sinkhorn with epslilon scaling [9][10] """ if method.lower() == 'sinkhorn': def sink(): - return sinkhorn_knopp_unbalanced(a, b, M, reg, mu, + return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) - elif method.lower() == 'sinkhorn_stabilized': - def sink(): - return sinkhorn_stabilized_unbalanced(a, b, M, reg, mu, - numItermax=numItermax, - stopThr=stopThr, - verbose=verbose, - log=log, **kwargs) - elif method.lower() in ['sinkhorn_reg_scaling']: + elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') def sink(): - return sinkhorn_knopp_unbalanced(a, b, M, reg, mu, + return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) @@ -129,7 +120,7 @@ def sinkhorn_unbalanced(a, b, M, reg, mu, method='sinkhorn', numItermax=1000, return sink() -def sinkhorn_unbalanced2(a, b, M, reg, mu, method='sinkhorn', +def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): r""" @@ -138,7 +129,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, mu, method='sinkhorn', The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\mu KL(\gamma 1, a) + \\mu KL(\gamma^T 1, b) + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b) s.t. \gamma\geq 0 @@ -163,11 +154,11 @@ def sinkhorn_unbalanced2(a, b, M, reg, mu, method='sinkhorn', loss matrix reg : float Entropy regularization term > 0 - mu : float + alpha : float Marginal relaxation term > 0 method : str method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or - 'sinkhorn_reg_scaling', see those function for specific parameters + 'sinkhorn_epsilon_scaling', see those function for specific parameters numItermax : int, optional Max number of iterations stopThr : float, optional @@ -193,7 +184,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, mu, method='sinkhorn', >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] >>> ot.unbalanced.sinkhorn_unbalanced2(a, b, M, 1., 1.) - array([0.31912862]) + array([0.31912866]) @@ -212,29 +203,22 @@ def sinkhorn_unbalanced2(a, b, M, reg, mu, method='sinkhorn', -------- ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10] ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10] - ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10] + ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10] """ if method.lower() == 'sinkhorn': def sink(): - return sinkhorn_knopp_unbalanced(a, b, M, reg, mu, + return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) - elif method.lower() == 'sinkhorn_stabilized': - def sink(): - return sinkhorn_stabilized_unbalanced(a, b, M, reg, mu, - numItermax=numItermax, - stopThr=stopThr, - verbose=verbose, - log=log, **kwargs) - elif method.lower() in ['sinkhorn_reg_scaling']: + elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') def sink(): - return sinkhorn_knopp_unbalanced(a, b, M, reg, mu, + return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) @@ -248,7 +232,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, mu, method='sinkhorn', return sink() -def sinkhorn_knopp_unbalanced(a, b, M, reg, mu, numItermax=1000, +def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -256,7 +240,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, mu, numItermax=1000, The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\mu KL(\gamma 1, a) + \\mu KL(\gamma^T 1, b) + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b) s.t. \gamma\geq 0 @@ -281,7 +265,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, mu, numItermax=1000, loss matrix reg : float Entropy regularization term > 0 - mu : float + alpha : float Marginal relaxation term > 0 numItermax : int, optional Max number of iterations @@ -308,8 +292,8 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, mu, numItermax=1000, >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] >>> ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.) - array([[0.51122818, 0.18807034], - [0.18807034, 0.51122818]]) + array([[0.51122823, 0.18807035], + [0.18807035, 0.51122823]]) References ---------- @@ -354,12 +338,14 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, mu, numItermax=1000, u = np.ones(n_a) / n_a v = np.ones(n_b) / n_b + # print(reg) # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute K = np.empty(M.shape, dtype=M.dtype) np.divide(M, -reg, out=K) np.exp(K, out=K) - fi = mu / (mu + reg) + # print(np.min(K)) + fi = alpha / (alpha + reg) cpt = 0 err = 1. @@ -385,9 +371,8 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, mu, numItermax=1000, if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.) - err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.) - err = 0.5 * (err_u + err_v) + err = np.sum((u - uprev)**2) / np.sum((u)**2) + \ + np.sum((v - vprev)**2) / np.sum((v)**2) if log: log['err'].append(err) if verbose: @@ -398,8 +383,8 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, mu, numItermax=1000, cpt += 1 if log: - log['logu'] = np.log(u + 1e-16) - log['logv'] = np.log(v + 1e-16) + log['u'] = u + log['v'] = v if n_hists: # return only loss res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) @@ -416,204 +401,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, mu, numItermax=1000, return u[:, None] * K * v[None, :] -def sinkhorn_stabilized_unbalanced(a, b, M, reg, mu, tau=1e5, numItermax=1000, - stopThr=1e-9, verbose=False, log=False, - **kwargs): - r""" - Solve the entropic regularization unbalanced optimal transport problem and return the loss - - The function solves the following optimization problem using log-domain - stabilization as proposed in [10]: - - .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\mu KL(\gamma 1, a) + \\mu KL(\gamma^T 1, b) - - s.t. - \gamma\geq 0 - where : - - - M is the (ns, nt) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights - - KL is the Kullback-Leibler divergence - - The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ - - - Parameters - ---------- - a : np.ndarray (ns,) - samples weights in the source domain - b : np.ndarray (nt,) or np.ndarray (nt, n_hists) - samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) - M : np.ndarray (ns,nt) - loss matrix - reg : float - Entropy regularization term > 0 - mu : float - Marginal relaxation term > 0 - tau : float - thershold for max value in u or v for log scaling - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshol on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - - - Returns - ------- - gamma : (ns x nt) ndarray - Optimal transportation matrix for the given parameters - log : dict - log dictionary return only if log==True in parameters - - Examples - -------- - - >>> import ot - >>> a=[.5, .5] - >>> b=[.5, .5] - >>> M=[[0., 1.],[1., 0.]] - >>> ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, 1., 1.) - array([[0.51122818, 0.18807034], - [0.18807034, 0.51122818]]) - - References - ---------- - - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. - - .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 - - See Also - -------- - ot.lp.emd : Unregularized OT - ot.optim.cg : General regularized OT - - """ - - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) - - n_a, n_b = M.shape - - if len(a) == 0: - a = np.ones(n_a, dtype=np.float64) / n_a - if len(b) == 0: - b = np.ones(n_b, dtype=np.float64) / n_b - - if len(b.shape) > 1: - n_hists = b.shape[1] - else: - n_hists = 0 - - if log: - log = {'err': []} - - # we assume that no distances are null except those of the diagonal of - # distances - if n_hists: - u = np.ones((n_a, n_hists)) / n_a - v = np.ones((n_b, n_hists)) / n_b - a = a.reshape(n_a, 1) - else: - u = np.ones(n_a) / n_a - v = np.ones(n_b) / n_b - - # print(reg) - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) - - fi = mu / (mu + reg) - - cpt = 0 - err = 1. - alpha = np.zeros(n_a) - beta = np.zeros(n_b) - while (err > stopThr and cpt < numItermax): - uprev = u - vprev = v - - Kv = K.dot(v) - f_alpha = np.exp(- alpha / (reg + mu)) - f_beta = np.exp(- beta / (reg + mu)) - - if n_hists: - f_alpha = f_alpha[:, None] - f_beta = f_beta[:, None] - u = ((a / (Kv + 1e-16)) ** fi) * f_alpha - Ktu = K.T.dot(u) - v = ((b / (Ktu + 1e-16)) ** fi) * f_beta - if (u > tau).any() or (v > tau).any(): - if n_hists: - alpha = alpha + reg * np.log(np.max(u, 1)) - beta = beta + reg * np.log(np.max(v, 1)) - else: - alpha = alpha + reg * np.log(np.max(u)) - beta = beta + reg * np.log(np.max(v)) - K = np.exp((alpha[:, None] + beta[None, :] - - M) / reg) - v = np.ones_like(v) - Kv = K.dot(v) - - if (np.any(Ktu == 0.) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): - # we have reached the machine precision - # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration %d' % cpt) - u = uprev - v = vprev - break - if cpt % 10 == 0: - # we can speed up the process by checking for the error only all - # the 10th iterations - err = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), - 1.) - if log: - log['err'].append(err) - if verbose: - if cpt % 200 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - cpt = cpt + 1 - - if n_hists: - logu = alpha[:, None] / reg + np.log(u) - logv = beta[:, None] / reg + np.log(v) - else: - logu = alpha / reg + np.log(u) - logv = beta / reg + np.log(v) - if log: - log['logu'] = logu - log['logv'] = logv - if n_hists: # return only loss - res = logsumexp(np.log(M + 1e-100)[:, :, None] + logu[:, None, :] + - logv[None, :, :] - M[:, :, None] / reg, axis=(0, 1)) - res = np.exp(res) - if log: - return res, log - else: - return res - - else: # return OT matrix - ot_matrix = np.exp(logu[:, None] + logv[None, :] - M / reg) - if log: - return ot_matrix, log - else: - return ot_matrix - - -def barycenter_unbalanced(A, M, reg, mu, weights=None, numItermax=1000, +def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False): r"""Compute the entropic regularized unbalanced wasserstein barycenter of distributions A @@ -627,7 +415,7 @@ def barycenter_unbalanced(A, M, reg, mu, weights=None, numItermax=1000, - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced) - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT - - mu is the marginal relaxation hyperparameter + - alpha is the marginal relaxation hyperparameter The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ Parameters @@ -638,7 +426,7 @@ def barycenter_unbalanced(A, M, reg, mu, weights=None, numItermax=1000, loss matrix for OT reg : float Entropy regularization term > 0 - mu : float + alpha : float Marginal relaxation term > 0 weights : np.ndarray (n,) Weights of each histogram a_i on the simplex (barycentric coodinates) @@ -679,7 +467,7 @@ def barycenter_unbalanced(A, M, reg, mu, weights=None, numItermax=1000, K = np.exp(- M / reg) - fi = mu / (mu + reg) + fi = alpha / (alpha + reg) v = np.ones((p, n_hists)) / p u = np.ones((p, 1)) / p @@ -711,9 +499,8 @@ def barycenter_unbalanced(A, M, reg, mu, weights=None, numItermax=1000, if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.) - err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.) - err = 0.5 * (err_u + err_v) + err = np.sum((u - uprev) ** 2) / np.sum((u) ** 2) + \ + np.sum((v - vprev) ** 2) / np.sum((v) ** 2) if log: log['err'].append(err) if verbose: @@ -725,8 +512,8 @@ def barycenter_unbalanced(A, M, reg, mu, weights=None, numItermax=1000, cpt += 1 if log: log['niter'] = cpt - log['logu'] = np.log(u + 1e-16) - log['logv'] = np.log(v + 1e-16) + log['u'] = u + log['v'] = v return q, log else: return q -- cgit v1.2.3 From a507556b1901e16351c211e69b38d8d74ac2bc3d Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Tue, 23 Jul 2019 21:51:53 +0200 Subject: rebase unbalanced --- test/test_unbalanced.py | 116 ++++++++++++++++-------------------------------- 1 file changed, 39 insertions(+), 77 deletions(-) diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index fc7aa5e..1395fe1 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -8,10 +8,8 @@ import numpy as np import ot import pytest -from scipy.misc import logsumexp - -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) +@pytest.mark.parametrize("method", ["sinkhorn"]) def test_unbalanced_convergence(method): # test generalized sinkhorn for unbalanced OT n = 100 @@ -25,34 +23,29 @@ def test_unbalanced_convergence(method): M = ot.dist(x, x) epsilon = 1. - mu = 1. + alpha = 1. + K = np.exp(- M / epsilon) - G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, mu=mu, + G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha, stopThr=1e-10, method=method, log=True) - loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu, + loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, method=method) # check fixed point equations - # in log-domain - fi = mu / (mu + epsilon) - logb = np.log(b + 1e-16) - loga = np.log(a + 1e-16) - logKtu = logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1) - logKv = logsumexp(log["logv"][None, :] - M / epsilon, axis=1) - - v_final = fi * (logb - logKtu) - u_final = fi * (loga - logKv) + fi = alpha / (alpha + epsilon) + v_final = (b / K.T.dot(log["u"])) ** fi + u_final = (a / K.dot(log["v"])) ** fi np.testing.assert_allclose( - u_final, log["logu"], atol=1e-05) + u_final, log["u"], atol=1e-05) np.testing.assert_allclose( - v_final, log["logv"], atol=1e-05) + v_final, log["v"], atol=1e-05) # check if sinkhorn_unbalanced2 returns the correct loss np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5) -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) +@pytest.mark.parametrize("method", ["sinkhorn"]) def test_unbalanced_multiple_inputs(method): # test generalized sinkhorn for unbalanced OT n = 100 @@ -66,55 +59,27 @@ def test_unbalanced_multiple_inputs(method): M = ot.dist(x, x) epsilon = 1. - mu = 1. + alpha = 1. + K = np.exp(- M / epsilon) - loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, mu=mu, + loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, + alpha=alpha, stopThr=1e-10, method=method, log=True) # check fixed point equations - # in log-domain - fi = mu / (mu + epsilon) - logb = np.log(b + 1e-16) - loga = np.log(a + 1e-16)[:, None] - logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon, - axis=0) - logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) - v_final = fi * (logb - logKtu) - u_final = fi * (loga - logKv) + fi = alpha / (alpha + epsilon) + v_final = (b / K.T.dot(log["u"])) ** fi + + u_final = (a[:, None] / K.dot(log["v"])) ** fi np.testing.assert_allclose( - u_final, log["logu"], atol=1e-05) + u_final, log["u"], atol=1e-05) np.testing.assert_allclose( - v_final, log["logv"], atol=1e-05) + v_final, log["v"], atol=1e-05) assert len(loss) == b.shape[1] -def test_stabilized_vs_sinkhorn(): - # test if stable version matches sinkhorn - n = 100 - - # Gaussian distributions - a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std - b1 = ot.datasets.make_1D_gauss(n, m=60, s=8) - b2 = ot.datasets.make_1D_gauss(n, m=30, s=4) - - # creating matrix A containing all distributions - b = np.vstack((b1, b2)).T - - M = ot.utils.dist0(n) - M /= np.median(M) - epsilon = 0.1 - mu = 1. - G, log = ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, reg=epsilon, - mu=mu, - log=True) - G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu, - method="sinkhorn", log=True) - - np.testing.assert_allclose(G, G2) - - def test_unbalanced_barycenter(): # test generalized sinkhorn for unbalanced OT barycenter n = 100 @@ -127,30 +92,27 @@ def test_unbalanced_barycenter(): A = A * np.array([1, 2])[None, :] M = ot.dist(x, x) epsilon = 1. - mu = 1. + alpha = 1. + K = np.exp(- M / epsilon) - q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, mu=mu, + q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, alpha=alpha, stopThr=1e-10, log=True) # check fixed point equations - fi = mu / (mu + epsilon) - logA = np.log(A + 1e-16) - logq = np.log(q + 1e-16)[:, None] - logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon, - axis=0) - logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) - v_final = fi * (logq - logKtu) - u_final = fi * (logA - logKv) + fi = alpha / (alpha + epsilon) + v_final = (q[:, None] / K.T.dot(log["u"])) ** fi + u_final = (A / K.dot(log["v"])) ** fi np.testing.assert_allclose( - u_final, log["logu"], atol=1e-05) + u_final, log["u"], atol=1e-05) np.testing.assert_allclose( - v_final, log["logv"], atol=1e-05) + v_final, log["v"], atol=1e-05) def test_implemented_methods(): - IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized'] - TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling'] + IMPLEMENTED_METHODS = ['sinkhorn'] + TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_stabilized', + 'sinkhorn_epsilon_scaling'] NOT_VALID_TOKENS = ['foo'] # test generalized sinkhorn for unbalanced OT barycenter n = 3 @@ -164,21 +126,21 @@ def test_implemented_methods(): M = ot.dist(x, x) epsilon = 1. - mu = 1. + alpha = 1. for method in IMPLEMENTED_METHODS: - ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, mu, + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, method=method) - ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu, + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, method=method) with pytest.warns(UserWarning, match='not implemented'): for method in set(TO_BE_IMPLEMENTED_METHODS): - ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, mu, + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, method=method) - ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu, + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, method=method) with pytest.raises(ValueError): for method in set(NOT_VALID_TOKENS): - ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, mu, + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, method=method) - ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu, + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, method=method) -- cgit v1.2.3