-rw-r--r--   README.md                        |   6
-rw-r--r--   examples/plot_screenkhorn_1D.py  |  68
-rw-r--r--   ot/bregman.py                    | 339
-rw-r--r--   requirements.txt                 |   5
-rw-r--r--   test/test_bregman.py             |  26
5 files changed, 435 insertions, 9 deletions
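In brief, this patch adds a `screenkhorn` solver to `ot.bregman`: it screens out all but a budgeted subset of source and target points (the active sets of [26]) and solves the restricted dual with L-BFGS-B. Below is a minimal usage sketch assembled from the example and test added in this patch; the budget values (30/30) are illustrative choices taken from the example, not library defaults.

    # Minimal usage sketch for the new screenkhorn solver (mirrors the patch below).
    import numpy as np
    import ot
    from ot.datasets import make_1D_gauss as gauss
    from ot.bregman import screenkhorn

    n = 100
    x = np.arange(n, dtype=np.float64)
    a = gauss(n, m=20, s=5)   # source histogram
    b = gauss(n, m=60, s=10)  # target histogram
    M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
    M /= M.max()

    # keep a budget of 30 points on each side; the rest are screened out
    G_screen = screenkhorn(a, b, M, reg=2e-03, ns_budget=30, nt_budget=30,
                           uniform=False, restricted=True)
    print(G_screen.shape)  # (100, 100): the plan is re-embedded in the full space
    print(G_screen.sum())  # 1.0: screenkhorn normalizes the returned plan

Note that `screenkhorn` returns a full (ns, nt) plan: dual variables outside the active sets are clamped to their lower bounds epsilon/kappa and epsilon*kappa before the plan is formed.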
diff --git a/README.md b/README.md
--- a/README.md
+++ b/README.md
@@ -28,6 +28,7 @@ It provides the following solvers:
 * Stochastic Optimization for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19])
 * Non regularized free support Wasserstein barycenters [20].
 * Unbalanced OT with KL relaxation distance and barycenter [10, 25].
+* Screening Sinkhorn Algorithm for OT [26].
 
 Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.
@@ -180,6 +181,7 @@ The contributors to this library are
 * [Vayer Titouan](https://tvayer.github.io/)
 * [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT)
 * [Romain Tavenard](https://rtavenar.github.io/) (1d Wasserstein)
+* [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn)
 
 This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
@@ -252,4 +254,6 @@ You can also post bug reports and feature requests in Github issues. Make sure t
 [24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N. (2019). [Optimal Transport for structured data with application on graphs](http://proceedings.mlr.press/v97/titouan19a.html) Proceedings of the 36th International Conference on Machine Learning (ICML).
 
-[25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. (2019). [Learning with a Wasserstein Loss](http://cbcl.mit.edu/wasserstein/) Advances in Neural Information Processing Systems (NIPS).
+[25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. (2015). [Learning with a Wasserstein Loss](http://cbcl.mit.edu/wasserstein/) Advances in Neural Information Processing Systems (NIPS).
+
+[26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). [Screening Sinkhorn Algorithm for Regularized Optimal Transport](https://papers.nips.cc/paper/9386-screening-sinkhorn-algorithm-for-regularized-optimal-transport), Advances in Neural Information Processing Systems 33 (NeurIPS).
diff --git a/examples/plot_screenkhorn_1D.py b/examples/plot_screenkhorn_1D.py
new file mode 100644
index 0000000..840ead8
--- /dev/null
+++ b/examples/plot_screenkhorn_1D.py
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+"""
+===============================
+1D Screened optimal transport
+===============================
+
+This example illustrates the computation of Screenkhorn:
+Screening Sinkhorn Algorithm for Optimal Transport.
+"""
+
+# Author: Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot.plot
+from ot.datasets import make_1D_gauss as gauss
+from ot.bregman import screenkhorn
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% parameters
+
+n = 100  # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a = gauss(n, m=20, s=5)  # m = mean, s = std
+b = gauss(n, m=60, s=10)
+
+# loss matrix
+M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+M /= M.max()
+
+##############################################################################
+# Plot distributions and loss matrix
+# ----------------------------------
+
+#%% plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+pl.legend()
+
+# plot distributions and loss matrix
+
+pl.figure(2, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
+
+##############################################################################
+# Solve Screenkhorn
+# -----------------------
+
+# Screenkhorn
+lambd = 2e-03  # entropic regularization parameter
+ns_budget = 30  # budget of points to be kept in the source distribution
+nt_budget = 30  # budget of points to be kept in the target distribution
+
+G_screen = screenkhorn(a, b, M, lambd, ns_budget, nt_budget, uniform=False, restricted=True, verbose=True)
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, G_screen, 'OT matrix Screenkhorn')
+pl.show()
diff --git a/ot/bregman.py b/ot/bregman.py
index ba5c7ba..2707b7c 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -8,12 +8,14 @@ Bregman projections for regularized OT
 #         Kilian Fatras <kilian.fatras@irisa.fr>
 #         Titouan Vayer <titouan.vayer@irisa.fr>
 #         Hicham Janati <hicham.janati@inria.fr>
+#         Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>
 #
 # License: MIT License
 
 import numpy as np
 import warnings
 from .utils import unif, dist
+from scipy.optimize import fmin_l_bfgs_b
 
 
 def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
@@ -1787,3 +1789,340 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
     sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
 
     return max(0, sinkhorn_div)
+
+
+def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, restricted=True,
+                maxiter=10000, maxfun=10000, pgtol=1e-09, verbose=False, log=False):
+    r"""
+    Screening Sinkhorn Algorithm for Regularized Optimal Transport
+
+    The function solves an approximate dual of the Sinkhorn divergence [2], written as the following optimization problem:
+
+    .. math::
+
+        (u, v) = \argmin_{u, v} \ \mathbf{1}_{ns}^T B(u, v) \mathbf{1}_{nt} - <\kappa u, a> - <v / \kappa, b>
+
+        s.t. \ e^{u_i} \geq \epsilon / \kappa, \ \forall i \in \{1, ..., ns\}
+
+             \ e^{v_j} \geq \epsilon \kappa, \ \forall j \in \{1, ..., nt\}
+
+    where :math:`B(u, v) = \diag(e^u) K \diag(e^v)` and :math:`K = e^{-M / reg}` is the Gibbs kernel.
+
+    The parameters \kappa and \epsilon are determined w.r.t. the couple of budget numbers of points (ns_budget, nt_budget); see Equation (5) in [26].
+
+    Parameters
+    ----------
+    a : `numpy.ndarray`, shape=(ns,)
+        samples weights in the source domain
+
+    b : `numpy.ndarray`, shape=(nt,)
+        samples weights in the target domain
+
+    M : `numpy.ndarray`, shape=(ns, nt)
+        Cost matrix
+
+    reg : `float`
+        Level of the entropy regularisation
+
+    ns_budget : `int`, default=None
+        Budget number of points to be kept in the source domain.
+        If None, 50% of the source sample points will be kept
+
+    nt_budget : `int`, default=None
+        Budget number of points to be kept in the target domain.
+        If None, 50% of the target sample points will be kept
+
+    uniform : `bool`, default=False
+        If `True`, the source and target distributions are assumed to be uniform, i.e., a_i = 1 / ns and b_j = 1 / nt
+
+    restricted : `bool`, default=True
+        If `True`, a warm-start initialization for the L-BFGS-B solver is computed
+        using a restricted Sinkhorn algorithm with at most 5 iterations
+
+    maxiter : `int`, default=10000
+        Maximum number of iterations in the L-BFGS-B solver
+
+    maxfun : `int`, default=10000
+        Maximum number of function evaluations in the L-BFGS-B solver
+
+    pgtol : `float`, default=1e-09
+        Final objective function accuracy in the L-BFGS-B solver
+
+    verbose : `bool`, default=False
+        If `True`, display information about the cardinality of the active sets and the parameters kappa and epsilon
+
+    log : `bool`, default=False
+        If `True`, also return a log dictionary with the dual variables and the active sets
+
+    Dependency
+    ----------
+    For efficiency, screenkhorn uses the "Bottleneck" package (https://pypi.org/project/Bottleneck/)
+    in the screening pre-processing step. If Bottleneck is not installed, a warning is raised and
+    NumPy is used as a (slower) fallback.
+
+    Returns
+    -------
+    gamma : `numpy.ndarray`, shape=(ns, nt)
+        Screened optimal transportation matrix for the given parameters
+
+    log : `dict`
+        Log dictionary, returned only if `log` is True
+
+    References
+    ----------
+    .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). Screening Sinkhorn Algorithm for Regularized Optimal Transport, Advances in Neural Information Processing Systems 33 (NeurIPS), 2019
+
+    """
+    # check if the bottleneck module exists
+    try:
+        import bottleneck
+    except ImportError:
+        warnings.warn("Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance.")
+        bottleneck = np
+
+    a = np.asarray(a, dtype=np.float64)
+    b = np.asarray(b, dtype=np.float64)
+    M = np.asarray(M, dtype=np.float64)
+    ns, nt = M.shape
+
+    # by default, we keep only 50% of the sample data points
+    if ns_budget is None:
+        ns_budget = int(np.floor(0.5 * ns))
+    if nt_budget is None:
+        nt_budget = int(np.floor(0.5 * nt))
+
+    # calculate the Gibbs kernel K = exp(-M / reg)
+    K = np.empty_like(M)
+    np.divide(M, -reg, out=K)
+    np.exp(K, out=K)
+
+    def projection(u, epsilon):
+        u[u <= epsilon] = epsilon
+        return u
+
+    # ---------------------------------------------------------------------------------------------------------- #
+    #   Step 1: Screening pre-processing                                                                          #
+    # ---------------------------------------------------------------------------------------------------------- #
+
+    if ns_budget == ns and nt_budget == nt:
+        # full budget of points, (ns_budget, nt_budget) = (ns, nt): no screening is performed
+        Isel = np.ones(ns, dtype=bool)
+        Jsel = np.ones(nt, dtype=bool)
+        epsilon = 0.0
+        kappa = 1.0
+
+        cst_u = 0.
+        cst_v = 0.
+
+        bounds_u = [(0.0, np.inf)] * ns
+        bounds_v = [(0.0, np.inf)] * nt
+
+        a_I = a
+        b_J = b
+        K_IJ = K
+        K_IJc = []
+        K_IcJ = []
+
+        vec_eps_IJc = np.zeros(nt)
+        vec_eps_IcJ = np.zeros(ns)
+
+    else:
+        # sums of the rows and columns of K
+        K_sum_cols = K.sum(axis=1)
+        K_sum_rows = K.sum(axis=0)
+
+        if uniform:
+            if ns / ns_budget < 4:
+                aK_sort = np.sort(K_sum_cols)
+                epsilon_u_square = a[0] / aK_sort[ns_budget - 1]
+            else:
+                aK_sort = bottleneck.partition(K_sum_cols, ns_budget - 1)[ns_budget - 1]
+                epsilon_u_square = a[0] / aK_sort
+
+            if nt / nt_budget < 4:
+                bK_sort = np.sort(K_sum_rows)
+                epsilon_v_square = b[0] / bK_sort[nt_budget - 1]
+            else:
+                bK_sort = bottleneck.partition(K_sum_rows, nt_budget - 1)[nt_budget - 1]
+                epsilon_v_square = b[0] / bK_sort
+        else:
+            aK = a / K_sum_cols
+            bK = b / K_sum_rows
+
+            aK_sort = np.sort(aK)[::-1]
+            epsilon_u_square = aK_sort[ns_budget - 1]
+
+            bK_sort = np.sort(bK)[::-1]
+            epsilon_v_square = bK_sort[nt_budget - 1]
+
+        # active sets I and J (see Lemma 1 in [26])
+        Isel = a >= epsilon_u_square * K_sum_cols
+        Jsel = b >= epsilon_v_square * K_sum_rows
+
+        if sum(Isel) != ns_budget:
+            if uniform:
+                aK = a / K_sum_cols
+                aK_sort = np.sort(aK)[::-1]
+            epsilon_u_square = aK_sort[ns_budget - 1:ns_budget + 1].mean()
+            Isel = a >= epsilon_u_square * K_sum_cols
+            ns_budget = sum(Isel)
+
+        if sum(Jsel) != nt_budget:
+            if uniform:
+                bK = b / K_sum_rows
+                bK_sort = np.sort(bK)[::-1]
+            epsilon_v_square = bK_sort[nt_budget - 1:nt_budget + 1].mean()
+            Jsel = b >= epsilon_v_square * K_sum_rows
+            nt_budget = sum(Jsel)
+
+        epsilon = (epsilon_u_square * epsilon_v_square) ** (1 / 4)
+        kappa = (epsilon_v_square / epsilon_u_square) ** (1 / 2)
+
+        if verbose:
+            print("epsilon = %s\n" % epsilon)
+            print("kappa = %s\n" % kappa)
+            print('Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n' % (sum(Isel), sum(Jsel)))
+
+        # Ic, Jc: complements of the active sets I and J
+        Ic = ~Isel
+        Jc = ~Jsel
+
+        K_IJ = K[np.ix_(Isel, Jsel)]
+        K_IcJ = K[np.ix_(Ic, Jsel)]
+        K_IJc = K[np.ix_(Isel, Jc)]
+
+        K_min = K_IJ.min()
+        if K_min == 0:
+            K_min = np.finfo(float).tiny
+
+        # a_I, b_J, a_Ic, b_Jc
+        a_I = a[Isel]
+        b_J = b[Jsel]
+        if not uniform:
+            a_I_min = a_I.min()
+            a_I_max = a_I.max()
+            b_J_max = b_J.max()
+            b_J_min = b_J.min()
+        else:
+            a_I_min = a_I[0]
+            a_I_max = a_I[0]
+            b_J_max = b_J[0]
+            b_J_min = b_J[0]
+
+        # box constraints in L-BFGS-B (see Proposition 1 in [26])
+        bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / (
+            ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget
+
+        bounds_v = [(max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))),
+                         epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget
+
+        # pre-calculated constants for the objective
+        vec_eps_IJc = epsilon * kappa * (K_IJc * np.ones(nt - nt_budget).reshape((1, -1))).sum(axis=1)
+        vec_eps_IcJ = (epsilon / kappa) * (np.ones(ns - ns_budget).reshape((-1, 1)) * K_IcJ).sum(axis=0)
+
+    # initialisation
+    u0 = np.full(ns_budget, (1. / ns_budget) + epsilon / kappa)
+    v0 = np.full(nt_budget, (1. / nt_budget) + epsilon * kappa)
+
+    # pre-calculated constants for the Restricted Sinkhorn (see Algorithm 1 in the supplementary of [26]);
+    # computed outside the `restricted` guard since they are needed in Step 2 as well
+    if ns_budget != ns or nt_budget != nt:
+        cst_u = kappa * epsilon * K_IJc.sum(axis=1)
+        cst_v = epsilon * K_IcJ.sum(axis=0) / kappa
+
+    if restricted:
+        # warm-start for L-BFGS-B: at most 5 Restricted Sinkhorn iterations
+        cpt = 1
+        while cpt < 5:
+            K_IJ_v = np.dot(K_IJ.T, u0) + cst_v
+            v0 = b_J / (kappa * K_IJ_v)
+            KIJ_u = np.dot(K_IJ, v0) + cst_u
+            u0 = (kappa * a_I) / KIJ_u
+            cpt += 1
+
+        u0 = projection(u0, epsilon / kappa)
+        v0 = projection(v0, epsilon * kappa)
+
+    def restricted_sinkhorn(usc, vsc, max_iter=5):
+        """
+        Restricted Sinkhorn algorithm used as a warm-start initialization for L-BFGS-B
+        (see Algorithm 1 in the supplementary of [26])
+        """
+        cpt = 1
+        while cpt < max_iter:
+            K_IJ_v = np.dot(K_IJ.T, usc) + cst_v
+            vsc = b_J / (kappa * K_IJ_v)
+            KIJ_u = np.dot(K_IJ, vsc) + cst_u
+            usc = (kappa * a_I) / KIJ_u
+            cpt += 1
+
+        usc = projection(usc, epsilon / kappa)
+        vsc = projection(vsc, epsilon * kappa)
+
+        return usc, vsc
+
+    def screened_obj(usc, vsc):
+        part_IJ = np.dot(np.dot(usc, K_IJ), vsc) - kappa * np.dot(a_I, np.log(usc)) - (1. / kappa) * np.dot(b_J, np.log(vsc))
+        part_IJc = np.dot(usc, vec_eps_IJc)
+        part_IcJ = np.dot(vec_eps_IcJ, vsc)
+        psi_epsilon = part_IJ + part_IJc + part_IcJ
+        return psi_epsilon
+
+    def screened_grad(usc, vsc):
+        # gradients of Psi_(kappa, epsilon) w.r.t. u and v
+        grad_u = np.dot(K_IJ, vsc) + vec_eps_IJc - kappa * a_I / usc
+        grad_v = np.dot(K_IJ.T, usc) + vec_eps_IcJ - (1. / kappa) * b_J / vsc
+        return grad_u, grad_v
+
+    def bfgspost(theta):
+        u = theta[:ns_budget]
+        v = theta[ns_budget:]
+        # objective value
+        f = screened_obj(u, v)
+        # gradient
+        g_u, g_v = screened_grad(u, v)
+        g = np.hstack([g_u, g_v])
+        return f, g
+
+    # ---------------------------------------------------------------------------------------------------------- #
+    #   Step 2: L-BFGS-B solver                                                                                   #
+    # ---------------------------------------------------------------------------------------------------------- #
+
+    u0, v0 = restricted_sinkhorn(u0, v0)
+    theta0 = np.hstack([u0, v0])
+
+    bounds = bounds_u + bounds_v  # box constraints
+
+    theta, _, _ = fmin_l_bfgs_b(func=bfgspost,
+                                x0=theta0,
+                                bounds=bounds,
+                                maxfun=maxfun,
+                                pgtol=pgtol,
+                                maxiter=maxiter)
+
+    usc = theta[:ns_budget]
+    vsc = theta[ns_budget:]
+
+    usc_full = np.full(ns, epsilon / kappa)
+    vsc_full = np.full(nt, epsilon * kappa)
+    usc_full[Isel] = usc
+    vsc_full[Jsel] = vsc
+
+    if log:
+        log = {}
+        log['u'] = usc_full
+        log['v'] = vsc_full
+        log['Isel'] = Isel
+        log['Jsel'] = Jsel
+
+    gamma = usc_full[:, None] * K * vsc_full[None, :]
+    gamma = gamma / gamma.sum()
+
+    if log:
+        return gamma, log
+    else:
+        return gamma
diff --git a/requirements.txt b/requirements.txt
index 5a3432b..c08822e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,6 +4,7 @@ cython
 matplotlib
 sphinx-gallery
 autograd
-pymanopt
+pymanopt==0.2.4; python_version < '3'
+pymanopt; python_version >= '3'
 cvxopt
-pytest
+pytest
\ No newline at end of file
diff --git a/test/test_bregman.py b/test/test_bregman.py
index f70df10..f54ba9f 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -106,7 +106,6 @@ def test_sinkhorn_variants_log():
 
 @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
 def test_barycenter(method):
-
     n_bins = 100  # nb bins
 
     # Gaussian distributions
@@ -133,7 +132,6 @@ def test_barycenter(method):
 
 
 def test_barycenter_stabilization():
-
     n_bins = 100  # nb bins
 
     # Gaussian distributions
@@ -161,7 +159,6 @@ def test_barycenter_stabilization():
 
 
 def test_wasserstein_bary_2d():
-
     size = 100  # size of a square image
     a1 = np.random.randn(size, size)
     a1 += a1.min()
@@ -185,7 +182,6 @@ def test_wasserstein_bary_2d():
 
 
 def test_unmix():
-
     n_bins = 50  # nb bins
 
     # Gaussian distributions
@@ -207,7 +203,7 @@ def test_unmix():
 
     # wasserstein
     reg = 1e-3
-    um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01,)
+    um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01, )
 
     np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03)
     np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03)
@@ -256,7 +252,7 @@ def test_empirical_sinkhorn():
 
 
 def test_empirical_sinkhorn_divergence():
-    #Test sinkhorn divergence
+    # Test sinkhorn divergence
     n = 10
     a = ot.unif(n)
     b = ot.unif(n)
@@ -337,3 +333,21 @@ def test_implemented_methods():
             ot.bregman.sinkhorn(a, b, M, epsilon, method=method)
         with pytest.raises(ValueError):
             ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
+
+
+def test_screenkhorn():
+    # test screenkhorn
+    rng = np.random.RandomState(0)
+    n = 100
+    a = ot.unif(n)
+    b = ot.unif(n)
+
+    x = rng.randn(n, 2)
+    M = ot.dist(x, x)
+    # sinkhorn
+    G_sink = ot.sinkhorn(a, b, M, 1e-03)
+    # screenkhorn
+    G_screen = ot.bregman.screenkhorn(a, b, M, 1e-03, uniform=True, verbose=True)
+    # check marginals
+    np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
+    np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)
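To see the Step 1 screening rule in isolation, here is a standalone NumPy sketch of the active-set selection implemented above (the non-uniform, sort-based branch); the synthetic data and problem sizes are illustrative only.

    # Standalone sketch of the Step 1 screening rule (Lemma 1 / Eq. (5) in [26]):
    # pick thresholds so that exactly ns_budget / nt_budget points satisfy
    # a_i >= eps_u^2 * (K 1)_i and b_j >= eps_v^2 * (1^T K)_j.
    import numpy as np

    rng = np.random.RandomState(0)
    ns, nt, reg = 60, 50, 1e-1
    ns_budget, nt_budget = 20, 15

    a = np.full(ns, 1. / ns)
    b = np.full(nt, 1. / nt)
    M = rng.rand(ns, nt)
    K = np.exp(-M / reg)        # Gibbs kernel

    K_sum_cols = K.sum(axis=1)  # row sums of K, shape (ns,)
    K_sum_rows = K.sum(axis=0)  # column sums of K, shape (nt,)

    # the ns_budget-th largest ratio a_i / (K 1)_i gives the threshold eps_u^2
    epsilon_u_square = np.sort(a / K_sum_cols)[::-1][ns_budget - 1]
    epsilon_v_square = np.sort(b / K_sum_rows)[::-1][nt_budget - 1]

    Isel = a >= epsilon_u_square * K_sum_cols  # active source points
    Jsel = b >= epsilon_v_square * K_sum_rows  # active target points

    epsilon = (epsilon_u_square * epsilon_v_square) ** (1 / 4)
    kappa = (epsilon_v_square / epsilon_u_square) ** (1 / 2)
    print(Isel.sum(), Jsel.sum())  # == ns_budget, nt_budget (up to ties)

Only the (ns_budget x nt_budget) block K[np.ix_(Isel, Jsel)] enters the L-BFGS-B problem, which is where the speed-up over plain Sinkhorn comes from.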