author    AdrienCorenflos <adrien.corenflos@gmail.com>    2021-04-19 14:57:51 +0300
committer GitHub <noreply@github.com>    2021-04-19 13:57:51 +0200
commit    2a3f2241951ea9cc044b4fba8a382b6ae9630513 (patch)
tree      c4a07fda0e2ac6495d673df8aba277588bb47783 /ot/bregman.py
parent    3a2ec71ae7d11aa650a7d3222357885010a9b2c3 (diff)
BUG/DOC FIX - Sinkhorn divergence used the wrong weights, and sinkhorn2 didn't support epsilon_scaling method. (#235)
* FIX:
  1. Documentation of loss-specific functions
  2. Sinkhorn divergence weights handling
  3. Sinkhorn2 does not support epsilon scaling, so I removed it (it *should* arguably support it, but this would require a refactoring of the sinkhorn iterates pretty much everywhere; maybe it should be done in torch first?)
* Fixed some PEP8 issues

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
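The weights fix concerns the debiased Sinkhorn divergence S(a, b) = W(a, b) - (W(a, a) + W(b, b)) / 2: the two self-terms must be evaluated with each distribution's own weights, not the mixed pair. A minimal sketch of the corrected computation on hypothetical toy data (the function names are the ones touched by this patch):

```python
import numpy as np
import ot

# Toy data (hypothetical, for illustration only).
rng = np.random.RandomState(0)
X_s, X_t = rng.randn(4, 2), rng.randn(5, 2)
a, b = ot.unif(4), ot.unif(5)  # source and target weights
reg = 0.1

# Debiased Sinkhorn divergence: the self-terms use each distribution's
# own weights, (a, a) and (b, b) -- before this fix they were computed
# with the mixed pair (a, b).
loss_ab = ot.bregman.empirical_sinkhorn2(X_s, X_t, reg, a, b)
loss_a = ot.bregman.empirical_sinkhorn2(X_s, X_s, reg, a, a)
loss_b = ot.bregman.empirical_sinkhorn2(X_t, X_t, reg, b, b)
div = max(0, loss_ab - 0.5 * (loss_a + loss_b))
```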
Diffstat (limited to 'ot/bregman.py')
-rw-r--r-- ot/bregman.py | 53
1 file changed, 25 insertions(+), 28 deletions(-)
diff --git a/ot/bregman.py b/ot/bregman.py
index dcd35e1..559db14 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -14,11 +14,13 @@ Bregman projections solvers for entropic regularized OT
#
# License: MIT License
-import numpy as np
import warnings
-from .utils import unif, dist
+
+import numpy as np
from scipy.optimize import fmin_l_bfgs_b
+from ot.utils import unif, dist
+
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
stopThr=1e-9, verbose=False, log=False, **kwargs):
@@ -179,8 +181,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
reg : float
Regularization term >0
method : str
- method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
- 'sinkhorn_epsilon_scaling', see those function for specific parameters
+ method used for the solver, either 'sinkhorn' or 'sinkhorn_stabilized', see those functions for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
@@ -207,7 +208,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
Returns
-------
- W : (n_hists) ndarray or float
+ W : (n_hists) ndarray
Optimal transportation loss for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -244,12 +245,12 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
ot.bregman.greenkhorn : Greenkhorn [21]
ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
- ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
"""
b = np.asarray(b, dtype=np.float64)
if len(b.shape) < 2:
b = b[:, None]
+
if method.lower() == 'sinkhorn':
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
@@ -258,10 +259,6 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
**kwargs)
- elif method.lower() == 'sinkhorn_epsilon_scaling':
- return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
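With epsilon scaling removed, `sinkhorn2` now dispatches only to the plain and stabilized solvers; any other method name falls through to the `ValueError`. A small usage sketch on made-up inputs:

```python
import numpy as np
import ot

# Illustrative inputs (hypothetical).
a = ot.unif(3)
b = ot.unif(3)
x = np.arange(3.0).reshape(-1, 1)
M = ot.dist(x, x)  # squared Euclidean cost by default

# Only these two methods remain; passing 'sinkhorn_epsilon_scaling'
# to sinkhorn2 now raises ValueError.
loss = ot.bregman.sinkhorn2(a, b, M, reg=1e-1, method='sinkhorn_stabilized')
```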
@@ -745,8 +742,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
# remove numerical problems and store them in K
if np.abs(u).max() > tau or np.abs(v).max() > tau:
if n_hists:
- alpha, beta = alpha + reg * \
- np.max(np.log(u), 1), beta + reg * np.max(np.log(v))
+ alpha, beta = alpha + reg * np.max(np.log(u), 1), beta + reg * np.max(np.log(v))
else:
alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v)
if n_hists:
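For context on the absorption step reflowed above: when a scaling u or v exceeds tau, the stabilized solver of [9][10] folds it into the dual potentials and resets it, so the rebuilt kernel stays numerically bounded. A toy sketch of that step (illustrative values, not taken from the patch):

```python
import numpy as np

# Log-stabilization sketch: absorb overflowing scalings into the dual
# potentials alpha, beta, then reset the scalings to one.
reg, tau = 0.1, 1e3
M = np.array([[0.0, 1.0], [1.0, 0.0]])
alpha, beta = np.zeros(2), np.zeros(2)
u, v = np.full(2, 1e4), np.full(2, 1e4)  # scalings that drifted past tau

if np.abs(u).max() > tau or np.abs(v).max() > tau:
    alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v)
    u, v = np.ones(2), np.ones(2)  # reset after absorption

# kernel recomputed with the absorbed potentials
K = np.exp(-(M - alpha[:, None] - beta[None, :]) / reg)
```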
@@ -1747,7 +1743,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
>>> reg = 0.1
>>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
>>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
- >>> empirical_sinkhorn(X_s, X_t, reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE
+ >>> empirical_sinkhorn(X_s, X_t, reg=reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE
array([[4.99977301e-01, 2.26989344e-05],
[2.26989344e-05, 4.99977301e-01]])
@@ -1825,8 +1821,8 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
Returns
-------
- gamma : ndarray, shape (n_samples_a, n_samples_b)
- Regularized optimal transportation matrix for the given parameters
+ W : (n_hists) ndarray or float
+ Optimal transportation loss for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -1838,8 +1834,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
>>> reg = 0.1
>>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
>>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
- >>> empirical_sinkhorn2(X_s, X_t, reg, verbose=False)
- array([4.53978687e-05])
+ >>> b = np.full((n_samples_b, 3), 1/n_samples_b)
+ >>> empirical_sinkhorn2(X_s, X_t, b=b, reg=reg, verbose=False)
+ array([4.53978687e-05, 4.53978687e-05, 4.53978687e-05])
References
@@ -1935,8 +1932,8 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
Returns
-------
- gamma : ndarray, shape (n_samples_a, n_samples_b)
- Regularized optimal transportation matrix for the given parameters
+ W : (1,) ndarray
+ Optimal transportation symmetrized loss for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -1959,13 +1956,13 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax,
stopThr=1e-9, verbose=verbose, log=log, **kwargs)
- sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax,
+ sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax,
stopThr=1e-9, verbose=verbose, log=log, **kwargs)
- sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax,
+ sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax,
stopThr=1e-9, verbose=verbose, log=log, **kwargs)
- sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
+ sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
log = {}
log['sinkhorn_loss_ab'] = sinkhorn_loss_ab
@@ -1981,13 +1978,13 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
verbose=verbose, log=log, **kwargs)
- sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
+ sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
verbose=verbose, log=log, **kwargs)
- sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
+ sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
verbose=verbose, log=log, **kwargs)
- sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
+ sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
return max(0, sinkhorn_div)
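A quick way to see that the matched weights are right: with identical samples (and hence identical weights) the three losses coincide, so the clipped divergence vanishes. A hypothetical sanity check:

```python
import numpy as np
import ot

# With X_s == X_t, loss_ab == loss_a == loss_b, so the debiased
# divergence should be exactly zero after the fix.
X = np.arange(4.0).reshape(-1, 1)
div = ot.bregman.empirical_sinkhorn_divergence(X, X, reg=0.1)
np.testing.assert_allclose(div, 0.0, atol=1e-12)
```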
@@ -2212,11 +2209,11 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
# box constraints in L-BFGS-B (see Proposition 1 in [26])
bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / (
- ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget
+ ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget
bounds_v = [(
- max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))),
- epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget
+ max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))),
+ epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget
# pre-calculated constants for the objective
vec_eps_IJc = epsilon * kappa * (K_IJc * np.ones(nt - nt_budget).reshape((1, -1))).sum(axis=1)
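The bounds built here are per-coordinate (lower, upper) pairs consumed by SciPy's L-BFGS-B. A minimal sketch with a stand-in quadratic objective (the real restricted dual from [26] is more involved):

```python
import numpy as np
from scipy.optimize import fmin_l_bfgs_b

# Hypothetical per-coordinate box bounds, in the same list-of-pairs
# format as bounds_u and bounds_v above.
ns_budget, nt_budget = 3, 3
bounds_u = [(0.1, 10.0)] * ns_budget
bounds_v = [(0.1, 10.0)] * nt_budget

def objective(theta):
    # Stand-in objective returning (value, gradient), as L-BFGS-B
    # expects when analytic gradients are supplied.
    return 0.5 * np.sum((theta - 1.0) ** 2), theta - 1.0

theta0 = np.ones(ns_budget + nt_budget)
theta_opt, obj_val, info = fmin_l_bfgs_b(objective, theta0,
                                         bounds=bounds_u + bounds_v)
```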