From 2a3f2241951ea9cc044b4fba8a382b6ae9630513 Mon Sep 17 00:00:00 2001
From: AdrienCorenflos
Date: Mon, 19 Apr 2021 14:57:51 +0300
Subject: BUG/DOC FIX - Sinkhorn divergence used the wrong weights, and
 sinkhorn2 didn't support epsilon_scaling method. (#235)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* FIX:

1. Documentation of the loss-specific functions
2. Sinkhorn divergence weights handling
3. Sinkhorn2 does not support epsilon scaling, so I removed it (it *should*
   arguably support it, but that would require refactoring the sinkhorn
   iterates pretty much everywhere; maybe this should be done in torch first?)

* Fixed some PEP8 issues

Co-authored-by: Rémi Flamary
---
 ot/bregman.py        | 53 +++++++++++++++++++++++++----------------------------
 test/test_bregman.py | 13 +++++++------
 2 files changed, 32 insertions(+), 34 deletions(-)

diff --git a/ot/bregman.py b/ot/bregman.py
index dcd35e1..559db14 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -14,11 +14,13 @@ Bregman projections solvers for entropic regularized OT
 #
 # License: MIT License
 
-import numpy as np
 import warnings
-from .utils import unif, dist
+
+import numpy as np
 from scipy.optimize import fmin_l_bfgs_b
 
+from ot.utils import unif, dist
+
 
 def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
              stopThr=1e-9, verbose=False, log=False, **kwargs):
@@ -179,8 +181,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
     reg : float
         Regularization term >0
     method : str
-        method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
-        'sinkhorn_epsilon_scaling', see those function for specific parameters
+        method used for the solver, either 'sinkhorn' or 'sinkhorn_stabilized'; see those functions for specific parameters
     numItermax : int, optional
         Max number of iterations
     stopThr : float, optional
@@ -207,7 +208,7 @@
 
     Returns
     -------
-    W : (n_hists) ndarray or float
+    W : (n_hists) ndarray
         Optimal transportation loss for the given parameters
     log : dict
         log dictionary return only if log==True in parameters
@@ -244,12 +245,12 @@
     ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
     ot.bregman.greenkhorn : Greenkhorn [21]
     ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
-    ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
 
     """
     b = np.asarray(b, dtype=np.float64)
     if len(b.shape) < 2:
         b = b[:, None]
+
     if method.lower() == 'sinkhorn':
         return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
                               stopThr=stopThr, verbose=verbose, log=log,
@@ -258,10 +259,6 @@
         return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
                                    stopThr=stopThr, verbose=verbose, log=log,
                                    **kwargs)
-    elif method.lower() == 'sinkhorn_epsilon_scaling':
-        return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax,
-                                        stopThr=stopThr, verbose=verbose,
-                                        log=log, **kwargs)
     else:
         raise ValueError("Unknown method '%s'." % method)
 
@@ -745,8 +742,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
         # remove numerical problems and store them in K
         if np.abs(u).max() > tau or np.abs(v).max() > tau:
             if n_hists:
-                alpha, beta = alpha + reg * \
-                    np.max(np.log(u), 1), beta + reg * np.max(np.log(v))
+                alpha, beta = alpha + reg * np.max(np.log(u), 1), beta + reg * np.max(np.log(v))
             else:
                 alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v)
             if n_hists:
@@ -1747,7 +1743,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
     >>> reg = 0.1
     >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
     >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
-    >>> empirical_sinkhorn(X_s, X_t, reg, verbose=False)  # doctest: +NORMALIZE_WHITESPACE
+    >>> empirical_sinkhorn(X_s, X_t, reg=reg, verbose=False)  # doctest: +NORMALIZE_WHITESPACE
     array([[4.99977301e-01, 2.26989344e-05],
            [2.26989344e-05, 4.99977301e-01]])
 
@@ -1825,8 +1821,8 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
 
     Returns
     -------
-    gamma : ndarray, shape (n_samples_a, n_samples_b)
-        Regularized optimal transportation matrix for the given parameters
+    W : (n_hists) ndarray or float
+        Optimal transportation loss for the given parameters
     log : dict
         log dictionary return only if log==True in parameters
 
@@ -1838,8 +1834,9 @@
     >>> reg = 0.1
     >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
     >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
-    >>> empirical_sinkhorn2(X_s, X_t, reg, verbose=False)
-    array([4.53978687e-05])
+    >>> b = np.full((n_samples_b, 3), 1/n_samples_b)
+    >>> empirical_sinkhorn2(X_s, X_t, b=b, reg=reg, verbose=False)
+    array([4.53978687e-05, 4.53978687e-05, 4.53978687e-05])
 
 
     References
@@ -1935,8 +1932,8 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
 
     Returns
     -------
-    gamma : ndarray, shape (n_samples_a, n_samples_b)
-        Regularized optimal transportation matrix for the given parameters
+    W : (1,) ndarray
+        Symmetrized optimal transportation loss for the given parameters
     log : dict
         log dictionary return only if log==True in parameters
 
@@ -1959,13 +1956,13 @@
         sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax,
                                                        stopThr=1e-9, verbose=verbose, log=log, **kwargs)
 
-        sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax,
+        sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax,
                                                      stopThr=1e-9, verbose=verbose, log=log, **kwargs)
 
-        sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax,
+        sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax,
                                                      stopThr=1e-9, verbose=verbose, log=log, **kwargs)
 
-        sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
+        sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
 
         log = {}
         log['sinkhorn_loss_ab'] = sinkhorn_loss_ab
@@ -1981,13 +1978,13 @@
         sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax,
                                                stopThr=1e-9, verbose=verbose, log=log, **kwargs)
 
-        sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
+        sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
                                               verbose=verbose, log=log, **kwargs)
 
-        sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
+        sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
                                               verbose=verbose, log=log, **kwargs)
 
-        sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
+        sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
 
         return max(0, sinkhorn_div)
 
@@ -2212,11 +2209,11 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
 
     # box constraints in L-BFGS-B (see Proposition 1 in [26])
     bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / (
-        ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget
+                ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget
 
     bounds_v = [(
-        max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))),
-            epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget
+            max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))),
+                epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget
 
     # pre-calculated constants for the objective
     vec_eps_IJc = epsilon * kappa * (K_IJc * np.ones(nt - nt_budget).reshape((1, -1))).sum(axis=1)
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 6aa4e08..331acd3 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -6,9 +6,10 @@
 # License: MIT License
 
 import numpy as np
-import ot
 import pytest
 
+import ot
+
 
 def test_sinkhorn():
     # test sinkhorn
@@ -257,7 +258,8 @@ def test_empirical_sinkhorn():
 def test_empirical_sinkhorn_divergence():
     # Test sinkhorn divergence
     n = 10
-    a = ot.unif(n)
+    a = np.linspace(1, n, n)
+    a /= a.sum()
     b = ot.unif(n)
     X_s = np.reshape(np.arange(n), (n, 1))
     X_t = np.reshape(np.arange(0, n * 2, 2), (n, 1))
@@ -265,16 +267,15 @@
     M_s = ot.dist(X_s, X_s)
     M_t = ot.dist(X_t, X_t)
 
-    emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1)
+    emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b)
     sinkhorn_div = (ot.sinkhorn2(a, b, M, 1) - 1 / 2 * ot.sinkhorn2(a, a, M_s, 1) - 1 / 2 * ot.sinkhorn2(b, b, M_t, 1))
 
-    emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, log=True)
+    emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b, log=True)
     sink_div_log_ab, log_s_ab = ot.sinkhorn2(a, b, M, 1, log=True)
     sink_div_log_a, log_s_a = ot.sinkhorn2(a, a, M_s, 1, log=True)
     sink_div_log_b, log_s_b = ot.sinkhorn2(b, b, M_t, 1, log=True)
     sink_div_log = sink_div_log_ab - 1 / 2 * (sink_div_log_a + sink_div_log_b)
-
-    # check constratints
+    # check constraints
    np.testing.assert_allclose(
        emp_sinkhorn_div, sinkhorn_div, atol=1e-05)  # cf conv emp sinkhorn
    np.testing.assert_allclose(
-- 
cgit v1.2.3
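
Illustration (not part of the patch): a minimal sketch of the invariant the fix
and the updated test pin down, using only calls that already appear in the diff
(ot.unif, ot.dist, ot.sinkhorn2, ot.bregman.empirical_sinkhorn_divergence).
With non-uniform weights, the two self-terms of the Sinkhorn divergence must be
computed with (a, a) and (b, b); before this patch both were computed with
(a, b), which only coincides in the uniform case.

    import numpy as np

    import ot

    # Non-uniform source weights and uniform target weights,
    # mirroring the updated test above.
    n = 10
    a = np.linspace(1, n, n)
    a /= a.sum()
    b = ot.unif(n)
    X_s = np.reshape(np.arange(n), (n, 1))
    X_t = np.reshape(np.arange(0, n * 2, 2), (n, 1))

    # Divergence computed from samples with explicit weights
    # (the code path fixed by this patch).
    emp_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b)

    # Reference value assembled term by term: the cross term uses (a, b),
    # the two self-terms use (a, a) and (b, b).
    M = ot.dist(X_s, X_t)
    M_s = ot.dist(X_s, X_s)
    M_t = ot.dist(X_t, X_t)
    ref_div = (ot.sinkhorn2(a, b, M, 1)
               - 0.5 * ot.sinkhorn2(a, a, M_s, 1)
               - 0.5 * ot.sinkhorn2(b, b, M_t, 1))

    np.testing.assert_allclose(emp_div, ref_div, atol=1e-05)

    # Note: after this patch, sinkhorn2 dispatches only to 'sinkhorn' and
    # 'sinkhorn_stabilized'; method='sinkhorn_epsilon_scaling' now falls
    # through to the ValueError branch shown in the diff above.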