-rw-r--r--  README.md                         |   6
-rw-r--r--  examples/plot_screenkhorn_1D.py   |  68
-rw-r--r--  ot/bregman.py                     | 339
-rw-r--r--  requirements.txt                  |   5
-rw-r--r--  test/test_bregman.py              |  26
5 files changed, 435 insertions, 9 deletions
diff --git a/README.md b/README.md
index d8bb051..c115776 100644
--- a/README.md
+++ b/README.md
@@ -28,6 +28,7 @@ It provides the following solvers:
* Stochastic Optimization for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19])
* Non regularized free support Wasserstein barycenters [20].
* Unbalanced OT with KL relaxation distance and barycenter [10, 25].
+* Screening Sinkhorn Algorithm for OT [26].
Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.
@@ -180,6 +181,7 @@ The contributors to this library are
* [Vayer Titouan](https://tvayer.github.io/)
* [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT)
* [Romain Tavenard](https://rtavenar.github.io/) (1d Wasserstein)
+* [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn)
This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
@@ -252,4 +254,6 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N. (2019). [Optimal Transport for structured data with application on graphs](http://proceedings.mlr.press/v97/titouan19a.html) Proceedings of the 36th International Conference on Machine Learning (ICML).
-[25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. (2019). [Learning with a Wasserstein Loss](http://cbcl.mit.edu/wasserstein/) Advances in Neural Information Processing Systems (NIPS).
+[25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. (2015). [Learning with a Wasserstein Loss](http://cbcl.mit.edu/wasserstein/) Advances in Neural Information Processing Systems (NIPS).
+
+[26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). [Screening Sinkhorn Algorithm for Regularized Optimal Transport](https://papers.nips.cc/paper/9386-screening-sinkhorn-algorithm-for-regularized-optimal-transport), Advances in Neural Information Processing Systems 32 (NeurIPS).
diff --git a/examples/plot_screenkhorn_1D.py b/examples/plot_screenkhorn_1D.py
new file mode 100644
index 0000000..840ead8
--- /dev/null
+++ b/examples/plot_screenkhorn_1D.py
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+"""
+===============================
+1D Screened optimal transport
+===============================
+
+This example illustrates the computation of Screenkhorn:
+Screening Sinkhorn Algorithm for Optimal Transport.
+"""
+
+# Author: Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot.plot
+from ot.datasets import make_1D_gauss as gauss
+from ot.bregman import screenkhorn
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% parameters
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a = gauss(n, m=20, s=5) # m= mean, s= std
+b = gauss(n, m=60, s=10)
+
+# loss matrix
+M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+M /= M.max()
+
+##############################################################################
+# Plot distributions and loss matrix
+# ----------------------------------
+
+#%% plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+pl.legend()
+
+# plot distributions and loss matrix
+
+pl.figure(2, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
+
+##############################################################################
+# Solve Screenkhorn
+# -----------------------
+
+# Screenkhorn
+lambd = 2e-03  # entropy regularization parameter
+ns_budget = 30  # budget number of points to be kept in the source distribution
+nt_budget = 30  # budget number of points to be kept in the target distribution
+
+G_screen = screenkhorn(a, b, M, lambd, ns_budget, nt_budget, uniform=False, restricted=True, verbose=True)
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, G_screen, 'OT matrix Screenkhorn')
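+
+# As a sketch (an addition, not part of the original example), the screened
+# plan can be compared with the full Sinkhorn plan on the same problem:
+G_sink = ot.sinkhorn(a, b, M, lambd)
+print('<M, G> with Sinkhorn:    %.6f' % np.sum(G_sink * M))
+print('<M, G> with Screenkhorn: %.6f' % np.sum(G_screen * M))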
+pl.show()
diff --git a/ot/bregman.py b/ot/bregman.py
index ba5c7ba..2707b7c 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -8,12 +8,14 @@ Bregman projections for regularized OT
# Kilian Fatras <kilian.fatras@irisa.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
# Hicham Janati <hicham.janati@inria.fr>
+# Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>
#
# License: MIT License
import numpy as np
import warnings
from .utils import unif, dist
+from scipy.optimize import fmin_l_bfgs_b
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
@@ -1787,3 +1789,340 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
return max(0, sinkhorn_div)
+
+
+def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, restricted=True,
+ maxiter=10000, maxfun=10000, pgtol=1e-09, verbose=False, log=False):
+    r"""
+    Screening Sinkhorn Algorithm for Regularized Optimal Transport
+
+    The function solves an approximate dual of the Sinkhorn divergence [2] which is written as the following optimization problem:
+
+    .. math::
+
+        (u, v) = \mathop{\arg \min}_{u, v} \; \mathbf{1}_{ns}^T B(u, v) \mathbf{1}_{nt} - \langle \kappa u, a \rangle - \langle v / \kappa, b \rangle
+
+        s.t. \; e^{u_i} \geq \epsilon / \kappa, \quad \forall i \in \{1, \ldots, ns\}
+
+        \quad \; e^{v_j} \geq \epsilon \kappa, \quad \forall j \in \{1, \ldots, nt\}
+
+    where :math:`B(u, v) = \mathrm{diag}(e^u) K \mathrm{diag}(e^v)`, with :math:`K = e^{-M / reg}`.
+
+    The parameters :math:`\kappa` and :math:`\epsilon` are determined w.r.t. the couple number budget of points
+    (ns_budget, nt_budget), see Equation (5) in [26]
+
+
+ Parameters
+ ----------
+    a : `numpy.ndarray`, shape=(ns,)
+        Sample weights in the source domain
+
+    b : `numpy.ndarray`, shape=(nt,)
+        Sample weights in the target domain
+
+ M : `numpy.ndarray`, shape=(ns, nt)
+ Cost matrix
+
+ reg : `float`
+ Level of the entropy regularisation
+
+    ns_budget : `int`, default=None
+        Number budget of points to be kept in the source domain.
+        If it is None then 50% of the source sample points will be kept
+
+    nt_budget : `int`, default=None
+        Number budget of points to be kept in the target domain.
+        If it is None then 50% of the target sample points will be kept
+
+    uniform : `bool`, default=False
+        If `True`, the source and target distributions are assumed to be uniform, i.e., a_i = 1 / ns and b_j = 1 / nt
+
+    restricted : `bool`, default=True
+        If `True`, a warm-start initialization for the L-BFGS-B solver is computed
+        using a restricted Sinkhorn algorithm with at most 5 iterations
+
+    maxiter : `int`, default=10000
+        Maximum number of iterations in the L-BFGS-B solver
+
+    maxfun : `int`, default=10000
+        Maximum number of function evaluations in the L-BFGS-B solver
+
+    pgtol : `float`, default=1e-09
+        Projected-gradient tolerance of the L-BFGS-B solver
+
+    verbose : `bool`, default=False
+        If `True`, display information about the cardinality of the active sets and the parameters
+        kappa and epsilon
+
+    log : `bool`, default=False
+        If `True`, record and return the log as a dictionary
+
+    Dependency
+    ----------
+    To gain more efficiency, screenkhorn relies on the "Bottleneck" package (https://pypi.org/project/Bottleneck/)
+    in the screening pre-processing step. If Bottleneck isn't installed, it falls back to numpy and the following
+    warning is raised:
+    "Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance."
+
+
+ Returns
+ -------
+ gamma : `numpy.ndarray`, shape=(ns, nt)
+ Screened optimal transportation matrix for the given parameters
+
+    log : `dict`
+        Log dictionary returned only if the `log` parameter is True
+
+
+    References
+    ----------
+    .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). Screening Sinkhorn Algorithm for Regularized Optimal Transport, Advances in Neural Information Processing Systems 32 (NeurIPS)
+
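+    Examples
+    --------
+    A minimal usage sketch (illustrative parameter values):
+
+    >>> import numpy as np
+    >>> import ot
+    >>> n = 100
+    >>> a = ot.unif(n)
+    >>> b = ot.unif(n)
+    >>> x = np.arange(n, dtype=np.float64).reshape((n, 1))
+    >>> M = ot.dist(x, x)
+    >>> M /= M.max()
+    >>> G = ot.bregman.screenkhorn(a, b, M, reg=1e-2, ns_budget=50, nt_budget=50)
+    >>> G.shape
+    (100, 100)
+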
+ """
+ # check if bottleneck module exists
+ try:
+ import bottleneck
+ except ImportError:
+ warnings.warn("Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance.")
+ bottleneck = np
+
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64)
+ ns, nt = M.shape
+
+ # by default, we keep only 50% of the sample data points
+ if ns_budget is None:
+ ns_budget = int(np.floor(0.5 * ns))
+ if nt_budget is None:
+ nt_budget = int(np.floor(0.5 * nt))
+
+    # Gibbs kernel K = exp(-M / reg), computed in place to avoid temporaries
+ K = np.empty_like(M)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
+    def projection(u, epsilon):
+        # project u onto the box constraint u >= epsilon (in place)
+        u[u <= epsilon] = epsilon
+        return u
+
+ # ----------------------------------------------------------------------------------------------------------------#
+ # Step 1: Screening pre-processing #
+ # ----------------------------------------------------------------------------------------------------------------#
+
+ if ns_budget == ns and nt_budget == nt:
+ # full number of budget points (ns, nt) = (ns_budget, nt_budget)
+ Isel = np.ones(ns, dtype=bool)
+ Jsel = np.ones(nt, dtype=bool)
+ epsilon = 0.0
+ kappa = 1.0
+
+ cst_u = 0.
+ cst_v = 0.
+
+ bounds_u = [(0.0, np.inf)] * ns
+ bounds_v = [(0.0, np.inf)] * nt
+
+ a_I = a
+ b_J = b
+ K_IJ = K
+ K_IJc = []
+ K_IcJ = []
+
+        vec_eps_IJc = np.zeros(ns)  # must match the length of u (here ns_budget = ns)
+        vec_eps_IcJ = np.zeros(nt)  # must match the length of v (here nt_budget = nt)
+
+ else:
+ # sum of rows and columns of K
+ K_sum_cols = K.sum(axis=1)
+ K_sum_rows = K.sum(axis=0)
+
+ if uniform:
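+            # small screening ratio: a full sort is used; otherwise bottleneck.partition
+            # retrieves the k-th smallest element without sorting the whole array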
+ if ns / ns_budget < 4:
+ aK_sort = np.sort(K_sum_cols)
+ epsilon_u_square = a[0] / aK_sort[ns_budget - 1]
+ else:
+ aK_sort = bottleneck.partition(K_sum_cols, ns_budget - 1)[ns_budget - 1]
+ epsilon_u_square = a[0] / aK_sort
+
+ if nt / nt_budget < 4:
+ bK_sort = np.sort(K_sum_rows)
+ epsilon_v_square = b[0] / bK_sort[nt_budget - 1]
+ else:
+ bK_sort = bottleneck.partition(K_sum_rows, nt_budget - 1)[nt_budget - 1]
+ epsilon_v_square = b[0] / bK_sort
+ else:
+ aK = a / K_sum_cols
+ bK = b / K_sum_rows
+
+ aK_sort = np.sort(aK)[::-1]
+ epsilon_u_square = aK_sort[ns_budget - 1]
+
+ bK_sort = np.sort(bK)[::-1]
+ epsilon_v_square = bK_sort[nt_budget - 1]
+
+ # active sets I and J (see Lemma 1 in [26])
+ Isel = a >= epsilon_u_square * K_sum_cols
+ Jsel = b >= epsilon_v_square * K_sum_rows
+
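+        # if a budget is not met exactly (e.g. ties in the kernel sums), adjust the
+        # threshold in the uniform case and update the budget to the attained cardinality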
+ if sum(Isel) != ns_budget:
+ if uniform:
+ aK = a / K_sum_cols
+ aK_sort = np.sort(aK)[::-1]
+ epsilon_u_square = aK_sort[ns_budget - 1:ns_budget + 1].mean()
+ Isel = a >= epsilon_u_square * K_sum_cols
+ ns_budget = sum(Isel)
+
+ if sum(Jsel) != nt_budget:
+ if uniform:
+ bK = b / K_sum_rows
+ bK_sort = np.sort(bK)[::-1]
+ epsilon_v_square = bK_sort[nt_budget - 1:nt_budget + 1].mean()
+ Jsel = b >= epsilon_v_square * K_sum_rows
+ nt_budget = sum(Jsel)
+
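+        # epsilon is the geometric mean of epsilon_u and epsilon_v and kappa their
+        # ratio (Equation (5) in [26]); the *_square variables hold the squares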
+ epsilon = (epsilon_u_square * epsilon_v_square) ** (1 / 4)
+ kappa = (epsilon_v_square / epsilon_u_square) ** (1 / 2)
+
+ if verbose:
+ print("epsilon = %s\n" % epsilon)
+ print("kappa = %s\n" % kappa)
+ print('Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n' % (sum(Isel), sum(Jsel)))
+
+ # Ic, Jc: complementary of the active sets I and J
+ Ic = ~Isel
+ Jc = ~Jsel
+
+ K_IJ = K[np.ix_(Isel, Jsel)]
+ K_IcJ = K[np.ix_(Ic, Jsel)]
+ K_IJc = K[np.ix_(Isel, Jc)]
+
+ K_min = K_IJ.min()
+ if K_min == 0:
+ K_min = np.finfo(float).tiny
+
+ # a_I, b_J, a_Ic, b_Jc
+ a_I = a[Isel]
+ b_J = b[Jsel]
+ if not uniform:
+ a_I_min = a_I.min()
+ a_I_max = a_I.max()
+ b_J_max = b_J.max()
+ b_J_min = b_J.min()
+ else:
+ a_I_min = a_I[0]
+ a_I_max = a_I[0]
+ b_J_max = b_J[0]
+ b_J_min = b_J[0]
+
+ # box constraints in L-BFGS-B (see Proposition 1 in [26])
+ bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / (
+ ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget
+
+ bounds_v = [(max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))),
+ epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget
+
+        # pre-calculated constants for the objective
+        vec_eps_IJc = epsilon * kappa * K_IJc.sum(axis=1)
+        vec_eps_IcJ = (epsilon / kappa) * K_IcJ.sum(axis=0)
+
+ # initialisation
+ u0 = np.full(ns_budget, (1. / ns_budget) + epsilon / kappa)
+ v0 = np.full(nt_budget, (1. / nt_budget) + epsilon * kappa)
+
+    # pre-calculated constants for the restricted Sinkhorn warm start (see Algorithm 1 in the supplementary of [26])
+    if restricted:
+        if ns_budget != ns or nt_budget != nt:
+            cst_u = kappa * epsilon * K_IJc.sum(axis=1)
+            cst_v = epsilon * K_IcJ.sum(axis=0) / kappa
+
+        cpt = 1
+        while cpt < 5:  # a few restricted Sinkhorn iterations as a warm start
+            K_IJ_v = np.dot(K_IJ.T, u0) + cst_v
+            v0 = b_J / (kappa * K_IJ_v)
+            KIJ_u = np.dot(K_IJ, v0) + cst_u
+            u0 = (kappa * a_I) / KIJ_u
+            cpt += 1
+
+        u0 = projection(u0, epsilon / kappa)
+        v0 = projection(v0, epsilon * kappa)
+
+    def restricted_sinkhorn(usc, vsc, max_iter=5):
+        """
+        Restricted Sinkhorn algorithm used to warm-start L-BFGS-B (see Algorithm 1 in the supplementary of [26])
+        """
+ cpt = 1
+ while cpt < max_iter:
+ K_IJ_v = np.dot(K_IJ.T, usc) + cst_v
+ vsc = b_J / (kappa * K_IJ_v)
+ KIJ_u = np.dot(K_IJ, vsc) + cst_u
+ usc = (kappa * a_I) / KIJ_u
+ cpt += 1
+
+ usc = projection(usc, epsilon / kappa)
+ vsc = projection(vsc, epsilon * kappa)
+
+ return usc, vsc
+
+ def screened_obj(usc, vsc):
+ part_IJ = np.dot(np.dot(usc, K_IJ), vsc) - kappa * np.dot(a_I, np.log(usc)) - (1. / kappa) * np.dot(b_J, np.log(vsc))
+ part_IJc = np.dot(usc, vec_eps_IJc)
+ part_IcJ = np.dot(vec_eps_IcJ, vsc)
+ psi_epsilon = part_IJ + part_IJc + part_IcJ
+ return psi_epsilon
+
+ def screened_grad(usc, vsc):
+ # gradients of Psi_(kappa,epsilon) w.r.t u and v
+ grad_u = np.dot(K_IJ, vsc) + vec_eps_IJc - kappa * a_I / usc
+ grad_v = np.dot(K_IJ.T, usc) + vec_eps_IcJ - (1. / kappa) * b_J / vsc
+ return grad_u, grad_v
+
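+    # stack objective and gradient into the single-vector form expected by fmin_l_bfgs_b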
+ def bfgspost(theta):
+ u = theta[:ns_budget]
+ v = theta[ns_budget:]
+ # objective
+ f = screened_obj(u, v)
+ # gradient
+ g_u, g_v = screened_grad(u, v)
+ g = np.hstack([g_u, g_v])
+ return f, g
+
+    # ----------------------------------------------------------------------------------------------------------------#
+    # Step 2: L-BFGS-B solver                                                                                          #
+    # ----------------------------------------------------------------------------------------------------------------#
+
+ u0, v0 = restricted_sinkhorn(u0, v0)
+ theta0 = np.hstack([u0, v0])
+
+    bounds = bounds_u + bounds_v  # constraint bounds
+
+    theta, _, _ = fmin_l_bfgs_b(func=bfgspost,
+                                x0=theta0,
+                                bounds=bounds,
+                                maxfun=maxfun,
+                                pgtol=pgtol,
+                                maxiter=maxiter)
+
+ usc = theta[:ns_budget]
+ vsc = theta[ns_budget:]
+
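+    # lift the screened dual variables back to the full support; screened-out
+    # entries stay at their lower bounds epsilon / kappa and epsilon * kappa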
+ usc_full = np.full(ns, epsilon / kappa)
+ vsc_full = np.full(nt, epsilon * kappa)
+ usc_full[Isel] = usc
+ vsc_full[Jsel] = vsc
+
+ if log:
+ log = {}
+ log['u'] = usc_full
+ log['v'] = vsc_full
+ log['Isel'] = Isel
+ log['Jsel'] = Jsel
+
+ gamma = usc_full[:, None] * K * vsc_full[None, :]
+ gamma = gamma / gamma.sum()
+
+ if log:
+ return gamma, log
+ else:
+ return gamma
diff --git a/requirements.txt b/requirements.txt
index 5a3432b..c08822e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,6 +4,7 @@ cython
matplotlib
sphinx-gallery
autograd
-pymanopt
+pymanopt==0.2.4; python_version < '3'
+pymanopt; python_version >= '3'
cvxopt
-pytest
+pytest
\ No newline at end of file
diff --git a/test/test_bregman.py b/test/test_bregman.py
index f70df10..f54ba9f 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -106,7 +106,6 @@ def test_sinkhorn_variants_log():
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
def test_barycenter(method):
-
n_bins = 100 # nb bins
# Gaussian distributions
@@ -133,7 +132,6 @@ def test_barycenter(method):
def test_barycenter_stabilization():
-
n_bins = 100 # nb bins
# Gaussian distributions
@@ -161,7 +159,6 @@ def test_barycenter_stabilization():
def test_wasserstein_bary_2d():
-
size = 100 # size of a square image
a1 = np.random.randn(size, size)
a1 += a1.min()
@@ -185,7 +182,6 @@ def test_wasserstein_bary_2d():
def test_unmix():
-
n_bins = 50 # nb bins
# Gaussian distributions
@@ -207,7 +203,7 @@ def test_unmix():
# wasserstein
reg = 1e-3
- um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01,)
+ um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01, )
np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03)
np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03)
@@ -256,7 +252,7 @@ def test_empirical_sinkhorn():
def test_empirical_sinkhorn_divergence():
- #Test sinkhorn divergence
+ # Test sinkhorn divergence
n = 10
a = ot.unif(n)
b = ot.unif(n)
@@ -337,3 +333,21 @@ def test_implemented_methods():
ot.bregman.sinkhorn(a, b, M, epsilon, method=method)
with pytest.raises(ValueError):
ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
+
+
+def test_screenkhorn():
+ # test screenkhorn
+ rng = np.random.RandomState(0)
+ n = 100
+ a = ot.unif(n)
+ b = ot.unif(n)
+
+ x = rng.randn(n, 2)
+ M = ot.dist(x, x)
+ # sinkhorn
+ G_sink = ot.sinkhorn(a, b, M, 1e-03)
+ # screenkhorn
+ G_screen = ot.bregman.screenkhorn(a, b, M, 1e-03, uniform=True, verbose=True)
+ # check marginals
+ np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
+ np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)
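+
+
+def test_screenkhorn_log():
+    # a complementary sketch (not in the original patch): the log=True variant
+    # should return the screened dual variables and the active sets
+    rng = np.random.RandomState(0)
+    n = 50
+    a = ot.unif(n)
+    b = ot.unif(n)
+    x = rng.randn(n, 2)
+    M = ot.dist(x, x)
+    G_screen, log_d = ot.bregman.screenkhorn(a, b, M, 1e-03, uniform=True, log=True)
+    assert G_screen.shape == (n, n)
+    for key in ['u', 'v', 'Isel', 'Jsel']:
+        assert key in log_d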