diff options
author | AdrienCorenflos <adrien.corenflos@gmail.com> | 2021-04-19 14:57:51 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-04-19 13:57:51 +0200 |
commit | 2a3f2241951ea9cc044b4fba8a382b6ae9630513 (patch) | |
tree | c4a07fda0e2ac6495d673df8aba277588bb47783 /ot | |
parent | 3a2ec71ae7d11aa650a7d3222357885010a9b2c3 (diff) |
BUG/DOC FIX - Sinkhorn divergence used the wrong weights, and sinkhorn2 didn't support epsilon_scaling method. (#235)
* FIX:
1. Documentation of loss specific functions
2. Sinkhorn divergence weights handling
3. Sinkhorn2 does not support epsilon scaling, so I removed it (it *should* arguably support it, but this would require a refactoring of the sinkhorn iterates pretty much everywhere, maybe should be done in torch first?)
* Had some PEP8 issues
Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot')
-rw-r--r-- | ot/bregman.py | 53 |
1 files changed, 25 insertions, 28 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index dcd35e1..559db14 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -14,11 +14,13 @@ Bregman projections solvers for entropic regularized OT # # License: MIT License -import numpy as np import warnings -from .utils import unif, dist + +import numpy as np from scipy.optimize import fmin_l_bfgs_b +from ot.utils import unif, dist + def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): @@ -179,8 +181,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, reg : float Regularization term >0 method : str - method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or - 'sinkhorn_epsilon_scaling', see those function for specific parameters + method used for the solver either 'sinkhorn', 'sinkhorn_stabilized', see those function for specific parameters numItermax : int, optional Max number of iterations stopThr : float, optional @@ -207,7 +208,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, Returns ------- - W : (n_hists) ndarray or float + W : (n_hists) ndarray Optimal transportation loss for the given parameters log : dict log dictionary return only if log==True in parameters @@ -244,12 +245,12 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2] ot.bregman.greenkhorn : Greenkhorn [21] ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] - ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10] """ b = np.asarray(b, dtype=np.float64) if len(b.shape) < 2: b = b[:, None] + if method.lower() == 'sinkhorn': return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, @@ -258,10 +259,6 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) - elif method.lower() == 'sinkhorn_epsilon_scaling': - return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) else: raise ValueError("Unknown method '%s'." % method) @@ -745,8 +742,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, # remove numerical problems and store them in K if np.abs(u).max() > tau or np.abs(v).max() > tau: if n_hists: - alpha, beta = alpha + reg * \ - np.max(np.log(u), 1), beta + reg * np.max(np.log(v)) + alpha, beta = alpha + reg * np.max(np.log(u), 1), beta + reg * np.max(np.log(v)) else: alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v) if n_hists: @@ -1747,7 +1743,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', >>> reg = 0.1 >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1)) >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1)) - >>> empirical_sinkhorn(X_s, X_t, reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE + >>> empirical_sinkhorn(X_s, X_t, reg=reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE array([[4.99977301e-01, 2.26989344e-05], [2.26989344e-05, 4.99977301e-01]]) @@ -1825,8 +1821,8 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num Returns ------- - gamma : ndarray, shape (n_samples_a, n_samples_b) - Regularized optimal transportation matrix for the given parameters + W : (n_hists) ndarray or float + Optimal transportation loss for the given parameters log : dict log dictionary return only if log==True in parameters @@ -1838,8 +1834,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num >>> reg = 0.1 >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1)) >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1)) - >>> empirical_sinkhorn2(X_s, X_t, reg, verbose=False) - array([4.53978687e-05]) + >>> b = np.full((n_samples_b, 3), 1/n_samples_b) + >>> empirical_sinkhorn2(X_s, X_t, b=b, reg=reg, verbose=False) + array([4.53978687e-05, 4.53978687e-05, 4.53978687e-05]) References @@ -1935,8 +1932,8 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli Returns ------- - gamma : ndarray, shape (n_samples_a, n_samples_b) - Regularized optimal transportation matrix for the given parameters + W : (1,) ndarray + Optimal transportation symmetrized loss for the given parameters log : dict log dictionary return only if log==True in parameters @@ -1959,13 +1956,13 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, + sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, + sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b) + sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b) log = {} log['sinkhorn_loss_ab'] = sinkhorn_loss_ab @@ -1981,13 +1978,13 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, + sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, + sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b) + sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b) return max(0, sinkhorn_div) @@ -2212,11 +2209,11 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res # box constraints in L-BFGS-B (see Proposition 1 in [26]) bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / ( - ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget + ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget bounds_v = [( - max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))), - epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget + max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))), + epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget # pre-calculated constants for the objective vec_eps_IJc = epsilon * kappa * (K_IJc * np.ones(nt - nt_budget).reshape((1, -1))).sum(axis=1) |