summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHuy Tran <huytran82125@gmail.com>2021-06-14 13:06:40 +0200
committerGitHub <noreply@github.com>2021-06-14 13:06:40 +0200
commit2dbeeda9308029a8e8db56bed07d48f4d5718efb (patch)
tree60fab2738fa0ffc6fc3c2762171d6d8d133e28c0
parent982510eb5085a0edd7a00fb96a308854957d32bf (diff)
[MRG] Batch/Lazy Log Sinkhorn Knopp on samples (#259)
* Add batch implementation of Sinkhorn * Reformat to pep8 and modify parameter * Fix error in batch size * Code review and add test * Fix accidental typo in test_empirical_sinkhorn * Remove whitespace * Edit config.yml
-rw-r--r--.circleci/config.yml1
-rw-r--r--ot/bregman.py134
-rw-r--r--test/test_bregman.py44
3 files changed, 158 insertions, 21 deletions
diff --git a/.circleci/config.yml b/.circleci/config.yml
index 29c9a07..e4c71dd 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -73,6 +73,7 @@ jobs:
command: |
cd docs;
make html;
+ no_output_timeout: 30m
# Save the outputs
- store_artifacts:
diff --git a/ot/bregman.py b/ot/bregman.py
index b10effd..105b38b 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -11,6 +11,7 @@ Bregman projections solvers for entropic regularized OT
# Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>
# Alexander Tong <alexander.tong@yale.edu>
# Ievgen Redko <ievgen.redko@univ-st-etienne.fr>
+# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
#
# License: MIT License
@@ -18,6 +19,7 @@ import warnings
import numpy as np
from scipy.optimize import fmin_l_bfgs_b
+from scipy.special import logsumexp
from ot.utils import unif, dist, list_to_array
from .backend import get_backend
@@ -1684,7 +1686,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
- numIterMax=10000, stopThr=1e-9, verbose=False,
+ numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False,
log=False, **kwargs):
r'''
Solve the entropic regularization optimal transport problem and return the
@@ -1723,6 +1725,12 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
Max number of iterations
stopThr : float, optional
Stop threshol on error (>0)
+ isLazy: boolean, optional
+ If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory)
+ If False, calculate full cost matrix and return outputs of sinkhorn function.
+ batchSize: int or tuple of 2 int, optional
+ Size of the batcheses used to compute the sinkhorn update without memory overhead.
+ When a tuple is provided it sets the size of the left/right batches.
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -1758,24 +1766,78 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
'''
-
+ ns, nt = X_s.shape[0], X_t.shape[0]
if a is None:
- a = unif(np.shape(X_s)[0])
+ a = unif(ns)
if b is None:
- b = unif(np.shape(X_t)[0])
+ b = unif(nt)
+
+ if isLazy:
+ if log:
+ dict_log = {"err": []}
- M = dist(X_s, X_t, metric=metric)
+ log_a, log_b = np.log(a), np.log(b)
+ f, g = np.zeros(ns), np.zeros(nt)
+
+ if isinstance(batchSize, int):
+ bs, bt = batchSize, batchSize
+ elif isinstance(batchSize, tuple) and len(batchSize) == 2:
+ bs, bt = batchSize[0], batchSize[1]
+ else:
+ raise ValueError("Batch size must be in integer or a tuple of two integers")
+
+ range_s, range_t = range(0, ns, bs), range(0, nt, bt)
+
+ lse_f = np.zeros(ns)
+ lse_g = np.zeros(nt)
+
+ for i_ot in range(numIterMax):
+
+ for i in range_s:
+ M = dist(X_s[i:i + bs, :], X_t, metric=metric)
+ lse_f[i:i + bs] = logsumexp(g[None, :] - M / reg, axis=1)
+ f = log_a - lse_f
+
+ for j in range_t:
+ M = dist(X_s, X_t[j:j + bt, :], metric=metric)
+ lse_g[j:j + bt] = logsumexp(f[:, None] - M / reg, axis=0)
+ g = log_b - lse_g
+
+ if (i_ot + 1) % 10 == 0:
+ m1 = np.zeros_like(a)
+ for i in range_s:
+ M = dist(X_s[i:i + bs, :], X_t, metric=metric)
+ m1[i:i + bs] = np.exp(f[i:i + bs, None] + g[None, :] - M / reg).sum(1)
+ err = np.abs(m1 - a).sum()
+ if log:
+ dict_log["err"].append(err)
+
+ if verbose and (i_ot + 1) % 100 == 0:
+ print("Error in marginal at iteration {} = {}".format(i_ot + 1, err))
+
+ if err <= stopThr:
+ break
+
+ if log:
+ dict_log["u"] = f
+ dict_log["v"] = g
+ return (f, g, dict_log)
+ else:
+ return (f, g)
- if log:
- pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
- return pi, log
else:
- pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
- return pi
+ M = dist(X_s, X_t, metric=metric)
+
+ if log:
+ pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
+ return pi, log
+ else:
+ pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
+ return pi
def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
- verbose=False, log=False, **kwargs):
+ isLazy=False, batchSize=100, verbose=False, log=False, **kwargs):
r'''
Solve the entropic regularization optimal transport problem from empirical
data and return the OT loss
@@ -1814,6 +1876,12 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
Max number of iterations
stopThr : float, optional
Stop threshol on error (>0)
+ isLazy: boolean, optional
+ If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory)
+ If False, calculate full cost matrix and return outputs of sinkhorn function.
+ batchSize: int or tuple of 2 int, optional
+ Size of the batcheses used to compute the sinkhorn update without memory overhead.
+ When a tuple is provided it sets the size of the left/right batches.
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -1850,21 +1918,45 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
'''
+ ns, nt = X_s.shape[0], X_t.shape[0]
if a is None:
- a = unif(np.shape(X_s)[0])
+ a = unif(ns)
if b is None:
- b = unif(np.shape(X_t)[0])
+ b = unif(nt)
- M = dist(X_s, X_t, metric=metric)
+ if isLazy:
+ if log:
+ f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr,
+ isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log)
+ else:
+ f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr,
+ isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log)
+
+ bs = batchSize if isinstance(batchSize, int) else batchSize[0]
+ range_s = range(0, ns, bs)
+
+ loss = 0
+ for i in range_s:
+ M_block = dist(X_s[i:i + bs, :], X_t, metric=metric)
+ pi_block = np.exp(f[i:i + bs, None] + g[None, :] - M_block / reg)
+ loss += np.sum(M_block * pi_block)
+
+ if log:
+ return loss, dict_log
+ else:
+ return loss
- if log:
- sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
- **kwargs)
- return sinkhorn_loss, log
else:
- sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
- **kwargs)
- return sinkhorn_loss
+ M = dist(X_s, X_t, metric=metric)
+
+ if log:
+ sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ return sinkhorn_loss, log
+ else:
+ sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ return sinkhorn_loss
def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 7c5162a..9665229 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -2,6 +2,7 @@
# Author: Remi Flamary <remi.flamary@unice.fr>
# Kilian Fatras <kilian.fatras@irisa.fr>
+# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
#
# License: MIT License
@@ -329,6 +330,49 @@ def test_empirical_sinkhorn():
np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
+def test_lazy_empirical_sinkhorn():
+ # test sinkhorn
+ n = 100
+ a = ot.unif(n)
+ b = ot.unif(n)
+ numIterMax = 1000
+
+ X_s = np.reshape(np.arange(n), (n, 1))
+ X_t = np.reshape(np.arange(0, n), (n, 1))
+ M = ot.dist(X_s, X_t)
+ M_m = ot.dist(X_s, X_t, metric='minkowski')
+
+ f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 1), verbose=True)
+ G_sqe = np.exp(f[:, None] + g[None, :] - M / 1)
+ sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)
+
+ f, g, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ G_log = np.exp(f[:, None] + g[None, :] - M / 0.1)
+ sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True)
+
+ f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski', numIterMax=numIterMax, isLazy=True, batchSize=1)
+ G_m = np.exp(f[:, None] + g[None, :] - M_m / 1)
+ sinkhorn_m = ot.sinkhorn(a, b, M_m, 1)
+
+ loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ loss_sinkhorn = ot.sinkhorn2(a, b, M, 1)
+
+ # check constratints
+ np.testing.assert_allclose(
+ sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian
+ np.testing.assert_allclose(
+ sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidian
+ np.testing.assert_allclose(
+ sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log
+ np.testing.assert_allclose(
+ sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log
+ np.testing.assert_allclose(
+ sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05) # metric euclidian
+ np.testing.assert_allclose(
+ sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian
+ np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
+
+
def test_empirical_sinkhorn_divergence():
# Test sinkhorn divergence
n = 10