From d7554331fc409fea48ee758fd630909dd9dc4827 Mon Sep 17 00:00:00 2001
From: Rémi Flamary
Date: Wed, 27 Oct 2021 08:41:08 +0200
Subject: [WIP] Sinkhorn in log space (#290)

* add sinkhorn log and working sinkhorn2 function
* more tests pass
* more tests pass
* it works but not by default yet
* remove warnings
* update circleci doc
* update circleci doc
* new sinkhorn implemented but not by default
* better
* doctest pass
* test doctest
* new test utils
* remove pep8 errors
* remove pep8 errors
* doc new implementation with log
* test sinkhorn 2
* doc for log implementation
---
 .circleci/config.yml       |  14 +--
 README.md                  |   4 +-
 docs/source/quickstart.rst |  10 +-
 ot/bregman.py              | 272 +++++++++++++++++++++++++++++++++++++++++----
 ot/dr.py                   |   4 +-
 ot/gromov.py               |   4 +-
 ot/optim.py                |   4 +-
 ot/utils.py                |   4 +-
 test/test_bregman.py       | 120 ++++++++++++++++++--
 test/test_gromov.py        |  10 +-
 test/test_helpers.py       |   4 +-
 test/test_utils.py         |  15 +++
 12 files changed, 403 insertions(+), 62 deletions(-)

diff --git a/.circleci/config.yml b/.circleci/config.yml
index e4c71dd..379394a 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -4,7 +4,7 @@ version: 2
 jobs:
   build_docs:
     docker:
-      - image: circleci/python:3.7-stretch
+      - image: cimg/python:3.9
     steps:
       - checkout
       - run:
@@ -34,18 +34,6 @@ jobs:
             - data-cache-0
             - pip-cache
 
-      - run:
-          name: Spin up Xvfb
-          command: |
-            /sbin/start-stop-daemon --start --quiet --pidfile /tmp/custom_xvfb_99.pid --make-pidfile --background --exec /usr/bin/Xvfb -- :99 -screen 0 1400x900x24 -ac +extension GLX +render -noreset;
-
-      # https://github.com/ContinuumIO/anaconda-issues/issues/9190#issuecomment-386508136
-      # https://github.com/golemfactory/golem/issues/1019
-      - run:
-          name: Fix libgcc_s.so.1 pthread_cancel bug
-          command: |
-            sudo apt-get install qt5-default
-
       - run:
           name: Get Python running
           command: |
diff --git a/README.md b/README.md
index 266d847..ffad0bd 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,7 @@ POT provides the following generic OT solvers (links to examples):
 
 * [OT Network Simplex solver](https://pythonot.github.io/auto_examples/plot_OT_1D.html) for the linear program/ Earth Movers Distance [1] .
 * [Conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) [6] and [Generalized conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) for regularized OT [7].
-* Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html).
+* Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10] [34], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html).
 * Bregman projections for [Wasserstein barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) [3], [convolutional barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_convolutional_barycenter.html) [21] and unmixing [4].
 * Sinkhorn divergence [23] and entropic regularization OT from empirical data.
 * [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17].
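
In user-facing terms, the patch boils down to a new value for the `method` parameter of
`ot.sinkhorn` / `ot.sinkhorn2`, plus `ot.sinkhorn2` now returning a scalar loss. A minimal
usage sketch, assuming a POT build that includes this patch; the commented values simply
restate the doctests further down:

    import numpy as np
    import ot

    a = np.array([.5, .5])              # source histogram
    b = np.array([.5, .5])              # target histogram
    M = np.array([[0., 1.], [1., 0.]])  # cost matrix

    # The classic solver and the new log-space solver compute the same OT
    # matrix, but 'sinkhorn_log' performs every update with a logsumexp and
    # therefore stays numerically stable for much smaller regularization.
    G = ot.sinkhorn(a, b, M, 1, method='sinkhorn')
    G_log = ot.sinkhorn(a, b, M, 1, method='sinkhorn_log')
    # both ~ [[0.3655, 0.1345], [0.1345, 0.3655]]

    # sinkhorn2 now returns the scalar loss <G, M> instead of a 1d array.
    loss = ot.sinkhorn2(a, b, M, 1, method='sinkhorn_log')  # ~ 0.2689

As the updated documentation below puts it, the log solver is the recommended choice on
GPU and when differentiating through a small number of iterations.
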
@@ -290,3 +290,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
 [32] Huang, M., Ma S., Lai, L. (2021). [A Riemannian Block Coordinate Descent Method for Computing the Projection Robust Wasserstein Distance](http://proceedings.mlr.press/v139/huang21e.html), Proceedings of the 38th International Conference on Machine Learning (ICML).
 
 [33] Kerdoncuff T., Emonet R., Marc S. [Sampled Gromov Wasserstein](https://hal.archives-ouvertes.fr/hal-03232509/document), Machine Learning Journal (MJL), 2021
+
+[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst
index fd046a1..232df7b 100644
--- a/docs/source/quickstart.rst
+++ b/docs/source/quickstart.rst
@@ -358,6 +358,11 @@ More details about the algorithms used are given in the following note.
 
     + :code:`method='sinkhorn'` calls :any:`ot.bregman.sinkhorn_knopp` the
      classic algorithm [2]_.
+    + :code:`method='sinkhorn_log'` calls :any:`ot.bregman.sinkhorn_log` the
+      Sinkhorn algorithm in log space [2]_ that is more stable but can be
+      slower in numpy since `logsumexp` is not implemented in parallel.
+      It is the recommended solver for applications that require
+      differentiability with a small number of iterations.
     + :code:`method='sinkhorn_stabilized'` calls :any:`ot.bregman.sinkhorn_stabilized`
      the log stabilized version of the algorithm [9]_.
     + :code:`method='sinkhorn_epsilon_scaling'` calls
@@ -389,7 +394,10 @@ More details about the algorithms used are given in the following note.
     solutions. Note that the greedy version of the Sinkhorn
     :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
     version of the Sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a
-    fast approximation of the Sinkhorn problem.
+    fast approximation of the Sinkhorn problem. For GPU use and gradient
+    computation with a small number of iterations, we strongly recommend the
+    :any:`ot.bregman.sinkhorn_log` solver, which does not need to check for
+    numerical problems.
diff --git a/ot/bregman.py b/ot/bregman.py
index b59ee1b..2aa76ff 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -64,7 +64,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
     solutions. Note that the greedy version of the sinkhorn
     :py:func:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
     version of the sinkhorn :py:func:`ot.bregman.screenkhorn` aim at providing a
-    fast approximation of the Sinkhorn problem.
+    fast approximation of the Sinkhorn problem. For GPU use and gradient
+    computation with a small number of iterations, we strongly recommend the
+    :any:`ot.bregman.sinkhorn_log` solver, which does not need to check for
+    numerical problems.
    Parameters
@@ -79,8 +82,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
     reg : float
         Regularization term >0
     method : str
-        method used for the solver either 'sinkhorn', 'greenkhorn', 'sinkhorn_stabilized' or
-        'sinkhorn_epsilon_scaling', see those function for specific parameters
+        method used for the solver either 'sinkhorn', 'sinkhorn_log',
+        'greenkhorn', 'sinkhorn_stabilized' or 'sinkhorn_epsilon_scaling', see
+        those functions for specific parameters
     numItermax : int, optional
         Max number of iterations
     stopThr : float, optional
@@ -118,6 +122,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
     .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
 
+    .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
 
 
     See Also
@@ -134,6 +139,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
         return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
                               stopThr=stopThr, verbose=verbose, log=log,
                               **kwargs)
+    elif method.lower() == 'sinkhorn_log':
+        return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
+                            stopThr=stopThr, verbose=verbose, log=log,
+                            **kwargs)
     elif method.lower() == 'greenkhorn':
         return greenkhorn(a, b, M, reg, numItermax=numItermax,
                           stopThr=stopThr, verbose=verbose, log=log)
@@ -182,7 +191,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
     By default and when using a regularization parameter that is not too small
     the default sinkhorn solver should be enough. If you need to use a small
     regularization to get sharper OT matrices, you should use the
-    :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
+    :any:`ot.bregman.sinkhorn_log` solver that will avoid numerical
     errors. This last solver can be very slow in practice and might not even
     converge to a reasonable OT matrix in a finite time. This is why
     :any:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value
@@ -190,7 +199,10 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
     solutions. Note that the greedy version of the sinkhorn
     :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
     version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a
-    fast approximation of the Sinkhorn problem.
+    fast approximation of the Sinkhorn problem. For GPU use and gradient
+    computation with a small number of iterations, we strongly recommend the
+    :any:`ot.bregman.sinkhorn_log` solver, which does not need to check for
+    numerical problems.
 
     Parameters
     ----------
@@ -204,7 +216,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
     reg : float
         Regularization term >0
     method : str
-        method used for the solver either 'sinkhorn', 'sinkhorn_stabilized', see those function for specific parameters
+        method used for the solver either 'sinkhorn', 'sinkhorn_log',
+        'sinkhorn_stabilized', see those functions for specific parameters
     numItermax : int, optional
         Max number of iterations
     stopThr : float, optional
@@ -230,7 +243,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
     >>> b=[.5, .5]
     >>> M=[[0., 1.], [1., 0.]]
     >>> ot.sinkhorn2(a, b, M, 1)
-    array([0.26894142])
+    0.26894142136999516
 
     .. 
_references-sinkhorn2: @@ -243,7 +256,11 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. - .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 + .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation + algorithms for optimal transport via Sinkhorn iteration, Advances in Neural + Information Processing Systems (NIPS) 31, 2017 + + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. @@ -257,20 +274,45 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, """ - b = list_to_array(b) + M, a, b = list_to_array(M, a, b) + nx = get_backend(M, a, b) + if len(b.shape) < 2: - b = b[:, None] + if method.lower() == 'sinkhorn': + res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + elif method.lower() == 'sinkhorn_log': + res = sinkhorn_log(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + elif method.lower() == 'sinkhorn_stabilized': + res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) + if log: + return nx.sum(M * res[0]), res[1] + else: + return nx.sum(M * res) - if method.lower() == 'sinkhorn': - return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, - **kwargs) - elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, - **kwargs) else: - raise ValueError("Unknown method '%s'." % method) + + if method.lower() == 'sinkhorn': + return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + elif method.lower() == 'sinkhorn_log': + return sinkhorn_log(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + elif method.lower() == 'sinkhorn_stabilized': + return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) def sinkhorn_knopp(a, b, M, reg, numItermax=1000, @@ -361,7 +403,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, # init data dim_a = len(a) - dim_b = len(b) + dim_b = b.shape[0] if len(b.shape) > 1: n_hists = b.shape[1] @@ -438,6 +480,191 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, return u.reshape((-1, 1)) * K * v.reshape((1, -1)) +def sinkhorn_log(a, b, M, reg, numItermax=1000, + stopThr=1e-9, verbose=False, log=False, **kwargs): + r""" + Solve the entropic regularization optimal transport problem in log space + and return the OT matrix + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + s.t. 
\gamma 1 = a
+
+             \gamma^T 1 = b
+
+             \gamma \geq 0
+    where:
+
+    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
+
+    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
+    scaling algorithm :ref:`[2] <references-sinkhorn-log>` with the
+    implementation from :ref:`[34] <references-sinkhorn-log>`
+
+
+    Parameters
+    ----------
+    a : array-like, shape (dim_a,)
+        sample weights in the source domain
+    b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
+        samples in the target domain, compute sinkhorn with multiple targets
+        and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log)
+    M : array-like, shape (dim_a, dim_b)
+        loss matrix
+    reg : float
+        Regularization term >0
+    numItermax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (>0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+    Returns
+    -------
+    gamma : array-like, shape (dim_a, dim_b)
+        Optimal transportation matrix for the given parameters
+    log : dict
+        log dictionary returned only if log==True in parameters
+
+    Examples
+    --------
+
+    >>> import ot
+    >>> a=[.5, .5]
+    >>> b=[.5, .5]
+    >>> M=[[0., 1.], [1., 0.]]
+    >>> ot.sinkhorn(a, b, M, 1)
+    array([[0.36552929, 0.13447071],
+           [0.13447071, 0.36552929]])
+
+
+    .. _references-sinkhorn-log:
+    References
+    ----------
+
+    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
+        Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+    .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
+
+
+    See Also
+    --------
+    ot.lp.emd : Unregularized OT
+    ot.optim.cg : General regularized OT
+
+    """
+
+    a, b, M = list_to_array(a, b, M)
+
+    nx = get_backend(M, a, b)
+
+    if len(a) == 0:
+        a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M)
+    if len(b) == 0:
+        b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M)
+
+    # init data
+    dim_a = len(a)
+    dim_b = b.shape[0]
+
+    if len(b.shape) > 1:
+        n_hists = b.shape[1]
+    else:
+        n_hists = 0
+
+    if n_hists:  # we do not want to use tensors so we do a loop
+
+        lst_loss = []
+        lst_u = []
+        lst_v = []
+
+        for k in range(n_hists):
+            res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax,
+                               stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+            if log:
+                lst_loss.append(nx.sum(M * res[0]))
+                lst_u.append(res[1]['log_u'])
+                lst_v.append(res[1]['log_v'])
+            else:
+                lst_loss.append(nx.sum(M * res))
+        res = nx.stack(lst_loss)
+        if log:
+            log = {'log_u': nx.stack(lst_u, 1),
+                   'log_v': nx.stack(lst_v, 1), }
+            log['u'] = nx.exp(log['log_u'])
+            log['v'] = nx.exp(log['log_v'])
+            return res, log
+        else:
+            return res
+
+    else:
+
+        if log:
+            log = {'err': []}
+
+        Mr = M / (-reg)
+
+        # we assume that no distances are null except those of the diagonal of
+        # distances
+
+        u = nx.zeros(dim_a, type_as=M)
+        v = nx.zeros(dim_b, type_as=M)
+
+        def get_logT(u, v):
+            if n_hists:
+                return Mr[:, :, None] + u + v
+            else:
+                return Mr + u[:, None] + v[None, :]
+
+        loga = nx.log(a)
+        logb = nx.log(b)
+
+        cpt = 0
+        err = 1
+        while (err > stopThr and cpt < numItermax):
+
+            # Sinkhorn updates in the log domain: each scaling step is a
+            # logsumexp, which keeps the iterations stable for small reg
+            v = logb - nx.logsumexp(Mr + u[:, None], 0)
+            u = loga - nx.logsumexp(Mr + v[None, :], 1)
+
+            if cpt % 10 == 0:
+                # we can speed up the process by checking for the error
+                # only every 10th iteration
+
+                # compute right marginal tmp2 = (diag(u) K diag(v))^T 1
+                tmp2 = nx.sum(nx.exp(get_logT(u, v)), 0)
+                err = nx.norm(tmp2 - b)  # violation of marginal
+                if log:
+                    log['err'].append(err)
+
+                if verbose:
+                    if cpt % 200 == 0:
+                        print(
+                            '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+                    print('{:5d}|{:8e}|'.format(cpt, err))
+            cpt = cpt + 1
+
+        if log:
+            log['log_u'] = u
+            log['log_v'] = v
+            log['u'] = nx.exp(u)
+            log['v'] = nx.exp(v)
+
+            return nx.exp(get_logT(u, v)), log
+
+        else:
+            return nx.exp(get_logT(u, v))
+
+
 def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9,
                verbose=False, log=False):
     r"""
@@ -1881,8 +2108,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
         return (f, g)
 
     else:
-        M = dist(nx.to_numpy(X_s), nx.to_numpy(X_t), metric=metric)
-        M = nx.from_numpy(M, type_as=a)
+        M = dist(X_s, X_t, metric=metric)
         if log:
             pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
             return pi, log
@@ -2102,7 +2328,7 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
     >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
     >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
     >>> empirical_sinkhorn_divergence(X_s, X_t, reg)  # doctest: +ELLIPSIS
-    array([1.499...])
+    1.499887176049052
 
 
     References
diff --git a/ot/dr.py b/ot/dr.py
index 64588cf..de39662 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -209,11 +209,11 @@ def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopTh
 
     .. 
math:: \max_{U \in St(d, k)} \min_{\pi \in \Pi(\mu,\nu)} \sum_{i,j} \pi_{i,j} \|U^T(x_i - y_j)\|^2 - reg * H(\pi) - + - :math:`U` is a linear projection operator in the Stiefel(d, k) manifold - :math:`H(\pi)` is entropy regularizer - :math:`x_i`, :math:`y_j` are samples of measures \mu and \nu respectively - + Parameters ---------- X : ndarray, shape (n, d) diff --git a/ot/gromov.py b/ot/gromov.py index 85b1549..33b4453 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -1030,7 +1030,7 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, # compute the gradient tens = gwggrad(constC, hC1, hC2, T) - T = sinkhorn(p, q, tens, epsilon) + T = sinkhorn(p, q, tens, epsilon, method='sinkhorn') if cpt % 10 == 0: # we can speed up the process by checking for the error only all @@ -1204,7 +1204,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, Cprev = C T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon, - max_iter, 1e-5, verbose, log) for s in range(S)] + max_iter, 1e-4, verbose, log) for s in range(S)] if loss_fun == 'square_loss': C = update_square_loss(p, lambdas, T, Cs) diff --git a/ot/optim.py b/ot/optim.py index 6822e4e..34cbb17 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -20,7 +20,7 @@ from .backend import get_backend def line_search_armijo(f, xk, pk, gfk, old_fval, args=(), c1=1e-4, alpha0=0.99): - """ + r""" Armijo linesearch function that works with matrices Find an approximate minimum of :math:`f(x_k + \\alpha \cdot p_k)` that satisfies the @@ -447,7 +447,7 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, def solve_1d_linesearch_quad(a, b, c): - """ + r""" For any convex or non-convex 1d quadratic function `f`, solve the following problem: .. math:: diff --git a/ot/utils.py b/ot/utils.py index 6a782e6..0608aee 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -183,7 +183,7 @@ def euclidean_distances(X, Y, squared=False): return c -def dist(x1, x2=None, metric='sqeuclidean'): +def dist(x1, x2=None, metric='sqeuclidean', p=2): """Compute distance between samples in x1 and x2 .. 
note:: This function is backend-compatible and will work on arrays @@ -222,7 +222,7 @@ def dist(x1, x2=None, metric='sqeuclidean'): if not get_backend(x1, x2).__name__ == 'numpy': raise NotImplementedError() else: - return cdist(x1, x2, metric=metric) + return cdist(x1, x2, metric=metric, p=p) def dist0(n, method='lin_square'): diff --git a/test/test_bregman.py b/test/test_bregman.py index 942cb6d..c1120ba 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -32,6 +32,27 @@ def test_sinkhorn(): u, G.sum(0), atol=1e-05) # cf convergence sinkhorn +def test_sinkhorn_multi_b(): + # test sinkhorn + n = 10 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + b = rng.rand(n, 3) + b = b / np.sum(b, 0, keepdims=True) + + M = ot.dist(x, x) + + loss0, log = ot.sinkhorn(u, b, M, .1, stopThr=1e-10, log=True) + + loss = [ot.sinkhorn2(u, b[:, k], M, .1, stopThr=1e-10) for k in range(3)] + # check constraints + np.testing.assert_allclose( + loss0, loss, atol=1e-06) # cf convergence sinkhorn + + def test_sinkhorn_backends(nx): n_samples = 100 n_features = 2 @@ -147,6 +168,7 @@ def test_sinkhorn_variants(nx): Mb = nx.from_numpy(M) G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10) + Gl = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_log', stopThr=1e-10)) G0 = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn', stopThr=1e-10)) Gs = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10)) Ges = nx.to_numpy(ot.sinkhorn( @@ -155,15 +177,73 @@ def test_sinkhorn_variants(nx): # check values np.testing.assert_allclose(G, G0, atol=1e-05) + np.testing.assert_allclose(G, Gl, atol=1e-05) np.testing.assert_allclose(G0, Gs, atol=1e-05) np.testing.assert_allclose(G0, Ges, atol=1e-05) np.testing.assert_allclose(G0, G_green, atol=1e-5) - print(G0, G_green) + + +@pytest.skip_backend("jax") +def test_sinkhorn_variants_multi_b(nx): + # test sinkhorn + n = 50 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + b = rng.rand(n, 3) + b = b / np.sum(b, 0, keepdims=True) + + M = ot.dist(x, x) + + ub = nx.from_numpy(u) + bb = nx.from_numpy(b) + Mb = nx.from_numpy(M) + + G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10) + Gl = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + + # check values + np.testing.assert_allclose(G, G0, atol=1e-05) + np.testing.assert_allclose(G, Gl, atol=1e-05) + np.testing.assert_allclose(G0, Gs, atol=1e-05) + + +@pytest.skip_backend("jax") +def test_sinkhorn2_variants_multi_b(nx): + # test sinkhorn + n = 50 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + b = rng.rand(n, 3) + b = b / np.sum(b, 0, keepdims=True) + + M = ot.dist(x, x) + + ub = nx.from_numpy(u) + bb = nx.from_numpy(b) + Mb = nx.from_numpy(M) + + G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10) + Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + + # check values + np.testing.assert_allclose(G, G0, atol=1e-05) + np.testing.assert_allclose(G, Gl, atol=1e-05) + np.testing.assert_allclose(G0, Gs, atol=1e-05) def test_sinkhorn_variants_log(): # test 
sinkhorn - n = 100 + n = 50 rng = np.random.RandomState(0) x = rng.randn(n, 2) @@ -172,6 +252,7 @@ def test_sinkhorn_variants_log(): M = ot.dist(x, x) G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10, log=True) + Gl, logl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True) Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) Ges, loges = ot.sinkhorn( u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True) @@ -179,9 +260,30 @@ def test_sinkhorn_variants_log(): # check values np.testing.assert_allclose(G0, Gs, atol=1e-05) + np.testing.assert_allclose(G0, Gl, atol=1e-05) np.testing.assert_allclose(G0, Ges, atol=1e-05) np.testing.assert_allclose(G0, G_green, atol=1e-5) - print(G0, G_green) + + +def test_sinkhorn_variants_log_multib(): + # test sinkhorn + n = 50 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + b = rng.rand(n, 3) + b = b / np.sum(b, 0, keepdims=True) + + M = ot.dist(x, x) + + G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10, log=True) + Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True) + Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) + + # check values + np.testing.assert_allclose(G0, Gs, atol=1e-05) + np.testing.assert_allclose(G0, Gl, atol=1e-05) @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) @@ -326,10 +428,10 @@ def test_empirical_sinkhorn(nx): a = ot.unif(n) b = ot.unif(n) - X_s = np.reshape(np.arange(n), (n, 1)) - X_t = np.reshape(np.arange(0, n), (n, 1)) + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) M = ot.dist(X_s, X_t) - M_m = ot.dist(X_s, X_t, metric='minkowski') + M_m = ot.dist(X_s, X_t, metric='euclidean') ab = nx.from_numpy(a) bb = nx.from_numpy(b) @@ -346,7 +448,7 @@ def test_empirical_sinkhorn(nx): sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) - G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='minkowski')) + G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean')) sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1)) @@ -378,7 +480,7 @@ def test_lazy_empirical_sinkhorn(nx): X_s = np.reshape(np.arange(n), (n, 1)) X_t = np.reshape(np.arange(0, n), (n, 1)) M = ot.dist(X_s, X_t) - M_m = ot.dist(X_s, X_t, metric='minkowski') + M_m = ot.dist(X_s, X_t, metric='euclidean') ab = nx.from_numpy(a) bb = nx.from_numpy(b) @@ -398,7 +500,7 @@ def test_lazy_empirical_sinkhorn(nx): sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) - f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='minkowski', numIterMax=numIterMax, isLazy=True, batchSize=1) + f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1) f, g = nx.to_numpy(f), nx.to_numpy(g) G_m = np.exp(f[:, None] + g[None, :] - M_m / 1) sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) diff --git a/test/test_gromov.py b/test/test_gromov.py index 19d61b1..0242d72 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -180,8 +180,8 @@ def test_sampled_gromov(): def test_gromov_barycenter(): - ns = 50 - nt = 60 + ns = 10 + nt = 20 Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) Xt, yt = 
ot.datasets.make_data_classif('3gauss2', nt, random_state=42) @@ -208,8 +208,8 @@ def test_gromov_barycenter(): @pytest.mark.filterwarnings("ignore:divide") def test_gromov_entropic_barycenter(): - ns = 20 - nt = 30 + ns = 10 + nt = 20 Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) @@ -222,7 +222,7 @@ def test_gromov_entropic_barycenter(): [ot.unif(ns), ot.unif(nt) ], ot.unif(n_samples), [.5, .5], 'square_loss', 1e-3, - max_iter=50, tol=1e-5, + max_iter=50, tol=1e-3, verbose=True) np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) diff --git a/test/test_helpers.py b/test/test_helpers.py index 8bd0015..cc4c90e 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -9,8 +9,8 @@ import sys sys.path.append(os.path.join("ot", "helpers")) -from openmp_helpers import get_openmp_flag, check_openmp_support # noqa -from pre_build_helpers import _get_compiler, compile_test_program # noqa +from openmp_helpers import get_openmp_flag, check_openmp_support # noqa +from pre_build_helpers import _get_compiler, compile_test_program # noqa def test_helpers(): diff --git a/test/test_utils.py b/test/test_utils.py index 60ad5d3..0650ce2 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -7,6 +7,7 @@ import ot import numpy as np import sys +import pytest def test_proj_simplex(nx): @@ -108,6 +109,10 @@ def test_dist(): D2 = ot.dist(x, x) D3 = ot.dist(x) + D4 = ot.dist(x, x, metric='minkowski', p=0.5) + + assert D4[0, 1] == D4[1, 0] + # dist shoul return squared euclidean np.testing.assert_allclose(D, D2, atol=1e-14) np.testing.assert_allclose(D, D3, atol=1e-14) @@ -220,6 +225,13 @@ def test_deprecated_func(): class Class(): pass + with pytest.warns(DeprecationWarning): + fun() + + with pytest.warns(DeprecationWarning): + cl = Class() + print(cl) + if sys.version_info < (3, 5): print('Not tested') else: @@ -250,4 +262,7 @@ def test_BaseEstimator(): params['first'] = 'spam again' cl.set_params(**params) + with pytest.raises(ValueError): + cl.set_params(bibi=10) + assert cl.first == 'spam again' -- cgit v1.2.3
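
Postscript: the core of the new sinkhorn_log loop above is the classic Sinkhorn fixed
point rewritten on the dual (log) potentials, with logsumexp replacing the kernel-vector
products. Below is a self-contained NumPy sketch of the same update with the same
every-10-iterations stopping rule. It is an illustration only, not the backend-agnostic
code from ot/bregman.py, and `sinkhorn_log_sketch` is a hypothetical name:

    import numpy as np
    from scipy.special import logsumexp

    def sinkhorn_log_sketch(a, b, M, reg, numItermax=1000, stopThr=1e-9):
        # log-domain kernel: log K = -M / reg; all updates stay in log space,
        # so a small reg does not under/overflow the scaling vectors
        Mr = -M / reg
        u = np.zeros(M.shape[0])  # log of the left scaling vector
        v = np.zeros(M.shape[1])  # log of the right scaling vector
        loga, logb = np.log(a), np.log(b)
        for cpt in range(numItermax):
            # same two updates as the patch: one logsumexp per marginal
            v = logb - logsumexp(Mr + u[:, None], axis=0)
            u = loga - logsumexp(Mr + v[None, :], axis=1)
            if cpt % 10 == 0:  # check the marginal violation every 10 iterations
                G = np.exp(Mr + u[:, None] + v[None, :])
                if np.linalg.norm(G.sum(axis=0) - b) < stopThr:
                    break
        return np.exp(Mr + u[:, None] + v[None, :])

    a = b = np.array([.5, .5])
    M = np.array([[0., 1.], [1., 0.]])
    G = sinkhorn_log_sketch(a, b, M, reg=1.)
    # G             ~ [[0.36552929, 0.13447071], [0.13447071, 0.36552929]]
    # (G * M).sum() ~ 0.26894142..., matching the sinkhorn2 doctest above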