diff options
author | Huy Tran <huytran82125@gmail.com> | 2021-06-14 13:06:40 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-06-14 13:06:40 +0200 |
commit | 2dbeeda9308029a8e8db56bed07d48f4d5718efb (patch) | |
tree | 60fab2738fa0ffc6fc3c2762171d6d8d133e28c0 /ot | |
parent | 982510eb5085a0edd7a00fb96a308854957d32bf (diff) |
[MRG] Batch/Lazy Log Sinkhorn Knopp on samples (#259)
* Add batch implementation of Sinkhorn
* Reformat to pep8 and modify parameter
* Fix error in batch size
* Code review and add test
* Fix accidental typo in test_empirical_sinkhorn
* Remove whitespace
* Edit config.yml
Diffstat (limited to 'ot')
-rw-r--r-- | ot/bregman.py | 134 |
1 files changed, 113 insertions, 21 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index b10effd..105b38b 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -11,6 +11,7 @@ Bregman projections solvers for entropic regularized OT # Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com> # Alexander Tong <alexander.tong@yale.edu> # Ievgen Redko <ievgen.redko@univ-st-etienne.fr> +# Quang Huy Tran <quang-huy.tran@univ-ubs.fr> # # License: MIT License @@ -18,6 +19,7 @@ import warnings import numpy as np from scipy.optimize import fmin_l_bfgs_b +from scipy.special import logsumexp from ot.utils import unif, dist, list_to_array from .backend import get_backend @@ -1684,7 +1686,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', - numIterMax=10000, stopThr=1e-9, verbose=False, + numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False, log=False, **kwargs): r''' Solve the entropic regularization optimal transport problem and return the @@ -1723,6 +1725,12 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', Max number of iterations stopThr : float, optional Stop threshol on error (>0) + isLazy: boolean, optional + If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory) + If False, calculate full cost matrix and return outputs of sinkhorn function. + batchSize: int or tuple of 2 int, optional + Size of the batcheses used to compute the sinkhorn update without memory overhead. + When a tuple is provided it sets the size of the left/right batches. verbose : bool, optional Print information along iterations log : bool, optional @@ -1758,24 +1766,78 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. ''' - + ns, nt = X_s.shape[0], X_t.shape[0] if a is None: - a = unif(np.shape(X_s)[0]) + a = unif(ns) if b is None: - b = unif(np.shape(X_t)[0]) + b = unif(nt) + + if isLazy: + if log: + dict_log = {"err": []} - M = dist(X_s, X_t, metric=metric) + log_a, log_b = np.log(a), np.log(b) + f, g = np.zeros(ns), np.zeros(nt) + + if isinstance(batchSize, int): + bs, bt = batchSize, batchSize + elif isinstance(batchSize, tuple) and len(batchSize) == 2: + bs, bt = batchSize[0], batchSize[1] + else: + raise ValueError("Batch size must be in integer or a tuple of two integers") + + range_s, range_t = range(0, ns, bs), range(0, nt, bt) + + lse_f = np.zeros(ns) + lse_g = np.zeros(nt) + + for i_ot in range(numIterMax): + + for i in range_s: + M = dist(X_s[i:i + bs, :], X_t, metric=metric) + lse_f[i:i + bs] = logsumexp(g[None, :] - M / reg, axis=1) + f = log_a - lse_f + + for j in range_t: + M = dist(X_s, X_t[j:j + bt, :], metric=metric) + lse_g[j:j + bt] = logsumexp(f[:, None] - M / reg, axis=0) + g = log_b - lse_g + + if (i_ot + 1) % 10 == 0: + m1 = np.zeros_like(a) + for i in range_s: + M = dist(X_s[i:i + bs, :], X_t, metric=metric) + m1[i:i + bs] = np.exp(f[i:i + bs, None] + g[None, :] - M / reg).sum(1) + err = np.abs(m1 - a).sum() + if log: + dict_log["err"].append(err) + + if verbose and (i_ot + 1) % 100 == 0: + print("Error in marginal at iteration {} = {}".format(i_ot + 1, err)) + + if err <= stopThr: + break + + if log: + dict_log["u"] = f + dict_log["v"] = g + return (f, g, dict_log) + else: + return (f, g) - if log: - pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs) - return pi, log else: - pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs) - return pi + M = dist(X_s, X_t, metric=metric) + + if log: + pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs) + return pi, log + else: + pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs) + return pi def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, - verbose=False, log=False, **kwargs): + isLazy=False, batchSize=100, verbose=False, log=False, **kwargs): r''' Solve the entropic regularization optimal transport problem from empirical data and return the OT loss @@ -1814,6 +1876,12 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num Max number of iterations stopThr : float, optional Stop threshol on error (>0) + isLazy: boolean, optional + If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory) + If False, calculate full cost matrix and return outputs of sinkhorn function. + batchSize: int or tuple of 2 int, optional + Size of the batcheses used to compute the sinkhorn update without memory overhead. + When a tuple is provided it sets the size of the left/right batches. verbose : bool, optional Print information along iterations log : bool, optional @@ -1850,21 +1918,45 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. ''' + ns, nt = X_s.shape[0], X_t.shape[0] if a is None: - a = unif(np.shape(X_s)[0]) + a = unif(ns) if b is None: - b = unif(np.shape(X_t)[0]) + b = unif(nt) - M = dist(X_s, X_t, metric=metric) + if isLazy: + if log: + f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr, + isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log) + else: + f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr, + isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log) + + bs = batchSize if isinstance(batchSize, int) else batchSize[0] + range_s = range(0, ns, bs) + + loss = 0 + for i in range_s: + M_block = dist(X_s[i:i + bs, :], X_t, metric=metric) + pi_block = np.exp(f[i:i + bs, None] + g[None, :] - M_block / reg) + loss += np.sum(M_block * pi_block) + + if log: + return loss, dict_log + else: + return loss - if log: - sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, - **kwargs) - return sinkhorn_loss, log else: - sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, - **kwargs) - return sinkhorn_loss + M = dist(X_s, X_t, metric=metric) + + if log: + sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + return sinkhorn_loss, log + else: + sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + return sinkhorn_loss def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, |