summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
authorHuy Tran <huytran82125@gmail.com>2021-06-14 13:06:40 +0200
committerGitHub <noreply@github.com>2021-06-14 13:06:40 +0200
commit2dbeeda9308029a8e8db56bed07d48f4d5718efb (patch)
tree60fab2738fa0ffc6fc3c2762171d6d8d133e28c0 /ot/bregman.py
parent982510eb5085a0edd7a00fb96a308854957d32bf (diff)
[MRG] Batch/Lazy Log Sinkhorn Knopp on samples (#259)
* Add batch implementation of Sinkhorn * Reformat to pep8 and modify parameter * Fix error in batch size * Code review and add test * Fix accidental typo in test_empirical_sinkhorn * Remove whitespace * Edit config.yml
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py134
1 files changed, 113 insertions, 21 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index b10effd..105b38b 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -11,6 +11,7 @@ Bregman projections solvers for entropic regularized OT
# Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>
# Alexander Tong <alexander.tong@yale.edu>
# Ievgen Redko <ievgen.redko@univ-st-etienne.fr>
+# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
#
# License: MIT License
@@ -18,6 +19,7 @@ import warnings
import numpy as np
from scipy.optimize import fmin_l_bfgs_b
+from scipy.special import logsumexp
from ot.utils import unif, dist, list_to_array
from .backend import get_backend
@@ -1684,7 +1686,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
- numIterMax=10000, stopThr=1e-9, verbose=False,
+ numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False,
log=False, **kwargs):
r'''
Solve the entropic regularization optimal transport problem and return the
@@ -1723,6 +1725,12 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
Max number of iterations
stopThr : float, optional
Stop threshold on error (>0)
+ isLazy: boolean, optional
+        If True, compute the cost matrix block by block and return only the dual potentials (to save memory)
+ If False, calculate full cost matrix and return outputs of sinkhorn function.
+ batchSize: int or tuple of 2 int, optional
+        Size of the batches used to compute the sinkhorn update without memory overhead.
+ When a tuple is provided it sets the size of the left/right batches.
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -1758,24 +1766,78 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
'''
-
+ ns, nt = X_s.shape[0], X_t.shape[0]
if a is None:
- a = unif(np.shape(X_s)[0])
+ a = unif(ns)
if b is None:
- b = unif(np.shape(X_t)[0])
+ b = unif(nt)
+
+ if isLazy:
+ if log:
+ dict_log = {"err": []}
- M = dist(X_s, X_t, metric=metric)
+ log_a, log_b = np.log(a), np.log(b)
+ f, g = np.zeros(ns), np.zeros(nt)
+
+ if isinstance(batchSize, int):
+ bs, bt = batchSize, batchSize
+ elif isinstance(batchSize, tuple) and len(batchSize) == 2:
+ bs, bt = batchSize[0], batchSize[1]
+ else:
+ raise ValueError("Batch size must be in integer or a tuple of two integers")
+
+ range_s, range_t = range(0, ns, bs), range(0, nt, bt)
+
+ lse_f = np.zeros(ns)
+ lse_g = np.zeros(nt)
+
+ for i_ot in range(numIterMax):
+
+ for i in range_s:
+ M = dist(X_s[i:i + bs, :], X_t, metric=metric)
+ lse_f[i:i + bs] = logsumexp(g[None, :] - M / reg, axis=1)
+ f = log_a - lse_f
+
+ for j in range_t:
+ M = dist(X_s, X_t[j:j + bt, :], metric=metric)
+ lse_g[j:j + bt] = logsumexp(f[:, None] - M / reg, axis=0)
+ g = log_b - lse_g
+
+ if (i_ot + 1) % 10 == 0:
+ m1 = np.zeros_like(a)
+ for i in range_s:
+ M = dist(X_s[i:i + bs, :], X_t, metric=metric)
+ m1[i:i + bs] = np.exp(f[i:i + bs, None] + g[None, :] - M / reg).sum(1)
+ err = np.abs(m1 - a).sum()
+ if log:
+ dict_log["err"].append(err)
+
+ if verbose and (i_ot + 1) % 100 == 0:
+ print("Error in marginal at iteration {} = {}".format(i_ot + 1, err))
+
+ if err <= stopThr:
+ break
+
+ if log:
+ dict_log["u"] = f
+ dict_log["v"] = g
+ return (f, g, dict_log)
+ else:
+ return (f, g)
- if log:
- pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
- return pi, log
else:
- pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
- return pi
+ M = dist(X_s, X_t, metric=metric)
+
+ if log:
+ pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
+ return pi, log
+ else:
+ pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
+ return pi
def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
- verbose=False, log=False, **kwargs):
+ isLazy=False, batchSize=100, verbose=False, log=False, **kwargs):
r'''
Solve the entropic regularization optimal transport problem from empirical
data and return the OT loss
@@ -1814,6 +1876,12 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
Max number of iterations
stopThr : float, optional
Stop threshold on error (>0)
+ isLazy: boolean, optional
+        If True, compute the cost matrix block by block and return only the dual potentials (to save memory)
+ If False, calculate full cost matrix and return outputs of sinkhorn function.
+ batchSize: int or tuple of 2 int, optional
+        Size of the batches used to compute the sinkhorn update without memory overhead.
+ When a tuple is provided it sets the size of the left/right batches.
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -1850,21 +1918,45 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
'''
+ ns, nt = X_s.shape[0], X_t.shape[0]
if a is None:
- a = unif(np.shape(X_s)[0])
+ a = unif(ns)
if b is None:
- b = unif(np.shape(X_t)[0])
+ b = unif(nt)
- M = dist(X_s, X_t, metric=metric)
+ if isLazy:
+ if log:
+ f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr,
+ isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log)
+ else:
+ f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr,
+ isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log)
+
+ bs = batchSize if isinstance(batchSize, int) else batchSize[0]
+ range_s = range(0, ns, bs)
+
+ loss = 0
+ for i in range_s:
+ M_block = dist(X_s[i:i + bs, :], X_t, metric=metric)
+ pi_block = np.exp(f[i:i + bs, None] + g[None, :] - M_block / reg)
+ loss += np.sum(M_block * pi_block)
+
+ if log:
+ return loss, dict_log
+ else:
+ return loss
- if log:
- sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
- **kwargs)
- return sinkhorn_loss, log
else:
- sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
- **kwargs)
- return sinkhorn_loss
+ M = dist(X_s, X_t, metric=metric)
+
+ if log:
+ sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ return sinkhorn_loss, log
+ else:
+ sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ return sinkhorn_loss
def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,