From daec9fe15f9728080b54a7ddbfdb67075e78c6bd Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Tue, 5 May 2020 13:14:35 +0100 Subject: break before exceeding array size --- ot/lp/emd_wrap.pyx | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'ot/lp') diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index c167964..e9e8fba 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -172,7 +172,7 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2), dtype=np.int) cdef int cur_idx = 0 - while i < n and j < m: + while True: if metric == 'sqeuclidean': m_ij = (u[i] - v[j]) * (u[i] - v[j]) elif metric == 'cityblock' or metric == 'euclidean': @@ -188,6 +188,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, indices[cur_idx, 0] = i indices[cur_idx, 1] = j i += 1 + if i == n: + break w_j -= w_i w_i = u_weights[i] else: @@ -196,6 +198,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, indices[cur_idx, 0] = i indices[cur_idx, 1] = j j += 1 + if j == m: + break w_i -= w_j w_j = v_weights[j] cur_idx += 1 -- cgit v1.2.3 From ea2890aa3cfbf09a32f8ef3063b6a413f485526b Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Tue, 5 May 2020 13:19:13 +0100 Subject: Some improvements for platform compatibility --- ot/lp/emd_wrap.pyx | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'ot/lp') diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index e9e8fba..10bc5cf 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -157,12 +157,12 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, cost associated to the optimal transportation """ cdef double cost = 0. - cdef int n = u_weights.shape[0] - cdef int m = v_weights.shape[0] + cdef Py_ssize_t n = u_weights.shape[0] + cdef Py_ssize_t m = v_weights.shape[0] - cdef int i = 0 + cdef Py_ssize_t i = 0 cdef double w_i = u_weights[0] - cdef int j = 0 + cdef Py_ssize_t j = 0 cdef double w_j = v_weights[0] cdef double m_ij = 0. @@ -171,7 +171,7 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, dtype=np.float64) cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2), dtype=np.int) - cdef int cur_idx = 0 + cdef Py_ssize_t cur_idx = 0 while True: if metric == 'sqeuclidean': m_ij = (u[i] - v[j]) * (u[i] - v[j]) -- cgit v1.2.3 From ea6642c4873b557b4d284f6f3717d8990e23ad51 Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Tue, 5 May 2020 13:37:08 +0100 Subject: fix failing test - cur_idx needs to be incremented by 1 after the loop --- ot/lp/emd_wrap.pyx | 1 + 1 file changed, 1 insertion(+) (limited to 'ot/lp') diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 10bc5cf..d79d0ca 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -203,4 +203,5 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, w_i -= w_j w_j = v_weights[j] cur_idx += 1 + cur_idx += 1 return G[:cur_idx], indices[:cur_idx], cost -- cgit v1.2.3 From 23db72c49465a1eeb2897d4c6dd9c189aec9cd6e Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Mon, 20 Jul 2020 14:59:13 +0300 Subject: Correct documentation for support barycenter (#201) * example for log treatment in bregman.py * Improve doc * Revert "example for log treatment in bregman.py" This reverts commit 9f51c14e * Add comments by Flamary * Delete repetitive description * Added raw string to avoid pbs with backslashes --- ot/lp/__init__.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) (limited to 'ot/lp') diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 514a607..2a1b082 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -272,7 +272,7 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): if np.any(~asel) or np.any(~bsel): u, v = estimate_dual_null_weights(u, v, a, b, M) - + result_code_string = check_result(result_code) if log: log = {} @@ -389,7 +389,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), if log or return_matrix: def f(b): bsel = b != 0 - + G, cost, u, v, result_code = emd_c(a, b, M, numItermax) if center_dual: @@ -435,26 +435,36 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None): - """ - Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance) + r""" + Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance), formally: + + .. math:: + \min_X \sum_{i=1}^N w_i W_2^2(b, X, a_i, X_i) + + where : + + - :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one + - the :math:`a_i \in \mathbb{R}^{k_i}` are the empirical measures weights and sum to one for each :math:`i` + - the :math:`X_i \in \mathbb{R}^{k_i, d}` are the empirical measures atoms locations + - :math:`b \in \mathbb{R}^{k}` is the desired weights vector of the barycenter - The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms. This problem is considered in [1] (Algorithm 2). There are two differences with the following codes: + - we do not optimize over the weights - we do not do line search for the locations updates, we use i.e. theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting. Parameters ---------- - measures_locations : list of (k_i,d) numpy.ndarray + measures_locations : list of N (k_i,d) numpy.ndarray The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list) - measures_weights : list of (k_i,) numpy.ndarray + measures_weights : list of N (k_i,) numpy.ndarray Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure X_init : (k,d) np.ndarray Initialization of the support locations (on k atoms) of the barycenter b : (k,) np.ndarray Initialization of the weights of the barycenter (non-negatives, sum to 1) - weights : (k,) np.ndarray + weights : (N,) np.ndarray Initialization of the coefficients of the barycenter (non-negatives, sum to 1) numItermax : int, optional -- cgit v1.2.3 From 7adc1b1aa73c55dc07983ff08dcb23fd71e9e8b6 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Thu, 22 Oct 2020 10:16:40 +0200 Subject: [MRG] Cleanup minimal build and add separate build for pep8 (#210) * cleanup requiorement minimal * add pep8 build * cleanup sklearn * skip test if no sklearn * debug build yaml * comment error out in test (test sklearn) * maybe small stuff for better robustness : copy the sub-array * bump verison minimal build * update version strict requireent * update version strict requirement last change --- .github/requirements_strict.txt | 9 +++------ .github/workflows/build_tests.yml | 37 +++++++++++++++++++++++++------------ .gitignore | 3 +++ Makefile | 4 ++-- ot/lp/__init__.py | 2 +- test/test_da.py | 8 ++++++++ 6 files changed, 42 insertions(+), 21 deletions(-) (limited to 'ot/lp') diff --git a/.github/requirements_strict.txt b/.github/requirements_strict.txt index d7539c5..9a1ada4 100644 --- a/.github/requirements_strict.txt +++ b/.github/requirements_strict.txt @@ -1,7 +1,4 @@ -numpy==1.16.* -scipy==1.0.* -cython==0.23.* -matplotlib -cvxopt -scikit-learn +numpy +scipy>=1.3 +cython pytest diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index 41b08b3..fa814ba 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -30,14 +30,7 @@ jobs: run: | python -m pip install --upgrade pip pip install -r requirements.txt - pip install flake8 pytest "pytest-cov<2.6" codecov - pip install -U "sklearn" - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 examples/ ot/ test/ --count --max-line-length=127 --statistics + pip install pytest "pytest-cov<2.6" codecov - name: Install POT run: | pip install -e . @@ -48,6 +41,29 @@ jobs: run: | codecov + pep8: + runs-on: ubuntu-latest + strategy: + max-parallel: 4 + matrix: + python-version: [3.8] + + steps: + - uses: actions/checkout@v1 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install flake8 + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 examples/ ot/ test/ --count --max-line-length=127 --statistics linux-minimal-deps: @@ -55,7 +71,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: [3.6] + python-version: [3.8] steps: - uses: actions/checkout@v1 @@ -68,7 +84,6 @@ jobs: python -m pip install --upgrade pip pip install -r .github/requirements_strict.txt pip install pytest - pip install -U "sklearn" - name: Install POT run: | pip install -e . @@ -95,7 +110,6 @@ jobs: python -m pip install --upgrade pip pip install -r requirements.txt pip install pytest "pytest-cov<2.6" - pip install -U "sklearn" - name: Install POT run: | pip install -e . @@ -122,7 +136,6 @@ jobs: python -m pip install --upgrade pip pip install -r requirements.txt pip install pytest "pytest-cov<2.6" - pip install -U "sklearn" - name: Install POT run: | pip install -e . diff --git a/.gitignore b/.gitignore index a2ace7c..b44ea43 100644 --- a/.gitignore +++ b/.gitignore @@ -40,6 +40,9 @@ var/ *.manifest *.spec +# env +pythonenv3.8/ + # Installer logs pip-log.txt pip-delete-this-directory.txt diff --git a/Makefile b/Makefile index 70cdbdd..32332b4 100644 --- a/Makefile +++ b/Makefile @@ -45,10 +45,10 @@ pep8 : flake8 examples/ ot/ test/ test : FORCE pep8 - $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/ --cov=ot --cov-report html:cov_html + $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/ pytest : FORCE - $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/ --cov=ot + $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/ release : twine upload dist/* diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 2a1b082..f08e020 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -426,7 +426,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), nb = b.shape[1] if processes > 1: - res = parmap(f, [b[:, i] for i in range(nb)], processes) + res = parmap(f, [b[:, i].copy() for i in range(nb)], processes) else: res = list(map(f, [b[:, i].copy() for i in range(nb)])) diff --git a/test/test_da.py b/test/test_da.py index 3b28119..52c6a48 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -6,11 +6,18 @@ import numpy as np from numpy.testing import assert_allclose, assert_equal +import pytest import ot from ot.datasets import make_data_classif from ot.utils import unif +try: # test if cudamat installed + import sklearn # noqa: F401 + nosklearn = False +except ImportError: + nosklearn = True + def test_sinkhorn_lpl1_transport_class(): """test_sinkhorn_transport @@ -691,6 +698,7 @@ def test_jcpot_barycenter(): np.testing.assert_allclose(prop, [1 - pt, pt], rtol=1e-3, atol=1e-3) +@pytest.mark.skipif(nosklearn, reason="No sklearn available") def test_emd_laplace_class(): """test_emd_laplace_transport """ -- cgit v1.2.3 From cb3e24aea8a2492ccb7e7664533ea3543b14c8ac Mon Sep 17 00:00:00 2001 From: Nicolas Courty Date: Thu, 12 Nov 2020 16:04:16 +0100 Subject: change precision EPSILON in C code (#217) * change precision EPSILON in C code * change precision EPSILON in C code V2 * change precision EPSILON in C code V3 (add comment and remove unnecessary lines --- ot/lp/network_simplex_simple.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'ot/lp') diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h index 5d93040..630b595 100644 --- a/ot/lp/network_simplex_simple.h +++ b/ot/lp/network_simplex_simple.h @@ -1507,7 +1507,7 @@ namespace lemon { if( retVal == OPTIMAL){ for (int e = _search_arc_num; e != _all_arc_num; ++e) { if (_flow[e] != 0){ - if (abs(_flow[e]) > EPSILON) + if (fabs(_flow[e]) > _EPSILON) // change of the original code following issue #126 return INFEASIBLE; else _flow[e]=0; -- cgit v1.2.3 From cd3ce6140d7a2dbe2bcf05927a8dd8289f4ce9e2 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 19 Apr 2021 15:03:57 +0200 Subject: [MRG] Cleanup test warnings (#242) * remove warnings in tests from docstrings * working tets for bregman implemneted methods * pep8 --- ot/da.py | 12 ++++++------ ot/dr.py | 2 +- ot/gpu/bregman.py | 2 +- ot/gromov.py | 20 ++++++++++---------- ot/lp/cvx.py | 3 +-- ot/optim.py | 4 ++-- test/test_bregman.py | 3 ++- 7 files changed, 23 insertions(+), 23 deletions(-) (limited to 'ot/lp') diff --git a/ot/da.py b/ot/da.py index f1e4769..cdc747c 100644 --- a/ot/da.py +++ b/ot/da.py @@ -26,7 +26,7 @@ from .optim import gcg def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerItermax=200, stopInnerThr=1e-9, verbose=False, log=False): - """ + r""" Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization @@ -137,7 +137,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerItermax=200, stopInnerThr=1e-9, verbose=False, log=False): - """ + r""" Solve the entropic regularization optimal transport problem with group lasso regularization @@ -245,7 +245,7 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, verbose2=False, numItermax=100, numInnerItermax=10, stopInnerThr=1e-6, stopThr=1e-5, log=False, **kwargs): - """Joint OT and linear mapping estimation as proposed in [8] + r"""Joint OT and linear mapping estimation as proposed in [8] The function solves the following optimization problem: @@ -434,7 +434,7 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', numItermax=100, numInnerItermax=10, stopInnerThr=1e-6, stopThr=1e-5, log=False, **kwargs): - """Joint OT and nonlinear mapping estimation with kernels as proposed in [8] + r"""Joint OT and nonlinear mapping estimation with kernels as proposed in [8] The function solves the following optimization problem: @@ -645,7 +645,7 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, wt=None, bias=True, log=False): - """ return OT linear operator between samples + r""" return OT linear operator between samples The function estimates the optimal linear operator that aligns the two empirical distributions. This is equivalent to estimating the closed @@ -1228,7 +1228,7 @@ class BaseTransport(BaseEstimator): class LinearTransport(BaseTransport): - """ OT linear operator between empirical distributions + r""" OT linear operator between empirical distributions The function estimates the optimal linear operator that aligns the two empirical distributions. This is equivalent to estimating the closed diff --git a/ot/dr.py b/ot/dr.py index 11d2e10..b7a1af0 100644 --- a/ot/dr.py +++ b/ot/dr.py @@ -109,7 +109,7 @@ def fda(X, y, p=2, reg=1e-16): def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None): - """ + r""" Wasserstein Discriminant Analysis [11]_ The function solves the following optimization problem: diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py index 2e2df83..82f34f3 100644 --- a/ot/gpu/bregman.py +++ b/ot/gpu/bregman.py @@ -15,7 +15,7 @@ from . import utils def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, log=False, to_numpy=True, **kwargs): - """ + r""" Solve the entropic regularization optimal transport on GPU If the input matrix are in numpy format, they will be uploaded to the diff --git a/ot/gromov.py b/ot/gromov.py index 4427a96..8f457e9 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -19,7 +19,7 @@ from .optim import cg def init_matrix(C1, C2, p, q, loss_fun='square_loss'): - """Return loss matrices and tensors for Gromov-Wasserstein fast computation + r"""Return loss matrices and tensors for Gromov-Wasserstein fast computation Returns the value of \mathcal{L}(C1,C2) \otimes T with the selected loss function as the loss function of Gromow-Wasserstein discrepancy. @@ -109,7 +109,7 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'): def tensor_product(constC, hC1, hC2, T): - """Return the tensor for Gromov-Wasserstein fast computation + r"""Return the tensor for Gromov-Wasserstein fast computation The tensor is computed as described in Proposition 1 Eq. (6) in [12]. @@ -262,7 +262,7 @@ def update_kl_loss(p, lambdas, T, Cs): def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): - """ + r""" Returns the gromov-wasserstein transport between (C1,p) and (C2,q) The function solves the following optimization problem: @@ -343,7 +343,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): - """ + r""" Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q) The function solves the following optimization problem: @@ -420,7 +420,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): - """ + r""" Computes the FGW transport between two graphs see [24] .. math:: @@ -496,7 +496,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): - """ + r""" Computes the FGW distance between two graphs see [24] .. math:: @@ -574,7 +574,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False): - """ + r""" Returns the gromov-wasserstein transport between (C1,p) and (C2,q) (C1,p) and (C2,q) @@ -681,7 +681,7 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False): - """ + r""" Returns the entropic gromov-wasserstein discrepancy between the two measured similarity matrices (C1,p) and (C2,q) @@ -747,7 +747,7 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None): - """ + r""" Returns the gromov-wasserstein barycenters of S measured similarity matrices (Cs)_{s=1}^{s=S} @@ -857,7 +857,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None): - """ + r""" Returns the gromov-wasserstein barycenters of S measured similarity matrices (Cs)_{s=1}^{s=S} diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index 8e763be..869d450 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -27,7 +27,7 @@ def scipy_sparse_to_spmatrix(A): def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'): - """Compute the Wasserstein barycenter of distributions A + r"""Compute the Wasserstein barycenter of distributions A The function solves the following optimization problem [16]: @@ -76,7 +76,6 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po .. [16] Agueh, M., & Carlier, G. (2011). Barycenters in the Wasserstein space. SIAM Journal on Mathematical Analysis, 43(2), 904-924. - """ if weights is None: diff --git a/ot/optim.py b/ot/optim.py index 1902907..abe9e6a 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -139,7 +139,7 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val, def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs): - """ + r""" Solve the general regularized OT problem with conditional gradient The function solves the following optimization problem: @@ -278,7 +278,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False): - """ + r""" Solve the general regularized OT problem with the generalized conditional gradient The function solves the following optimization problem: diff --git a/test/test_bregman.py b/test/test_bregman.py index 331acd3..1ebd21f 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -321,8 +321,9 @@ def test_implemented_methods(): # make dists unbalanced b = ot.utils.unif(n) A = rng.rand(n, 2) + A /= A.sum(0, keepdims=True) M = ot.dist(x, x) - epsilon = 1. + epsilon = 1.0 for method in IMPLEMENTED_METHODS: ot.bregman.sinkhorn(a, b, M, epsilon, method=method) -- cgit v1.2.3 From 0d995011b19b243bc980588cd98786b7c41a0509 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Wed, 21 Apr 2021 17:12:29 +0200 Subject: [MRG] Fixes issue #239 (deprecated numpy types) (#244) * remove warning numpy int? * use long long * stoupid mistake * cleanup double test run in PR from local branch --- .github/workflows/build_tests.yml | 4 +++- ot/lp/emd_wrap.pyx | 6 ++---- 2 files changed, 5 insertions(+), 5 deletions(-) (limited to 'ot/lp') diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index f4d55d1..2fc6770 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -3,9 +3,11 @@ name: Tests on: workflow_dispatch: pull_request: + branches: + - 'master' push: branches: - - '**' + - 'master' create: branches: - 'master' diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index d79d0ca..de9a700 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -97,8 +97,6 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([0, 0]) cdef np.ndarray[double, ndim=1, mode="c"] Gv=np.zeros(0) - cdef np.ndarray[long, ndim=1, mode="c"] iG=np.zeros(0,dtype=np.int) - cdef np.ndarray[long, ndim=1, mode="c"] jG=np.zeros(0,dtype=np.int) if not len(a): a=np.ones((n1,))/n1 @@ -169,8 +167,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, cdef np.ndarray[double, ndim=1, mode="c"] G = np.zeros((n + m - 1, ), dtype=np.float64) - cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2), - dtype=np.int) + cdef np.ndarray[long long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2), + dtype=np.int64) cdef Py_ssize_t cur_idx = 0 while True: if metric == 'sqeuclidean': -- cgit v1.2.3 From 178c281fc91e014f5e148b7017430928d715de8c Mon Sep 17 00:00:00 2001 From: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> Date: Wed, 26 May 2021 16:56:00 +0200 Subject: Docs correction (#252) --- ot/lp/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'ot/lp') diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index f08e020..d5c3a5e 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -323,7 +323,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), The maximum number of iterations before stopping the optimization algorithm if it has not converged. log: boolean, optional (default=False) - If True, returns a dictionary containing the cost and dual + If True, returns a dictionary containing dual variables. Otherwise returns only the optimal transportation cost. return_matrix: boolean, optional (default=False) If True, returns the optimal transportation matrix in the log. @@ -333,10 +333,10 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), Returns ------- - gamma: (ns x nt) ndarray - Optimal transportation matrix for the given parameters + W: float + Optimal transportation loss for the given parameters log: dictnp - If input log is true, a dictionary containing the cost and dual + If input log is true, a dictionary containing dual variables and exit status -- cgit v1.2.3 From 184f8f4f7ac78f1dd7f653496d2753211a4e3426 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Tue, 1 Jun 2021 10:10:54 +0200 Subject: [MRG] POT numpy/torch/jax backends (#249) * add numpy and torch backends * stat sets on functions * proper import * install recent torch on windows * install recent torch on windows * now testing all functions in backedn * add jax backedn * clenaup windowds * proper convert for jax backedn * pep8 * try again windows tests * test jax conversion * try proper widows tests * emd fuction ses backedn * better test partial OT * proper tests to_numpy and teplate Backend * pep8 * pep8 x2 * feaking sinkhorn works with torch * sinkhorn2 compatible * working ot.emd2 * important detach * it should work * jax autodiff emd * pep8 * no tast same for jax * new independat tests per backedn * freaking pep8 * add tests for gradients * deprecate ot.gpu * worging dist function * working dist * dist done in backedn * not in * remove indexing * change accuacy for jax * first pull backend * projection simplex * projection simplex * projection simplex * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 * add backedn discusion to quickstart guide * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 + better doc * proper links * corect doctest * big debug documentation * doctest again * doctest again bis * doctest again ter (last one or i kill myself) * backend test + doc proj simplex * correction test_utils * correction test_utils * correction cumsum * correction flip * correction flip v2 * more debug * more debug * more debug + pep8 * pep8 * argh * proj_simplex * backedn works for sort * proj simplex * jax sucks * update doc * Update test/test_utils.py Co-authored-by: Alexandre Gramfort * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort * Update test/test_utils.py Co-authored-by: Alexandre Gramfort * Update ot/utils.py Co-authored-by: Alexandre Gramfort * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort * Update ot/lp/__init__.py Co-authored-by: Alexandre Gramfort * begin comment alex * comment alex part 2 * optimize test gromov * proj_simplex on vectors * add awesome gradient decsnt example on the weights * pep98 of course * proof read example by alex * pep8 again * encoding oos in translation * correct legend Co-authored-by: Nicolas Courty Co-authored-by: Alexandre Gramfort --- .github/requirements_test_windows.txt | 10 + .github/workflows/build_tests.yml | 9 +- README.md | 8 +- docs/source/quickstart.rst | 68 +++- docs/source/readme.rst | 70 ++-- examples/README.txt | 2 +- examples/backends/README.txt | 4 + examples/backends/plot_unmix_optim_torch.py | 161 +++++++++ ot/__init__.py | 1 + ot/backend.py | 536 ++++++++++++++++++++++++++++ ot/bregman.py | 141 ++++---- ot/gpu/__init__.py | 4 +- ot/lp/__init__.py | 137 ++++--- ot/utils.py | 128 +++++-- requirements.txt | 3 + test/test_backend.py | 364 +++++++++++++++++++ test/test_bregman.py | 74 ++++ test/test_gromov.py | 10 +- test/test_ot.py | 91 ++++- test/test_partial.py | 4 +- test/test_utils.py | 76 +++- 21 files changed, 1692 insertions(+), 209 deletions(-) create mode 100644 .github/requirements_test_windows.txt create mode 100644 examples/backends/README.txt create mode 100644 examples/backends/plot_unmix_optim_torch.py create mode 100644 ot/backend.py create mode 100644 test/test_backend.py (limited to 'ot/lp') diff --git a/.github/requirements_test_windows.txt b/.github/requirements_test_windows.txt new file mode 100644 index 0000000..331dd57 --- /dev/null +++ b/.github/requirements_test_windows.txt @@ -0,0 +1,10 @@ +numpy +scipy>=1.3 +cython +matplotlib +autograd +pymanopt==0.2.4; python_version <'3' +pymanopt; python_version >= '3' +cvxopt +scikit-learn +pytest \ No newline at end of file diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index 2fc6770..92a07b5 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -40,7 +40,7 @@ jobs: pip install -e . - name: Run tests run: | - python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot + python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes - name: Upload codecov run: | codecov @@ -142,11 +142,12 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt - pip install pytest "pytest-cov<2.6" + python -m pip install -r .github/requirements_test_windows.txt + python -m pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html + python -m pip install pytest "pytest-cov<2.6" - name: Install POT run: | - pip install -e . + python -m pip install -e . - name: Run tests run: | python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot diff --git a/README.md b/README.md index f5d18c1..e5e16e0 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ POT provides the following generic OT solvers (links to examples): * [OT Network Simplex solver](https://pythonot.github.io/auto_examples/plot_OT_1D.html) for the linear program/ Earth Movers Distance [1] . * [Conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) [6] and [Generalized conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) for regularized OT [7]. -* Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html) with optional GPU implementation (requires cupy). +* Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html). * Bregman projections for [Wasserstein barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) [3], [convolutional barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_convolutional_barycenter.html) [21] and unmixing [4]. * Sinkhorn divergence [23] and entropic regularization OT from empirical data. * [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17]. @@ -33,6 +33,7 @@ POT provides the following generic OT solvers (links to examples): * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] formulations). * [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32]. +* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/) arrays. POT provides the following Machine Learning related solvers: @@ -77,8 +78,7 @@ The library has been tested on Linux, MacOSX and Windows. It requires a C++ comp - Numpy (>=1.16) - Scipy (>=1.0) -- Cython (>=0.23) -- Matplotlib (>=1.5) +- Cython (>=0.23) (build only, not necessary when installing wheels from pip or conda) #### Pip installation @@ -129,7 +129,7 @@ Some sub-modules require additional dependences which are discussed below pip install pymanopt autograd ``` -* **ot.gpu** (GPU accelerated OT) depends on cupy that have to be installed following instructions on [this page](https://docs-cupy.chainer.org/en/stable/install.html). Obviously you will need CUDA installed and a compatible GPU. +* **ot.gpu** (GPU accelerated OT) depends on cupy that have to be installed following instructions on [this page](https://docs-cupy.chainer.org/en/stable/install.html). Obviously you will need CUDA installed and a compatible GPU. Note that this module is deprecated since version 0.8 and will be deleted in the future. GPU is now handled automatically through the backends and several solver already can run on GPU using the Pytorch backend. ## Examples diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index cf5d6aa..fd046a1 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -15,6 +15,12 @@ are also available as notebooks on the POT Github. in ML applications we refer the reader to the following `OTML tutorial `_. +.. note:: + + Since version 0.8, POT provides a backend to automatically solve some OT + problems independently from the toolbox used by the user (numpy/torch/jax). + We provide a discussion about which functions are compatible in section + `Backend section <#solving-ot-with-multiple-backends>`_ . Why Optimal Transport ? @@ -158,7 +164,6 @@ Wasserstein but has better computational and `statistical properties `_. - Optimal transport and Wasserstein distance ------------------------------------------ @@ -922,6 +927,13 @@ The implementations of FGW and FGW barycenter is provided in functions GPU acceleration ^^^^^^^^^^^^^^^^ +.. warning:: + + The :any:`ot.gpu` has been deprecated since the release 0.8 of POT and + should not be used. The GPU implementation (in Pytorch for instance) can be + used with the novel backends using the compatible functions from POT. + + We provide several implementation of our OT solvers in :any:`ot.gpu`. Those implementations use the :code:`cupy` toolbox that obviously need to be installed. @@ -950,6 +962,60 @@ explicitly. use it you have to specifically import it with :code:`import ot.gpu` . +Solving OT with Multiple backends +--------------------------------- + +.. _backends_section: + +Since version 0.8, POT provides a backend that allows to code solvers +independently from the type of the input arrays. The idea is to provide the user +with a package that works seamlessly and returns a solution for instance as a +Pytorch tensors when the function has Pytorch tensors as input. + + +How it works +^^^^^^^^^^^^ + +The aim of the backend is to use the same function independently of the type of +the input arrays. + +For instance when executing the following code + +.. code:: python + + # a and b are 1D histograms (sum to 1 and positive) + # M is the ground cost matrix + T = ot.emd(a, b, M) # exact linear program + w = ot.emd2(a, b, M) # Wasserstein computation + +the functions :any:`ot.emd` and :any:`ot.emd2` can take inputs of the type +:any:`numpy.array`, :any:`torch.tensor` or :any:`jax.numpy.array`. The output of +the function will be the same type as the inputs and on the same device. When +possible all computations are done on the same device and also when possible the +output will be differentiable with respect to the input of the function. + + + +List of compatible Backends +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +- `Numpy `_ (all functions and solvers) +- `Pytorch `_ (all outputs differentiable w.r.t. inputs) +- `Jax `_ (Some functions are differentiable some require a wrapper) + +List of compatible functions +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +This list will get longer for new releases and will hopefully disappear when POT +become fully implemented with the backend. + +- :any:`ot.emd` +- :any:`ot.emd2` +- :any:`ot.sinkhorn` +- :any:`ot.sinkhorn2` +- :any:`ot.dist` + + FAQ --- diff --git a/docs/source/readme.rst b/docs/source/readme.rst index 3b594c2..82d3e6c 100644 --- a/docs/source/readme.rst +++ b/docs/source/readme.rst @@ -26,8 +26,7 @@ POT provides the following generic OT solvers (links to examples): Algorithm `__ [2] , stabilized version [9] [10], greedy Sinkhorn [22] and `Screening Sinkhorn - [26] `__ - with optional GPU implementation (requires cupy). + [26] `__. - Bregman projections for `Wasserstein barycenter `__ [3], `convolutional @@ -69,6 +68,11 @@ POT provides the following generic OT solvers (links to examples): - `Sliced Wasserstein `__ [31, 32]. +- `Several + backends `__ + for easy use of POT with + `Pytorch `__/`jax `__/`Numpy `__ + arrays. POT provides the following Machine Learning related solvers: @@ -104,12 +108,14 @@ paper `__: :: - Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer;, POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021. + Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer, + POT Python Optimal Transport library, + Journal of Machine Learning Research, 22(78):1−8, 2021. Website: https://pythonot.github.io/ In Bibtex format: -:: +.. code:: bibtex @article{flamary2021pot, author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer}, @@ -131,8 +137,8 @@ following Python modules: - Numpy (>=1.16) - Scipy (>=1.0) -- Cython (>=0.23) -- Matplotlib (>=1.5) +- Cython (>=0.23) (build only, not necessary when installing wheels + from pip or conda) Pip installation ^^^^^^^^^^^^^^^^ @@ -140,19 +146,19 @@ Pip installation Note that due to a limitation of pip, ``cython`` and ``numpy`` need to be installed prior to installing POT. This can be done easily with -:: +.. code:: console pip install numpy cython You can install the toolbox through PyPI with: -:: +.. code:: console pip install POT or get the very latest version by running: -:: +.. code:: console pip install -U https://github.com/PythonOT/POT/archive/master.zip # with --user for user install (no root) @@ -163,7 +169,7 @@ If you use the Anaconda python distribution, POT is available in `conda-forge `__. To install it and the required dependencies: -:: +.. code:: console conda install -c conda-forge pot @@ -188,15 +194,17 @@ below - **ot.dr** (Wasserstein dimensionality reduction) depends on autograd and pymanopt that can be installed with: - :: +.. code:: shell - pip install pymanopt autograd + pip install pymanopt autograd - **ot.gpu** (GPU accelerated OT) depends on cupy that have to be installed following instructions on `this page `__. - -obviously you need CUDA installed and a compatible GPU. + Obviously you will need CUDA installed and a compatible GPU. Note + that this module is deprecated since version 0.8 and will be deleted + in the future. GPU is now handled automatically through the backends + and several solver already can run on GPU using the Pytorch backend. Examples -------- @@ -206,36 +214,36 @@ Short examples - Import the toolbox - .. code:: python +.. code:: python - import ot + import ot - Compute Wasserstein distances - .. code:: python +.. code:: python - # a,b are 1D histograms (sum to 1 and positive) - # M is the ground cost matrix - Wd=ot.emd2(a,b,M) # exact linear program - Wd_reg=ot.sinkhorn2(a,b,M,reg) # entropic regularized OT - # if b is a matrix compute all distances to a and return a vector + # a and b are 1D histograms (sum to 1 and positive) + # M is the ground cost matrix + Wd = ot.emd2(a, b, M) # exact linear program + Wd_reg = ot.sinkhorn2(a, b, M, reg) # entropic regularized OT + # if b is a matrix compute all distances to a and return a vector - Compute OT matrix - .. code:: python +.. code:: python - # a,b are 1D histograms (sum to 1 and positive) - # M is the ground cost matrix - T=ot.emd(a,b,M) # exact linear program - T_reg=ot.sinkhorn(a,b,M,reg) # entropic regularized OT + # a and b are 1D histograms (sum to 1 and positive) + # M is the ground cost matrix + T = ot.emd(a, b, M) # exact linear program + T_reg = ot.sinkhorn(a, b, M, reg) # entropic regularized OT - Compute Wasserstein barycenter - .. code:: python +.. code:: python - # A is a n*d matrix containing d 1D histograms - # M is the ground cost matrix - ba=ot.barycenter(A,M,reg) # reg is regularization parameter + # A is a n*d matrix containing d 1D histograms + # M is the ground cost matrix + ba = ot.barycenter(A, M, reg) # reg is regularization parameter Examples and Notebooks ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/examples/README.txt b/examples/README.txt index 69a9f84..b48487f 100644 --- a/examples/README.txt +++ b/examples/README.txt @@ -1,7 +1,7 @@ Examples gallery ================ -This is a gallery of all the POT example files. +This is a gallery of all the POT example files. OT and regularized OT diff --git a/examples/backends/README.txt b/examples/backends/README.txt new file mode 100644 index 0000000..3ee0e27 --- /dev/null +++ b/examples/backends/README.txt @@ -0,0 +1,4 @@ + + +POT backend examples +-------------------- \ No newline at end of file diff --git a/examples/backends/plot_unmix_optim_torch.py b/examples/backends/plot_unmix_optim_torch.py new file mode 100644 index 0000000..9ae66e9 --- /dev/null +++ b/examples/backends/plot_unmix_optim_torch.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +r""" +================================= +Wasserstein unmixing with PyTorch +================================= + +In this example we estimate mixing parameters from distributions that minimize +the Wasserstein distance. In other words we suppose that a target +distribution :math:`\mu^t` can be expressed as a weighted sum of source +distributions :math:`\mu^s_k` with the following model: + +.. math:: + \mu^t = \sum_{k=1}^K w_k\mu^s_k + +where :math:`\mathbf{w}` is a vector of size :math:`K` and belongs in the +distribution simplex :math:`\Delta_K`. + +In order to estimate this weight vector we propose to optimize the Wasserstein +distance between the model and the observed :math:`\mu^t` with respect to +the vector. This leads to the following optimization problem: + +.. math:: + \min_{\mathbf{w}\in\Delta_K} \quad W \left(\mu^t,\sum_{k=1}^K w_k\mu^s_k\right) + +This minimization is done in this example with a simple projected gradient +descent in PyTorch. We use the automatic backend of POT that allows us to +compute the Wasserstein distance with :any:`ot.emd2` with +differentiable losses. + +""" + +# Author: Remi Flamary +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 2 + +import numpy as np +import matplotlib.pylab as pl +import ot +import torch + + +############################################################################## +# Generate data +# ------------- + +#%% Data + +nt = 100 +nt1 = 10 # + +ns1 = 50 +ns = 2 * ns1 + +rng = np.random.RandomState(2) + +xt = rng.randn(nt, 2) * 0.2 +xt[:nt1, 0] += 1 +xt[nt1:, 1] += 1 + + +xs1 = rng.randn(ns1, 2) * 0.2 +xs1[:, 0] += 1 +xs2 = rng.randn(ns1, 2) * 0.2 +xs2[:, 1] += 1 + +xs = np.concatenate((xs1, xs2)) + +# Sample reweighting matrix H +H = np.zeros((ns, 2)) +H[:ns1, 0] = 1 / ns1 +H[ns1:, 1] = 1 / ns1 +# each columns sums to 1 and has weights only for samples form the +# corresponding source distribution + +M = ot.dist(xs, xt) + +############################################################################## +# Plot data +# --------- + +#%% plot the distributions + +pl.figure(1) +pl.scatter(xt[:, 0], xt[:, 1], label='Target $\mu^t$', alpha=0.5) +pl.scatter(xs1[:, 0], xs1[:, 1], label='Source $\mu^s_1$', alpha=0.5) +pl.scatter(xs2[:, 0], xs2[:, 1], label='Source $\mu^s_2$', alpha=0.5) +pl.title('Sources and Target distributions') +pl.legend() + + +############################################################################## +# Optimization of the model wrt the Wasserstein distance +# ------------------------------------------------------ + + +#%% Weights optimization with gradient descent + +# convert numpy arrays to torch tensors +H2 = torch.tensor(H) +M2 = torch.tensor(M) + +# weights for the source distributions +w = torch.tensor(ot.unif(2), requires_grad=True) + +# uniform weights for target +b = torch.tensor(ot.unif(nt)) + +lr = 2e-3 # learning rate +niter = 500 # number of iterations +losses = [] # loss along the iterations + +# loss for the minimal Wasserstein estimator + + +def get_loss(w): + a = torch.mv(H2, w) # distribution reweighting + return ot.emd2(a, b, M2) # squared Wasserstein 2 + + +for i in range(niter): + + loss = get_loss(w) + losses.append(float(loss)) + + loss.backward() + + with torch.no_grad(): + w -= lr * w.grad # gradient step + w[:] = ot.utils.proj_simplex(w) # projection on the simplex + + w.grad.zero_() + + +############################################################################## +# Estimated weights and convergence of the objective +# --------------------------------------------------- + +we = w.detach().numpy() +print('Estimated mixture:', we) + +pl.figure(2) +pl.semilogy(losses) +pl.grid() +pl.title('Wasserstein distance') +pl.xlabel("Iterations") + +############################################################################## +# Ploting the reweighted source distribution +# ------------------------------------------ + +pl.figure(3) + +# compute source weights +ws = H.dot(we) + +pl.scatter(xt[:, 0], xt[:, 1], label='Target $\mu^t$', alpha=0.5) +pl.scatter(xs[:, 0], xs[:, 1], color='C3', s=ws * 20 * ns, label='Weighted sources $\sum_{k} w_k\mu^s_k$', alpha=0.5) +pl.title('Target and reweighted source distributions') +pl.legend() diff --git a/ot/__init__.py b/ot/__init__.py index 5a8a415..3b072c6 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -33,6 +33,7 @@ from . import smooth from . import stochastic from . import unbalanced from . import partial +from . import backend # OT functions from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d diff --git a/ot/backend.py b/ot/backend.py new file mode 100644 index 0000000..d68f5cf --- /dev/null +++ b/ot/backend.py @@ -0,0 +1,536 @@ +# -*- coding: utf-8 -*- +""" +Multi-lib backend for POT +""" + +# Author: Remi Flamary +# Nicolas Courty +# +# License: MIT License + +import numpy as np + +try: + import torch + torch_type = torch.Tensor +except ImportError: + torch = False + torch_type = float + +try: + import jax + import jax.numpy as jnp + jax_type = jax.numpy.ndarray +except ImportError: + jax = False + jax_type = float + +str_type_error = "All array should be from the same type/backend. Current types are : {}" + + +def get_backend_list(): + """ returns the list of available backends)""" + lst = [NumpyBackend(), ] + + if torch: + lst.append(TorchBackend()) + + if jax: + lst.append(JaxBackend()) + + return lst + + +def get_backend(*args): + """returns the proper backend for a list of input arrays + + Also raises TypeError if all arrays are not from the same backend + """ + # check that some arrays given + if not len(args) > 0: + raise ValueError(" The function takes at least one parameter") + # check all same type + + if isinstance(args[0], np.ndarray): + if not len(set(type(a) for a in args)) == 1: + raise ValueError(str_type_error.format([type(a) for a in args])) + return NumpyBackend() + elif torch and isinstance(args[0], torch_type): + if not len(set(type(a) for a in args)) == 1: + raise ValueError(str_type_error.format([type(a) for a in args])) + return TorchBackend() + elif isinstance(args[0], jax_type): + return JaxBackend() + else: + raise ValueError("Unknown type of non implemented backend.") + + +def to_numpy(*args): + """returns numpy arrays from any compatible backend""" + + if len(args) == 1: + return get_backend(args[0]).to_numpy(args[0]) + else: + return [get_backend(a).to_numpy(a) for a in args] + + +class Backend(): + + __name__ = None + __type__ = None + + def __str__(self): + return self.__name__ + + # convert to numpy + def to_numpy(self, a): + raise NotImplementedError() + + # convert from numpy + def from_numpy(self, a, type_as=None): + raise NotImplementedError() + + def set_gradients(self, val, inputs, grads): + """ define the gradients for the value val wrt the inputs """ + raise NotImplementedError() + + def zeros(self, shape, type_as=None): + raise NotImplementedError() + + def ones(self, shape, type_as=None): + raise NotImplementedError() + + def arange(self, stop, start=0, step=1, type_as=None): + raise NotImplementedError() + + def full(self, shape, fill_value, type_as=None): + raise NotImplementedError() + + def eye(self, N, M=None, type_as=None): + raise NotImplementedError() + + def sum(self, a, axis=None, keepdims=False): + raise NotImplementedError() + + def cumsum(self, a, axis=None): + raise NotImplementedError() + + def max(self, a, axis=None, keepdims=False): + raise NotImplementedError() + + def min(self, a, axis=None, keepdims=False): + raise NotImplementedError() + + def maximum(self, a, b): + raise NotImplementedError() + + def minimum(self, a, b): + raise NotImplementedError() + + def dot(self, a, b): + raise NotImplementedError() + + def abs(self, a): + raise NotImplementedError() + + def exp(self, a): + raise NotImplementedError() + + def log(self, a): + raise NotImplementedError() + + def sqrt(self, a): + raise NotImplementedError() + + def norm(self, a): + raise NotImplementedError() + + def any(self, a): + raise NotImplementedError() + + def isnan(self, a): + raise NotImplementedError() + + def isinf(self, a): + raise NotImplementedError() + + def einsum(self, subscripts, *operands): + raise NotImplementedError() + + def sort(self, a, axis=-1): + raise NotImplementedError() + + def argsort(self, a, axis=None): + raise NotImplementedError() + + def flip(self, a, axis=None): + raise NotImplementedError() + + +class NumpyBackend(Backend): + + __name__ = 'numpy' + __type__ = np.ndarray + + def to_numpy(self, a): + return a + + def from_numpy(self, a, type_as=None): + if type_as is None: + return a + elif isinstance(a, float): + return a + else: + return a.astype(type_as.dtype) + + def set_gradients(self, val, inputs, grads): + # no gradients for numpy + return val + + def zeros(self, shape, type_as=None): + if type_as is None: + return np.zeros(shape) + else: + return np.zeros(shape, dtype=type_as.dtype) + + def ones(self, shape, type_as=None): + if type_as is None: + return np.ones(shape) + else: + return np.ones(shape, dtype=type_as.dtype) + + def arange(self, stop, start=0, step=1, type_as=None): + return np.arange(start, stop, step) + + def full(self, shape, fill_value, type_as=None): + if type_as is None: + return np.full(shape, fill_value) + else: + return np.full(shape, fill_value, dtype=type_as.dtype) + + def eye(self, N, M=None, type_as=None): + if type_as is None: + return np.eye(N, M) + else: + return np.eye(N, M, dtype=type_as.dtype) + + def sum(self, a, axis=None, keepdims=False): + return np.sum(a, axis, keepdims=keepdims) + + def cumsum(self, a, axis=None): + return np.cumsum(a, axis) + + def max(self, a, axis=None, keepdims=False): + return np.max(a, axis, keepdims=keepdims) + + def min(self, a, axis=None, keepdims=False): + return np.min(a, axis, keepdims=keepdims) + + def maximum(self, a, b): + return np.maximum(a, b) + + def minimum(self, a, b): + return np.minimum(a, b) + + def dot(self, a, b): + return np.dot(a, b) + + def abs(self, a): + return np.abs(a) + + def exp(self, a): + return np.exp(a) + + def log(self, a): + return np.log(a) + + def sqrt(self, a): + return np.sqrt(a) + + def norm(self, a): + return np.sqrt(np.sum(np.square(a))) + + def any(self, a): + return np.any(a) + + def isnan(self, a): + return np.isnan(a) + + def isinf(self, a): + return np.isinf(a) + + def einsum(self, subscripts, *operands): + return np.einsum(subscripts, *operands) + + def sort(self, a, axis=-1): + return np.sort(a, axis) + + def argsort(self, a, axis=-1): + return np.argsort(a, axis) + + def flip(self, a, axis=None): + return np.flip(a, axis) + + +class JaxBackend(Backend): + + __name__ = 'jax' + __type__ = jax_type + + def to_numpy(self, a): + return np.array(a) + + def from_numpy(self, a, type_as=None): + if type_as is None: + return jnp.array(a) + else: + return jnp.array(a).astype(type_as.dtype) + + def set_gradients(self, val, inputs, grads): + # no gradients for jax because it is functional + + # does not work + # from jax import custom_jvp + # @custom_jvp + # def f(*inputs): + # return val + # f.defjvps(*grads) + # return f(*inputs) + + return val + + def zeros(self, shape, type_as=None): + if type_as is None: + return jnp.zeros(shape) + else: + return jnp.zeros(shape, dtype=type_as.dtype) + + def ones(self, shape, type_as=None): + if type_as is None: + return jnp.ones(shape) + else: + return jnp.ones(shape, dtype=type_as.dtype) + + def arange(self, stop, start=0, step=1, type_as=None): + return jnp.arange(start, stop, step) + + def full(self, shape, fill_value, type_as=None): + if type_as is None: + return jnp.full(shape, fill_value) + else: + return jnp.full(shape, fill_value, dtype=type_as.dtype) + + def eye(self, N, M=None, type_as=None): + if type_as is None: + return jnp.eye(N, M) + else: + return jnp.eye(N, M, dtype=type_as.dtype) + + def sum(self, a, axis=None, keepdims=False): + return jnp.sum(a, axis, keepdims=keepdims) + + def cumsum(self, a, axis=None): + return jnp.cumsum(a, axis) + + def max(self, a, axis=None, keepdims=False): + return jnp.max(a, axis, keepdims=keepdims) + + def min(self, a, axis=None, keepdims=False): + return jnp.min(a, axis, keepdims=keepdims) + + def maximum(self, a, b): + return jnp.maximum(a, b) + + def minimum(self, a, b): + return jnp.minimum(a, b) + + def dot(self, a, b): + return jnp.dot(a, b) + + def abs(self, a): + return jnp.abs(a) + + def exp(self, a): + return jnp.exp(a) + + def log(self, a): + return jnp.log(a) + + def sqrt(self, a): + return jnp.sqrt(a) + + def norm(self, a): + return jnp.sqrt(jnp.sum(jnp.square(a))) + + def any(self, a): + return jnp.any(a) + + def isnan(self, a): + return jnp.isnan(a) + + def isinf(self, a): + return jnp.isinf(a) + + def einsum(self, subscripts, *operands): + return jnp.einsum(subscripts, *operands) + + def sort(self, a, axis=-1): + return jnp.sort(a, axis) + + def argsort(self, a, axis=-1): + return jnp.argsort(a, axis) + + def flip(self, a, axis=None): + return jnp.flip(a, axis) + + +class TorchBackend(Backend): + + __name__ = 'torch' + __type__ = torch_type + + def to_numpy(self, a): + return a.cpu().detach().numpy() + + def from_numpy(self, a, type_as=None): + if type_as is None: + return torch.from_numpy(a) + else: + return torch.as_tensor(a, dtype=type_as.dtype, device=type_as.device) + + def set_gradients(self, val, inputs, grads): + from torch.autograd import Function + + # define a function that takes inputs and return val + class ValFunction(Function): + @staticmethod + def forward(ctx, *inputs): + return val + + @staticmethod + def backward(ctx, grad_output): + # the gradients are grad + return grads + + return ValFunction.apply(*inputs) + + def zeros(self, shape, type_as=None): + if type_as is None: + return torch.zeros(shape) + else: + return torch.zeros(shape, dtype=type_as.dtype, device=type_as.device) + + def ones(self, shape, type_as=None): + if type_as is None: + return torch.ones(shape) + else: + return torch.ones(shape, dtype=type_as.dtype, device=type_as.device) + + def arange(self, stop, start=0, step=1, type_as=None): + if type_as is None: + return torch.arange(start, stop, step) + else: + return torch.arange(start, stop, step, device=type_as.device) + + def full(self, shape, fill_value, type_as=None): + if type_as is None: + return torch.full(shape, fill_value) + else: + return torch.full(shape, fill_value, dtype=type_as.dtype, device=type_as.device) + + def eye(self, N, M=None, type_as=None): + if M is None: + M = N + if type_as is None: + return torch.eye(N, m=M) + else: + return torch.eye(N, m=M, dtype=type_as.dtype, device=type_as.device) + + def sum(self, a, axis=None, keepdims=False): + if axis is None: + return torch.sum(a) + else: + return torch.sum(a, axis, keepdim=keepdims) + + def cumsum(self, a, axis=None): + if axis is None: + return torch.cumsum(a.flatten(), 0) + else: + return torch.cumsum(a, axis) + + def max(self, a, axis=None, keepdims=False): + if axis is None: + return torch.max(a) + else: + return torch.max(a, axis, keepdim=keepdims)[0] + + def min(self, a, axis=None, keepdims=False): + if axis is None: + return torch.min(a) + else: + return torch.min(a, axis, keepdim=keepdims)[0] + + def maximum(self, a, b): + if isinstance(a, int) or isinstance(a, float): + a = torch.tensor([float(a)], dtype=b.dtype, device=b.device) + if isinstance(b, int) or isinstance(b, float): + b = torch.tensor([float(b)], dtype=a.dtype, device=a.device) + return torch.maximum(a, b) + + def minimum(self, a, b): + if isinstance(a, int) or isinstance(a, float): + a = torch.tensor([float(a)], dtype=b.dtype, device=b.device) + if isinstance(b, int) or isinstance(b, float): + b = torch.tensor([float(b)], dtype=a.dtype, device=a.device) + return torch.minimum(a, b) + + def dot(self, a, b): + if len(a.shape) == len(b.shape) == 1: + return torch.dot(a, b) + elif len(a.shape) == 2 and len(b.shape) == 1: + return torch.mv(a, b) + else: + return torch.mm(a, b) + + def abs(self, a): + return torch.abs(a) + + def exp(self, a): + return torch.exp(a) + + def log(self, a): + return torch.log(a) + + def sqrt(self, a): + return torch.sqrt(a) + + def norm(self, a): + return torch.sqrt(torch.sum(torch.square(a))) + + def any(self, a): + return torch.any(a) + + def isnan(self, a): + return torch.isnan(a) + + def isinf(self, a): + return torch.isinf(a) + + def einsum(self, subscripts, *operands): + return torch.einsum(subscripts, *operands) + + def sort(self, a, axis=-1): + sorted0, indices = torch.sort(a, dim=axis) + return sorted0 + + def argsort(self, a, axis=-1): + sorted, indices = torch.sort(a, dim=axis) + return indices + + def flip(self, a, axis=None): + if axis is None: + return torch.flip(a, tuple(i for i in range(len(a.shape)))) + if isinstance(axis, int): + return torch.flip(a, (axis,)) + else: + return torch.flip(a, dims=axis) diff --git a/ot/bregman.py b/ot/bregman.py index 559db14..b10effd 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -19,7 +19,8 @@ import warnings import numpy as np from scipy.optimize import fmin_l_bfgs_b -from ot.utils import unif, dist +from ot.utils import unif, dist, list_to_array +from .backend import get_backend def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, @@ -43,17 +44,36 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - a and b are source and target weights (histograms, both sum to 1) - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm as proposed in [2]_ + + **Choosing a Sinkhorn solver** + + By default and when using a regularization parameter that is not too small + the default sinkhorn solver should be enough. If you need to use a small + regularization to get sharper OT matrices, you should use the + :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical + errors. This last solver can be very slow in practice and might not even + converge to a reasonable OT matrix in a finite time. This is why + :any:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value + of the regularization (and using warm start) sometimes leads to better + solutions. Note that the greedy version of the sinkhorn + :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening + version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a + fast approximation of the Sinkhorn problem. Parameters ---------- - a : ndarray, shape (dim_a,) + a : array-like, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists) + b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets and fixed M if b is a matrix (return OT loss + dual variables in log) - M : ndarray, shape (dim_a, dim_b) + M : array-like, shape (dim_a, dim_b) loss matrix reg : float Regularization term >0 @@ -69,25 +89,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, log : bool, optional record log if True - **Choosing a Sinkhorn solver** - - By default and when using a regularization parameter that is not too small - the default sinkhorn solver should be enough. If you need to use a small - regularization to get sharper OT matrices, you should use the - :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical - errors. This last solver can be very slow in practice and might not even - converge to a reasonable OT matrix in a finite time. This is why - :any:`ot.bregman.sinkhorn_epsilon_scaling` that relie on iterating the value - of the regularization (and using warm start) sometimes leads to better - solutions. Note that the greedy version of the sinkhorn - :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening - version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a - fast approximation of the Sinkhorn problem. - - Returns ------- - gamma : ndarray, shape (dim_a, dim_b) + gamma : array-like, shape (dim_a, dim_b) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -166,17 +170,35 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - a and b are source and target weights (histograms, both sum to 1) + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ + **Choosing a Sinkhorn solver** + + By default and when using a regularization parameter that is not too small + the default sinkhorn solver should be enough. If you need to use a small + regularization to get sharper OT matrices, you should use the + :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical + errors. This last solver can be very slow in practice and might not even + converge to a reasonable OT matrix in a finite time. This is why + :any:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value + of the regularization (and using warm start) sometimes leads to better + solutions. Note that the greedy version of the sinkhorn + :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening + version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a + fast approximation of the Sinkhorn problem. + Parameters ---------- - a : ndarray, shape (dim_a,) + a : array-like, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists) + b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets and fixed M if b is a matrix (return OT loss + dual variables in log) - M : ndarray, shape (dim_a, dim_b) + M : array-like, shape (dim_a, dim_b) loss matrix reg : float Regularization term >0 @@ -191,28 +213,14 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, log : bool, optional record log if True - **Choosing a Sinkhorn solver** - - By default and when using a regularization parameter that is not too small - the default sinkhorn solver should be enough. If you need to use a small - regularization to get sharper OT matrices, you should use the - :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical - errors. This last solver can be very slow in practice and might not even - converge to a reasonable OT matrix in a finite time. This is why - :any:`ot.bregman.sinkhorn_epsilon_scaling` that relie on iterating the value - of the regularization (and using warm start) sometimes leads to better - solutions. Note that the greedy version of the sinkhorn - :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening - version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a - fast approximation of the Sinkhorn problem. - Returns ------- - W : (n_hists) ndarray + W : (n_hists) float/array-like Optimal transportation loss for the given parameters log : dict log dictionary return only if log==True in parameters + Examples -------- @@ -247,7 +255,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] """ - b = np.asarray(b, dtype=np.float64) + + b = list_to_array(b) if len(b.shape) < 2: b = b[:, None] @@ -339,14 +348,14 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, """ - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) if len(a) == 0: - a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M) if len(b) == 0: - b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M) # init data dim_a = len(a) @@ -363,21 +372,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, # we assume that no distances are null except those of the diagonal of # distances if n_hists: - u = np.ones((dim_a, n_hists)) / dim_a - v = np.ones((dim_b, n_hists)) / dim_b + u = nx.ones((dim_a, n_hists), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b else: - u = np.ones(dim_a) / dim_a - v = np.ones(dim_b) / dim_b + u = nx.ones(dim_a, type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b - # print(reg) - - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) - - # print(np.min(K)) - tmp2 = np.empty(b.shape, dtype=M.dtype) + K = nx.exp(M / (-reg)) Kp = (1 / a).reshape(-1, 1) * K cpt = 0 @@ -386,13 +387,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, uprev = u vprev = v - KtransposeU = np.dot(K.T, u) - v = np.divide(b, KtransposeU) - u = 1. / np.dot(Kp, v) + KtransposeU = nx.dot(K.T, u) + v = b / KtransposeU + u = 1. / nx.dot(Kp, v) - if (np.any(KtransposeU == 0) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + if (nx.any(KtransposeU == 0) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop print('Warning: numerical errors at iteration', cpt) @@ -403,11 +404,11 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, # we can speed up the process by checking for the error only all # the 10th iterations if n_hists: - np.einsum('ik,ij,jk->jk', u, K, v, out=tmp2) + tmp2 = nx.einsum('ik,ij,jk->jk', u, K, v) else: # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 - np.einsum('i,ij,j->j', u, K, v, out=tmp2) - err = np.linalg.norm(tmp2 - b) # violation of marginal + tmp2 = nx.einsum('i,ij,j->j', u, K, v) + err = nx.norm(tmp2 - b) # violation of marginal if log: log['err'].append(err) @@ -422,7 +423,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, log['v'] = v if n_hists: # return only loss - res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) + res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M) if log: return res, log else: diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py index 7478fb9..e939610 100644 --- a/ot/gpu/__init__.py +++ b/ot/gpu/__init__.py @@ -25,6 +25,8 @@ result of the function with parameter ``to_numpy=False``. # # License: MIT License +import warnings + from . import bregman from . import da from .bregman import sinkhorn @@ -34,7 +36,7 @@ from . import utils from .utils import dist, to_gpu, to_np - +warnings.warn('This module will be deprecated in the next minor release of POT', category=DeprecationWarning) __all__ = ["utils", "dist", "sinkhorn", diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index d5c3a5e..c8c9da6 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -18,8 +18,9 @@ from . import cvx from .cvx import barycenter # import compiled emd from .emd_wrap import emd_c, check_result, emd_1d_sorted -from ..utils import dist +from ..utils import dist, list_to_array from ..utils import parmap +from ..backend import get_backend __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', 'emd_1d', 'emd2_1d', 'wasserstein_1d'] @@ -176,8 +177,7 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): r"""Solves the Earth Movers distance problem and returns the OT matrix - .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + .. math:: \gamma = arg\min_\gamma <\gamma,M>_F s.t. \gamma 1 = a @@ -189,37 +189,41 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): - M is the metric cost matrix - a and b are the sample weights - .. warning:: - Note that the M matrix needs to be a C-order numpy.array in float64 - format. + .. warning:: Note that the M matrix in numpy needs to be a C-order + numpy.array in float64 format. It will be converted if not in this + format + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. Uses the algorithm proposed in [1]_ Parameters ---------- - a : (ns,) numpy.ndarray, float64 + a : (ns,) array-like, float Source histogram (uniform weight if empty list) - b : (nt,) numpy.ndarray, float64 - Target histogram (uniform weight if empty list) - M : (ns,nt) numpy.ndarray, float64 - Loss matrix (c-order array with type float64) - numItermax : int, optional (default=100000) + b : (nt,) array-like, float + Target histogram (uniform weight if empty list) + M : (ns,nt) array-like, float + Loss matrix (c-order array in numpy with type float64) + numItermax : int, optional (default=100000) The maximum number of iterations before stopping the optimization - algorithm if it has not converged. - log: bool, optional (default=False) - If True, returns a dictionary containing the cost and dual - variables. Otherwise returns only the optimal transportation matrix. + algorithm if it has not converged. + log: bool, optional (default=False) + If True, returns a dictionary containing the cost and dual variables. + Otherwise returns only the optimal transportation matrix. center_dual: boolean, optional (default=True) - If True, centers the dual potential using function + If True, centers the dual potential using function :ref:`center_ot_dual`. Returns ------- - gamma: (ns x nt) numpy.ndarray - Optimal transportation matrix for the given parameters - log: dict - If input log is true, a dictionary containing the cost and dual - variables and exit status + gamma: array-like, shape (ns, nt) + Optimal transportation matrix for the given + parameters + log: dict, optional + If input log is true, a dictionary containing the + cost and dual variables and exit status Examples @@ -232,26 +236,37 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): >>> a=[.5,.5] >>> b=[.5,.5] >>> M=[[0.,1.],[1.,0.]] - >>> ot.emd(a,b,M) + >>> ot.emd(a, b, M) array([[0.5, 0. ], [0. , 0.5]]) References ---------- - .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. - (2011, December). Displacement interpolation using Lagrangian mass - transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. - 158). ACM. + .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, + December). Displacement interpolation using Lagrangian mass transport. + In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM. See Also -------- - ot.bregman.sinkhorn : Entropic regularized OT - ot.optim.cg : General regularized OT""" - + ot.bregman.sinkhorn : Entropic regularized OT ot.optim.cg : General + regularized OT""" + + # convert to numpy if list + a, b, M = list_to_array(a, b, M) + + a0, b0, M0 = a, b, M + nx = get_backend(M0, a0, b0) + + # convert to numpy + M = nx.to_numpy(M) + a = nx.to_numpy(a) + b = nx.to_numpy(b) + + # ensure float64 a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + M = np.asarray(M, dtype=np.float64, order='C') # if empty array given then use uniform distributions if len(a) == 0: @@ -262,6 +277,11 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \ "Dimension mismatch, check dimensions of M with a and b" + # ensure that same mass + np.testing.assert_almost_equal(a.sum(0), + b.sum(0), err_msg='a and b vector must have the same sum') + b=b*a.sum()/b.sum() + asel = a != 0 bsel = b != 0 @@ -277,12 +297,12 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): if log: log = {} log['cost'] = cost - log['u'] = u - log['v'] = v + log['u'] = nx.from_numpy(u, type_as=a0) + log['v'] = nx.from_numpy(v, type_as=b0) log['warning'] = result_code_string log['result_code'] = result_code - return G, log - return G + return nx.from_numpy(G, type_as=M0), log + return nx.from_numpy(G, type_as=M0) def emd2(a, b, M, processes=multiprocessing.cpu_count(), @@ -303,20 +323,19 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), - M is the metric cost matrix - a and b are the sample weights - .. warning:: - Note that the M matrix needs to be a C-order numpy.array in float64 - format. + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. Uses the algorithm proposed in [1]_ Parameters ---------- - a : (ns,) numpy.ndarray, float64 + a : (ns,) array-like, float64 Source histogram (uniform weight if empty list) - b : (nt,) numpy.ndarray, float64 + b : (nt,) array-like, float64 Target histogram (uniform weight if empty list) - M : (ns,nt) numpy.ndarray, float64 - Loss matrix (c-order array with type float64) + M : (ns,nt) array-like, float64 + Loss matrix (for numpy c-order array with type float64) processes : int, optional (default=nb cpu) Nb of processes used for multiple emd computation (not used on windows) numItermax : int, optional (default=100000) @@ -333,9 +352,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), Returns ------- - W: float + W: float, array-like Optimal transportation loss for the given parameters - log: dictnp + log: dict If input log is true, a dictionary containing dual variables and exit status @@ -367,12 +386,22 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), ot.bregman.sinkhorn : Entropic regularized OT ot.optim.cg : General regularized OT""" + a, b, M = list_to_array(a, b, M) + + a0, b0, M0 = a, b, M + nx = get_backend(M0, a0, b0) + + # convert to numpy + M = nx.to_numpy(M) + a = nx.to_numpy(a) + b = nx.to_numpy(b) + a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + M = np.asarray(M, dtype=np.float64, order= 'C') # problem with pikling Forks - if sys.platform.endswith('win32'): + if sys.platform.endswith('win32') or not nx.__name__ == 'numpy': processes = 1 # if empty array given then use uniform distributions @@ -400,12 +429,15 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), result_code_string = check_result(result_code) log = {} + G = nx.from_numpy(G, type_as=M0) if return_matrix: log['G'] = G - log['u'] = u - log['v'] = v + log['u'] = nx.from_numpy(u, type_as=a0) + log['v'] = nx.from_numpy(v, type_as=b0) log['warning'] = result_code_string log['result_code'] = result_code + cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), + (a0,b0, M0), (log['u'], log['v'], G)) return [cost, log] else: def f(b): @@ -418,6 +450,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), if np.any(~asel) or np.any(~bsel): u, v = estimate_dual_null_weights(u, v, a, b, M) + G = nx.from_numpy(G, type_as=M0) + cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), + (a0,b0, M0), (nx.from_numpy(u, type_as=a0), + nx.from_numpy(v, type_as=b0),G)) + check_result(result_code) return cost @@ -637,6 +674,10 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, if b.ndim == 0 or len(b) == 0: b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0] + # ensure that same mass + np.testing.assert_almost_equal(a.sum(0),b.sum(0),err_msg='a and b vector must have the same sum') + b=b*a.sum()/b.sum() + x_a_1d = x_a.reshape((-1,)) x_b_1d = x_b.reshape((-1,)) perm_a = np.argsort(x_a_1d) diff --git a/ot/utils.py b/ot/utils.py index 544c569..4dac0c5 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -16,6 +16,7 @@ from scipy.spatial.distance import cdist import sys import warnings from inspect import signature +from .backend import get_backend __time_tic_toc = time.time() @@ -41,8 +42,11 @@ def toq(): def kernel(x1, x2, method='gaussian', sigma=1, **kwargs): """Compute kernel matrix""" + + nx = get_backend(x1, x2) + if method.lower() in ['gaussian', 'gauss', 'rbf']: - K = np.exp(-dist(x1, x2) / (2 * sigma**2)) + K = nx.exp(-dist(x1, x2) / (2 * sigma**2)) return K @@ -52,6 +56,66 @@ def laplacian(x): return L +def list_to_array(*lst): + """ Convert a list if in numpy format """ + if len(lst) > 1: + return [np.array(a) if isinstance(a, list) else a for a in lst] + else: + return np.array(lst[0]) if isinstance(lst[0], list) else lst[0] + + +def proj_simplex(v, z=1): + r""" compute the closest point (orthogonal projection) on the + generalized (n-1)-simplex of a vector v wrt. to the Euclidean + distance, thus solving: + .. math:: + \mathcal{P}(w) \in arg\min_\gamma || \gamma - v ||_2 + + s.t. \gamma^T 1= z + + \gamma\geq 0 + + If v is a 2d array, compute all the projections wrt. axis 0 + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + + Parameters + ---------- + v : {array-like}, shape (n, d) + z : int, optional + 'size' of the simplex (each vectors sum to z, 1 by default) + + Returns + ------- + h : ndarray, shape (n,d) + Array of projections on the simplex + """ + nx = get_backend(v) + n = v.shape[0] + if v.ndim == 1: + d1 = 1 + v = v[:, None] + else: + d1 = 0 + d = v.shape[1] + + # sort u in ascending order + u = nx.sort(v, axis=0) + # take the descending order + u = nx.flip(u, 0) + cssv = nx.cumsum(u, axis=0) - z + ind = nx.arange(n, type_as=v)[:, None] + 1 + cond = u - cssv / ind > 0 + rho = nx.sum(cond, 0) + theta = cssv[rho - 1, nx.arange(d)] / rho + w = nx.maximum(v - theta[None, :], nx.zeros(v.shape, type_as=v)) + if d1: + return w[:, 0] + else: + return w + + def unif(n): """ return a uniform histogram of length n (simplex) @@ -84,52 +148,68 @@ def euclidean_distances(X, Y, squared=False): """ Considering the rows of X (and Y=X) as vectors, compute the distance matrix between each pair of vectors. + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + Parameters ---------- X : {array-like}, shape (n_samples_1, n_features) Y : {array-like}, shape (n_samples_2, n_features) squared : boolean, optional Return squared Euclidean distances. + Returns ------- distances : {array}, shape (n_samples_1, n_samples_2) """ - XX = np.einsum('ij,ij->i', X, X)[:, np.newaxis] - YY = np.einsum('ij,ij->i', Y, Y)[np.newaxis, :] - distances = np.dot(X, Y.T) - distances *= -2 - distances += XX - distances += YY - np.maximum(distances, 0, out=distances) + + nx = get_backend(X, Y) + + a2 = nx.einsum('ij,ij->i', X, X) + b2 = nx.einsum('ij,ij->i', Y, Y) + + c = -2 * nx.dot(X, Y.T) + c += a2[:, None] + c += b2[None, :] + + c = nx.maximum(c, 0) + + if not squared: + c = nx.sqrt(c) + if X is Y: - # Ensure that distances between vectors and themselves are set to 0.0. - # This may not be the case due to floating point rounding errors. - distances.flat[::distances.shape[0] + 1] = 0.0 - return distances if squared else np.sqrt(distances, out=distances) + c = c * (1 - nx.eye(X.shape[0], type_as=c)) + + return c def dist(x1, x2=None, metric='sqeuclidean'): - """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist + """Compute distance between samples in x1 and x2 + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. Parameters ---------- - x1 : ndarray, shape (n1,d) + x1 : array-like, shape (n1,d) matrix with n1 samples of size d - x2 : array, shape (n2,d), optional + x2 : array-like, shape (n2,d), optional matrix with n2 samples of size d (if None then x2=x1) metric : str | callable, optional - Name of the metric to be computed (full list in the doc of scipy), If a string, - the distance function can be 'braycurtis', 'canberra', 'chebyshev', 'cityblock', - 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski', - 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', + 'sqeuclidean' or 'euclidean' on all backends. On numpy the function also + accepts from the scipy.spatial.distance.cdist function : 'braycurtis', + 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice', + 'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis', + 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'. Returns ------- - M : np.array (n1,n2) + M : array-like, shape (n1, n2) distance matrix computed with given metric """ @@ -137,7 +217,13 @@ def dist(x1, x2=None, metric='sqeuclidean'): x2 = x1 if metric == "sqeuclidean": return euclidean_distances(x1, x2, squared=True) - return cdist(x1, x2, metric=metric) + elif metric == "euclidean": + return euclidean_distances(x1, x2, squared=False) + else: + if not get_backend(x1, x2).__name__ == 'numpy': + raise NotImplementedError() + else: + return cdist(x1, x2, metric=metric) def dist0(n, method='lin_square'): diff --git a/requirements.txt b/requirements.txt index 331dd57..4353247 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,7 @@ pymanopt==0.2.4; python_version <'3' pymanopt; python_version >= '3' cvxopt scikit-learn +torch +jax +jaxlib pytest \ No newline at end of file diff --git a/test/test_backend.py b/test/test_backend.py new file mode 100644 index 0000000..bc5b00c --- /dev/null +++ b/test/test_backend.py @@ -0,0 +1,364 @@ +"""Tests for backend module """ + +# Author: Remi Flamary +# +# License: MIT License + +import ot +import ot.backend +from ot.backend import torch, jax + +import pytest + +import numpy as np +from numpy.testing import assert_array_almost_equal_nulp + +from ot.backend import get_backend, get_backend_list, to_numpy + + +backend_list = get_backend_list() + + +def test_get_backend_list(): + + lst = get_backend_list() + + assert len(lst) > 0 + assert isinstance(lst[0], ot.backend.NumpyBackend) + + +@pytest.mark.parametrize('nx', backend_list) +def test_to_numpy(nx): + + v = nx.zeros(10) + M = nx.ones((10, 10)) + + v2 = to_numpy(v) + assert isinstance(v2, np.ndarray) + + v2, M2 = to_numpy(v, M) + assert isinstance(M2, np.ndarray) + + +def test_get_backend(): + + A = np.zeros((3, 2)) + B = np.zeros((3, 1)) + + nx = get_backend(A) + assert nx.__name__ == 'numpy' + + nx = get_backend(A, B) + assert nx.__name__ == 'numpy' + + # error if no parameters + with pytest.raises(ValueError): + get_backend() + + # error if unknown types + with pytest.raises(ValueError): + get_backend(1, 2.0) + + # test torch + if torch: + + A2 = torch.from_numpy(A) + B2 = torch.from_numpy(B) + + nx = get_backend(A2) + assert nx.__name__ == 'torch' + + nx = get_backend(A2, B2) + assert nx.__name__ == 'torch' + + # test not unique types in input + with pytest.raises(ValueError): + get_backend(A, B2) + + if jax: + + A2 = jax.numpy.array(A) + B2 = jax.numpy.array(B) + + nx = get_backend(A2) + assert nx.__name__ == 'jax' + + nx = get_backend(A2, B2) + assert nx.__name__ == 'jax' + + # test not unique types in input + with pytest.raises(ValueError): + get_backend(A, B2) + + +@pytest.mark.parametrize('nx', backend_list) +def test_convert_between_backends(nx): + + A = np.zeros((3, 2)) + B = np.zeros((3, 1)) + + A2 = nx.from_numpy(A) + B2 = nx.from_numpy(B) + + assert isinstance(A2, nx.__type__) + assert isinstance(B2, nx.__type__) + + nx2 = get_backend(A2, B2) + + assert nx2.__name__ == nx.__name__ + + assert_array_almost_equal_nulp(nx.to_numpy(A2), A) + assert_array_almost_equal_nulp(nx.to_numpy(B2), B) + + +def test_empty_backend(): + + rnd = np.random.RandomState(0) + M = rnd.randn(10, 3) + v = rnd.randn(3) + + nx = ot.backend.Backend() + + with pytest.raises(NotImplementedError): + nx.from_numpy(M) + with pytest.raises(NotImplementedError): + nx.to_numpy(M) + with pytest.raises(NotImplementedError): + nx.set_gradients(0, 0, 0) + with pytest.raises(NotImplementedError): + nx.zeros((10, 3)) + with pytest.raises(NotImplementedError): + nx.ones((10, 3)) + with pytest.raises(NotImplementedError): + nx.arange(10, 1, 2) + with pytest.raises(NotImplementedError): + nx.full((10, 3), 3.14) + with pytest.raises(NotImplementedError): + nx.eye((10, 3)) + with pytest.raises(NotImplementedError): + nx.sum(M) + with pytest.raises(NotImplementedError): + nx.cumsum(M) + with pytest.raises(NotImplementedError): + nx.max(M) + with pytest.raises(NotImplementedError): + nx.min(M) + with pytest.raises(NotImplementedError): + nx.maximum(v, v) + with pytest.raises(NotImplementedError): + nx.minimum(v, v) + with pytest.raises(NotImplementedError): + nx.abs(M) + with pytest.raises(NotImplementedError): + nx.log(M) + with pytest.raises(NotImplementedError): + nx.exp(M) + with pytest.raises(NotImplementedError): + nx.sqrt(M) + with pytest.raises(NotImplementedError): + nx.dot(v, v) + with pytest.raises(NotImplementedError): + nx.norm(M) + with pytest.raises(NotImplementedError): + nx.exp(M) + with pytest.raises(NotImplementedError): + nx.any(M) + with pytest.raises(NotImplementedError): + nx.isnan(M) + with pytest.raises(NotImplementedError): + nx.isinf(M) + with pytest.raises(NotImplementedError): + nx.einsum('ij->i', M) + with pytest.raises(NotImplementedError): + nx.sort(M) + with pytest.raises(NotImplementedError): + nx.argsort(M) + with pytest.raises(NotImplementedError): + nx.flip(M) + + +@pytest.mark.parametrize('backend', backend_list) +def test_func_backends(backend): + + rnd = np.random.RandomState(0) + M = rnd.randn(10, 3) + v = rnd.randn(3) + val = np.array([1.0]) + + lst_tot = [] + + for nx in [ot.backend.NumpyBackend(), backend]: + + print('Backend: ', nx.__name__) + + lst_b = [] + lst_name = [] + + Mb = nx.from_numpy(M) + vb = nx.from_numpy(v) + val = nx.from_numpy(val) + + A = nx.set_gradients(val, v, v) + lst_b.append(nx.to_numpy(A)) + lst_name.append('set_gradients') + + A = nx.zeros((10, 3)) + A = nx.zeros((10, 3), type_as=Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('zeros') + + A = nx.ones((10, 3)) + A = nx.ones((10, 3), type_as=Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('ones') + + A = nx.arange(10, 1, 2) + lst_b.append(nx.to_numpy(A)) + lst_name.append('arange') + + A = nx.full((10, 3), 3.14) + A = nx.full((10, 3), 3.14, type_as=Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('full') + + A = nx.eye(10, 3) + A = nx.eye(10, 3, type_as=Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('eye') + + A = nx.sum(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('sum') + + A = nx.sum(Mb, axis=1, keepdims=True) + lst_b.append(nx.to_numpy(A)) + lst_name.append('sum(axis)') + + A = nx.cumsum(Mb, 0) + lst_b.append(nx.to_numpy(A)) + lst_name.append('cumsum(axis)') + + A = nx.max(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('max') + + A = nx.max(Mb, axis=1, keepdims=True) + lst_b.append(nx.to_numpy(A)) + lst_name.append('max(axis)') + + A = nx.min(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('min') + + A = nx.min(Mb, axis=1, keepdims=True) + lst_b.append(nx.to_numpy(A)) + lst_name.append('min(axis)') + + A = nx.maximum(vb, 0) + lst_b.append(nx.to_numpy(A)) + lst_name.append('maximum') + + A = nx.minimum(vb, 0) + lst_b.append(nx.to_numpy(A)) + lst_name.append('minimum') + + A = nx.abs(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('abs') + + A = nx.log(A) + lst_b.append(nx.to_numpy(A)) + lst_name.append('log') + + A = nx.exp(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('exp') + + A = nx.sqrt(nx.abs(Mb)) + lst_b.append(nx.to_numpy(A)) + lst_name.append('sqrt') + + A = nx.dot(vb, vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('dot(v,v)') + + A = nx.dot(Mb, vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('dot(M,v)') + + A = nx.dot(Mb, Mb.T) + lst_b.append(nx.to_numpy(A)) + lst_name.append('dot(M,M)') + + A = nx.norm(vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('norm') + + A = nx.any(vb > 0) + lst_b.append(nx.to_numpy(A)) + lst_name.append('any') + + A = nx.isnan(vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('isnan') + + A = nx.isinf(vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('isinf') + + A = nx.einsum('ij->i', Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('einsum(ij->i)') + + A = nx.einsum('ij,j->i', Mb, vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('nx.einsum(ij,j->i)') + + A = nx.einsum('ij->i', Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('nx.einsum(ij->i)') + + A = nx.sort(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('sort') + + A = nx.argsort(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('argsort') + + A = nx.flip(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('flip') + + lst_tot.append(lst_b) + + lst_np = lst_tot[0] + lst_b = lst_tot[1] + + for a1, a2, name in zip(lst_np, lst_b, lst_name): + if not np.allclose(a1, a2): + print('Assert fail on: ', name) + assert np.allclose(a1, a2, atol=1e-7) + + +def test_gradients_backends(): + + rnd = np.random.RandomState(0) + v = rnd.randn(10) + c = rnd.randn(1) + + if torch: + + nx = ot.backend.TorchBackend() + + v2 = torch.tensor(v, requires_grad=True) + c2 = torch.tensor(c, requires_grad=True) + + val = c2 * torch.sum(v2 * v2) + + val2 = nx.set_gradients(val, (v2, c2), (v2, c2)) + + val2.backward() + + assert torch.equal(v2.grad, v2) + assert torch.equal(c2.grad, c2) diff --git a/test/test_bregman.py b/test/test_bregman.py index 1ebd21f..7c5162a 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -9,6 +9,10 @@ import numpy as np import pytest import ot +from ot.backend import get_backend_list +from ot.backend import torch + +backend_list = get_backend_list() def test_sinkhorn(): @@ -30,6 +34,76 @@ def test_sinkhorn(): u, G.sum(0), atol=1e-05) # cf convergence sinkhorn +@pytest.mark.parametrize('nx', backend_list) +def test_sinkhorn_backends(nx): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + G = ot.sinkhorn(a, a, M, 1) + + ab = nx.from_numpy(a) + Mb = nx.from_numpy(M) + + Gb = ot.sinkhorn(ab, ab, Mb, 1) + + np.allclose(G, nx.to_numpy(Gb)) + + +@pytest.mark.parametrize('nx', backend_list) +def test_sinkhorn2_backends(nx): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + G = ot.sinkhorn(a, a, M, 1) + + ab = nx.from_numpy(a) + Mb = nx.from_numpy(M) + + Gb = ot.sinkhorn2(ab, ab, Mb, 1) + + np.allclose(G, nx.to_numpy(Gb)) + + +def test_sinkhorn2_gradients(): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + if torch: + + a1 = torch.tensor(a, requires_grad=True) + b1 = torch.tensor(a, requires_grad=True) + M1 = torch.tensor(M, requires_grad=True) + + val = ot.sinkhorn2(a1, b1, M1, 1) + + val.backward() + + assert a1.shape == a1.grad.shape + assert b1.shape == b1.grad.shape + assert M1.shape == M1.grad.shape + + def test_sinkhorn_empty(): # test sinkhorn n = 100 diff --git a/test/test_gromov.py b/test/test_gromov.py index 43da9fc..81138ca 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -181,7 +181,7 @@ def test_fgw(): M = ot.dist(ys, yt) M /= M.max() - G = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5) + G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True) # check constratints np.testing.assert_allclose( @@ -242,9 +242,9 @@ def test_fgw_barycenter(): init_X = np.random.randn(n_samples, ys.shape[1]) - X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5, - fixed_structure=False, fixed_features=True, init_X=init_X, - p=ot.unif(n_samples), loss_fun='square_loss', - max_iter=100, tol=1e-3) + X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5, + fixed_structure=False, fixed_features=True, init_X=init_X, + p=ot.unif(n_samples), loss_fun='square_loss', + max_iter=100, tol=1e-3, log=True) np.testing.assert_allclose(C.shape, (n_samples, n_samples)) np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) diff --git a/test/test_ot.py b/test/test_ot.py index f45e4c9..3e953dc 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -12,9 +12,12 @@ from scipy.stats import wasserstein_distance import ot from ot.datasets import make_1D_gauss as gauss +from ot.backend import get_backend_list, torch +backend_list = get_backend_list() -def test_emd_dimension_mismatch(): + +def test_emd_dimension_and_mass_mismatch(): # test emd and emd2 for dimension mismatch n_samples = 100 n_features = 2 @@ -29,6 +32,80 @@ def test_emd_dimension_mismatch(): np.testing.assert_raises(AssertionError, ot.emd2, a, a, M) + b = a.copy() + a[0] = 100 + np.testing.assert_raises(AssertionError, ot.emd, a, b, M) + + +@pytest.mark.parametrize('nx', backend_list) +def test_emd_backends(nx): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + G = ot.emd(a, a, M) + + ab = nx.from_numpy(a) + Mb = nx.from_numpy(M) + + Gb = ot.emd(ab, ab, Mb) + + np.allclose(G, nx.to_numpy(Gb)) + + +@pytest.mark.parametrize('nx', backend_list) +def test_emd2_backends(nx): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + val = ot.emd2(a, a, M) + + ab = nx.from_numpy(a) + Mb = nx.from_numpy(M) + + valb = ot.emd2(ab, ab, Mb) + + np.allclose(val, nx.to_numpy(valb)) + + +def test_emd2_gradients(): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + if torch: + + a1 = torch.tensor(a, requires_grad=True) + b1 = torch.tensor(a, requires_grad=True) + M1 = torch.tensor(M, requires_grad=True) + + val = ot.emd2(a1, b1, M1) + + val.backward() + + assert a1.shape == a1.grad.shape + assert b1.shape == b1.grad.shape + assert M1.shape == M1.grad.shape + def test_emd_emd2(): # test emd and emd2 for simple identity @@ -83,7 +160,7 @@ def test_emd_1d_emd2_1d(): np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0)) # check G is similar - np.testing.assert_allclose(G, G_1d) + np.testing.assert_allclose(G, G_1d, atol=1e-15) # check AssertionError is raised if called on non 1d arrays u = np.random.randn(n, 2) @@ -292,16 +369,6 @@ def test_warnings(): ot.emd(a, b, M, numItermax=1) assert "numItermax" in str(w[-1].message) #assert len(w) == 1 - a[0] = 100 - print('Computing {} EMD '.format(2)) - ot.emd(a, b, M) - assert "infeasible" in str(w[-1].message) - #assert len(w) == 2 - a[0] = -1 - print('Computing {} EMD '.format(2)) - ot.emd(a, b, M) - assert "infeasible" in str(w[-1].message) - #assert len(w) == 3 def test_dual_variables(): diff --git a/test/test_partial.py b/test/test_partial.py index 121f345..3571e2a 100755 --- a/test/test_partial.py +++ b/test/test_partial.py @@ -129,9 +129,9 @@ def test_partial_wasserstein(): # check constratints np.testing.assert_equal( - G.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein + G.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein np.testing.assert_equal( - G.sum(0) <= q, [True] * len(q)) # cf convergence wasserstein + G.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein np.testing.assert_allclose( np.sum(G), m, atol=1e-04) diff --git a/test/test_utils.py b/test/test_utils.py index db9cda6..76b1faa 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -4,11 +4,47 @@ # # License: MIT License - +import pytest import ot import numpy as np import sys +from ot.backend import get_backend_list + +backend_list = get_backend_list() + + +@pytest.mark.parametrize('nx', backend_list) +def test_proj_simplex(nx): + n = 10 + rng = np.random.RandomState(0) + + # test on matrix when projection is done on axis 0 + x = rng.randn(n, 2) + x1 = nx.from_numpy(x) + + # all projections should sum to 1 + proj = ot.utils.proj_simplex(x1) + l1 = np.sum(nx.to_numpy(proj), axis=0) + l2 = np.ones(2) + np.testing.assert_allclose(l1, l2, atol=1e-5) + + # all projections should sum to 3 + proj = ot.utils.proj_simplex(x1, 3) + l1 = np.sum(nx.to_numpy(proj), axis=0) + l2 = 3 * np.ones(2) + np.testing.assert_allclose(l1, l2, atol=1e-5) + + # tets on vector + x = rng.randn(n) + x1 = nx.from_numpy(x) + + # all projections should sum to 1 + proj = ot.utils.proj_simplex(x1) + l1 = np.sum(nx.to_numpy(proj), axis=0) + l2 = np.ones(2) + np.testing.assert_allclose(l1, l2, atol=1e-5) + def test_parmap(): @@ -45,8 +81,8 @@ def test_tic_toc(): def test_kernel(): n = 100 - - x = np.random.randn(n, 2) + rng = np.random.RandomState(0) + x = rng.randn(n, 2) K = ot.utils.kernel(x, x) @@ -67,7 +103,8 @@ def test_dist(): n = 100 - x = np.random.randn(n, 2) + rng = np.random.RandomState(0) + x = rng.randn(n, 2) D = np.zeros((n, n)) for i in range(n): @@ -78,8 +115,27 @@ def test_dist(): D3 = ot.dist(x) # dist shoul return squared euclidean - np.testing.assert_allclose(D, D2) - np.testing.assert_allclose(D, D3) + np.testing.assert_allclose(D, D2, atol=1e-14) + np.testing.assert_allclose(D, D3, atol=1e-14) + + +@ pytest.mark.parametrize('nx', backend_list) +def test_dist_backends(nx): + + n = 100 + rng = np.random.RandomState(0) + x = rng.randn(n, 2) + x1 = nx.from_numpy(x) + + lst_metric = ['euclidean', 'sqeuclidean'] + + for metric in lst_metric: + + D = ot.dist(x, x, metric=metric) + D1 = ot.dist(x1, x1, metric=metric) + + # low atol because jax forces float32 + np.testing.assert_allclose(D, nx.to_numpy(D1), atol=1e-5) def test_dist0(): @@ -95,9 +151,11 @@ def test_dots(): n1, n2, n3, n4 = 100, 50, 200, 100 - A = np.random.randn(n1, n2) - B = np.random.randn(n2, n3) - C = np.random.randn(n3, n4) + rng = np.random.RandomState(0) + + A = rng.randn(n1, n2) + B = rng.randn(n2, n3) + C = rng.randn(n3, n4) X1 = ot.utils.dots(A, B, C) -- cgit v1.2.3 From 1c7e7ce2da8bb362c184fb6eae71fe7e36356494 Mon Sep 17 00:00:00 2001 From: kguerda-idris <84066930+kguerda-idris@users.noreply.github.com> Date: Wed, 29 Sep 2021 15:29:31 +0200 Subject: [MRG] OpenMP support (#260) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Added : OpenMP support Restored : Epsilon and Debug mode Replaced : parmap => multiprocessing is now replace by multithreading * Commit clean up * Number of CPUs correctly calculated on SLURM clusters * Corrected number of processes for cluster slurm * Mistake corrected * parmap is now deprecated * Now a different solver is used depending on the requested number of threads * Tiny mistake corrected * Folders are now in the ot library instead of at the root * Helpers is now correctly placed * Attempt to make compilation work smoothly * OS compatible path * NumThreads now defaults to 1 * Better flags * Mistake corrected in case of OpenMP unavailability * Revert OpenMP flags modification, which do not compile on Windows * Test helper functions * Helpers comments * Documentation update * File title corrected * Warning no longer using print * Last attempt for macos compilation * pls work * atempt * solving a type error * TypeError OpenMP * Compilation finally working on Windows * Bug solve, number of threads now correctly selected * 64 bits solver to avoid overflows for bigger problems * 64 bits EMD corrected Co-authored-by: kguerda-idris Co-authored-by: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> Co-authored-by: ncassereau Co-authored-by: Rémi Flamary --- ot/helpers/openmp_helpers.py | 85 ++ ot/helpers/pre_build_helpers.py | 87 ++ ot/lp/EMD.h | 5 +- ot/lp/EMD_wrapper.cpp | 124 ++- ot/lp/__init__.py | 71 +- ot/lp/emd_wrap.pyx | 9 +- ot/lp/full_bipartitegraph.h | 27 +- ot/lp/full_bipartitegraph_omp.h | 234 +++++ ot/lp/network_simplex_simple.h | 210 ++--- ot/lp/network_simplex_simple_omp.h | 1699 ++++++++++++++++++++++++++++++++++++ ot/utils.py | 38 +- setup.py | 12 +- test/test_helpers.py | 26 + 13 files changed, 2442 insertions(+), 185 deletions(-) create mode 100644 ot/helpers/openmp_helpers.py create mode 100644 ot/helpers/pre_build_helpers.py create mode 100644 ot/lp/full_bipartitegraph_omp.h create mode 100644 ot/lp/network_simplex_simple_omp.h create mode 100644 test/test_helpers.py (limited to 'ot/lp') diff --git a/ot/helpers/openmp_helpers.py b/ot/helpers/openmp_helpers.py new file mode 100644 index 0000000..a6ad38b --- /dev/null +++ b/ot/helpers/openmp_helpers.py @@ -0,0 +1,85 @@ +"""Helpers for OpenMP support during the build.""" + +# This code is adapted for a large part from the astropy openmp helpers, which +# can be found at: https://github.com/astropy/extension-helpers/blob/master/extension_helpers/_openmp_helpers.py # noqa + + +import os +import sys +import textwrap +import subprocess + +from distutils.errors import CompileError, LinkError + +from pre_build_helpers import compile_test_program + + +def get_openmp_flag(compiler): + """Get openmp flags for a given compiler""" + + if hasattr(compiler, 'compiler'): + compiler = compiler.compiler[0] + else: + compiler = compiler.__class__.__name__ + + if sys.platform == "win32" and ('icc' in compiler or 'icl' in compiler): + omp_flag = ['/Qopenmp'] + elif sys.platform == "win32": + omp_flag = ['/openmp'] + elif sys.platform in ("darwin", "linux") and "icc" in compiler: + omp_flag = ['-qopenmp'] + elif sys.platform == "darwin" and 'openmp' in os.getenv('CPPFLAGS', ''): + omp_flag = [] + else: + # Default flag for GCC and clang: + omp_flag = ['-fopenmp'] + if sys.platform.startswith("darwin"): + omp_flag += ["-Xpreprocessor", "-lomp"] + return omp_flag + + +def check_openmp_support(): + """Check whether OpenMP test code can be compiled and run""" + + code = textwrap.dedent( + """\ + #include + #include + int main(void) { + #pragma omp parallel + printf("nthreads=%d\\n", omp_get_num_threads()); + return 0; + } + """) + + extra_preargs = os.getenv('LDFLAGS', None) + if extra_preargs is not None: + extra_preargs = extra_preargs.strip().split(" ") + extra_preargs = [ + flag for flag in extra_preargs + if flag.startswith(('-L', '-Wl,-rpath', '-l'))] + + extra_postargs = get_openmp_flag + + try: + output, compile_flags = compile_test_program( + code, + extra_preargs=extra_preargs, + extra_postargs=extra_postargs + ) + + if output and 'nthreads=' in output[0]: + nthreads = int(output[0].strip().split('=')[1]) + openmp_supported = len(output) == nthreads + elif "PYTHON_CROSSENV" in os.environ: + # Since we can't run the test program when cross-compiling + # assume that openmp is supported if the program can be + # compiled. + openmp_supported = True + else: + openmp_supported = False + + except (CompileError, LinkError, subprocess.CalledProcessError): + openmp_supported = False + compile_flags = [] + return openmp_supported, compile_flags diff --git a/ot/helpers/pre_build_helpers.py b/ot/helpers/pre_build_helpers.py new file mode 100644 index 0000000..93ecd6a --- /dev/null +++ b/ot/helpers/pre_build_helpers.py @@ -0,0 +1,87 @@ +"""Helpers to check build environment before actual build of POT""" + +import os +import sys +import glob +import tempfile +import setuptools # noqa +import subprocess + +from distutils.dist import Distribution +from distutils.sysconfig import customize_compiler +from numpy.distutils.ccompiler import new_compiler +from numpy.distutils.command.config_compiler import config_cc + + +def _get_compiler(): + """Get a compiler equivalent to the one that will be used to build POT + Handles compiler specified as follows: + - python setup.py build_ext --compiler= + - CC= python setup.py build_ext + """ + dist = Distribution({'script_name': os.path.basename(sys.argv[0]), + 'script_args': sys.argv[1:], + 'cmdclass': {'config_cc': config_cc}}) + + cmd_opts = dist.command_options.get('build_ext') + if cmd_opts is not None and 'compiler' in cmd_opts: + compiler = cmd_opts['compiler'][1] + else: + compiler = None + + ccompiler = new_compiler(compiler=compiler) + customize_compiler(ccompiler) + + return ccompiler + + +def compile_test_program(code, extra_preargs=[], extra_postargs=[]): + """Check that some C code can be compiled and run""" + ccompiler = _get_compiler() + + # extra_(pre/post)args can be a callable to make it possible to get its + # value from the compiler + if callable(extra_preargs): + extra_preargs = extra_preargs(ccompiler) + if callable(extra_postargs): + extra_postargs = extra_postargs(ccompiler) + + start_dir = os.path.abspath('.') + + with tempfile.TemporaryDirectory() as tmp_dir: + try: + os.chdir(tmp_dir) + + # Write test program + with open('test_program.c', 'w') as f: + f.write(code) + + os.mkdir('objects') + + # Compile, test program + ccompiler.compile(['test_program.c'], output_dir='objects', + extra_postargs=extra_postargs) + + # Link test program + objects = glob.glob( + os.path.join('objects', '*' + ccompiler.obj_extension)) + ccompiler.link_executable(objects, 'test_program', + extra_preargs=extra_preargs, + extra_postargs=extra_postargs) + + if "PYTHON_CROSSENV" not in os.environ: + # Run test program if not cross compiling + # will raise a CalledProcessError if return code was non-zero + output = subprocess.check_output('./test_program') + output = output.decode( + sys.stdout.encoding or 'utf-8').splitlines() + else: + # Return an empty output if we are cross compiling + # as we cannot run the test_program + output = [] + except Exception: + raise + finally: + os.chdir(start_dir) + + return output, extra_postargs diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h index c0fe7a3..8a1f9ac 100644 --- a/ot/lp/EMD.h +++ b/ot/lp/EMD.h @@ -18,19 +18,18 @@ #include #include -#include "network_simplex_simple.h" -using namespace lemon; typedef unsigned int node_id_type; enum ProblemType { INFEASIBLE, OPTIMAL, UNBOUNDED, - MAX_ITER_REACHED + MAX_ITER_REACHED }; int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter); +int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads); diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index bc873ed..2bdc172 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -12,16 +12,22 @@ * */ + +#include "network_simplex_simple.h" +#include "network_simplex_simple_omp.h" #include "EMD.h" +#include int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) { - // beware M and C anre strored in row major C style!!! - int n, m, i, cur; + // beware M and C are stored in row major C style!!! + + using namespace lemon; + int n, m, cur; typedef FullBipartiteDigraph Digraph; - DIGRAPH_TYPEDEFS(FullBipartiteDigraph); + DIGRAPH_TYPEDEFS(Digraph); // Get the number of non zero coordinates for r and c n=0; @@ -48,7 +54,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, std::vector indI(n), indJ(m); std::vector weights1(n), weights2(m); Digraph di(n, m); - NetworkSimplexSimple net(di, true, n+m, n*m, maxIter); + NetworkSimplexSimple net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter); // Set supply and demand, don't account for 0 values (faster) @@ -76,10 +82,12 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, net.supplyMap(&weights1[0], n, &weights2[0], m); // Set the cost of each edge + int64_t idarc = 0; for (int i=0; i0) { + n++; + }else if(val<0){ + return INFEASIBLE; + } + } + m=0; + for (int i=0; i0) { + m++; + }else if(val<0){ + return INFEASIBLE; + } + } + + // Define the graph + + std::vector indI(n), indJ(m); + std::vector weights1(n), weights2(m); + Digraph di(n, m); + NetworkSimplexSimple net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter, numThreads); + + // Set supply and demand, don't account for 0 values (faster) + + cur=0; + for (int i=0; i0) { + weights1[ cur ] = val; + indI[cur++]=i; + } + } + + // Demand is actually negative supply... + + cur=0; + for (int i=0; i0) { + weights2[ cur ] = -val; + indJ[cur++]=i; + } + } + + + net.supplyMap(&weights1[0], n, &weights2[0], m); + + // Set the cost of each edge + int64_t idarc = 0; + for (int i=0; i 1: - res = parmap(f, [b[:, i].copy() for i in range(nb)], processes) - else: - res = list(map(f, [b[:, i].copy() for i in range(nb)])) + warnings.warn( + "The 'processes' parameter has been deprecated. " + "Multiprocessing should be done outside of POT." + ) + res = list(map(f, [b[:, i].copy() for i in range(nb)])) return res def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, - stopThr=1e-7, verbose=False, log=None): + stopThr=1e-7, verbose=False, log=None, numThreads=1): r""" Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance), formally: @@ -512,6 +541,10 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None Print information along iterations log : bool, optional record log if True + numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) + If compiled with OpenMP, chooses the number of threads to parallelize. + "max" selects the highest number possible. + Returns ------- @@ -551,7 +584,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()): M_i = dist(X, measure_locations_i) - T_i = emd(b, measure_weights_i, M_i) + T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads) T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i) displacement_square_norm = np.sum(np.square(T_sum - X)) diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index de9a700..42e08f4 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -20,6 +20,7 @@ import warnings cdef extern from "EMD.h": int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) nogil + int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads) nogil cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED @@ -38,7 +39,7 @@ def check_result(result_code): @cython.boundscheck(False) @cython.wraparound(False) -def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter): +def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, int numThreads): """ Solves the Earth Movers distance problem and returns the optimal transport matrix @@ -109,8 +110,10 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod # calling the function with nogil: - result_code = EMD_wrap(n1, n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, max_iter) - + if numThreads == 1: + result_code = EMD_wrap(n1, n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, max_iter) + else: + result_code = EMD_wrap_omp(n1, n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, max_iter, numThreads) return G, cost, alpha, beta, result_code diff --git a/ot/lp/full_bipartitegraph.h b/ot/lp/full_bipartitegraph.h index 87a1bec..713ccb5 100644 --- a/ot/lp/full_bipartitegraph.h +++ b/ot/lp/full_bipartitegraph.h @@ -23,10 +23,10 @@ * */ -#ifndef LEMON_FULL_BIPARTITE_GRAPH_H -#define LEMON_FULL_BIPARTITE_GRAPH_H +#pragma once #include "core.h" +#include ///\ingroup graphs ///\file @@ -44,16 +44,16 @@ namespace lemon { //class Node; typedef int Node; //class Arc; - typedef long long Arc; + typedef int64_t Arc; protected: int _node_num; - long long _arc_num; + int64_t _arc_num; FullBipartiteDigraphBase() {} - void construct(int n1, int n2) { _node_num = n1+n2; _arc_num = n1 * n2; _n1=n1; _n2=n2;} + void construct(int n1, int n2) { _node_num = n1+n2; _arc_num = (int64_t)n1 * (int64_t)n2; _n1=n1; _n2=n2;} public: @@ -65,25 +65,25 @@ namespace lemon { Arc arc(const Node& s, const Node& t) const { if (s<_n1 && t>=_n1) - return Arc(s * _n2 + (t-_n1) ); + return Arc((int64_t)s * (int64_t)_n2 + (int64_t)(t-_n1) ); else return Arc(-1); } int nodeNum() const { return _node_num; } - long long arcNum() const { return _arc_num; } + int64_t arcNum() const { return _arc_num; } int maxNodeId() const { return _node_num - 1; } - long long maxArcId() const { return _arc_num - 1; } + int64_t maxArcId() const { return _arc_num - 1; } Node source(Arc arc) const { return arc / _n2; } Node target(Arc arc) const { return (arc % _n2) + _n1; } static int id(Node node) { return node; } - static long long id(Arc arc) { return arc; } + static int64_t id(Arc arc) { return arc; } static Node nodeFromId(int id) { return Node(id);} - static Arc arcFromId(int id) { return Arc(id);} + static Arc arcFromId(int64_t id) { return Arc(id);} Arc findArc(Node s, Node t, Arc prev = -1) const { @@ -136,7 +136,7 @@ namespace lemon { /// /// \brief A directed full graph class. /// - /// FullBipartiteDigraph is a simple and fast implmenetation of directed full + /// FullBipartiteDigraph is a simple and fast implementation of directed full /// (complete) graphs. It contains an arc from each node to each node /// (including a loop for each node), therefore the number of arcs /// is the square of the number of nodes. @@ -203,13 +203,10 @@ namespace lemon { /// \brief Number of nodes. int nodeNum() const { return Parent::nodeNum(); } /// \brief Number of arcs. - long long arcNum() const { return Parent::arcNum(); } + int64_t arcNum() const { return Parent::arcNum(); } }; } //namespace lemon - - -#endif //LEMON_FULL_GRAPH_H diff --git a/ot/lp/full_bipartitegraph_omp.h b/ot/lp/full_bipartitegraph_omp.h new file mode 100644 index 0000000..8cbed0b --- /dev/null +++ b/ot/lp/full_bipartitegraph_omp.h @@ -0,0 +1,234 @@ +/* -*- mode: C++; indent-tabs-mode: nil; -*- + * + * This file has been adapted by Nicolas Bonneel (2013), + * from full_graph.h from LEMON, a generic C++ optimization library, + * to implement a lightweight fully connected bipartite graph. A previous + * version of this file is used as part of the Displacement Interpolation + * project, + * Web: http://www.cs.ubc.ca/labs/imager/tr/2011/DisplacementInterpolation/ + * + * + **** Original file Copyright Notice : + * Copyright (C) 2003-2010 + * Egervary Jeno Kombinatorikus Optimalizalasi Kutatocsoport + * (Egervary Research Group on Combinatorial Optimization, EGRES). + * + * Permission to use, modify and distribute this software is granted + * provided that this copyright notice appears in all copies. For + * precise terms see the accompanying LICENSE file. + * + * This software is provided "AS IS" with no warranty of any kind, + * express or implied, and with no claim as to its suitability for any + * purpose. + * + */ + +#pragma once + +#include + +///\ingroup graphs +///\file +///\brief FullBipartiteDigraph and FullBipartiteGraph classes. + + +namespace lemon_omp { + + ///This \c \#define creates convenient type definitions for the following + ///types of \c Digraph: \c Node, \c NodeIt, \c Arc, \c ArcIt, \c InArcIt, + ///\c OutArcIt, \c BoolNodeMap, \c IntNodeMap, \c DoubleNodeMap, + ///\c BoolArcMap, \c IntArcMap, \c DoubleArcMap. + /// + ///\note If the graph type is a dependent type, ie. the graph type depend + ///on a template parameter, then use \c TEMPLATE_DIGRAPH_TYPEDEFS() + ///macro. +#define DIGRAPH_TYPEDEFS(Digraph) \ + typedef Digraph::Node Node; \ + typedef Digraph::Arc Arc; \ + + + ///Create convenience typedefs for the digraph types and iterators + + ///\see DIGRAPH_TYPEDEFS + /// + ///\note Use this macro, if the graph type is a dependent type, + ///ie. the graph type depend on a template parameter. +#define TEMPLATE_DIGRAPH_TYPEDEFS(Digraph) \ + typedef typename Digraph::Node Node; \ + typedef typename Digraph::Arc Arc; \ + + + class FullBipartiteDigraphBase { + public: + + typedef FullBipartiteDigraphBase Digraph; + + //class Node; + typedef int Node; + //class Arc; + typedef int64_t Arc; + + protected: + + int _node_num; + int64_t _arc_num; + + FullBipartiteDigraphBase() {} + + void construct(int n1, int n2) { _node_num = n1+n2; _arc_num = (int64_t)n1 * (int64_t)n2; _n1=n1; _n2=n2;} + + public: + + int _n1, _n2; + + + Node operator()(int ix) const { return Node(ix); } + static int index(const Node& node) { return node; } + + Arc arc(const Node& s, const Node& t) const { + if (s<_n1 && t>=_n1) + return Arc((int64_t)s * (int64_t)_n2 + (int64_t)(t-_n1) ); + else + return Arc(-1); + } + + int nodeNum() const { return _node_num; } + int64_t arcNum() const { return _arc_num; } + + int maxNodeId() const { return _node_num - 1; } + int64_t maxArcId() const { return _arc_num - 1; } + + Node source(Arc arc) const { return arc / _n2; } + Node target(Arc arc) const { return (arc % _n2) + _n1; } + + static int id(Node node) { return node; } + static int64_t id(Arc arc) { return arc; } + + static Node nodeFromId(int id) { return Node(id);} + static Arc arcFromId(int64_t id) { return Arc(id);} + + + Arc findArc(Node s, Node t, Arc prev = -1) const { + return prev == -1 ? arc(s, t) : -1; + } + + void first(Node& node) const { + node = _node_num - 1; + } + + static void next(Node& node) { + --node; + } + + void first(Arc& arc) const { + arc = _arc_num - 1; + } + + static void next(Arc& arc) { + --arc; + } + + void firstOut(Arc& arc, const Node& node) const { + if (node>=_n1) + arc = -1; + else + arc = (node + 1) * _n2 - 1; + } + + void nextOut(Arc& arc) const { + if (arc % _n2 == 0) arc = 0; + --arc; + } + + void firstIn(Arc& arc, const Node& node) const { + if (node<_n1) + arc = -1; + else + arc = _arc_num + node - _node_num; + } + + void nextIn(Arc& arc) const { + arc -= _n2; + if (arc < 0) arc = -1; + } + + }; + + /// \ingroup graphs + /// + /// \brief A directed full graph class. + /// + /// FullBipartiteDigraph is a simple and fast implmenetation of directed full + /// (complete) graphs. It contains an arc from each node to each node + /// (including a loop for each node), therefore the number of arcs + /// is the square of the number of nodes. + /// This class is completely static and it needs constant memory space. + /// Thus you can neither add nor delete nodes or arcs, however + /// the structure can be resized using resize(). + /// + /// This type fully conforms to the \ref concepts::Digraph "Digraph concept". + /// Most of its member functions and nested classes are documented + /// only in the concept class. + /// + /// This class provides constant time counting for nodes and arcs. + /// + /// \note FullBipartiteDigraph and FullBipartiteGraph classes are very similar, + /// but there are two differences. While this class conforms only + /// to the \ref concepts::Digraph "Digraph" concept, FullBipartiteGraph + /// conforms to the \ref concepts::Graph "Graph" concept, + /// moreover FullBipartiteGraph does not contain a loop for each + /// node as this class does. + /// + /// \sa FullBipartiteGraph + class FullBipartiteDigraph : public FullBipartiteDigraphBase { + typedef FullBipartiteDigraphBase Parent; + + public: + + /// \brief Default constructor. + /// + /// Default constructor. The number of nodes and arcs will be zero. + FullBipartiteDigraph() { construct(0,0); } + + /// \brief Constructor + /// + /// Constructor. + /// \param n The number of the nodes. + FullBipartiteDigraph(int n1, int n2) { construct(n1, n2); } + + + /// \brief Returns the node with the given index. + /// + /// Returns the node with the given index. Since this structure is + /// completely static, the nodes can be indexed with integers from + /// the range [0..nodeNum()-1]. + /// The index of a node is the same as its ID. + /// \sa index() + Node operator()(int ix) const { return Parent::operator()(ix); } + + /// \brief Returns the index of the given node. + /// + /// Returns the index of the given node. Since this structure is + /// completely static, the nodes can be indexed with integers from + /// the range [0..nodeNum()-1]. + /// The index of a node is the same as its ID. + /// \sa operator()() + static int index(const Node& node) { return Parent::index(node); } + + /// \brief Returns the arc connecting the given nodes. + /// + /// Returns the arc connecting the given nodes. + /*Arc arc(Node u, Node v) const { + return Parent::arc(u, v); + }*/ + + /// \brief Number of nodes. + int nodeNum() const { return Parent::nodeNum(); } + /// \brief Number of arcs. + int64_t arcNum() const { return Parent::arcNum(); } + }; + + + + +} //namespace lemon_omp diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h index 630b595..3b46b9b 100644 --- a/ot/lp/network_simplex_simple.h +++ b/ot/lp/network_simplex_simple.h @@ -25,15 +25,17 @@ * */ -#ifndef LEMON_NETWORK_SIMPLEX_SIMPLE_H -#define LEMON_NETWORK_SIMPLEX_SIMPLE_H +#pragma once +#undef DEBUG_LVL #define DEBUG_LVL 0 #if DEBUG_LVL>0 #include #endif - +#undef EPSILON +#undef _EPSILON +#undef MAX_DEBUG_ITER #define EPSILON 2.2204460492503131e-15 #define _EPSILON 1e-8 #define MAX_DEBUG_ITER 100000 @@ -50,6 +52,7 @@ #include #include #include +#include #include #ifdef HASHMAP #include @@ -63,6 +66,8 @@ //#include "sparse_array_n.h" #include "full_bipartitegraph.h" +#undef INVALIDNODE +#undef INVALID #define INVALIDNODE -1 #define INVALID (-1) @@ -76,16 +81,16 @@ namespace lemon { class SparseValueVector { public: - SparseValueVector(int n=0) + SparseValueVector(size_t n=0) { } - void resize(int n=0){}; - T operator[](const int id) const + void resize(size_t n=0){}; + T operator[](const size_t id) const { #ifdef HASHMAP - typename stdext::hash_map::const_iterator it = data.find(id); + typename stdext::hash_map::const_iterator it = data.find(id); #else - typename std::map::const_iterator it = data.find(id); + typename std::map::const_iterator it = data.find(id); #endif if (it==data.end()) return 0; @@ -93,16 +98,16 @@ namespace lemon { return it->second; } - ProxyObject operator[](const int id) + ProxyObject operator[](const size_t id) { return ProxyObject( this, id ); } //private: #ifdef HASHMAP - stdext::hash_map data; + stdext::hash_map data; #else - std::map data; + std::map data; #endif }; @@ -110,7 +115,7 @@ namespace lemon { template class ProxyObject { public: - ProxyObject( SparseValueVector *v, int idx ){_v=v; _idx=idx;}; + ProxyObject( SparseValueVector *v, size_t idx ){_v=v; _idx=idx;}; ProxyObject & operator=( const T &v ) { // If we get here, we know that operator[] was called to perform a write access, // so we can insert an item in the vector if needed @@ -123,9 +128,9 @@ namespace lemon { // If we get here, we know that operator[] was called to perform a read access, // so we can simply return the existing object #ifdef HASHMAP - typename stdext::hash_map::iterator it = _v->data.find(_idx); + typename stdext::hash_map::iterator it = _v->data.find(_idx); #else - typename std::map::iterator it = _v->data.find(_idx); + typename std::map::iterator it = _v->data.find(_idx); #endif if (it==_v->data.end()) return 0; @@ -137,9 +142,9 @@ namespace lemon { { if (val==0) return; #ifdef HASHMAP - typename stdext::hash_map::iterator it = _v->data.find(_idx); + typename stdext::hash_map::iterator it = _v->data.find(_idx); #else - typename std::map::iterator it = _v->data.find(_idx); + typename std::map::iterator it = _v->data.find(_idx); #endif if (it==_v->data.end()) _v->data[_idx] = val; @@ -156,9 +161,9 @@ namespace lemon { { if (val==0) return; #ifdef HASHMAP - typename stdext::hash_map::iterator it = _v->data.find(_idx); + typename stdext::hash_map::iterator it = _v->data.find(_idx); #else - typename std::map::iterator it = _v->data.find(_idx); + typename std::map::iterator it = _v->data.find(_idx); #endif if (it==_v->data.end()) _v->data[_idx] = -val; @@ -173,7 +178,7 @@ namespace lemon { } SparseValueVector *_v; - int _idx; + size_t _idx; }; @@ -204,7 +209,7 @@ namespace lemon { /// /// \tparam GR The digraph type the algorithm runs on. /// \tparam V The number type used for flow amounts, capacity bounds - /// and supply values in the algorithm. By default, it is \c int. + /// and supply values in the algorithm. By default, it is \c int64_t. /// \tparam C The number type used for costs and potentials in the /// algorithm. By default, it is the same as \c V. /// @@ -214,7 +219,7 @@ namespace lemon { /// \note %NetworkSimplexSimple provides five different pivot rule /// implementations, from which the most efficient one is used /// by default. For more information, see \ref PivotRule. - template + template class NetworkSimplexSimple { public: @@ -228,7 +233,7 @@ namespace lemon { /// mixed order in the internal data structure. /// In special cases, it could lead to better overall performance, /// but it is usually slower. Therefore it is disabled by default. - NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, long long nb_arcs,int maxiters) : + NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters) : _graph(graph), //_arc_id(graph), _arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs), MAX(std::numeric_limits::max()), @@ -288,11 +293,11 @@ namespace lemon { private: - int max_iter; + size_t max_iter; TEMPLATE_DIGRAPH_TYPEDEFS(GR); typedef std::vector IntVector; - typedef std::vector UHalfIntVector; + typedef std::vector ArcVector; typedef std::vector ValueVector; typedef std::vector CostVector; // typedef SparseValueVector CostVector; @@ -315,9 +320,9 @@ namespace lemon { // Data related to the underlying digraph const GR &_graph; int _node_num; - int _arc_num; - int _all_arc_num; - int _search_arc_num; + ArcsType _arc_num; + ArcsType _all_arc_num; + ArcsType _search_arc_num; // Parameters of the problem SupplyType _stype; @@ -325,9 +330,9 @@ namespace lemon { inline int _node_id(int n) const {return _node_num-n-1;} ; - //IntArcMap _arc_id; - UHalfIntVector _source; - UHalfIntVector _target; +// IntArcMap _arc_id; + IntVector _source; // keep nodes as integers + IntVector _target; bool _arc_mixing; public: // Node and arc data @@ -341,7 +346,7 @@ namespace lemon { private: // Data for storing the spanning tree structure IntVector _parent; - IntVector _pred; + ArcVector _pred; IntVector _thread; IntVector _rev_thread; IntVector _succ_num; @@ -349,17 +354,17 @@ namespace lemon { IntVector _dirty_revs; BoolVector _forward; StateVector _state; - int _root; + ArcsType _root; // Temporary data used in the current pivot iteration - int in_arc, join, u_in, v_in, u_out, v_out; - int first, second, right, last; - int stem, par_stem, new_stem; + ArcsType in_arc, join, u_in, v_in, u_out, v_out; + ArcsType first, second, right, last; + ArcsType stem, par_stem, new_stem; Value delta; const Value MAX; - int mixingCoeff; + ArcsType mixingCoeff; public: @@ -373,27 +378,27 @@ namespace lemon { private: // thank you to DVK and MizardX from StackOverflow for this function! - inline int sequence(int k) const { - int smallv = (k > num_total_big_subsequence_numbers) & 1; + inline ArcsType sequence(ArcsType k) const { + ArcsType smallv = (k > num_total_big_subsequence_numbers) & 1; k -= num_total_big_subsequence_numbers * smallv; - int subsequence_length2 = subsequence_length- smallv; - int subsequence_num = (k / subsequence_length2) + num_big_subseqiences * smallv; - int subsequence_offset = (k % subsequence_length2) * mixingCoeff; + ArcsType subsequence_length2 = subsequence_length- smallv; + ArcsType subsequence_num = (k / subsequence_length2) + num_big_subseqiences * smallv; + ArcsType subsequence_offset = (k % subsequence_length2) * mixingCoeff; return subsequence_offset + subsequence_num; } - int subsequence_length; - int num_big_subseqiences; - int num_total_big_subsequence_numbers; + ArcsType subsequence_length; + ArcsType num_big_subseqiences; + ArcsType num_total_big_subsequence_numbers; - inline int getArcID(const Arc &arc) const + inline ArcsType getArcID(const Arc &arc) const { //int n = _arc_num-arc._id-1; - int n = _arc_num-GR::id(arc)-1; + ArcsType n = _arc_num-GR::id(arc)-1; - //int a = mixingCoeff*(n%mixingCoeff) + n/mixingCoeff; - //int b = _arc_id[arc]; + //ArcsType a = mixingCoeff*(n%mixingCoeff) + n/mixingCoeff; + //ArcsType b = _arc_id[arc]; if (_arc_mixing) return sequence(n); else @@ -401,16 +406,16 @@ namespace lemon { } // finally unused because too slow - inline int getSource(const int arc) const + inline ArcsType getSource(const ArcsType arc) const { - //int a = _source[arc]; + //ArcsType a = _source[arc]; //return a; - int n = _arc_num-arc-1; + ArcsType n = _arc_num-arc-1; if (_arc_mixing) n = mixingCoeff*(n%mixingCoeff) + n/mixingCoeff; - int b; + ArcsType b; if (n>=0) b = _node_id(_graph.source(GR::arcFromId( n ) )); else @@ -436,17 +441,17 @@ namespace lemon { private: // References to the NetworkSimplexSimple class - const UHalfIntVector &_source; - const UHalfIntVector &_target; + const IntVector &_source; + const IntVector &_target; const CostVector &_cost; const StateVector &_state; const CostVector &_pi; - int &_in_arc; - int _search_arc_num; + ArcsType &_in_arc; + ArcsType _search_arc_num; // Pivot rule data - int _block_size; - int _next_arc; + ArcsType _block_size; + ArcsType _next_arc; NetworkSimplexSimple &_ns; public: @@ -460,17 +465,16 @@ namespace lemon { { // The main parameters of the pivot rule const double BLOCK_SIZE_FACTOR = 1.0; - const int MIN_BLOCK_SIZE = 10; + const ArcsType MIN_BLOCK_SIZE = 10; - _block_size = std::max( int(BLOCK_SIZE_FACTOR * - std::sqrt(double(_search_arc_num))), - MIN_BLOCK_SIZE ); + _block_size = std::max(ArcsType(BLOCK_SIZE_FACTOR * std::sqrt(double(_search_arc_num))), MIN_BLOCK_SIZE); } + // Find next entering arc bool findEnteringArc() { Cost c, min = 0; - int e; - int cnt = _block_size; + ArcsType e; + ArcsType cnt = _block_size; double a; for (e = _next_arc; e != _search_arc_num; ++e) { c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]); @@ -516,7 +520,7 @@ namespace lemon { int _init_nb_nodes; - long long _init_nb_arcs; + ArcsType _init_nb_arcs; /// \name Parameters /// The parameters of the algorithm can be specified using these @@ -736,7 +740,7 @@ namespace lemon { for (int i = 0; i != _node_num; ++i) { _supply[i] = 0; } - for (int i = 0; i != _arc_num; ++i) { + for (ArcsType i = 0; i != _arc_num; ++i) { _cost[i] = 1; } _stype = GEQ; @@ -745,7 +749,7 @@ namespace lemon { - int divid (int x, int y) + int64_t divid (int64_t x, int64_t y) { return (x-x%y)/y; } @@ -775,7 +779,7 @@ namespace lemon { _node_num = _init_nb_nodes; _arc_num = _init_nb_arcs; int all_node_num = _node_num + 1; - int max_arc_num = _arc_num + 2 * _node_num; + ArcsType max_arc_num = _arc_num + 2 * _node_num; _source.resize(max_arc_num); _target.resize(max_arc_num); @@ -798,13 +802,13 @@ namespace lemon { //_arc_mixing=false; if (_arc_mixing) { // Store the arcs in a mixed order - int k = std::max(int(std::sqrt(double(_arc_num))), 10); + const ArcsType k = std::max(ArcsType(std::sqrt(double(_arc_num))), ArcsType(10)); mixingCoeff = k; subsequence_length = _arc_num / mixingCoeff + 1; num_big_subseqiences = _arc_num % mixingCoeff; num_total_big_subsequence_numbers = subsequence_length * num_big_subseqiences; - int i = 0, j = 0; + ArcsType i = 0, j = 0; Arc a; _graph.first(a); for (; a != INVALID; _graph.next(a)) { _source[i] = _node_id(_graph.source(a)); @@ -814,7 +818,7 @@ namespace lemon { } } else { // Store the arcs in the original order - int i = 0; + ArcsType i = 0; Arc a; _graph.first(a); for (; a != INVALID; _graph.next(a), ++i) { _source[i] = _node_id(_graph.source(a)); @@ -856,7 +860,7 @@ namespace lemon { Number totalCost() const { Number c = 0; for (ArcIt a(_graph); a != INVALID; ++a) { - int i = getArcID(a); + int64_t i = getArcID(a); c += Number(_flow[i]) * Number(_cost[i]); } return c; @@ -867,15 +871,15 @@ namespace lemon { Number c = 0; /*#ifdef HASHMAP - typename stdext::hash_map::const_iterator it; + typename stdext::hash_map::const_iterator it; #else - typename std::map::const_iterator it; + typename std::map::const_iterator it; #endif for (it = _flow.data.begin(); it!=_flow.data.end(); ++it) c += Number(it->second) * Number(_cost[it->first]); return c;*/ - for (unsigned long i=0; i<_flow.size(); i++) + for (ArcsType i=0; i<_flow.size(); i++) c += _flow[i] * Number(_cost[i]); return c; @@ -944,14 +948,14 @@ namespace lemon { // Initialize internal data structures bool init() { if (_node_num == 0) return false; - + // Check the sum of supply values _sum_supply = 0; for (int i = 0; i != _node_num; ++i) { _sum_supply += _supply[i]; } if ( fabs(_sum_supply) > _EPSILON ) return false; - + _sum_supply = 0; // Initialize artifical cost @@ -960,14 +964,14 @@ namespace lemon { ART_COST = std::numeric_limits::max() / 2 + 1; } else { ART_COST = 0; - for (int i = 0; i != _arc_num; ++i) { + for (ArcsType i = 0; i != _arc_num; ++i) { if (_cost[i] > ART_COST) ART_COST = _cost[i]; } ART_COST = (ART_COST + 1) * _node_num; } // Initialize arc maps - for (int i = 0; i != _arc_num; ++i) { + for (ArcsType i = 0; i != _arc_num; ++i) { //_flow[i] = 0; //by default, the sparse matrix is empty _state[i] = STATE_LOWER; } @@ -988,7 +992,7 @@ namespace lemon { // EQ supply constraints _search_arc_num = _arc_num; _all_arc_num = _arc_num + _node_num; - for (int u = 0, e = _arc_num; u != _node_num; ++u, ++e) { + for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) { _parent[u] = _root; _pred[u] = e; _thread[u] = u + 1; @@ -1016,8 +1020,8 @@ namespace lemon { else if (_sum_supply > 0) { // LEQ supply constraints _search_arc_num = _arc_num + _node_num; - int f = _arc_num + _node_num; - for (int u = 0, e = _arc_num; u != _node_num; ++u, ++e) { + ArcsType f = _arc_num + _node_num; + for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) { _parent[u] = _root; _thread[u] = u + 1; _rev_thread[u + 1] = u; @@ -1054,8 +1058,8 @@ namespace lemon { else { // GEQ supply constraints _search_arc_num = _arc_num + _node_num; - int f = _arc_num + _node_num; - for (int u = 0, e = _arc_num; u != _node_num; ++u, ++e) { + ArcsType f = _arc_num + _node_num; + for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) { _parent[u] = _root; _thread[u] = u + 1; _rev_thread[u + 1] = u; @@ -1120,9 +1124,9 @@ namespace lemon { second = _source[in_arc]; } delta = INF; - int result = 0; + char result = 0; Value d; - int e; + ArcsType e; // Search the cycle along the path form the first node to the root for (int u = first; u != join; u = _parent[u]) { @@ -1239,7 +1243,7 @@ namespace lemon { // Update _rev_thread using the new _thread values for (int i = 0; i != int(_dirty_revs.size()); ++i) { - u = _dirty_revs[i]; + int u = _dirty_revs[i]; _rev_thread[_thread[u]] = u; } @@ -1257,7 +1261,7 @@ namespace lemon { u = w; } _pred[u_in] = in_arc; - _forward[u_in] = ((unsigned int)u_in == _source[in_arc]); + _forward[u_in] = (u_in == _source[in_arc]); _succ_num[u_in] = old_succ_num; // Set limits for updating _last_succ form v_in and v_out @@ -1328,7 +1332,7 @@ namespace lemon { if (_sum_supply > 0) total -= _sum_supply; if (total <= 0) return true; - IntVector arc_vector; + ArcVector arc_vector; if (_sum_supply >= 0) { if (supply_nodes.size() == 1 && demand_nodes.size() == 1) { // Perform a reverse graph search from the sink to the source @@ -1345,7 +1349,7 @@ namespace lemon { Arc a; _graph.firstIn(a, v); for (; a != INVALID; _graph.nextIn(a)) { if (reached[u = _graph.source(a)]) continue; - int j = getArcID(a); + ArcsType j = getArcID(a); if (INF >= total) { arc_vector.push_back(j); reached[u] = true; @@ -1355,7 +1359,7 @@ namespace lemon { } } else { // Find the min. cost incomming arc for each demand node - for (int i = 0; i != int(demand_nodes.size()); ++i) { + for (int i = 0; i != demand_nodes.size(); ++i) { Node v = demand_nodes[i]; Cost c, min_cost = std::numeric_limits::max(); Arc min_arc = INVALID; @@ -1393,7 +1397,7 @@ namespace lemon { } // Perform heuristic initial pivots - for (int i = 0; i != int(arc_vector.size()); ++i) { + for (ArcsType i = 0; i != arc_vector.size(); ++i) { in_arc = arc_vector[i]; // l'erreur est probablement ici... if (_state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] - @@ -1423,7 +1427,7 @@ namespace lemon { // Perform heuristic initial pivots if (!initialPivots()) return UNBOUNDED; - int iter_number=0; + size_t iter_number=0; //pivot.setDantzig(true); // Execute the Network Simplex algorithm while (pivot.findEnteringArc()) { @@ -1443,7 +1447,7 @@ namespace lemon { double a; a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]); a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]); - for (int i=0; i<_flow.size(); i++) { + for (int64_t i=0; i<_flow.size(); i++) { sumFlow+=_state[i]*_flow[i]; } std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << iter_number << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n"; @@ -1482,12 +1486,12 @@ namespace lemon { double a; a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]); a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]); - for (int i=0; i<_flow.size(); i++) { + for (int64_t i=0; i<_flow.size(); i++) { sumFlow+=_state[i]*_flow[i]; } - + std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << niter << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n"; - + std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n"; std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n"; @@ -1505,7 +1509,7 @@ namespace lemon { #endif // Check feasibility if( retVal == OPTIMAL){ - for (int e = _search_arc_num; e != _all_arc_num; ++e) { + for (ArcsType e = _search_arc_num; e != _all_arc_num; ++e) { if (_flow[e] != 0){ if (fabs(_flow[e]) > _EPSILON) // change of the original code following issue #126 return INFEASIBLE; @@ -1521,20 +1525,20 @@ namespace lemon { if (_sum_supply == 0) { if (_stype == GEQ) { Cost max_pot = -std::numeric_limits::max(); - for (int i = 0; i != _node_num; ++i) { + for (ArcsType i = 0; i != _node_num; ++i) { if (_pi[i] > max_pot) max_pot = _pi[i]; } if (max_pot > 0) { - for (int i = 0; i != _node_num; ++i) + for (ArcsType i = 0; i != _node_num; ++i) _pi[i] -= max_pot; } } else { Cost min_pot = std::numeric_limits::max(); - for (int i = 0; i != _node_num; ++i) { + for (ArcsType i = 0; i != _node_num; ++i) { if (_pi[i] < min_pot) min_pot = _pi[i]; } if (min_pot < 0) { - for (int i = 0; i != _node_num; ++i) + for (ArcsType i = 0; i != _node_num; ++i) _pi[i] -= min_pot; } } @@ -1548,5 +1552,3 @@ namespace lemon { ///@} } //namespace lemon - -#endif //LEMON_NETWORK_SIMPLEX_H diff --git a/ot/lp/network_simplex_simple_omp.h b/ot/lp/network_simplex_simple_omp.h new file mode 100644 index 0000000..87e4c05 --- /dev/null +++ b/ot/lp/network_simplex_simple_omp.h @@ -0,0 +1,1699 @@ +/* -*- mode: C++; indent-tabs-mode: nil; -*- +* +* +* This file has been adapted by Nicolas Bonneel (2013), +* from network_simplex.h from LEMON, a generic C++ optimization library, +* to implement a lightweight network simplex for mass transport, more +* memory efficient than the original file. A previous version of this file +* is used as part of the Displacement Interpolation project, +* Web: http://www.cs.ubc.ca/labs/imager/tr/2011/DisplacementInterpolation/ +* +* Revisions: +* March 2015: added OpenMP parallelization +* March 2017: included Antoine Rolet's trick to make it more robust +* April 2018: IMPORTANT bug fix + uses 64bit integers (slightly slower but less risks of overflows), updated to a newer version of the algo by LEMON, sparse flow by default + minor edits. +* +* +**** Original file Copyright Notice : +* +* Copyright (C) 2003-2010 +* Egervary Jeno Kombinatorikus Optimalizalasi Kutatocsoport +* (Egervary Research Group on Combinatorial Optimization, EGRES). +* +* Permission to use, modify and distribute this software is granted +* provided that this copyright notice appears in all copies. For +* precise terms see the accompanying LICENSE file. +* +* This software is provided "AS IS" with no warranty of any kind, +* express or implied, and with no claim as to its suitability for any +* purpose. +* +*/ + +#pragma once +#undef DEBUG_LVL +#define DEBUG_LVL 0 + +#if DEBUG_LVL>0 +#include +#endif + +#undef EPSILON +#undef _EPSILON +#undef MAX_DEBUG_ITER +#define EPSILON std::numeric_limits::epsilon()*10 +#define _EPSILON 1e-8 +#define MAX_DEBUG_ITER 100000 + +/// \ingroup min_cost_flow_algs +/// +/// \file +/// \brief Network Simplex algorithm for finding a minimum cost flow. + +// if your compiler has troubles with unorderedmaps, just comment the following line to use a slower std::map instead +#define HASHMAP // now handled with unorderedmaps instead of stdext::hash_map. Should be better supported. + +#define SPARSE_FLOW // a sparse flow vector will be 10-15% slower for small problems but uses less memory and becomes faster for large problems (40k total nodes) + +#include +#include +#include +#include +#ifdef HASHMAP +#include +#else +#include +#endif +//#include "core.h" +//#include "lmath.h" + +#ifdef OMP +#include +#endif +#include + + +//#include "sparse_array_n.h" +#include "full_bipartitegraph_omp.h" + +#undef INVALIDNODE +#undef INVALID +#define INVALIDNODE -1 +#define INVALID (-1) + +namespace lemon_omp { + + int64_t max_threads = -1; + + template + class ProxyObject; + + template + class SparseValueVector + { + public: + SparseValueVector(size_t n = 0) // parameter n for compatibility with standard vectors + { + } + void resize(size_t n = 0) {}; + T operator[](const size_t id) const + { +#ifdef HASHMAP + typename std::unordered_map::const_iterator it = data.find(id); +#else + typename std::map::const_iterator it = data.find(id); +#endif + if (it == data.end()) + return 0; + else + return it->second; + } + + ProxyObject operator[](const size_t id) + { + return ProxyObject(this, id); + } + + //private: +#ifdef HASHMAP + std::unordered_map data; +#else + std::map data; +#endif + + }; + + template + class ProxyObject { + public: + ProxyObject(SparseValueVector *v, size_t idx) { _v = v; _idx = idx; }; + ProxyObject & operator=(const T &v) { + // If we get here, we know that operator[] was called to perform a write access, + // so we can insert an item in the vector if needed + if (v != 0) + _v->data[_idx] = v; + return *this; + } + + operator T() { + // If we get here, we know that operator[] was called to perform a read access, + // so we can simply return the existing object +#ifdef HASHMAP + typename std::unordered_map::iterator it = _v->data.find(_idx); +#else + typename std::map::iterator it = _v->data.find(_idx); +#endif + if (it == _v->data.end()) + return 0; + else + return it->second; + } + + void operator+=(T val) + { + if (val == 0) return; +#ifdef HASHMAP + typename std::unordered_map::iterator it = _v->data.find(_idx); +#else + typename std::map::iterator it = _v->data.find(_idx); +#endif + if (it == _v->data.end()) + _v->data[_idx] = val; + else + { + T sum = it->second + val; + if (sum == 0) + _v->data.erase(it); + else + it->second = sum; + } + } + void operator-=(T val) + { + if (val == 0) return; +#ifdef HASHMAP + typename std::unordered_map::iterator it = _v->data.find(_idx); +#else + typename std::map::iterator it = _v->data.find(_idx); +#endif + if (it == _v->data.end()) + _v->data[_idx] = -val; + else + { + T sum = it->second - val; + if (sum == 0) + _v->data.erase(it); + else + it->second = sum; + } + } + + SparseValueVector *_v; + size_t _idx; + }; + + + + /// \addtogroup min_cost_flow_algs + /// @{ + + /// \brief Implementation of the primal Network Simplex algorithm + /// for finding a \ref min_cost_flow "minimum cost flow". + /// + /// \ref NetworkSimplexSimple implements the primal Network Simplex algorithm + /// for finding a \ref min_cost_flow "minimum cost flow" + /// \ref amo93networkflows, \ref dantzig63linearprog, + /// \ref kellyoneill91netsimplex. + /// This algorithm is a highly efficient specialized version of the + /// linear programming simplex method directly for the minimum cost + /// flow problem. + /// + /// In general, %NetworkSimplexSimple is the fastest implementation available + /// in LEMON for this problem. + /// Moreover, it supports both directions of the supply/demand inequality + /// constraints. For more information, see \ref SupplyType. + /// + /// Most of the parameters of the problem (except for the digraph) + /// can be given using separate functions, and the algorithm can be + /// executed using the \ref run() function. If some parameters are not + /// specified, then default values will be used. + /// + /// \tparam GR The digraph type the algorithm runs on. + /// \tparam V The number type used for flow amounts, capacity bounds + /// and supply values in the algorithm. By default, it is \c int. + /// \tparam C The number type used for costs and potentials in the + /// algorithm. By default, it is the same as \c V. + /// + /// \warning Both number types must be signed and all input data must + /// be integer. + /// + /// \note %NetworkSimplexSimple provides five different pivot rule + /// implementations, from which the most efficient one is used + /// by default. For more information, see \ref PivotRule. + template + class NetworkSimplexSimple + { + public: + + /// \brief Constructor. + /// + /// The constructor of the class. + /// + /// \param graph The digraph the algorithm runs on. + /// \param arc_mixing Indicate if the arcs have to be stored in a + /// mixed order in the internal data structure. + /// In special cases, it could lead to better overall performance, + /// but it is usually slower. Therefore it is disabled by default. + NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters = 0, int numThreads=-1) : + _graph(graph), //_arc_id(graph), + _arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs), + MAX(std::numeric_limits::max()), + INF(std::numeric_limits::has_infinity ? + std::numeric_limits::infinity() : MAX) + { + // Reset data structures + reset(); + max_iter = maxiters; +#ifdef OMP + if (max_threads < 0) { + max_threads = omp_get_max_threads(); + } + if (numThreads > 0 && numThreads<=max_threads){ + num_threads = numThreads; + } else if (numThreads == -1 || numThreads>max_threads) { + num_threads = max_threads; + } else { + num_threads = 1; + } + omp_set_num_threads(num_threads); +#else + num_threads = 1; +#endif + } + + /// The type of the flow amounts, capacity bounds and supply values + typedef V Value; + /// The type of the arc costs + typedef C Cost; + + public: + /// \brief Problem type constants for the \c run() function. + /// + /// Enum type containing the problem type constants that can be + /// returned by the \ref run() function of the algorithm. + enum ProblemType { + /// The problem has no feasible solution (flow). + INFEASIBLE, + /// The problem has optimal solution (i.e. it is feasible and + /// bounded), and the algorithm has found optimal flow and node + /// potentials (primal and dual solutions). + OPTIMAL, + /// The objective function of the problem is unbounded, i.e. + /// there is a directed cycle having negative total cost and + /// infinite upper bound. + UNBOUNDED, + // The maximum number of iteration has been reached + MAX_ITER_REACHED + }; + + /// \brief Constants for selecting the type of the supply constraints. + /// + /// Enum type containing constants for selecting the supply type, + /// i.e. the direction of the inequalities in the supply/demand + /// constraints of the \ref min_cost_flow "minimum cost flow problem". + /// + /// The default supply type is \c GEQ, the \c LEQ type can be + /// selected using \ref supplyType(). + /// The equality form is a special case of both supply types. + enum SupplyType { + /// This option means that there are "greater or equal" + /// supply/demand constraints in the definition of the problem. + GEQ, + /// This option means that there are "less or equal" + /// supply/demand constraints in the definition of the problem. + LEQ + }; + + + + private: + size_t max_iter; + int num_threads; + TEMPLATE_DIGRAPH_TYPEDEFS(GR); + + typedef std::vector IntVector; + typedef std::vector ArcVector; + typedef std::vector ValueVector; + typedef std::vector CostVector; + // typedef SparseValueVector CostVector; + typedef std::vector BoolVector; + // Note: vector is used instead of vector for efficiency reasons + + // State constants for arcs + enum ArcState { + STATE_UPPER = -1, + STATE_TREE = 0, + STATE_LOWER = 1 + }; + + typedef std::vector StateVector; + // Note: vector is used instead of vector for + // efficiency reasons + + private: + + // Data related to the underlying digraph + const GR &_graph; + int _node_num; + ArcsType _arc_num; + ArcsType _all_arc_num; + ArcsType _search_arc_num; + + // Parameters of the problem + SupplyType _stype; + Value _sum_supply; + + inline int _node_id(int n) const { return _node_num - n - 1; }; + + //IntArcMap _arc_id; + IntVector _source; // keep nodes as integers + IntVector _target; + bool _arc_mixing; + + // Node and arc data + CostVector _cost; + ValueVector _supply; +#ifdef SPARSE_FLOW + SparseValueVector _flow; +#else + ValueVector _flow; +#endif + + CostVector _pi; + + // Data for storing the spanning tree structure + IntVector _parent; + ArcVector _pred; + IntVector _thread; + IntVector _rev_thread; + IntVector _succ_num; + IntVector _last_succ; + IntVector _dirty_revs; + BoolVector _forward; + StateVector _state; + ArcsType _root; + + // Temporary data used in the current pivot iteration + ArcsType in_arc, join, u_in, v_in, u_out, v_out; + ArcsType first, second, right, last; + ArcsType stem, par_stem, new_stem; + Value delta; + + const Value MAX; + + ArcsType mixingCoeff; + + public: + + /// \brief Constant for infinite upper bounds (capacities). + /// + /// Constant for infinite upper bounds (capacities). + /// It is \c std::numeric_limits::infinity() if available, + /// \c std::numeric_limits::max() otherwise. + const Value INF; + + private: + + // thank you to DVK and MizardX from StackOverflow for this function! + inline ArcsType sequence(ArcsType k) const { + ArcsType smallv = (k > num_total_big_subsequence_numbers) & 1; + + k -= num_total_big_subsequence_numbers * smallv; + ArcsType subsequence_length2 = subsequence_length - smallv; + ArcsType subsequence_num = (k / subsequence_length2) + num_big_subsequences * smallv; + ArcsType subsequence_offset = (k % subsequence_length2) * mixingCoeff; + + return subsequence_offset + subsequence_num; + } + ArcsType subsequence_length; + ArcsType num_big_subsequences; + ArcsType num_total_big_subsequence_numbers; + + inline ArcsType getArcID(const Arc &arc) const + { + //int n = _arc_num-arc._id-1; + ArcsType n = _arc_num - GR::id(arc) - 1; + + //ArcsType a = mixingCoeff*(n%mixingCoeff) + n/mixingCoeff; + //ArcsType b = _arc_id[arc]; + if (_arc_mixing) + return sequence(n); + else + return n; + } + + // finally unused because too slow + inline ArcsType getSource(const ArcsType arc) const + { + //ArcsType a = _source[arc]; + //return a; + + ArcsType n = _arc_num - arc - 1; + if (_arc_mixing) + n = mixingCoeff*(n%mixingCoeff) + n / mixingCoeff; + + ArcsType b; + if (n >= 0) + b = _node_id(_graph.source(GR::arcFromId(n))); + else + { + n = arc + 1 - _arc_num; + if (n <= _node_num) + b = _node_num; + else + if (n >= _graph._n1) + b = _graph._n1; + else + b = _graph._n1 - n; + } + + return b; + } + + + + // Implementation of the Block Search pivot rule + class BlockSearchPivotRule + { + private: + + // References to the NetworkSimplexSimple class + const IntVector &_source; + const IntVector &_target; + const CostVector &_cost; + const StateVector &_state; + const CostVector &_pi; + ArcsType &_in_arc; + ArcsType _search_arc_num; + + // Pivot rule data + ArcsType _block_size; + ArcsType _next_arc; + NetworkSimplexSimple &_ns; + + public: + + // Constructor + BlockSearchPivotRule(NetworkSimplexSimple &ns) : + _source(ns._source), _target(ns._target), + _cost(ns._cost), _state(ns._state), _pi(ns._pi), + _in_arc(ns.in_arc), _search_arc_num(ns._search_arc_num), + _next_arc(0), _ns(ns) + { + // The main parameters of the pivot rule + const double BLOCK_SIZE_FACTOR = 1; + const ArcsType MIN_BLOCK_SIZE = 10; + + _block_size = std::max(ArcsType(BLOCK_SIZE_FACTOR * std::sqrt(double(_search_arc_num))), MIN_BLOCK_SIZE); + } + + // Find next entering arc + bool findEnteringArc() { + Cost min_val = 0; + + ArcsType N = _ns.num_threads; + + std::vector minArray(N, 0); + std::vector arcId(N); + ArcsType bs = (ArcsType)ceil(_block_size / (double)N); + + for (ArcsType i = 0; i < _search_arc_num; i += _block_size) { + + ArcsType e; + int j; +#pragma omp parallel + { +#ifdef OMP + int t = omp_get_thread_num(); +#else + int t = 0; +#endif + +#pragma omp for schedule(static, bs) lastprivate(e) + for (j = 0; j < std::min(i + _block_size, _search_arc_num) - i; j++) { + e = (_next_arc + i + j); if (e >= _search_arc_num) e -= _search_arc_num; + Cost c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]); + if (c < minArray[t]) { + minArray[t] = c; + arcId[t] = e; + } + } + } + for (int j = 0; j < N; j++) { + if (minArray[j] < min_val) { + min_val = minArray[j]; + _in_arc = arcId[j]; + } + } + Cost a = std::abs(_pi[_source[_in_arc]]) > std::abs(_pi[_target[_in_arc]]) ? std::abs(_pi[_source[_in_arc]]) : std::abs(_pi[_target[_in_arc]]); + a = a > std::abs(_cost[_in_arc]) ? a : std::abs(_cost[_in_arc]); + if (min_val < -EPSILON*a) { + _next_arc = e; + return true; + } + } + + Cost a = fabs(_pi[_source[_in_arc]]) > fabs(_pi[_target[_in_arc]]) ? fabs(_pi[_source[_in_arc]]) : fabs(_pi[_target[_in_arc]]); + a = a > fabs(_cost[_in_arc]) ? a : fabs(_cost[_in_arc]); + if (min_val >= -EPSILON*a) return false; + + return true; + } + + + // Find next entering arc + /*bool findEnteringArc() { + Cost min_val = 0; + int N = omp_get_max_threads(); + std::vector minArray(N); + std::vector arcId(N); + + ArcsType bs = (ArcsType)ceil(_block_size / (double)N); + for (ArcsType i = 0; i < _search_arc_num; i += _block_size) { + + ArcsType maxJ = std::min(i + _block_size, _search_arc_num) - i; + ArcsType j; +#pragma omp parallel + { + int t = omp_get_thread_num(); + Cost minV = 0; + ArcsType arcStart = _next_arc + i; + ArcsType arc = -1; +#pragma omp for schedule(static, bs) + for (j = 0; j < maxJ; j++) { + ArcsType e = arcStart + j; if (e >= _search_arc_num) e -= _search_arc_num; + Cost c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]); + if (c < minV) { + minV = c; + arc = e; + } + } + + minArray[t] = minV; + arcId[t] = arc; + } + for (int j = 0; j < N; j++) { + if (minArray[j] < min_val) { + min_val = minArray[j]; + _in_arc = arcId[j]; + } + } + + //FIX by Antoine Rolet to avoid precision issues + Cost a = std::max(std::abs(_cost[_in_arc]), std::max(std::abs(_pi[_source[_in_arc]]), std::abs(_pi[_target[_in_arc]]))); + if (min_val <-std::numeric_limits::epsilon()*a) { + _next_arc = _next_arc + i + maxJ - 1; + if (_next_arc >= _search_arc_num) _next_arc -= _search_arc_num; + return true; + } + } + + if (min_val >= 0) { + return false; + } + + return true; + }*/ + + + /*bool findEnteringArc() { + Cost c, min = 0; + int cnt = _block_size; + int e, min_arc = _next_arc; + for (e = _next_arc; e < _search_arc_num; ++e) { + c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]); + if (c < min) { + min = c; + min_arc = e; + + } + if (--cnt == 0) { + if (min < 0) break; + cnt = _block_size; + + } + + } + if (min == 0 || cnt > 0) { + for (e = 0; e < _next_arc; ++e) { + c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]); + if (c < min) { + min = c; + min_arc = e; + + } + if (--cnt == 0) { + if (min < 0) break; + cnt = _block_size; + + } + + } + + } + if (min >= 0) return false; + _in_arc = min_arc; + _next_arc = e; + return true; + }*/ + + + + }; //class BlockSearchPivotRule + + + + public: + + + + int _init_nb_nodes; + ArcsType _init_nb_arcs; + + /// \name Parameters + /// The parameters of the algorithm can be specified using these + /// functions. + + /// @{ + + + /// \brief Set the costs of the arcs. + /// + /// This function sets the costs of the arcs. + /// If it is not used before calling \ref run(), the costs + /// will be set to \c 1 on all arcs. + /// + /// \param map An arc map storing the costs. + /// Its \c Value type must be convertible to the \c Cost type + /// of the algorithm. + /// + /// \return (*this) + template + NetworkSimplexSimple& costMap(const CostMap& map) { + Arc a; _graph.first(a); + for (; a != INVALID; _graph.next(a)) { + _cost[getArcID(a)] = map[a]; + } + return *this; + } + + + /// \brief Set the costs of one arc. + /// + /// This function sets the costs of one arcs. + /// Done for memory reasons + /// + /// \param arc An arc. + /// \param arc A cost + /// + /// \return (*this) + template + NetworkSimplexSimple& setCost(const Arc& arc, const Value cost) { + _cost[getArcID(arc)] = cost; + return *this; + } + + + /// \brief Set the supply values of the nodes. + /// + /// This function sets the supply values of the nodes. + /// If neither this function nor \ref stSupply() is used before + /// calling \ref run(), the supply of each node will be set to zero. + /// + /// \param map A node map storing the supply values. + /// Its \c Value type must be convertible to the \c Value type + /// of the algorithm. + /// + /// \return (*this) + template + NetworkSimplexSimple& supplyMap(const SupplyMap& map) { + Node n; _graph.first(n); + for (; n != INVALIDNODE; _graph.next(n)) { + _supply[_node_id(n)] = map[n]; + } + return *this; + } + template + NetworkSimplexSimple& supplyMap(const SupplyMap* map1, int n1, const SupplyMap* map2, int n2) { + Node n; _graph.first(n); + for (; n != INVALIDNODE; _graph.next(n)) { + if (n + NetworkSimplexSimple& supplyMapAll(SupplyMap val1, int n1, SupplyMap val2, int n2) { + Node n; _graph.first(n); + for (; n != INVALIDNODE; _graph.next(n)) { + if (n(*this) + NetworkSimplexSimple& stSupply(const Node& s, const Node& t, Value k) { + for (int i = 0; i != _node_num; ++i) { + _supply[i] = 0; + } + _supply[_node_id(s)] = k; + _supply[_node_id(t)] = -k; + return *this; + } + + /// \brief Set the type of the supply constraints. + /// + /// This function sets the type of the supply/demand constraints. + /// If it is not used before calling \ref run(), the \ref GEQ supply + /// type will be used. + /// + /// For more information, see \ref SupplyType. + /// + /// \return (*this) + NetworkSimplexSimple& supplyType(SupplyType supply_type) { + _stype = supply_type; + return *this; + } + + /// @} + + /// \name Execution Control + /// The algorithm can be executed using \ref run(). + + /// @{ + + /// \brief Run the algorithm. + /// + /// This function runs the algorithm. + /// The paramters can be specified using functions \ref lowerMap(), + /// \ref upperMap(), \ref costMap(), \ref supplyMap(), \ref stSupply(), + /// \ref supplyType(). + /// For example, + /// \code + /// NetworkSimplexSimple ns(graph); + /// ns.lowerMap(lower).upperMap(upper).costMap(cost) + /// .supplyMap(sup).run(); + /// \endcode + /// + /// This function can be called more than once. All the given parameters + /// are kept for the next call, unless \ref resetParams() or \ref reset() + /// is used, thus only the modified parameters have to be set again. + /// If the underlying digraph was also modified after the construction + /// of the class (or the last \ref reset() call), then the \ref reset() + /// function must be called. + /// + /// \param pivot_rule The pivot rule that will be used during the + /// algorithm. For more information, see \ref PivotRule. + /// + /// \return \c INFEASIBLE if no feasible flow exists, + /// \n \c OPTIMAL if the problem has optimal solution + /// (i.e. it is feasible and bounded), and the algorithm has found + /// optimal flow and node potentials (primal and dual solutions), + /// \n \c UNBOUNDED if the objective function of the problem is + /// unbounded, i.e. there is a directed cycle having negative total + /// cost and infinite upper bound. + /// + /// \see ProblemType, PivotRule + /// \see resetParams(), reset() + ProblemType run() { +#if DEBUG_LVL>0 + std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "\nUNBOUNDED = " << UNBOUNDED << "\nMAX_ITER_REACHED" << MAX_ITER_REACHED << "\n" ; +#endif + if (!init()) return INFEASIBLE; +#if DEBUG_LVL>0 + std::cout << "Init done, starting iterations\n"; +#endif + + return start(); + } + + /// \brief Reset all the parameters that have been given before. + /// + /// This function resets all the paramaters that have been given + /// before using functions \ref lowerMap(), \ref upperMap(), + /// \ref costMap(), \ref supplyMap(), \ref stSupply(), \ref supplyType(). + /// + /// It is useful for multiple \ref run() calls. Basically, all the given + /// parameters are kept for the next \ref run() call, unless + /// \ref resetParams() or \ref reset() is used. + /// If the underlying digraph was also modified after the construction + /// of the class or the last \ref reset() call, then the \ref reset() + /// function must be used, otherwise \ref resetParams() is sufficient. + /// + /// For example, + /// \code + /// NetworkSimplexSimple ns(graph); + /// + /// // First run + /// ns.lowerMap(lower).upperMap(upper).costMap(cost) + /// .supplyMap(sup).run(); + /// + /// // Run again with modified cost map (resetParams() is not called, + /// // so only the cost map have to be set again) + /// cost[e] += 100; + /// ns.costMap(cost).run(); + /// + /// // Run again from scratch using resetParams() + /// // (the lower bounds will be set to zero on all arcs) + /// ns.resetParams(); + /// ns.upperMap(capacity).costMap(cost) + /// .supplyMap(sup).run(); + /// \endcode + /// + /// \return (*this) + /// + /// \see reset(), run() + NetworkSimplexSimple& resetParams() { + for (int i = 0; i != _node_num; ++i) { + _supply[i] = 0; + } + for (ArcsType i = 0; i != _arc_num; ++i) { + _cost[i] = 1; + } + _stype = GEQ; + return *this; + } + + + /// \brief Reset the internal data structures and all the parameters + /// that have been given before. + /// + /// This function resets the internal data structures and all the + /// paramaters that have been given before using functions \ref lowerMap(), + /// \ref upperMap(), \ref costMap(), \ref supplyMap(), \ref stSupply(), + /// \ref supplyType(). + /// + /// It is useful for multiple \ref run() calls. Basically, all the given + /// parameters are kept for the next \ref run() call, unless + /// \ref resetParams() or \ref reset() is used. + /// If the underlying digraph was also modified after the construction + /// of the class or the last \ref reset() call, then the \ref reset() + /// function must be used, otherwise \ref resetParams() is sufficient. + /// + /// See \ref resetParams() for examples. + /// + /// \return (*this) + /// + /// \see resetParams(), run() + NetworkSimplexSimple& reset() { + // Resize vectors + _node_num = _init_nb_nodes; + _arc_num = _init_nb_arcs; + int all_node_num = _node_num + 1; + ArcsType max_arc_num = _arc_num + 2 * _node_num; + + _source.resize(max_arc_num); + _target.resize(max_arc_num); + + _cost.resize(max_arc_num); + _supply.resize(all_node_num); + _flow.resize(max_arc_num); + _pi.resize(all_node_num); + + _parent.resize(all_node_num); + _pred.resize(all_node_num); + _forward.resize(all_node_num); + _thread.resize(all_node_num); + _rev_thread.resize(all_node_num); + _succ_num.resize(all_node_num); + _last_succ.resize(all_node_num); + _state.resize(max_arc_num); + + + //_arc_mixing=false; + if (_arc_mixing && _node_num > 1) { + // Store the arcs in a mixed order + //ArcsType k = std::max(ArcsType(std::sqrt(double(_arc_num))), ArcsType(10)); + const ArcsType k = std::max(ArcsType(_arc_num / _node_num), ArcsType(3)); + mixingCoeff = k; + subsequence_length = _arc_num / mixingCoeff + 1; + num_big_subsequences = _arc_num % mixingCoeff; + num_total_big_subsequence_numbers = subsequence_length * num_big_subsequences; + +#pragma omp parallel for schedule(static) + for (Arc a = 0; a <= _graph.maxArcId(); a++) { // --a <=> _graph.next(a) , -1 == INVALID + ArcsType i = sequence(_graph.maxArcId()-a); + _source[i] = _node_id(_graph.source(a)); + _target[i] = _node_id(_graph.target(a)); + } + } else { + // Store the arcs in the original order + ArcsType i = 0; + Arc a; _graph.first(a); + for (; a != INVALID; _graph.next(a), ++i) { + _source[i] = _node_id(_graph.source(a)); + _target[i] = _node_id(_graph.target(a)); + //_arc_id[a] = i; + } + } + + // Reset parameters + resetParams(); + return *this; + } + + /// @} + + /// \name Query Functions + /// The results of the algorithm can be obtained using these + /// functions.\n + /// The \ref run() function must be called before using them. + + /// @{ + + /// \brief Return the total cost of the found flow. + /// + /// This function returns the total cost of the found flow. + /// Its complexity is O(e). + /// + /// \note The return type of the function can be specified as a + /// template parameter. For example, + /// \code + /// ns.totalCost(); + /// \endcode + /// It is useful if the total cost cannot be stored in the \c Cost + /// type of the algorithm, which is the default return type of the + /// function. + /// + /// \pre \ref run() must be called before using this function. + /*template + Number totalCost() const { + Number c = 0; + for (ArcIt a(_graph); a != INVALID; ++a) { + int i = getArcID(a); + c += Number(_flow[i]) * Number(_cost[i]); + } + return c; + }*/ + + template + Number totalCost() const { + Number c = 0; + +#ifdef SPARSE_FLOW + #ifdef HASHMAP + typename std::unordered_map::const_iterator it; + #else + typename std::map::const_iterator it; + #endif + for (it = _flow.data.begin(); it!=_flow.data.end(); ++it) + c += Number(it->second) * Number(_cost[it->first]); + return c; +#else + for (ArcsType i = 0; i<_flow.size(); i++) + c += _flow[i] * Number(_cost[i]); + return c; +#endif + } + +#ifndef DOXYGEN + Cost totalCost() const { + return totalCost(); + } +#endif + + /// \brief Return the flow on the given arc. + /// + /// This function returns the flow on the given arc. + /// + /// \pre \ref run() must be called before using this function. + Value flow(const Arc& a) const { + return _flow[getArcID(a)]; + } + + /// \brief Return the flow map (the primal solution). + /// + /// This function copies the flow value on each arc into the given + /// map. The \c Value type of the algorithm must be convertible to + /// the \c Value type of the map. + /// + /// \pre \ref run() must be called before using this function. + template + void flowMap(FlowMap &map) const { + Arc a; _graph.first(a); + for (; a != INVALID; _graph.next(a)) { + map.set(a, _flow[getArcID(a)]); + } + } + + /// \brief Return the potential (dual value) of the given node. + /// + /// This function returns the potential (dual value) of the + /// given node. + /// + /// \pre \ref run() must be called before using this function. + Cost potential(const Node& n) const { + return _pi[_node_id(n)]; + } + + /// \brief Return the potential map (the dual solution). + /// + /// This function copies the potential (dual value) of each node + /// into the given map. + /// The \c Cost type of the algorithm must be convertible to the + /// \c Value type of the map. + /// + /// \pre \ref run() must be called before using this function. + template + void potentialMap(PotentialMap &map) const { + Node n; _graph.first(n); + for (; n != INVALID; _graph.next(n)) { + map.set(n, _pi[_node_id(n)]); + } + } + + /// @} + + private: + + // Initialize internal data structures + bool init() { + if (_node_num == 0) return false; + + // Check the sum of supply values + _sum_supply = 0; + for (int i = 0; i != _node_num; ++i) { + _sum_supply += _supply[i]; + } + /*if (!((_stype == GEQ && _sum_supply <= 0) || + (_stype == LEQ && _sum_supply >= 0))) return false;*/ + + + // Initialize artifical cost + Cost ART_COST; + if (std::numeric_limits::is_exact) { + ART_COST = std::numeric_limits::max() / 2 + 1; + } else { + ART_COST = 0; + for (ArcsType i = 0; i != _arc_num; ++i) { + if (_cost[i] > ART_COST) ART_COST = _cost[i]; + } + ART_COST = (ART_COST + 1) * _node_num; + } + + // Initialize arc maps + for (ArcsType i = 0; i != _arc_num; ++i) { +#ifndef SPARSE_FLOW + _flow[i] = 0; //by default, the sparse matrix is empty +#endif + _state[i] = STATE_LOWER; + } +#ifdef SPARSE_FLOW + _flow = SparseValueVector(); +#endif + + // Set data for the artificial root node + _root = _node_num; + _parent[_root] = -1; + _pred[_root] = -1; + _thread[_root] = 0; + _rev_thread[0] = _root; + _succ_num[_root] = _node_num + 1; + _last_succ[_root] = _root - 1; + _supply[_root] = -_sum_supply; + _pi[_root] = 0; + + // Add artificial arcs and initialize the spanning tree data structure + if (_sum_supply == 0) { + // EQ supply constraints + _search_arc_num = _arc_num; + _all_arc_num = _arc_num + _node_num; + for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) { + _parent[u] = _root; + _pred[u] = e; + _thread[u] = u + 1; + _rev_thread[u + 1] = u; + _succ_num[u] = 1; + _last_succ[u] = u; + _state[e] = STATE_TREE; + if (_supply[u] >= 0) { + _forward[u] = true; + _pi[u] = 0; + _source[e] = u; + _target[e] = _root; + _flow[e] = _supply[u]; + _cost[e] = 0; + } else { + _forward[u] = false; + _pi[u] = ART_COST; + _source[e] = _root; + _target[e] = u; + _flow[e] = -_supply[u]; + _cost[e] = ART_COST; + } + } + } else if (_sum_supply > 0) { + // LEQ supply constraints + _search_arc_num = _arc_num + _node_num; + ArcsType f = _arc_num + _node_num; + for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) { + _parent[u] = _root; + _thread[u] = u + 1; + _rev_thread[u + 1] = u; + _succ_num[u] = 1; + _last_succ[u] = u; + if (_supply[u] >= 0) { + _forward[u] = true; + _pi[u] = 0; + _pred[u] = e; + _source[e] = u; + _target[e] = _root; + _flow[e] = _supply[u]; + _cost[e] = 0; + _state[e] = STATE_TREE; + } else { + _forward[u] = false; + _pi[u] = ART_COST; + _pred[u] = f; + _source[f] = _root; + _target[f] = u; + _flow[f] = -_supply[u]; + _cost[f] = ART_COST; + _state[f] = STATE_TREE; + _source[e] = u; + _target[e] = _root; + //_flow[e] = 0; //by default, the sparse matrix is empty + _cost[e] = 0; + _state[e] = STATE_LOWER; + ++f; + } + } + _all_arc_num = f; + } else { + // GEQ supply constraints + _search_arc_num = _arc_num + _node_num; + ArcsType f = _arc_num + _node_num; + for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) { + _parent[u] = _root; + _thread[u] = u + 1; + _rev_thread[u + 1] = u; + _succ_num[u] = 1; + _last_succ[u] = u; + if (_supply[u] <= 0) { + _forward[u] = false; + _pi[u] = 0; + _pred[u] = e; + _source[e] = _root; + _target[e] = u; + _flow[e] = -_supply[u]; + _cost[e] = 0; + _state[e] = STATE_TREE; + } else { + _forward[u] = true; + _pi[u] = -ART_COST; + _pred[u] = f; + _source[f] = u; + _target[f] = _root; + _flow[f] = _supply[u]; + _state[f] = STATE_TREE; + _cost[f] = ART_COST; + _source[e] = _root; + _target[e] = u; + //_flow[e] = 0; //by default, the sparse matrix is empty + _cost[e] = 0; + _state[e] = STATE_LOWER; + ++f; + } + } + _all_arc_num = f; + } + + return true; + } + + // Find the join node + void findJoinNode() { + int u = _source[in_arc]; + int v = _target[in_arc]; + while (u != v) { + if (_succ_num[u] < _succ_num[v]) { + u = _parent[u]; + } else { + v = _parent[v]; + } + } + join = u; + } + + // Find the leaving arc of the cycle and returns true if the + // leaving arc is not the same as the entering arc + bool findLeavingArc() { + // Initialize first and second nodes according to the direction + // of the cycle + if (_state[in_arc] == STATE_LOWER) { + first = _source[in_arc]; + second = _target[in_arc]; + } else { + first = _target[in_arc]; + second = _source[in_arc]; + } + delta = INF; + char result = 0; + Value d; + ArcsType e; + + // Search the cycle along the path form the first node to the root + for (int u = first; u != join; u = _parent[u]) { + e = _pred[u]; + d = _forward[u] ? _flow[e] : INF; + if (d < delta) { + delta = d; + u_out = u; + result = 1; + } + } + // Search the cycle along the path form the second node to the root + for (int u = second; u != join; u = _parent[u]) { + e = _pred[u]; + d = _forward[u] ? INF : _flow[e]; + if (d <= delta) { + delta = d; + u_out = u; + result = 2; + } + } + + if (result == 1) { + u_in = first; + v_in = second; + } else { + u_in = second; + v_in = first; + } + return result != 0; + } + + // Change _flow and _state vectors + void changeFlow(bool change) { + // Augment along the cycle + if (delta > 0) { + Value val = _state[in_arc] * delta; + _flow[in_arc] += val; + for (int u = _source[in_arc]; u != join; u = _parent[u]) { + _flow[_pred[u]] += _forward[u] ? -val : val; + } + for (int u = _target[in_arc]; u != join; u = _parent[u]) { + _flow[_pred[u]] += _forward[u] ? val : -val; + } + } + // Update the state of the entering and leaving arcs + if (change) { + _state[in_arc] = STATE_TREE; + _state[_pred[u_out]] = + (_flow[_pred[u_out]] == 0) ? STATE_LOWER : STATE_UPPER; + } else { + _state[in_arc] = -_state[in_arc]; + } + } + + // Update the tree structure + void updateTreeStructure() { + int old_rev_thread = _rev_thread[u_out]; + int old_succ_num = _succ_num[u_out]; + int old_last_succ = _last_succ[u_out]; + v_out = _parent[u_out]; + + // Check if u_in and u_out coincide + if (u_in == u_out) { + // Update _parent, _pred, _pred_dir + _parent[u_in] = v_in; + _pred[u_in] = in_arc; + _forward[u_in] = (u_in == _source[in_arc]); + + // Update _thread and _rev_thread + if (_thread[v_in] != u_out) { + ArcsType after = _thread[old_last_succ]; + _thread[old_rev_thread] = after; + _rev_thread[after] = old_rev_thread; + after = _thread[v_in]; + _thread[v_in] = u_out; + _rev_thread[u_out] = v_in; + _thread[old_last_succ] = after; + _rev_thread[after] = old_last_succ; + } + } else { + // Handle the case when old_rev_thread equals to v_in + // (it also means that join and v_out coincide) + int thread_continue = old_rev_thread == v_in ? + _thread[old_last_succ] : _thread[v_in]; + + // Update _thread and _parent along the stem nodes (i.e. the nodes + // between u_in and u_out, whose parent have to be changed) + int stem = u_in; // the current stem node + int par_stem = v_in; // the new parent of stem + int next_stem; // the next stem node + int last = _last_succ[u_in]; // the last successor of stem + int before, after = _thread[last]; + _thread[v_in] = u_in; + _dirty_revs.clear(); + _dirty_revs.push_back(v_in); + while (stem != u_out) { + // Insert the next stem node into the thread list + next_stem = _parent[stem]; + _thread[last] = next_stem; + _dirty_revs.push_back(last); + + // Remove the subtree of stem from the thread list + before = _rev_thread[stem]; + _thread[before] = after; + _rev_thread[after] = before; + + // Change the parent node and shift stem nodes + _parent[stem] = par_stem; + par_stem = stem; + stem = next_stem; + + // Update last and after + last = _last_succ[stem] == _last_succ[par_stem] ? + _rev_thread[par_stem] : _last_succ[stem]; + after = _thread[last]; + } + _parent[u_out] = par_stem; + _thread[last] = thread_continue; + _rev_thread[thread_continue] = last; + _last_succ[u_out] = last; + + // Remove the subtree of u_out from the thread list except for + // the case when old_rev_thread equals to v_in + if (old_rev_thread != v_in) { + _thread[old_rev_thread] = after; + _rev_thread[after] = old_rev_thread; + } + + // Update _rev_thread using the new _thread values + for (int i = 0; i != int(_dirty_revs.size()); ++i) { + int u = _dirty_revs[i]; + _rev_thread[_thread[u]] = u; + } + + // Update _pred, _pred_dir, _last_succ and _succ_num for the + // stem nodes from u_out to u_in + int tmp_sc = 0, tmp_ls = _last_succ[u_out]; + for (int u = u_out, p = _parent[u]; u != u_in; u = p, p = _parent[u]) { + _pred[u] = _pred[p]; + _forward[u] = !_forward[p]; + tmp_sc += _succ_num[u] - _succ_num[p]; + _succ_num[u] = tmp_sc; + _last_succ[p] = tmp_ls; + } + _pred[u_in] = in_arc; + _forward[u_in] = (u_in == _source[in_arc]); + _succ_num[u_in] = old_succ_num; + } + + // Update _last_succ from v_in towards the root + int up_limit_out = _last_succ[join] == v_in ? join : -1; + int last_succ_out = _last_succ[u_out]; + for (int u = v_in; u != -1 && _last_succ[u] == v_in; u = _parent[u]) { + _last_succ[u] = last_succ_out; + } + + // Update _last_succ from v_out towards the root + if (join != old_rev_thread && v_in != old_rev_thread) { + for (int u = v_out; u != up_limit_out && _last_succ[u] == old_last_succ; + u = _parent[u]) { + _last_succ[u] = old_rev_thread; + } + } else if (last_succ_out != old_last_succ) { + for (int u = v_out; u != up_limit_out && _last_succ[u] == old_last_succ; + u = _parent[u]) { + _last_succ[u] = last_succ_out; + } + } + + // Update _succ_num from v_in to join + for (int u = v_in; u != join; u = _parent[u]) { + _succ_num[u] += old_succ_num; + } + // Update _succ_num from v_out to join + for (int u = v_out; u != join; u = _parent[u]) { + _succ_num[u] -= old_succ_num; + } + } + + void updatePotential() { + Cost sigma = _pi[v_in] - _pi[u_in] - + ((_forward[u_in])?_cost[in_arc]:(-_cost[in_arc])); + int end = _thread[_last_succ[u_in]]; + for (int u = u_in; u != end; u = _thread[u]) { + _pi[u] += sigma; + } + } + + + // Heuristic initial pivots + bool initialPivots() { + Value curr, total = 0; + std::vector supply_nodes, demand_nodes; + Node u; _graph.first(u); + for (; u != INVALIDNODE; _graph.next(u)) { + curr = _supply[_node_id(u)]; + if (curr > 0) { + total += curr; + supply_nodes.push_back(u); + } else if (curr < 0) { + demand_nodes.push_back(u); + } + } + if (_sum_supply > 0) total -= _sum_supply; + if (total <= 0) return true; + + ArcVector arc_vector; + if (_sum_supply >= 0) { + if (supply_nodes.size() == 1 && demand_nodes.size() == 1) { + // Perform a reverse graph search from the sink to the source + //typename GR::template NodeMap reached(_graph, false); + BoolVector reached(_node_num, false); + Node s = supply_nodes[0], t = demand_nodes[0]; + std::vector stack; + reached[t] = true; + stack.push_back(t); + while (!stack.empty()) { + Node u, v = stack.back(); + stack.pop_back(); + if (v == s) break; + Arc a; _graph.firstIn(a, v); + for (; a != INVALID; _graph.nextIn(a)) { + if (reached[u = _graph.source(a)]) continue; + ArcsType j = getArcID(a); + arc_vector.push_back(j); + reached[u] = true; + stack.push_back(u); + } + } + } else { + arc_vector.resize(demand_nodes.size()); + // Find the min. cost incomming arc for each demand node +#pragma omp parallel for + for (int i = 0; i < demand_nodes.size(); ++i) { + Node v = demand_nodes[i]; + Cost min_cost = std::numeric_limits::max(); + Arc min_arc = INVALID; + Arc a; _graph.firstIn(a, v); + for (; a != INVALID; _graph.nextIn(a)) { + Cost c = _cost[getArcID(a)]; + if (c < min_cost) { + min_cost = c; + min_arc = a; + } + } + arc_vector[i] = getArcID(min_arc); + } + arc_vector.erase(std::remove(arc_vector.begin(), arc_vector.end(), INVALID), arc_vector.end()); + } + } else { + arc_vector.resize(supply_nodes.size()); + // Find the min. cost outgoing arc for each supply node +#pragma omp parallel for + for (int i = 0; i < int(supply_nodes.size()); ++i) { + Node u = supply_nodes[i]; + Cost min_cost = std::numeric_limits::max(); + Arc min_arc = INVALID; + Arc a; _graph.firstOut(a, u); + for (; a != INVALID; _graph.nextOut(a)) { + Cost c = _cost[getArcID(a)]; + if (c < min_cost) { + min_cost = c; + min_arc = a; + } + } + arc_vector[i] = getArcID(min_arc); + } + arc_vector.erase(std::remove(arc_vector.begin(), arc_vector.end(), INVALID), arc_vector.end()); + } + + // Perform heuristic initial pivots + for (ArcsType i = 0; i != ArcsType(arc_vector.size()); ++i) { + in_arc = arc_vector[i]; + if (_state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] - + _pi[_target[in_arc]]) >= 0) continue; + findJoinNode(); + bool change = findLeavingArc(); + if (delta >= MAX) return false; + changeFlow(change); + if (change) { + updateTreeStructure(); + updatePotential(); + } + } + return true; + } + + // Execute the algorithm + ProblemType start() { + return start(); + } + + template + ProblemType start() { + PivotRuleImpl pivot(*this); + ProblemType retVal = OPTIMAL; + + // Perform heuristic initial pivots + if (!initialPivots()) return UNBOUNDED; + + size_t iter_number = 0; + // Execute the Network Simplex algorithm + while (pivot.findEnteringArc()) { + if ((++iter_number <= max_iter&&max_iter > 0) || max_iter<=0) { +#if DEBUG_LVL>0 + if(iter_number>MAX_DEBUG_ITER) + break; + if(iter_number%1000==0||iter_number%1000==1){ + Cost curCost=totalCost(); + Value sumFlow=0; + Cost a; + a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]); + a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]); + for (int i=0; i<_flow.size(); i++) { + sumFlow+=_state[i]*_flow[i]; + } + std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << iter_number << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n"; + std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n"; + std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n"; + std::cout << _cost[in_arc] << "\n"; + std::cout << _pi[_source[in_arc]] << "\n"; + std::cout << _pi[_target[in_arc]] << "\n"; + std::cout << a << "\n"; + } +#endif + + findJoinNode(); + bool change = findLeavingArc(); + if (delta >= MAX) return UNBOUNDED; + changeFlow(change); + if (change) { + updateTreeStructure(); + updatePotential(); + } + +#if DEBUG_LVL>0 + else{ + std::cout << "No change\n"; + } +#endif + +#if DEBUG_LVL>1 + std::cout << "Arc in = (" << _source[in_arc] << ", " << _target[in_arc] << ")\n"; +#endif + + + } else { + char errMess[1000]; + sprintf( errMess, "RESULT MIGHT BE INACURATE\nMax number of iteration reached, currently \%d. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher\n",iter_number ); + std::cerr << errMess; + retVal = MAX_ITER_REACHED; + break; + } + + } + + + +#if DEBUG_LVL>0 + Cost curCost=totalCost(); + Value sumFlow=0; + Cost a; + a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]); + a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]); + for (int i=0; i<_flow.size(); i++) { + sumFlow+=_state[i]*_flow[i]; + } + + std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << niter << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n"; + + std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n"; + std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n"; + +#endif + + + +#if DEBUG_LVL>1 + sumFlow=0; + for (int i=0; i<_flow.size(); i++) { + sumFlow+=_state[i]*_flow[i]; + if (_state[i]==STATE_TREE) { + std::cout << "Non zero value at (" << _node_num+1-_source[i] << ", " << _node_num+1-_target[i] << ")\n"; + } + } + std::cout << "Sum of the flow " << sumFlow << "\n"<< niter <<" iterations, current cost=" << totalCost() << "\n"; +#endif + + + + //Check feasibility + if(retVal == OPTIMAL){ + for (ArcsType e = _search_arc_num; e != _all_arc_num; ++e) { + if (_flow[e] != 0){ + if (fabs(_flow[e]) > _EPSILON) // change of the original code following issue #126 + return INFEASIBLE; + else + _flow[e]=0; + } + } + } + + // Shift potentials to meet the requirements of the GEQ/LEQ type + // optimality conditions + if (_sum_supply == 0) { + if (_stype == GEQ) { + Cost max_pot = -std::numeric_limits::max(); + for (ArcsType i = 0; i != _node_num; ++i) { + if (_pi[i] > max_pot) max_pot = _pi[i]; + } + if (max_pot > 0) { + for (ArcsType i = 0; i != _node_num; ++i) + _pi[i] -= max_pot; + } + } else { + Cost min_pot = std::numeric_limits::max(); + for (ArcsType i = 0; i != _node_num; ++i) { + if (_pi[i] < min_pot) min_pot = _pi[i]; + } + if (min_pot < 0) { + for (ArcsType i = 0; i != _node_num; ++i) + _pi[i] -= min_pot; + } + } + } + + return retVal; + } + + }; //class NetworkSimplexSimple + + ///@} + +} //namespace lemon_omp diff --git a/ot/utils.py b/ot/utils.py index 4dac0c5..6a782e6 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -7,7 +7,6 @@ Various useful functions # # License: MIT License -import multiprocessing from functools import reduce import time @@ -311,38 +310,11 @@ def label_normalization(y, start=0): return y -def fun(f, q_in, q_out): - """ Utility function for parmap with no serializing problems """ - while True: - i, x = q_in.get() - if i is None: - break - q_out.put((i, f(x))) - - -def parmap(f, X, nprocs=multiprocessing.cpu_count()): - """ paralell map for multiprocessing (only map on windows)""" - - if not sys.platform.endswith('win32') and not sys.platform.endswith('darwin'): - - q_in = multiprocessing.Queue(1) - q_out = multiprocessing.Queue() - - proc = [multiprocessing.Process(target=fun, args=(f, q_in, q_out)) - for _ in range(nprocs)] - for p in proc: - p.daemon = True - p.start() - - sent = [q_in.put((i, x)) for i, x in enumerate(X)] - [q_in.put((None, None)) for _ in range(nprocs)] - res = [q_out.get() for _ in range(len(sent))] - - [p.join() for p in proc] - - return [x for i, x in sorted(res)] - else: - return list(map(f, X)) +def parmap(f, X, nprocs="default"): + """ paralell map for multiprocessing. + The function has been deprecated and only performs a regular map. + """ + return list(map(f, X)) def check_params(**kwargs): diff --git a/setup.py b/setup.py index 37a5824..ef95eaf 100644 --- a/setup.py +++ b/setup.py @@ -11,6 +11,8 @@ from setuptools.extension import Extension import numpy from Cython.Build import cythonize +sys.path.append(os.path.join("ot", "helpers")) +from openmp_helpers import check_openmp_support # dirty but working __version__ = re.search( @@ -30,7 +32,14 @@ if 'clean' in sys.argv[1:]: os.remove('ot/lp/emd_wrap.cpp') # add platform dependant optional compilation argument -compile_args = ["-O3"] +openmp_supported, flags = check_openmp_support() +compile_args = ["/O2" if sys.platform == "win32" else "-O3"] +link_args = [] + +if openmp_supported: + compile_args += flags + ["/DOMP" if sys.platform == 'win32' else "-DOMP"] + link_args += flags + if sys.platform.startswith('darwin'): compile_args.append("-stdlib=libc++") sdk_path = subprocess.check_output(['xcrun', '--show-sdk-path']) @@ -52,6 +61,7 @@ setup( language="c++", include_dirs=[numpy.get_include(), os.path.join(ROOT, 'ot/lp')], extra_compile_args=compile_args, + extra_link_args=link_args )), platforms=['linux', 'macosx', 'windows'], download_url='https://github.com/PythonOT/POT/archive/{}.tar.gz'.format(__version__), diff --git a/test/test_helpers.py b/test/test_helpers.py new file mode 100644 index 0000000..8bd0015 --- /dev/null +++ b/test/test_helpers.py @@ -0,0 +1,26 @@ +"""Tests for helpers functions """ + +# Author: Remi Flamary +# +# License: MIT License + +import os +import sys + +sys.path.append(os.path.join("ot", "helpers")) + +from openmp_helpers import get_openmp_flag, check_openmp_support # noqa +from pre_build_helpers import _get_compiler, compile_test_program # noqa + + +def test_helpers(): + + compiler = _get_compiler() + + get_openmp_flag(compiler) + + s = '#include \n#include \n\nint main(void) {\n\tprintf("Hello world!\\n");\n\treturn 0;\n}' + output, _ = compile_test_program(s) + assert len(output) == 1 and output[0] == "Hello world!" + + check_openmp_support() -- cgit v1.2.3 From 76450dddf8dd62b9714b72e99ae075516246d433 Mon Sep 17 00:00:00 2001 From: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> Date: Mon, 25 Oct 2021 17:35:36 +0200 Subject: [MRG] Backend for optim (#282) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Backend for optim * Bug solve * Doc update * backend tests now with fixture * Unused imports removed * Docs * Docs * Docs * Outer product backend docs * Prettier docs * Pep8 * Mistakes corrected Co-authored-by: Rémi Flamary --- ot/backend.py | 118 ++++++++++++++++++++++++++------------- ot/lp/__init__.py | 4 +- ot/optim.py | 155 ++++++++++++++++++++++++++++++--------------------- test/test_backend.py | 22 +++++--- test/test_optim.py | 78 ++++++++++++++++++++------ test/test_ot.py | 6 +- test/test_utils.py | 7 --- 7 files changed, 250 insertions(+), 140 deletions(-) (limited to 'ot/lp') diff --git a/ot/backend.py b/ot/backend.py index a4a4757..876b96a 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -123,7 +123,7 @@ class Backend(): r""" Creates a tensor full of zeros. - This function follow the api from :any:`numpy.zeros` + This function follows the api from :any:`numpy.zeros` See: https://numpy.org/doc/stable/reference/generated/numpy.zeros.html """ @@ -133,7 +133,7 @@ class Backend(): r""" Creates a tensor full of ones. - This function follow the api from :any:`numpy.ones` + This function follows the api from :any:`numpy.ones` See: https://numpy.org/doc/stable/reference/generated/numpy.ones.html """ @@ -143,7 +143,7 @@ class Backend(): r""" Returns evenly spaced values within a given interval. - This function follow the api from :any:`numpy.arange` + This function follows the api from :any:`numpy.arange` See: https://numpy.org/doc/stable/reference/generated/numpy.arange.html """ @@ -153,7 +153,7 @@ class Backend(): r""" Creates a tensor with given shape, filled with given value. - This function follow the api from :any:`numpy.full` + This function follows the api from :any:`numpy.full` See: https://numpy.org/doc/stable/reference/generated/numpy.full.html """ @@ -163,7 +163,7 @@ class Backend(): r""" Creates the identity matrix of given size. - This function follow the api from :any:`numpy.eye` + This function follows the api from :any:`numpy.eye` See: https://numpy.org/doc/stable/reference/generated/numpy.eye.html """ @@ -173,7 +173,7 @@ class Backend(): r""" Sums tensor elements over given dimensions. - This function follow the api from :any:`numpy.sum` + This function follows the api from :any:`numpy.sum` See: https://numpy.org/doc/stable/reference/generated/numpy.sum.html """ @@ -183,7 +183,7 @@ class Backend(): r""" Returns the cumulative sum of tensor elements over given dimensions. - This function follow the api from :any:`numpy.cumsum` + This function follows the api from :any:`numpy.cumsum` See: https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html """ @@ -193,7 +193,7 @@ class Backend(): r""" Returns the maximum of an array or maximum along given dimensions. - This function follow the api from :any:`numpy.amax` + This function follows the api from :any:`numpy.amax` See: https://numpy.org/doc/stable/reference/generated/numpy.amax.html """ @@ -203,7 +203,7 @@ class Backend(): r""" Returns the maximum of an array or maximum along given dimensions. - This function follow the api from :any:`numpy.amin` + This function follows the api from :any:`numpy.amin` See: https://numpy.org/doc/stable/reference/generated/numpy.amin.html """ @@ -213,7 +213,7 @@ class Backend(): r""" Returns element-wise maximum of array elements. - This function follow the api from :any:`numpy.maximum` + This function follows the api from :any:`numpy.maximum` See: https://numpy.org/doc/stable/reference/generated/numpy.maximum.html """ @@ -223,7 +223,7 @@ class Backend(): r""" Returns element-wise minimum of array elements. - This function follow the api from :any:`numpy.minimum` + This function follows the api from :any:`numpy.minimum` See: https://numpy.org/doc/stable/reference/generated/numpy.minimum.html """ @@ -233,7 +233,7 @@ class Backend(): r""" Returns the dot product of two tensors. - This function follow the api from :any:`numpy.dot` + This function follows the api from :any:`numpy.dot` See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html """ @@ -243,7 +243,7 @@ class Backend(): r""" Computes the absolute value element-wise. - This function follow the api from :any:`numpy.absolute` + This function follows the api from :any:`numpy.absolute` See: https://numpy.org/doc/stable/reference/generated/numpy.absolute.html """ @@ -253,7 +253,7 @@ class Backend(): r""" Computes the exponential value element-wise. - This function follow the api from :any:`numpy.exp` + This function follows the api from :any:`numpy.exp` See: https://numpy.org/doc/stable/reference/generated/numpy.exp.html """ @@ -263,7 +263,7 @@ class Backend(): r""" Computes the natural logarithm, element-wise. - This function follow the api from :any:`numpy.log` + This function follows the api from :any:`numpy.log` See: https://numpy.org/doc/stable/reference/generated/numpy.log.html """ @@ -273,7 +273,7 @@ class Backend(): r""" Returns the non-ngeative square root of a tensor, element-wise. - This function follow the api from :any:`numpy.sqrt` + This function follows the api from :any:`numpy.sqrt` See: https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html """ @@ -283,7 +283,7 @@ class Backend(): r""" First tensor elements raised to powers from second tensor, element-wise. - This function follow the api from :any:`numpy.power` + This function follows the api from :any:`numpy.power` See: https://numpy.org/doc/stable/reference/generated/numpy.power.html """ @@ -293,7 +293,7 @@ class Backend(): r""" Computes the matrix frobenius norm. - This function follow the api from :any:`numpy.linalg.norm` + This function follows the api from :any:`numpy.linalg.norm` See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html """ @@ -303,7 +303,7 @@ class Backend(): r""" Tests whether any tensor element along given dimensions evaluates to True. - This function follow the api from :any:`numpy.any` + This function follows the api from :any:`numpy.any` See: https://numpy.org/doc/stable/reference/generated/numpy.any.html """ @@ -313,7 +313,7 @@ class Backend(): r""" Tests element-wise for NaN and returns result as a boolean tensor. - This function follow the api from :any:`numpy.isnan` + This function follows the api from :any:`numpy.isnan` See: https://numpy.org/doc/stable/reference/generated/numpy.isnan.html """ @@ -323,7 +323,7 @@ class Backend(): r""" Tests element-wise for positive or negative infinity and returns result as a boolean tensor. - This function follow the api from :any:`numpy.isinf` + This function follows the api from :any:`numpy.isinf` See: https://numpy.org/doc/stable/reference/generated/numpy.isinf.html """ @@ -333,7 +333,7 @@ class Backend(): r""" Evaluates the Einstein summation convention on the operands. - This function follow the api from :any:`numpy.einsum` + This function follows the api from :any:`numpy.einsum` See: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html """ @@ -343,7 +343,7 @@ class Backend(): r""" Returns a sorted copy of a tensor. - This function follow the api from :any:`numpy.sort` + This function follows the api from :any:`numpy.sort` See: https://numpy.org/doc/stable/reference/generated/numpy.sort.html """ @@ -353,7 +353,7 @@ class Backend(): r""" Returns the indices that would sort a tensor. - This function follow the api from :any:`numpy.argsort` + This function follows the api from :any:`numpy.argsort` See: https://numpy.org/doc/stable/reference/generated/numpy.argsort.html """ @@ -363,7 +363,7 @@ class Backend(): r""" Finds indices where elements should be inserted to maintain order in given tensor. - This function follow the api from :any:`numpy.searchsorted` + This function follows the api from :any:`numpy.searchsorted` See: https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html """ @@ -373,7 +373,7 @@ class Backend(): r""" Reverses the order of elements in a tensor along given dimensions. - This function follow the api from :any:`numpy.flip` + This function follows the api from :any:`numpy.flip` See: https://numpy.org/doc/stable/reference/generated/numpy.flip.html """ @@ -383,7 +383,7 @@ class Backend(): """ Limits the values in a tensor. - This function follow the api from :any:`numpy.clip` + This function follows the api from :any:`numpy.clip` See: https://numpy.org/doc/stable/reference/generated/numpy.clip.html """ @@ -393,7 +393,7 @@ class Backend(): r""" Repeats elements of a tensor. - This function follow the api from :any:`numpy.repeat` + This function follows the api from :any:`numpy.repeat` See: https://numpy.org/doc/stable/reference/generated/numpy.repeat.html """ @@ -403,7 +403,7 @@ class Backend(): r""" Gathers elements of a tensor along given dimensions. - This function follow the api from :any:`numpy.take_along_axis` + This function follows the api from :any:`numpy.take_along_axis` See: https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html """ @@ -413,7 +413,7 @@ class Backend(): r""" Joins a sequence of tensors along an existing dimension. - This function follow the api from :any:`numpy.concatenate` + This function follows the api from :any:`numpy.concatenate` See: https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html """ @@ -423,7 +423,7 @@ class Backend(): r""" Pads a tensor. - This function follow the api from :any:`numpy.pad` + This function follows the api from :any:`numpy.pad` See: https://numpy.org/doc/stable/reference/generated/numpy.pad.html """ @@ -433,7 +433,7 @@ class Backend(): r""" Returns the indices of the maximum values of a tensor along given dimensions. - This function follow the api from :any:`numpy.argmax` + This function follows the api from :any:`numpy.argmax` See: https://numpy.org/doc/stable/reference/generated/numpy.argmax.html """ @@ -443,7 +443,7 @@ class Backend(): r""" Computes the arithmetic mean of a tensor along given dimensions. - This function follow the api from :any:`numpy.mean` + This function follows the api from :any:`numpy.mean` See: https://numpy.org/doc/stable/reference/generated/numpy.mean.html """ @@ -453,7 +453,7 @@ class Backend(): r""" Computes the standard deviation of a tensor along given dimensions. - This function follow the api from :any:`numpy.std` + This function follows the api from :any:`numpy.std` See: https://numpy.org/doc/stable/reference/generated/numpy.std.html """ @@ -463,7 +463,7 @@ class Backend(): r""" Returns a specified number of evenly spaced values over a given interval. - This function follow the api from :any:`numpy.linspace` + This function follows the api from :any:`numpy.linspace` See: https://numpy.org/doc/stable/reference/generated/numpy.linspace.html """ @@ -473,7 +473,7 @@ class Backend(): r""" Returns coordinate matrices from coordinate vectors (Numpy convention). - This function follow the api from :any:`numpy.meshgrid` + This function follows the api from :any:`numpy.meshgrid` See: https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html """ @@ -483,7 +483,7 @@ class Backend(): r""" Extracts or constructs a diagonal tensor. - This function follow the api from :any:`numpy.diag` + This function follows the api from :any:`numpy.diag` See: https://numpy.org/doc/stable/reference/generated/numpy.diag.html """ @@ -493,7 +493,7 @@ class Backend(): r""" Finds unique elements of given tensor. - This function follow the api from :any:`numpy.unique` + This function follows the api from :any:`numpy.unique` See: https://numpy.org/doc/stable/reference/generated/numpy.unique.html """ @@ -503,7 +503,7 @@ class Backend(): r""" Computes the log of the sum of exponentials of input elements. - This function follow the api from :any:`scipy.special.logsumexp` + This function follows the api from :any:`scipy.special.logsumexp` See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html """ @@ -513,12 +513,32 @@ class Backend(): r""" Joins a sequence of tensors along a new dimension. - This function follow the api from :any:`numpy.stack` + This function follows the api from :any:`numpy.stack` See: https://numpy.org/doc/stable/reference/generated/numpy.stack.html """ raise NotImplementedError() + def outer(self, a, b): + r""" + Computes the outer product between two vectors. + + This function follows the api from :any:`numpy.outer` + + See: https://numpy.org/doc/stable/reference/generated/numpy.outer.html + """ + raise NotImplementedError() + + def reshape(self, a, shape): + r""" + Gives a new shape to a tensor without changing its data. + + This function follows the api from :any:`numpy.reshape` + + See: https://numpy.org/doc/stable/reference/generated/numpy.reshape.html + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -644,6 +664,9 @@ class NumpyBackend(Backend): def flip(self, a, axis=None): return np.flip(a, axis) + def outer(self, a, b): + return np.outer(a, b) + def clip(self, a, a_min, a_max): return np.clip(a, a_min, a_max) @@ -686,6 +709,9 @@ class NumpyBackend(Backend): def stack(self, arrays, axis=0): return np.stack(arrays, axis) + def reshape(self, a, shape): + return np.reshape(a, shape) + class JaxBackend(Backend): """ @@ -815,6 +841,9 @@ class JaxBackend(Backend): def flip(self, a, axis=None): return jnp.flip(a, axis) + def outer(self, a, b): + return jnp.outer(a, b) + def clip(self, a, a_min, a_max): return jnp.clip(a, a_min, a_max) @@ -857,6 +886,9 @@ class JaxBackend(Backend): def stack(self, arrays, axis=0): return jnp.stack(arrays, axis) + def reshape(self, a, shape): + return jnp.reshape(a, shape) + class TorchBackend(Backend): """ @@ -1035,6 +1067,9 @@ class TorchBackend(Backend): else: return torch.flip(a, dims=axis) + def outer(self, a, b): + return torch.outer(a, b) + def clip(self, a, a_min, a_max): return torch.clamp(a, a_min, a_max) @@ -1091,3 +1126,6 @@ class TorchBackend(Backend): def stack(self, arrays, axis=0): return torch.stack(arrays, dim=axis) + + def reshape(self, a, shape): + return torch.reshape(a, shape) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index b907b10..c6757d1 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -281,12 +281,12 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): a0, b0, M0 = a, b, M nx = get_backend(M0, a0, b0) - + # convert to numpy M = nx.to_numpy(M) a = nx.to_numpy(a) b = nx.to_numpy(b) - + # ensure float64 a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) diff --git a/ot/optim.py b/ot/optim.py index 0359343..6822e4e 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -12,6 +12,8 @@ import numpy as np from scipy.optimize.linesearch import scalar_search_armijo from .lp import emd from .bregman import sinkhorn +from ot.utils import list_to_array +from .backend import get_backend # The corresponding scipy function does not work for matrices @@ -21,25 +23,25 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, """ Armijo linesearch function that works with matrices - find an approximate minimum of f(xk+alpha*pk) that satifies the + Find an approximate minimum of :math:`f(x_k + \\alpha \cdot p_k)` that satisfies the armijo conditions. Parameters ---------- f : callable loss function - xk : ndarray + xk : array-like initial position - pk : ndarray + pk : array-like descent direction - gfk : ndarray - gradient of f at xk + gfk : array-like + gradient of `f` at :math:`x_k` old_fval : float - loss value at xk + loss value at :math:`x_k` args : tuple, optional - arguments given to f + arguments given to `f` c1 : float, optional - c1 const in armijo rule (>0) + :math:`c_1` const in armijo rule (>0) alpha0 : float, optional initial step (>0) @@ -53,7 +55,13 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, loss value at step alpha """ - xk = np.atleast_1d(xk) + + xk, pk, gfk = list_to_array(xk, pk, gfk) + nx = get_backend(xk, pk) + + if len(xk.shape) == 0: + xk = nx.reshape(xk, (-1,)) + fc = [0] def phi(alpha1): @@ -65,7 +73,7 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, else: phi0 = old_fval - derphi0 = np.sum(pk * gfk) # Quickfix for matrices + derphi0 = nx.sum(pk * gfk) # Quickfix for matrices alpha, phi1 = scalar_search_armijo( phi, phi0, derphi0, c1=c1, alpha0=alpha0) @@ -79,55 +87,64 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val, armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None): """ Solve the linesearch in the FW iterations + Parameters ---------- cost : method Cost in the FW for the linesearch - G : ndarray, shape(ns,nt) + G : array-like, shape(ns,nt) The transport map at a given iteration of the FW - deltaG : ndarray (ns,nt) + deltaG : array-like (ns,nt) Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration - Mi : ndarray (ns,nt) + Mi : array-like (ns,nt) Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost - f_val : float - Value of the cost at G + f_val : float + Value of the cost at `G` armijo : bool, optional - If True the steps of the line-search is found via an armijo research. Else closed form is used. - If there is convergence issues use False. - C1 : ndarray (ns,ns), optional + If True the steps of the line-search is found via an armijo research. Else closed form is used. + If there is convergence issues use False. + C1 : array-like (ns,ns), optional Structure matrix in the source domain. Only used and necessary when armijo=False - C2 : ndarray (nt,nt), optional + C2 : array-like (nt,nt), optional Structure matrix in the target domain. Only used and necessary when armijo=False reg : float, optional - Regularization parameter. Only used and necessary when armijo=False - Gc : ndarray (ns,nt) + Regularization parameter. Only used and necessary when armijo=False + Gc : array-like (ns,nt) Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False - constC : ndarray (ns,nt) - Constant for the gromov cost. See [24]. Only used and necessary when armijo=False - M : ndarray (ns,nt), optional + constC : array-like (ns,nt) + Constant for the gromov cost. See :ref:`[24] `. Only used and necessary when armijo=False + M : array-like (ns,nt), optional Cost matrix between the features. Only used and necessary when armijo=False + Returns ------- alpha : float - The optimal step size of the FW + The optimal step size of the FW fc : int - nb of function call. Useless here - f_val : float - The value of the cost for the next iteration + nb of function call. Useless here + f_val : float + The value of the cost for the next iteration + + + .. _references-solve-linesearch: References ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain - and Courty Nicolas + .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. """ if armijo: alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val) else: # requires symetric matrices - dot1 = np.dot(C1, deltaG) - dot12 = dot1.dot(C2) - a = -2 * reg * np.sum(dot12 * deltaG) - b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG)) + G, deltaG, C1, C2, constC, M = list_to_array(G, deltaG, C1, C2, constC, M) + if isinstance(M, int) or isinstance(M, float): + nx = get_backend(G, deltaG, C1, C2, constC) + else: + nx = get_backend(G, deltaG, C1, C2, constC, M) + + dot = nx.dot(nx.dot(C1, deltaG), C2) + a = -2 * reg * nx.sum(dot * deltaG) + b = nx.sum((M + reg * constC) * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2) * deltaG)) c = cost(G) alpha = solve_1d_linesearch_quad(a, b, c) @@ -145,33 +162,33 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg*f(\gamma) + \gamma = arg\min_\gamma <\gamma,M>_F + \mathrm{reg} \cdot f(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma 1 = a \gamma^T 1= b \gamma\geq 0 where : - - M is the (ns,nt) metric cost matrix - - :math:`f` is the regularization term ( and df is its gradient) - - a and b are source and target weights (sum to 1) + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`f` is the regularization term (and `df` is its gradient) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) - The algorithm used for solving the problem is conditional gradient as discussed in [1]_ + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] ` Parameters ---------- - a : ndarray, shape (ns,) + a : array-like, shape (ns,) samples weights in the source domain - b : ndarray, shape (nt,) + b : array-like, shape (nt,) samples in the target domain - M : ndarray, shape (ns, nt) + M : array-like, shape (ns, nt) loss matrix reg : float Regularization term >0 - G0 : ndarray, shape (ns,nt), optional + G0 : array-like, shape (ns,nt), optional initial guess (default is indep joint density) numItermax : int, optional Max number of iterations @@ -196,6 +213,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, log dictionary return only if log==True in parameters + .. _references-cg: References ---------- @@ -207,6 +225,11 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, ot.bregman.sinkhorn : Entropic regularized optimal transport """ + a, b, M, G0 = list_to_array(a, b, M, G0) + if isinstance(M, int) or isinstance(M, float): + nx = get_backend(a, b) + else: + nx = get_backend(a, b, M) loop = 1 @@ -214,12 +237,12 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, log = {'loss': []} if G0 is None: - G = np.outer(a, b) + G = nx.outer(a, b) else: G = G0 def cost(G): - return np.sum(M * G) + reg * f(G) + return nx.sum(M * G) + reg * f(G) f_val = cost(G) if log: @@ -240,7 +263,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, # problem linearization Mi = M + reg * df(G) # set M positive - Mi += Mi.min() + Mi += nx.min(Mi) # solve linear program Gc = emd(a, b, Mi, numItermax=numItermaxEmd) @@ -286,36 +309,36 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg1\cdot\Omega(\gamma) + reg2\cdot f(\gamma) + \gamma = arg\min_\gamma <\gamma,M>_F + \mathrm{reg_1}\cdot\Omega(\gamma) + \mathrm{reg_2}\cdot f(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma 1 = a \gamma^T 1= b \gamma\geq 0 where : - - M is the (ns,nt) metric cost matrix + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`f` is the regularization term ( and df is its gradient) - - a and b are source and target weights (sum to 1) + - :math:`f` is the regularization term (and `df` is its gradient) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) - The algorithm used for solving the problem is the generalized conditional gradient as discussed in [5,7]_ + The algorithm used for solving the problem is the generalized conditional gradient as discussed in :ref:`[5, 7] ` Parameters ---------- - a : ndarray, shape (ns,) + a : array-like, shape (ns,) samples weights in the source domain - b : ndarrayv (nt,) + b : array-like, (nt,) samples in the target domain - M : ndarray, shape (ns, nt) + M : array-like, shape (ns, nt) loss matrix reg1 : float Entropic Regularization term >0 reg2 : float Second Regularization term >0 - G0 : ndarray, shape (ns, nt), optional + G0 : array-like, shape (ns, nt), optional initial guess (default is indep joint density) numItermax : int, optional Max number of iterations @@ -337,9 +360,13 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, log : dict log dictionary return only if log==True in parameters + + .. _references-gcg: References ---------- + .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567. See Also @@ -347,6 +374,8 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, ot.optim.cg : conditional gradient """ + a, b, M, G0 = list_to_array(a, b, M, G0) + nx = get_backend(a, b, M) loop = 1 @@ -354,12 +383,12 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, log = {'loss': []} if G0 is None: - G = np.outer(a, b) + G = nx.outer(a, b) else: G = G0 def cost(G): - return np.sum(M * G) + reg1 * np.sum(G * np.log(G)) + reg2 * f(G) + return nx.sum(M * G) + reg1 * nx.sum(G * nx.log(G)) + reg2 * f(G) f_val = cost(G) if log: @@ -387,7 +416,7 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, deltaG = Gc - G # line search - dcost = Mi + reg1 * (1 + np.log(G)) # ?? + dcost = Mi + reg1 * (1 + nx.log(G)) # ?? alpha, fc, f_val = line_search_armijo(cost, G, deltaG, dcost, f_val) G = G + alpha * deltaG @@ -419,9 +448,11 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, def solve_1d_linesearch_quad(a, b, c): """ - For any convex or non-convex 1d quadratic function f, solve on [0,1] the following problem: + For any convex or non-convex 1d quadratic function `f`, solve the following problem: + .. math:: - \argmin f(x)=a*x^{2}+b*x+c + + arg\min_{0 \leq x \leq 1} f(x) = ax^{2} + bx + c Parameters ---------- diff --git a/test/test_backend.py b/test/test_backend.py index 859da5a..5853282 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -17,9 +17,6 @@ from numpy.testing import assert_array_almost_equal_nulp from ot.backend import get_backend, get_backend_list, to_numpy -backend_list = get_backend_list() - - def test_get_backend_list(): lst = get_backend_list() @@ -28,7 +25,6 @@ def test_get_backend_list(): assert isinstance(lst[0], ot.backend.NumpyBackend) -@pytest.mark.parametrize('nx', backend_list) def test_to_numpy(nx): v = nx.zeros(10) @@ -92,7 +88,6 @@ def test_get_backend(): get_backend(A, B2) -@pytest.mark.parametrize('nx', backend_list) def test_convert_between_backends(nx): A = np.zeros((3, 2)) @@ -180,6 +175,8 @@ def test_empty_backend(): nx.searchsorted(v, v) with pytest.raises(NotImplementedError): nx.flip(M) + with pytest.raises(NotImplementedError): + nx.outer(v, v) with pytest.raises(NotImplementedError): nx.clip(M, -1, 1) with pytest.raises(NotImplementedError): @@ -208,10 +205,11 @@ def test_empty_backend(): nx.logsumexp(M) with pytest.raises(NotImplementedError): nx.stack([M, M]) + with pytest.raises(NotImplementedError): + nx.reshape(M, (5, 3, 2)) -@pytest.mark.parametrize('backend', backend_list) -def test_func_backends(backend): +def test_func_backends(nx): rnd = np.random.RandomState(0) M = rnd.randn(10, 3) @@ -220,7 +218,7 @@ def test_func_backends(backend): lst_tot = [] - for nx in [ot.backend.NumpyBackend(), backend]: + for nx in [ot.backend.NumpyBackend(), nx]: print('Backend: ', nx.__name__) @@ -371,6 +369,10 @@ def test_func_backends(backend): lst_b.append(nx.to_numpy(A)) lst_name.append('flip') + A = nx.outer(vb, vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('outer') + A = nx.clip(vb, 0, 1) lst_b.append(nx.to_numpy(A)) lst_name.append('clip') @@ -432,6 +434,10 @@ def test_func_backends(backend): lst_b.append(nx.to_numpy(A)) lst_name.append('stack') + A = nx.reshape(Mb, (5, 3, 2)) + lst_b.append(nx.to_numpy(A)) + lst_name.append('reshape') + lst_tot.append(lst_b) lst_np = lst_tot[0] diff --git a/test/test_optim.py b/test/test_optim.py index 94995d5..4efd9b1 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -8,7 +8,7 @@ import numpy as np import ot -def test_conditional_gradient(): +def test_conditional_gradient(nx): n_bins = 100 # nb bins np.random.seed(0) @@ -29,15 +29,25 @@ def test_conditional_gradient(): def df(G): return G + def fb(G): + return 0.5 * nx.sum(G ** 2) + + ab = nx.from_numpy(a) + bb = nx.from_numpy(b) + Mb = nx.from_numpy(M, type_as=ab) + reg = 1e-1 G, log = ot.optim.cg(a, b, M, reg, f, df, verbose=True, log=True) + Gb, log = ot.optim.cg(ab, bb, Mb, reg, fb, df, verbose=True, log=True) + Gb = nx.to_numpy(Gb) - np.testing.assert_allclose(a, G.sum(1)) - np.testing.assert_allclose(b, G.sum(0)) + np.testing.assert_allclose(Gb, G) + np.testing.assert_allclose(a, Gb.sum(1)) + np.testing.assert_allclose(b, Gb.sum(0)) -def test_conditional_gradient_itermax(): +def test_conditional_gradient_itermax(nx): n = 100 # nb samples mu_s = np.array([0, 0]) @@ -61,16 +71,27 @@ def test_conditional_gradient_itermax(): def df(G): return G + def fb(G): + return 0.5 * nx.sum(G ** 2) + + ab = nx.from_numpy(a) + bb = nx.from_numpy(b) + Mb = nx.from_numpy(M, type_as=ab) + reg = 1e-1 G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=10000, verbose=True, log=True) + Gb, log = ot.optim.cg(ab, bb, Mb, reg, fb, df, numItermaxEmd=10000, + verbose=True, log=True) + Gb = nx.to_numpy(Gb) - np.testing.assert_allclose(a, G.sum(1)) - np.testing.assert_allclose(b, G.sum(0)) + np.testing.assert_allclose(Gb, G) + np.testing.assert_allclose(a, Gb.sum(1)) + np.testing.assert_allclose(b, Gb.sum(0)) -def test_generalized_conditional_gradient(): +def test_generalized_conditional_gradient(nx): n_bins = 100 # nb bins np.random.seed(0) @@ -91,13 +112,23 @@ def test_generalized_conditional_gradient(): def df(G): return G + def fb(G): + return 0.5 * nx.sum(G ** 2) + reg1 = 1e-3 reg2 = 1e-1 + ab = nx.from_numpy(a) + bb = nx.from_numpy(b) + Mb = nx.from_numpy(M, type_as=ab) + G, log = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True, log=True) + Gb, log = ot.optim.gcg(ab, bb, Mb, reg1, reg2, fb, df, verbose=True, log=True) + Gb = nx.to_numpy(Gb) - np.testing.assert_allclose(a, G.sum(1), atol=1e-05) - np.testing.assert_allclose(b, G.sum(0), atol=1e-05) + np.testing.assert_allclose(Gb, G) + np.testing.assert_allclose(a, Gb.sum(1), atol=1e-05) + np.testing.assert_allclose(b, Gb.sum(0), atol=1e-05) def test_solve_1d_linesearch_quad_funct(): @@ -106,24 +137,31 @@ def test_solve_1d_linesearch_quad_funct(): np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1) -def test_line_search_armijo(): +def test_line_search_armijo(nx): xk = np.array([[0.25, 0.25], [0.25, 0.25]]) pk = np.array([[-0.25, 0.25], [0.25, -0.25]]) gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]]) old_fval = -123 # Should not throw an exception and return None for alpha - alpha, _, _ = ot.optim.line_search_armijo(lambda x: 1, xk, pk, gfk, old_fval) + alpha, a, b = ot.optim.line_search_armijo( + lambda x: 1, nx.from_numpy(xk), nx.from_numpy(pk), nx.from_numpy(gfk), old_fval + ) + alpha_np, anp, bnp = ot.optim.line_search_armijo( + lambda x: 1, xk, pk, gfk, old_fval + ) + assert a == anp + assert b == bnp assert alpha is None # check line search armijo def f(x): - return np.sum((x - 5.0) ** 2) + return nx.sum((x - 5.0) ** 2) def grad(x): return 2 * (x - 5.0) - xk = np.array([[[-5.0, -5.0]]]) - pk = np.array([[[100.0, 100.0]]]) + xk = nx.from_numpy(np.array([[[-5.0, -5.0]]])) + pk = nx.from_numpy(np.array([[[100.0, 100.0]]])) gfk = grad(xk) old_fval = f(xk) @@ -132,10 +170,18 @@ def test_line_search_armijo(): np.testing.assert_allclose(alpha, 0.1) # check the case where the direction is not far enough - pk = np.array([[[3.0, 3.0]]]) + pk = nx.from_numpy(np.array([[[3.0, 3.0]]])) alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval, alpha0=1.0) np.testing.assert_allclose(alpha, 1.0) - # check the case where the checking the wrong direction + # check the case where checking the wrong direction alpha, _, _ = ot.optim.line_search_armijo(f, xk, -pk, gfk, old_fval) assert alpha <= 0 + + # check the case where the point is not a vector + xk = nx.from_numpy(np.array(-5.0)) + pk = nx.from_numpy(np.array(100.0)) + gfk = grad(xk) + old_fval = f(xk) + alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval) + np.testing.assert_allclose(alpha, 0.1) diff --git a/test/test_ot.py b/test/test_ot.py index 3e953dc..4dfc510 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -12,9 +12,7 @@ from scipy.stats import wasserstein_distance import ot from ot.datasets import make_1D_gauss as gauss -from ot.backend import get_backend_list, torch - -backend_list = get_backend_list() +from ot.backend import torch def test_emd_dimension_and_mass_mismatch(): @@ -37,7 +35,6 @@ def test_emd_dimension_and_mass_mismatch(): np.testing.assert_raises(AssertionError, ot.emd, a, b, M) -@pytest.mark.parametrize('nx', backend_list) def test_emd_backends(nx): n_samples = 100 n_features = 2 @@ -59,7 +56,6 @@ def test_emd_backends(nx): np.allclose(G, nx.to_numpy(Gb)) -@pytest.mark.parametrize('nx', backend_list) def test_emd2_backends(nx): n_samples = 100 n_features = 2 diff --git a/test/test_utils.py b/test/test_utils.py index 76b1faa..60ad5d3 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -4,17 +4,11 @@ # # License: MIT License -import pytest import ot import numpy as np import sys -from ot.backend import get_backend_list -backend_list = get_backend_list() - - -@pytest.mark.parametrize('nx', backend_list) def test_proj_simplex(nx): n = 10 rng = np.random.RandomState(0) @@ -119,7 +113,6 @@ def test_dist(): np.testing.assert_allclose(D, D3, atol=1e-14) -@ pytest.mark.parametrize('nx', backend_list) def test_dist_backends(nx): n = 100 -- cgit v1.2.3 From a335324d008e8982be61d7ace937815a2bfa98f9 Mon Sep 17 00:00:00 2001 From: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> Date: Tue, 2 Nov 2021 13:42:02 +0100 Subject: [MRG] Backend for gromov (#294) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * bregman: small correction * gromov backend first draft * Removing decorators * Reworked casting method * Bug solve * Removing casting * Bug solve * toarray renamed todense ; expand_dims removed * Warning (jax not supporting sparse matrix) moved * Mistake corrected * test backend * Sparsity test for older versions of pytorch * Trying pytorch/1.10 * Attempt to correct torch sparse bug * Backend version of gromov tests * Random state introduced for remaining gromov functions * review changes * code coverage * Docs (first draft, to be continued) * Gromov docs * Prettified docs * mistake corrected in the docs * little change Co-authored-by: Rémi Flamary --- ot/backend.py | 214 ++++++++- ot/bregman.py | 184 ++++---- ot/gromov.py | 1220 +++++++++++++++++++++++++++----------------------- ot/lp/__init__.py | 58 ++- ot/optim.py | 22 +- test/test_backend.py | 56 +++ test/test_bregman.py | 4 +- test/test_gromov.py | 297 ++++++++---- 8 files changed, 1289 insertions(+), 766 deletions(-) (limited to 'ot/lp') diff --git a/ot/backend.py b/ot/backend.py index 876b96a..358297c 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -26,6 +26,7 @@ Examples import numpy as np import scipy.special as scipy +from scipy.sparse import issparse, coo_matrix, csr_matrix try: import torch @@ -539,6 +540,86 @@ class Backend(): """ raise NotImplementedError() + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): + r""" + Creates a sparse tensor in COOrdinate format. + + This function follows the api from :any:`scipy.sparse.coo_matrix` + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html + """ + raise NotImplementedError() + + def issparse(self, a): + r""" + Checks whether or not the input tensor is a sparse tensor. + + This function follows the api from :any:`scipy.sparse.issparse` + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.issparse.html + """ + raise NotImplementedError() + + def tocsr(self, a): + r""" + Converts this matrix to Compressed Sparse Row format. + + This function follows the api from :any:`scipy.sparse.coo_matrix.tocsr` + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.tocsr.html + """ + raise NotImplementedError() + + def eliminate_zeros(self, a, threshold=0.): + r""" + Removes entries smaller than the given threshold from the sparse tensor. + + This function follows the api from :any:`scipy.sparse.csr_matrix.eliminate_zeros` + + See: https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.sparse.csr_matrix.eliminate_zeros.html + """ + raise NotImplementedError() + + def todense(self, a): + r""" + Converts a sparse tensor to a dense tensor. + + This function follows the api from :any:`scipy.sparse.csr_matrix.toarray` + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.toarray.html + """ + raise NotImplementedError() + + def where(self, condition, x, y): + r""" + Returns elements chosen from x or y depending on condition. + + This function follows the api from :any:`numpy.where` + + See: https://numpy.org/doc/stable/reference/generated/numpy.where.html + """ + raise NotImplementedError() + + def copy(self, a): + r""" + Returns a copy of the given tensor. + + This function follows the api from :any:`numpy.copy` + + See: https://numpy.org/doc/stable/reference/generated/numpy.copy.html + """ + raise NotImplementedError() + + def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): + r""" + Returns True if two arrays are element-wise equal within a tolerance. + + This function follows the api from :any:`numpy.allclose` + + See: https://numpy.org/doc/stable/reference/generated/numpy.allclose.html + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -712,6 +793,46 @@ class NumpyBackend(Backend): def reshape(self, a, shape): return np.reshape(a, shape) + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): + if type_as is None: + return coo_matrix((data, (rows, cols)), shape=shape) + else: + return coo_matrix((data, (rows, cols)), shape=shape, dtype=type_as.dtype) + + def issparse(self, a): + return issparse(a) + + def tocsr(self, a): + if self.issparse(a): + return a.tocsr() + else: + return csr_matrix(a) + + def eliminate_zeros(self, a, threshold=0.): + if threshold > 0: + if self.issparse(a): + a.data[self.abs(a.data) <= threshold] = 0 + else: + a[self.abs(a) <= threshold] = 0 + if self.issparse(a): + a.eliminate_zeros() + return a + + def todense(self, a): + if self.issparse(a): + return a.toarray() + else: + return a + + def where(self, condition, x, y): + return np.where(condition, x, y) + + def copy(self, a): + return a.copy() + + def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): + return np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + class JaxBackend(Backend): """ @@ -889,6 +1010,48 @@ class JaxBackend(Backend): def reshape(self, a, shape): return jnp.reshape(a, shape) + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): + # Currently, JAX does not support sparse matrices + data = self.to_numpy(data) + rows = self.to_numpy(rows) + cols = self.to_numpy(cols) + nx = NumpyBackend() + coo_matrix = nx.coo_matrix(data, rows, cols, shape=shape, type_as=type_as) + matrix = nx.todense(coo_matrix) + return self.from_numpy(matrix) + + def issparse(self, a): + # Currently, JAX does not support sparse matrices + return False + + def tocsr(self, a): + # Currently, JAX does not support sparse matrices + return a + + def eliminate_zeros(self, a, threshold=0.): + # Currently, JAX does not support sparse matrices + if threshold > 0: + return self.where( + self.abs(a) <= threshold, + self.zeros((1,), type_as=a), + a + ) + return a + + def todense(self, a): + # Currently, JAX does not support sparse matrices + return a + + def where(self, condition, x, y): + return jnp.where(condition, x, y) + + def copy(self, a): + # No need to copy, JAX arrays are immutable + return a + + def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): + return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + class TorchBackend(Backend): """ @@ -999,7 +1162,7 @@ class TorchBackend(Backend): a = torch.tensor([float(a)], dtype=b.dtype, device=b.device) if isinstance(b, int) or isinstance(b, float): b = torch.tensor([float(b)], dtype=a.dtype, device=a.device) - if torch.__version__ >= '1.7.0': + if hasattr(torch, "maximum"): return torch.maximum(a, b) else: return torch.max(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0] @@ -1009,7 +1172,7 @@ class TorchBackend(Backend): a = torch.tensor([float(a)], dtype=b.dtype, device=b.device) if isinstance(b, int) or isinstance(b, float): b = torch.tensor([float(b)], dtype=a.dtype, device=a.device) - if torch.__version__ >= '1.7.0': + if hasattr(torch, "minimum"): return torch.minimum(a, b) else: return torch.min(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0] @@ -1129,3 +1292,50 @@ class TorchBackend(Backend): def reshape(self, a, shape): return torch.reshape(a, shape) + + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): + if type_as is None: + return torch.sparse_coo_tensor(torch.stack([rows, cols]), data, size=shape) + else: + return torch.sparse_coo_tensor( + torch.stack([rows, cols]), data, size=shape, + dtype=type_as.dtype, device=type_as.device + ) + + def issparse(self, a): + return getattr(a, "is_sparse", False) or getattr(a, "is_sparse_csr", False) + + def tocsr(self, a): + # Versions older than 1.9 do not support CSR tensors. PyTorch 1.9 and 1.10 offer a very limited support + return self.todense(a) + + def eliminate_zeros(self, a, threshold=0.): + if self.issparse(a): + if threshold > 0: + mask = self.abs(a) <= threshold + mask = ~mask + mask = mask.nonzero() + else: + mask = a._values().nonzero() + nv = a._values().index_select(0, mask.view(-1)) + ni = a._indices().index_select(1, mask.view(-1)) + return self.coo_matrix(nv, ni[0], ni[1], shape=a.shape, type_as=a) + else: + if threshold > 0: + a[self.abs(a) <= threshold] = 0 + return a + + def todense(self, a): + if self.issparse(a): + return a.to_dense() + else: + return a + + def where(self, condition, x, y): + return torch.where(condition, x, y) + + def copy(self, a): + return torch.clone(a) + + def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): + return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) diff --git a/ot/bregman.py b/ot/bregman.py index 2aa76ff..0499b8e 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -32,13 +32,14 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) - s.t. \ \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} + + \gamma &\geq 0 - \gamma\geq 0 where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix @@ -167,13 +168,14 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + W = \min_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) + + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - s.t. \ \gamma 1 = a + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma^T 1= b + \gamma &\geq 0 - \gamma\geq 0 where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix @@ -323,13 +325,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix @@ -489,13 +491,13 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix @@ -550,8 +552,7 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal - Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. @@ -675,13 +676,13 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) - s.t. \ \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix @@ -820,13 +821,13 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) - s.t. \ \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix @@ -965,7 +966,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, # remove numerical problems and store them in K if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau: if n_hists: - alpha, beta = alpha + reg * nx.max(nx.log(u), 1), beta + reg * nx.max(np.log(v)) + alpha, beta = alpha + reg * nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v)) else: alpha, beta = alpha + reg * nx.log(u), beta + reg * nx.log(v) if n_hists: @@ -1055,13 +1056,13 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) - s.t. \ \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix @@ -1245,12 +1246,12 @@ def projC(gamma, q): def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, stopThr=1e-4, verbose=False, log=False, **kwargs): - r"""Compute the entropic regularized wasserstein barycenter of distributions A + r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` The function solves the following optimization problem: .. math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) where : @@ -1263,7 +1264,7 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, Parameters ---------- A : array-like, shape (dim, n_hists) - `n_hists` training distributions :math:`a_i` of size `dim` + `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim` M : array-like, shape (dim, dim) loss matrix for OT reg : float @@ -1271,7 +1272,7 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, method : str (optional) method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized' weights : array-like, shape (n_hists,) - Weights of each histogram :math:`a_i` on the simplex (barycentric coodinates) + Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) numItermax : int, optional Max number of iterations stopThr : float, optional @@ -1314,12 +1315,12 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False): - r"""Compute the entropic regularized wasserstein barycenter of distributions A + r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` The function solves the following optimization problem: .. math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) where : @@ -1332,13 +1333,13 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, Parameters ---------- A : array-like, shape (dim, n_hists) - `n_hists` training distributions :math:`a_i` of size `dim` + `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim` M : array-like, shape (dim, dim) loss matrix for OT reg : float Regularization term > 0 weights : array-like, shape (n_hists,) - Weights of each histogram :math:`a_i` on the simplex (barycentric coodinates) + Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) numItermax : int, optional Max number of iterations stopThr : float, optional @@ -1414,12 +1415,12 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False): - r"""Compute the entropic regularized wasserstein barycenter of distributions A with stabilization. + r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` with stabilization. The function solves the following optimization problem: .. math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) where : @@ -1432,7 +1433,7 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, Parameters ---------- A : array-like, shape (dim, n_hists) - `n_hists` training distributions :math:`a_i` of size `dim` + `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim` M : array-like, shape (dim, dim) loss matrix for OT reg : float @@ -1440,7 +1441,7 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, tau : float threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` for log scaling weights : array-like, shape (n_hists,) - Weights of each histogram :math:`a_i` on the simplex (barycentric coodinates) + Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) numItermax : int, optional Max number of iterations stopThr : float, optional @@ -1533,8 +1534,8 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, "Or a larger absorption threshold `tau`.") if log: log['niter'] = cpt - log['logu'] = np.log(u + 1e-16) - log['logv'] = np.log(v + 1e-16) + log['logu'] = nx.log(u + 1e-16) + log['logv'] = nx.log(v + 1e-16) return q, log else: return q @@ -1543,13 +1544,13 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False): - r"""Compute the entropic regularized wasserstein barycenter of distributions A - where A is a collection of 2D images. + r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` + where :math:`\mathbf{A}` is a collection of 2D images. The function solves the following optimization problem: .. math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) where : @@ -1673,12 +1674,12 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, .. math:: - \mathbf{h} = arg\min_\mathbf{h} (1- \alpha) W_{M,reg}(\mathbf{a},\mathbf{Dh})+\alpha W_{M_0,reg_0}(\mathbf{h}_0,\mathbf{h}) + \mathbf{h} = \mathop{\arg \min}_\mathbf{h} (1- \alpha) W_{\mathbf{M}, \mathrm{reg}}(\mathbf{a},\mathbf{Dh})+\alpha W_{\mathbf{M_0},\mathrm{reg}_0}(\mathbf{h}_0,\mathbf{h}) where : - - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with M loss matrix (see :py:func:`ot.bregman.sinkhorn`) + - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with :math:`\mathbf{M}` loss matrix (see :py:func:`ot.bregman.sinkhorn`) - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, its expected shape is `(dim_a, n_atoms)` - :math:`\mathbf{h}` is the estimated unmixing of dimension `n_atoms` - :math:`\mathbf{a}` is an observed distribution of dimension `dim_a` @@ -1790,7 +1791,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, .. math:: - \mathbf{h} = arg\min_{\mathbf{h}}\quad \sum_{k=1}^{K} \lambda_k + \mathbf{h} = \mathop{\arg \min}_{\mathbf{h}} \sum_{k=1}^{K} \lambda_k W_{reg}((\mathbf{D}_2^{(k)} \mathbf{h})^T, \mathbf{a}) s.t. \ \forall k, \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h} @@ -1898,7 +1899,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, K.append(Ktmp) # uniform target distribution - a = nx.from_numpy(unif(np.shape(Xt)[0])) + a = nx.from_numpy(unif(Xt.shape[0]), type_as=Xs[0]) cpt = 0 # iterations count err = 1 @@ -1956,13 +1957,13 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg} \cdot\Omega(\gamma) - s.t. \ \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= a - \gamma^T 1= b + \gamma^T \mathbf{1} &= b - \gamma\geq 0 + \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix @@ -2010,8 +2011,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', >>> n_samples_a = 2 >>> n_samples_b = 2 >>> reg = 0.1 - >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1)) - >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1)) + >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1)) + >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1)) >>> empirical_sinkhorn(X_s, X_t, reg=reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE array([[4.99977301e-01, 2.26989344e-05], [2.26989344e-05, 4.99977301e-01]]) @@ -2033,9 +2034,9 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', ns, nt = X_s.shape[0], X_t.shape[0] if a is None: - a = nx.from_numpy(unif(ns)) + a = nx.from_numpy(unif(ns), type_as=X_s) if b is None: - b = nx.from_numpy(unif(nt)) + b = nx.from_numpy(unif(nt), type_as=X_s) if isLazy: if log: @@ -2127,13 +2128,13 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + W = \min_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg} \cdot\Omega(\gamma) - s.t. \ \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= a - \gamma^T 1= b + \gamma^T \mathbf{1} &= b - \gamma\geq 0 + \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix @@ -2181,8 +2182,8 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num >>> n_samples_a = 2 >>> n_samples_b = 2 >>> reg = 0.1 - >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1)) - >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1)) + >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1)) + >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1)) >>> b = np.full((n_samples_b, 3), 1/n_samples_b) >>> empirical_sinkhorn2(X_s, X_t, b=b, reg=reg, verbose=False) array([4.53978687e-05, 4.53978687e-05, 4.53978687e-05]) @@ -2204,9 +2205,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num ns, nt = X_s.shape[0], X_t.shape[0] if a is None: - a = nx.from_numpy(unif(ns)) + a = nx.from_numpy(unif(ns), type_as=X_s) if b is None: - b = nx.from_numpy(unif(nt)) + b = nx.from_numpy(unif(nt), type_as=X_s) if isLazy: if log: @@ -2259,32 +2260,32 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli .. math:: - W &= \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + W &= \min_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg} \cdot\Omega(\gamma) - W_a &= \min_{\gamma_a} <\gamma_a,M_a>_F + reg\cdot\Omega(\gamma_a) + W_a &= \min_{\gamma_a} <\gamma_a, \mathbf{M_a}>_F + \mathrm{reg} \cdot\Omega(\gamma_a) - W_b &= \min_{\gamma_b} <\gamma_b,M_b>_F + reg\cdot\Omega(\gamma_b) + W_b &= \min_{\gamma_b} <\gamma_b, \mathbf{M_b}>_F + \mathrm{reg} \cdot\Omega(\gamma_b) S &= W - \frac{W_a + W_b}{2} .. math:: - s.t. \ \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= a - \gamma^T 1= b + \gamma^T \mathbf{1} &= b - \gamma\geq 0 + \gamma &\geq 0 - \gamma_a 1 = a + \gamma_a \mathbf{1} &= \mathbf{a} - \gamma_a^T 1= a + \gamma_a^T \mathbf{1} &= \mathbf{a} - \gamma_a\geq 0 + \gamma_a &\geq 0 - \gamma_b 1 = b + \gamma_b \mathbf{1} &= \mathbf{b} - \gamma_b^T 1= b + \gamma_b^T \mathbf{1} &= \mathbf{b} - \gamma_b\geq 0 + \gamma_b &\geq 0 where : - :math:`\mathbf{M}` (resp. :math:`\mathbf{M_a}`, :math:`\mathbf{M_b}`) is the (`n_samples_a`, `n_samples_b`) metric cost matrix (resp (`n_samples_a, n_samples_a`) and (`n_samples_b`, `n_samples_b`)) @@ -2325,8 +2326,8 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli >>> n_samples_a = 2 >>> n_samples_b = 4 >>> reg = 0.1 - >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1)) - >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1)) + >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1)) + >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1)) >>> empirical_sinkhorn_divergence(X_s, X_t, reg) # doctest: +ELLIPSIS 1.499887176049052 @@ -2380,19 +2381,19 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res .. math:: - (u, v) = arg\min_{u, v} 1_{ns}^T B(u,v) 1_{nt} - <\kappa u, a> - + (\mathbf{u}, \mathbf{v}) = \mathop{\arg \min}_{\mathbf{u}, \mathbf{v}} \ \mathbf{1}_{ns}^T \mathbf{B}(\mathbf{u}, \mathbf{v}) \mathbf{1}_{nt} - <\kappa \mathbf{u}, \mathbf{a}> - <\mathbf{v} / \kappa, \mathbf{b}> where: .. math:: - B(u,v) = \mathrm{diag}(e^u) K \mathrm{diag}(e^v) \text{, with } K = e^{-M/reg} \text{ and} + \mathbf{B}(\mathbf{u}, \mathbf{v}) = \mathrm{diag}(e^\mathbf{u}) \mathbf{K} \mathrm{diag}(e^\mathbf{v}) \text{, with } \mathbf{K} = e^{-\mathbf{M} / \mathrm{reg}} \text{ and} .. math:: - s.t. \ e^{u_i} \geq \epsilon / \kappa, \forall i \in \{1, \ldots, ns\} + s.t. \ e^{u_i} &\geq \epsilon / \kappa, \forall i \in \{1, \ldots, ns\} - e^{v_j} \geq \epsilon \kappa, \forall j \in \{1, \ldots, nt\} + e^{v_j} &\geq \epsilon \kappa, \forall j \in \{1, \ldots, nt\} The parameters `kappa` and `epsilon` are determined w.r.t the couple number budget of points (`ns_budget`, `nt_budget`), see Equation (5) in :ref:`[26] ` @@ -2531,7 +2532,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res epsilon_u_square = a[0] / aK_sort[ns_budget - 1] else: aK_sort = nx.from_numpy( - bottleneck.partition(nx.to_numpy(K_sum_cols), ns_budget - 1)[ns_budget - 1] + bottleneck.partition(nx.to_numpy(K_sum_cols), ns_budget - 1)[ns_budget - 1], + type_as=M ) epsilon_u_square = a[0] / aK_sort @@ -2540,7 +2542,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res epsilon_v_square = b[0] / bK_sort[nt_budget - 1] else: bK_sort = nx.from_numpy( - bottleneck.partition(nx.to_numpy(K_sum_rows), nt_budget - 1)[nt_budget - 1] + bottleneck.partition(nx.to_numpy(K_sum_rows), nt_budget - 1)[nt_budget - 1], + type_as=M ) epsilon_v_square = b[0] / bK_sort else: @@ -2589,10 +2592,9 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res K_IcJ = K[np.ix_(Ic, Jsel)] K_IJc = K[np.ix_(Isel, Jc)] - #K_min = K_IJ.min() K_min = nx.min(K_IJ) if K_min == 0: - K_min = np.finfo(float).tiny + K_min = float(np.finfo(float).tiny) # a_I, b_J, a_Ic, b_Jc a_I = a[Isel] @@ -2713,7 +2715,7 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res maxfun=maxfun, pgtol=pgtol, maxiter=maxiter) - theta = nx.from_numpy(theta) + theta = nx.from_numpy(theta, type_as=M) usc = theta[:ns_budget] vsc = theta[ns_budget:] diff --git a/ot/gromov.py b/ot/gromov.py index 33b4453..a0fbf48 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -14,67 +14,85 @@ import numpy as np from .bregman import sinkhorn -from .utils import dist, UndefinedParameter +from .utils import dist, UndefinedParameter, list_to_array from .optim import cg from .lp import emd_1d, emd from .utils import check_random_state - -from scipy.sparse import issparse +from .backend import get_backend def init_matrix(C1, C2, p, q, loss_fun='square_loss'): r"""Return loss matrices and tensors for Gromov-Wasserstein fast computation - Returns the value of \mathcal{L}(C1,C2) \otimes T with the selected loss - function as the loss function of Gromow-Wasserstein discrepancy. + Returns the value of :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` with the + selected loss function as the loss function of Gromow-Wasserstein discrepancy. - The matrices are computed as described in Proposition 1 in [12] + The matrices are computed as described in Proposition 1 in :ref:`[12] ` Where : - * C1 : Metric cost matrix in the source space - * C2 : Metric cost matrix in the target space - * T : A coupling between those two spaces - - The square-loss function L(a,b)=|a-b|^2 is read as : - L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with : - * f1(a)=(a^2) - * f2(b)=(b^2) - * h1(a)=a - * h2(b)=2*b - - The kl-loss function L(a,b)=a*log(a/b)-a+b is read as : - L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with : - * f1(a)=a*log(a)-a - * f2(b)=b - * h1(a)=a - * h2(b)=log(b) + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{T}`: A coupling between those two spaces + + The square-loss function :math:`L(a, b) = |a - b|^2` is read as : + + .. math:: + + L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b) + + \mathrm{with} \ f_1(a) &= a^2 + + f_2(b) &= b^2 + + h_1(a) &= a + + h_2(b) &= 2b + + The kl-loss function :math:`L(a, b) = a \log\left(\frac{a}{b}\right) - a + b` is read as : + + .. math:: + + L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b) + + \mathrm{with} \ f_1(a) &= a \log(a) - a + + f_2(b) &= b + + h_1(a) &= a + + h_2(b) &= \log(b) Parameters ---------- - C1 : ndarray, shape (ns, ns) + C1 : array-like, shape (ns, ns) Metric cost matrix in the source space - C2 : ndarray, shape (nt, nt) - Metric costfr matrix in the target space - T : ndarray, shape (ns, nt) + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + T : array-like, shape (ns, nt) Coupling between source and target spaces - p : ndarray, shape (ns,) + p : array-like, shape (ns,) Returns ------- - constC : ndarray, shape (ns, nt) - Constant C matrix in Eq. (6) - hC1 : ndarray, shape (ns, ns) - h1(C1) matrix in Eq. (6) - hC2 : ndarray, shape (nt, nt) - h2(C) matrix in Eq. (6) + constC : array-like, shape (ns, nt) + Constant :math:`\mathbf{C}` matrix in Eq. (6) + hC1 : array-like, shape (ns, ns) + :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) + hC2 : array-like, shape (nt, nt) + :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) + + .. _references-init-matrix: References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. """ + C1, C2, p, q = list_to_array(C1, C2, p, q) + nx = get_backend(C1, C2, p, q) if loss_fun == 'square_loss': def f1(a): @@ -90,7 +108,7 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'): return 2 * b elif loss_fun == 'kl_loss': def f1(a): - return a * np.log(a + 1e-15) - a + return a * nx.log(a + 1e-15) - a def f2(b): return b @@ -99,12 +117,16 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'): return a def h2(b): - return np.log(b + 1e-15) - - constC1 = np.dot(np.dot(f1(C1), p.reshape(-1, 1)), - np.ones(len(q)).reshape(1, -1)) - constC2 = np.dot(np.ones(len(p)).reshape(-1, 1), - np.dot(q.reshape(1, -1), f2(C2).T)) + return nx.log(b + 1e-15) + + constC1 = nx.dot( + nx.dot(f1(C1), nx.reshape(p, (-1, 1))), + nx.ones((1, len(q)), type_as=q) + ) + constC2 = nx.dot( + nx.ones((len(p), 1), type_as=p), + nx.dot(nx.reshape(q, (1, -1)), f2(C2).T) + ) constC = constC1 + constC2 hC1 = h1(C1) hC2 = h2(C2) @@ -115,30 +137,37 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'): def tensor_product(constC, hC1, hC2, T): r"""Return the tensor for Gromov-Wasserstein fast computation - The tensor is computed as described in Proposition 1 Eq. (6) in [12]. + The tensor is computed as described in Proposition 1 Eq. (6) in :ref:`[12] ` Parameters ---------- - constC : ndarray, shape (ns, nt) - Constant C matrix in Eq. (6) - hC1 : ndarray, shape (ns, ns) - h1(C1) matrix in Eq. (6) - hC2 : ndarray, shape (nt, nt) - h2(C) matrix in Eq. (6) + constC : array-like, shape (ns, nt) + Constant :math:`\mathbf{C}` matrix in Eq. (6) + hC1 : array-like, shape (ns, ns) + :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) + hC2 : array-like, shape (nt, nt) + :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) Returns ------- - tens : ndarray, shape (ns, nt) - \mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result + tens : array-like, shape (`ns`, `nt`) + :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` tensor-matrix multiplication result + + .. _references-tensor-product: References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. """ - A = -np.dot(hC1, T).dot(hC2.T) + constC, hC1, hC2, T = list_to_array(constC, hC1, hC2, T) + nx = get_backend(constC, hC1, hC2, T) + + A = - nx.dot( + nx.dot(hC1, T), hC2.T + ) tens = constC + A # tens -= tens.min() return tens @@ -147,27 +176,29 @@ def tensor_product(constC, hC1, hC2, T): def gwloss(constC, hC1, hC2, T): """Return the Loss for Gromov-Wasserstein - The loss is computed as described in Proposition 1 Eq. (6) in [12]. + The loss is computed as described in Proposition 1 Eq. (6) in :ref:`[12] ` Parameters ---------- - constC : ndarray, shape (ns, nt) - Constant C matrix in Eq. (6) - hC1 : ndarray, shape (ns, ns) - h1(C1) matrix in Eq. (6) - hC2 : ndarray, shape (nt, nt) - h2(C) matrix in Eq. (6) - T : ndarray, shape (ns, nt) - Current value of transport matrix T + constC : array-like, shape (ns, nt) + Constant :math:`\mathbf{C}` matrix in Eq. (6) + hC1 : array-like, shape (ns, ns) + :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) + hC2 : array-like, shape (nt, nt) + :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) + T : array-like, shape (ns, nt) + Current value of transport matrix :math:`\mathbf{T}` Returns ------- loss : float Gromov Wasserstein loss + + .. _references-gwloss: References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. @@ -175,33 +206,38 @@ def gwloss(constC, hC1, hC2, T): tens = tensor_product(constC, hC1, hC2, T) - return np.sum(tens * T) + tens, T = list_to_array(tens, T) + nx = get_backend(tens, T) + + return nx.sum(tens * T) def gwggrad(constC, hC1, hC2, T): """Return the gradient for Gromov-Wasserstein - The gradient is computed as described in Proposition 2 in [12]. + The gradient is computed as described in Proposition 2 in :ref:`[12] ` Parameters ---------- - constC : ndarray, shape (ns, nt) - Constant C matrix in Eq. (6) - hC1 : ndarray, shape (ns, ns) - h1(C1) matrix in Eq. (6) - hC2 : ndarray, shape (nt, nt) - h2(C) matrix in Eq. (6) - T : ndarray, shape (ns, nt) - Current value of transport matrix T + constC : array-like, shape (ns, nt) + Constant :math:`\mathbf{C}` matrix in Eq. (6) + hC1 : array-like, shape (ns, ns) + :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) + hC2 : array-like, shape (nt, nt) + :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) + T : array-like, shape (ns, nt) + Current value of transport matrix :math:`\mathbf{T}` Returns ------- - grad : ndarray, shape (ns, nt) + grad : array-like, shape (`ns`, `nt`) Gromov Wasserstein gradient + + .. _references-gwggrad: References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. @@ -212,88 +248,107 @@ def gwggrad(constC, hC1, hC2, T): def update_square_loss(p, lambdas, T, Cs): """ - Updates C according to the L2 Loss kernel with the S Ts couplings - calculated at each iteration + Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` + couplings calculated at each iteration Parameters ---------- - p : ndarray, shape (N,) + p : array-like, shape (N,) Masses in the targeted barycenter. lambdas : list of float - List of the S spaces' weights. - T : list of S np.ndarray of shape (ns,N) - The S Ts couplings calculated at each iteration. - Cs : list of S ndarray, shape(ns,ns) + List of the `S` spaces' weights. + T : list of S array-like of shape (ns,N) + The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. + Cs : list of S array-like, shape(ns,ns) Metric cost matrices. Returns ---------- - C : ndarray, shape (nt, nt) - Updated C matrix. + C : array-like, shape (`nt`, `nt`) + Updated :math:`\mathbf{C}` matrix. """ - tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) - for s in range(len(T))]) - ppt = np.outer(p, p) + T = list_to_array(*T) + Cs = list_to_array(*Cs) + p = list_to_array(p) + nx = get_backend(p, *T, *Cs) + + tmpsum = sum([ + lambdas[s] * nx.dot( + nx.dot(T[s].T, Cs[s]), + T[s] + ) for s in range(len(T)) + ]) + ppt = nx.outer(p, p) - return np.divide(tmpsum, ppt) + return tmpsum / ppt def update_kl_loss(p, lambdas, T, Cs): """ - Updates C according to the KL Loss kernel with the S Ts couplings calculated at each iteration + Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration Parameters ---------- - p : ndarray, shape (N,) + p : array-like, shape (N,) Weights in the targeted barycenter. - lambdas : list of the S spaces' weights - T : list of S np.ndarray of shape (ns,N) - The S Ts couplings calculated at each iteration. - Cs : list of S ndarray, shape(ns,ns) + lambdas : list of float + List of the `S` spaces' weights + T : list of S array-like of shape (ns,N) + The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. + Cs : list of S array-like, shape(ns,ns) Metric cost matrices. Returns ---------- - C : ndarray, shape (ns,ns) - updated C matrix + C : array-like, shape (`ns`, `ns`) + updated :math:`\mathbf{C}` matrix """ - tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) - for s in range(len(T))]) - ppt = np.outer(p, p) + Cs = list_to_array(*Cs) + T = list_to_array(*T) + p = list_to_array(p) + nx = get_backend(p, *T, *Cs) - return np.exp(np.divide(tmpsum, ppt)) + tmpsum = sum([ + lambdas[s] * nx.dot( + nx.dot(T[s].T, Cs[s]), + T[s] + ) for s in range(len(T)) + ]) + ppt = nx.outer(p, p) + + return nx.exp(tmpsum / ppt) def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): r""" - Returns the gromov-wasserstein transport between (C1,p) and (C2,q) + Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` The function solves the following optimization problem: .. math:: - GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} Where : - - C1 : Metric cost matrix in the source space - - C2 : Metric cost matrix in the target space - - p : distribution in the source space - - q : distribution in the target space - - L : loss function to account for the misfit between the similarity matrices + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices Parameters ---------- - C1 : ndarray, shape (ns, ns) + C1 : array-like, shape (ns, ns) Metric cost matrix in the source space - C2 : ndarray, shape (nt, nt) - Metric costfr matrix in the target space - p : ndarray, shape (ns,) + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) Distribution in the source space - q : ndarray, shape (nt,) + q : array-like, shape (nt,) Distribution in the target space loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss' - max_iter : int, optional Max number of iterations tol : float, optional @@ -303,22 +358,23 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs log : bool, optional record log if True armijo : bool, optional - If True the steps of the line-search is found via an armijo research. Else closed form is used. - If there is convergence issues use False. + If True the step of the line-search is found via an armijo research. Else closed form is used. + If there are convergence issues use False. **kwargs : dict parameters can be directly passed to the ot.optim.cg solver Returns ------- - T : ndarray, shape (ns, nt) - Doupling between the two spaces that minimizes: - \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + T : array-like, shape (`ns`, `nt`) + Coupling between the two spaces that minimizes: + + :math:`\sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}` log : dict Convergence information and loss. References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. @@ -327,6 +383,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs mathematics 11.4 (2011): 417-487. """ + p, q = list_to_array(p, q) constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) @@ -348,29 +405,30 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): r""" - Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q) + Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` The function solves the following optimization problem: .. math:: - GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + GW = \min_\mathbf{T} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} Where : - - C1 : Metric cost matrix in the source space - - C2 : Metric cost matrix in the target space - - p : distribution in the source space - - q : distribution in the target space - - L : loss function to account for the misfit between the similarity matrices + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices Parameters ---------- - C1 : ndarray, shape (ns, ns) + C1 : array-like, shape (ns, ns) Metric cost matrix in the source space - C2 : ndarray, shape (nt, nt) + C2 : array-like, shape (nt, nt) Metric cost matrix in the target space - p : ndarray, shape (ns,) + p : array-like, shape (ns,) Distribution in the source space. - q : ndarray, shape (nt,) + q : array-like, shape (nt,) Distribution in the target space. loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss' @@ -383,8 +441,8 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg log : bool, optional record log if True armijo : bool, optional - If True the steps of the line-search is found via an armijo research. Else closed form is used. - If there is convergence issues use False. + If True the step of the line-search is found via an armijo research. Else closed form is used. + If there are convergence issues use False. Returns ------- @@ -395,7 +453,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. @@ -404,6 +462,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg mathematics 11.4 (2011): 417-487. """ + p, q = list_to_array(p, q) constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) @@ -425,42 +484,45 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): r""" - Computes the FGW transport between two graphs see [24] + Computes the FGW transport between two graphs (see :ref:`[24] `) .. math:: - \gamma = arg\min_\gamma (1-\\alpha)*<\gamma,M>_F + \\alpha* \sum_{i,j,k,l} - L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + \gamma = \mathop{\arg \min}_\gamma (1 - \alpha) <\gamma, \mathbf{M}>_F + \alpha \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} - s.t. \gamma 1 = p - \gamma^T 1= q - \gamma\geq 0 + \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q} + + \mathbf{\gamma} &\geq 0 where : - - M is the (ns,nt) metric cost matrix - - p and q are source and target weights (sum to 1) - - L is a loss function to account for the misfit between the similarity matrices - The algorithm used for solving the problem is conditional gradient as discussed in [24]_ + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1) + - `L` is a loss function to account for the misfit between the similarity matrices + + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] ` Parameters ---------- - M : ndarray, shape (ns, nt) + M : array-like, shape (ns, nt) Metric cost matrix between features across domains - C1 : ndarray, shape (ns, ns) + C1 : array-like, shape (ns, ns) Metric cost matrix representative of the structure in the source space - C2 : ndarray, shape (nt, nt) + C2 : array-like, shape (nt, nt) Metric cost matrix representative of the structure in the target space - p : ndarray, shape (ns,) + p : array-like, shape (ns,) Distribution in the source space - q : ndarray, shape (nt,) + q : array-like, shape (nt,) Distribution in the target space loss_fun : str, optional Loss function used for the solver alpha : float, optional Trade-off parameter (0 < alpha < 1) armijo : bool, optional - If True the steps of the line-search is found via an armijo research. Else closed form is used. - If there is convergence issues use False. + If True the step of the line-search is found via an armijo research. Else closed form is used. + If there are convergence issues use False. log : bool, optional record log if True **kwargs : dict @@ -468,18 +530,21 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, Returns ------- - gamma : ndarray, shape (ns, nt) + gamma : array-like, shape (`ns`, `nt`) Optimal transportation matrix for the given parameters. log : dict Log dictionary return only if log==True in parameters. + + .. _references-fused-gromov-wasserstein: References ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas "Optimal Transport for structured data with application on graphs", International Conference on Machine Learning (ICML). 2019. """ + p, q = list_to_array(p, q) constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) @@ -501,61 +566,67 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): r""" - Computes the FGW distance between two graphs see [24] + Computes the FGW distance between two graphs see (see :ref:`[24] `) .. math:: - \min_\gamma (1-\\alpha)*<\gamma,M>_F + \\alpha* \sum_{i,j,k,l} - L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + \min_\gamma (1 - \alpha) <\gamma, \mathbf{M}>_F + \alpha \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} - s.t. \gamma 1 = p - \gamma^T 1= q - \gamma\geq 0 + \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q} + + \mathbf{\gamma} &\geq 0 where : - - M is the (ns,nt) metric cost matrix - - p and q are source and target weights (sum to 1) - - L is a loss function to account for the misfit between the similarity matrices - The algorithm used for solving the problem is conditional gradient as discussed in [1]_ + + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1) + - `L` is a loss function to account for the misfit between the similarity matrices + + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] ` Parameters ---------- - M : ndarray, shape (ns, nt) + M : array-like, shape (ns, nt) Metric cost matrix between features across domains - C1 : ndarray, shape (ns, ns) + C1 : array-like, shape (ns, ns) Metric cost matrix respresentative of the structure in the source space. - C2 : ndarray, shape (nt, nt) + C2 : array-like, shape (nt, nt) Metric cost matrix espresentative of the structure in the target space. - p : ndarray, shape (ns,) + p : array-like, shape (ns,) Distribution in the source space. - q : ndarray, shape (nt,) + q : array-like, shape (nt,) Distribution in the target space. loss_fun : str, optional Loss function used for the solver. alpha : float, optional Trade-off parameter (0 < alpha < 1) armijo : bool, optional - If True the steps of the line-search is found via an armijo research. - Else closed form is used. If there is convergence issues use False. + If True the step of the line-search is found via an armijo research. + Else closed form is used. If there are convergence issues use False. log : bool, optional Record log if True. **kwargs : dict - Parameters can be directly pased to the ot.optim.cg solver. + Parameters can be directly passed to the ot.optim.cg solver. Returns ------- - gamma : ndarray, shape (ns, nt) + gamma : array-like, shape (ns, nt) Optimal transportation matrix for the given parameters. log : dict Log dictionary return only if log==True in parameters. + + .. _references-fused-gromov-wasserstein2: References ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. """ + p, q = list_to_array(p, q) constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) @@ -579,60 +650,64 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 def GW_distance_estimation(C1, C2, p, q, loss_fun, T, nb_samples_p=None, nb_samples_q=None, std=True, random_state=None): r""" - Returns an approximation of the gromov-wasserstein cost between (C1,p) and (C2,q) - with a fixed transport plan T. - - The function gives an unbiased approximation of the following equation: - - .. math:: - GW = \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} - - Where : - - - C1 : Metric cost matrix in the source space - - C2 : Metric cost matrix in the target space - - L : Loss function to account for the misfit between the similarity matrices - - T : Matrix with marginal p and q - - Parameters - ---------- - C1 : ndarray, shape (ns, ns) - Metric cost matrix in the source space - C2 : ndarray, shape (nt, nt) - Metric costfr matrix in the target space - p : ndarray, shape (ns,) - Distribution in the source space - q : ndarray, shape (nt,) - Distribution in the target space - loss_fun : function: \mathcal{R} \times \mathcal{R} \shortarrow \mathcal{R} - Loss function used for the distance, the transport plan does not depend on the loss function - T : csr or ndarray, shape (ns, nt) - Transport plan matrix, either a sparse csr matrix or - nb_samples_p : int, optional - nb_samples_p is the number of samples (without replacement) along the first dimension of T. - nb_samples_q : int, optional - nb_samples_q is the number of samples along the second dimension of T, for each sample along the first. - std : bool, optional - Standard deviation associated with the prediction of the gromov-wasserstein cost. - random_state : int or RandomState instance, optional - Fix the seed for to allow reproducibility - - Returns - ------- - : float - Gromov-wasserstein cost - - References - ---------- - .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc - "Sampled Gromov Wasserstein." - Machine Learning Journal (MLJ). 2021. - - """ + Returns an approximation of the gromov-wasserstein cost between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` + with a fixed transport plan :math:`\mathbf{T}`. + + The function gives an unbiased approximation of the following equation: + + .. math:: + + GW = \sum_{i,j,k,l} L(\mathbf{C_{1}}_{i,k}, \mathbf{C_{2}}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - `L` : Loss function to account for the misfit between the similarity matrices + - :math:`\mathbf{T}`: Matrix with marginal :math:`\mathbf{p}` and :math:`\mathbf{q}` + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Distribution in the source space + q : array-like, shape (nt,) + Distribution in the target space + loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}` + Loss function used for the distance, the transport plan does not depend on the loss function + T : csr or array-like, shape (ns, nt) + Transport plan matrix, either a sparse csr or a dense matrix + nb_samples_p : int, optional + `nb_samples_p` is the number of samples (without replacement) along the first dimension of :math:`\mathbf{T}` + nb_samples_q : int, optional + `nb_samples_q` is the number of samples along the second dimension of :math:`\mathbf{T}`, for each sample along the first + std : bool, optional + Standard deviation associated with the prediction of the gromov-wasserstein cost + random_state : int or RandomState instance, optional + Fix the seed for reproducibility + + Returns + ------- + : float + Gromov-wasserstein cost + + References + ---------- + .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc + "Sampled Gromov Wasserstein." + Machine Learning Journal (MLJ). 2021. + + """ + C1, C2, p, q = list_to_array(C1, C2, p, q) + nx = get_backend(C1, C2, p, q) + generator = check_random_state(random_state) - len_p = len(p) - len_q = len(q) + len_p = p.shape[0] + len_q = q.shape[0] # It is always better to sample from the biggest distribution first. if len_p < len_q: @@ -642,7 +717,7 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T, T = T.T if nb_samples_p is None: - if issparse(T): + if nx.issparse(T): # If T is sparse, it probably mean that PoGroW was used, thus the number of sample is reduced nb_samples_p = min(int(5 * (len_p * np.log(len_p)) ** 0.5), len_p) else: @@ -657,100 +732,112 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T, index_k = np.zeros((nb_samples_p, nb_samples_q), dtype=int) index_l = np.zeros((nb_samples_p, nb_samples_q), dtype=int) - list_value_sample = np.zeros((nb_samples_p, nb_samples_p, nb_samples_q)) index_i = generator.choice(len_p, size=nb_samples_p, p=p, replace=False) index_j = generator.choice(len_p, size=nb_samples_p, p=p, replace=False) for i in range(nb_samples_p): - if issparse(T): - T_indexi = T[index_i[i], :].toarray()[0] - T_indexj = T[index_j[i], :].toarray()[0] + if nx.issparse(T): + T_indexi = nx.reshape(nx.todense(T[index_i[i], :]), (-1,)) + T_indexj = nx.reshape(nx.todense(T[index_j[i], :]), (-1,)) else: T_indexi = T[index_i[i], :] T_indexj = T[index_j[i], :] # For each of the row sampled, the column is sampled. - index_k[i] = generator.choice(len_q, size=nb_samples_q, p=T_indexi / T_indexi.sum(), replace=True) - index_l[i] = generator.choice(len_q, size=nb_samples_q, p=T_indexj / T_indexj.sum(), replace=True) - - for n in range(nb_samples_q): - list_value_sample[:, :, n] = loss_fun(C1[np.ix_(index_i, index_j)], C2[np.ix_(index_k[:, n], index_l[:, n])]) + index_k[i] = generator.choice( + len_q, + size=nb_samples_q, + p=T_indexi / nx.sum(T_indexi), + replace=True + ) + index_l[i] = generator.choice( + len_q, + size=nb_samples_q, + p=T_indexj / nx.sum(T_indexj), + replace=True + ) + + list_value_sample = nx.stack([ + loss_fun( + C1[np.ix_(index_i, index_j)], + C2[np.ix_(index_k[:, n], index_l[:, n])] + ) for n in range(nb_samples_q) + ], axis=2) if std: - std_value = np.sum(np.std(list_value_sample, axis=2) ** 2) ** 0.5 - return np.mean(list_value_sample), std_value / (nb_samples_p * nb_samples_p) + std_value = nx.sum(nx.std(list_value_sample, axis=2) ** 2) ** 0.5 + return nx.mean(list_value_sample), std_value / (nb_samples_p * nb_samples_p) else: - return np.mean(list_value_sample) + return nx.mean(list_value_sample) def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, alpha=1, max_iter=100, threshold_plan=0, log=False, verbose=False, random_state=None): r""" - Returns the gromov-wasserstein transport between (C1,p) and (C2,q) using a stochastic Frank-Wolfe. - This method as a O(max_iter \times PN^2) time complexity with P the number of Sinkhorn iterations. - - The function solves the following optimization problem: - - .. math:: - GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} - - s.t. T 1 = p - - T^T 1= q - - T\geq 0 - - Where : - - - C1 : Metric cost matrix in the source space - - C2 : Metric cost matrix in the target space - - p : distribution in the source space - - q : distribution in the target space - - L : loss function to account for the misfit between the similarity matrices - - Parameters - ---------- - C1 : ndarray, shape (ns, ns) - Metric cost matrix in the source space - C2 : ndarray, shape (nt, nt) - Metric costfr matrix in the target space - p : ndarray, shape (ns,) - Distribution in the source space - q : ndarray, shape (nt,) - Distribution in the target space - loss_fun : function: \mathcal{R} \times \mathcal{R} \shortarrow \mathcal{R} - Loss function used for the distance, the transport plan does not depend on the loss function - alpha : float - Step of the Frank-Wolfe algorithm, should be between 0 and 1 - max_iter : int, optional - Max number of iterations - threshold_plan : float, optional - Deleting very small value in the transport plan. If above zero, it violate the marginal constraints. - verbose : bool, optional - Print information along iterations - log : bool, optional - Gives the distance estimated and the standard deviation - random_state : int or RandomState instance, optional - Fix the seed for to allow reproducibility - - Returns - ------- - T : ndarray, shape (ns, nt) - Optimal coupling between the two spaces - - References - ---------- - .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc - "Sampled Gromov Wasserstein." - Machine Learning Journal (MLJ). 2021. - - """ - C1 = np.asarray(C1, dtype=np.float64) - C2 = np.asarray(C2, dtype=np.float64) - p = np.asarray(p, dtype=np.float64) - q = np.asarray(q, dtype=np.float64) - len_p = len(p) - len_q = len(q) + Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a stochastic Frank-Wolfe. + This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times PN^2)` time complexity with `P` the number of Sinkhorn iterations. + + The function solves the following optimization problem: + + .. math:: + \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T}^T \mathbf{1} &= \mathbf{q} + + \mathbf{T} &\geq 0 + + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Distribution in the source space + q : array-like, shape (nt,) + Distribution in the target space + loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}` + Loss function used for the distance, the transport plan does not depend on the loss function + alpha : float + Step of the Frank-Wolfe algorithm, should be between 0 and 1 + max_iter : int, optional + Max number of iterations + threshold_plan : float, optional + Deleting very small values in the transport plan. If above zero, it violates the marginal constraints. + verbose : bool, optional + Print information along iterations + log : bool, optional + Gives the distance estimated and the standard deviation + random_state : int or RandomState instance, optional + Fix the seed for reproducibility + + Returns + ------- + T : array-like, shape (`ns`, `nt`) + Optimal coupling between the two spaces + + References + ---------- + .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc + "Sampled Gromov Wasserstein." + Machine Learning Journal (MLJ). 2021. + + """ + C1, C2, p, q = list_to_array(C1, C2, p, q) + nx = get_backend(C1, C2, p, q) + + len_p = p.shape[0] + len_q = q.shape[0] generator = check_random_state(random_state) @@ -759,30 +846,35 @@ def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, # Initialize with default marginal index[0] = generator.choice(len_p, size=1, p=p) index[1] = generator.choice(len_q, size=1, p=q) - T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr() + T = nx.tocsr(emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False)) best_gw_dist_estimated = np.inf for cpt in range(max_iter): index[0] = generator.choice(len_p, size=1, p=p) - T_index0 = T[index[0], :].toarray()[0] + T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,)) index[1] = generator.choice(len_q, size=1, p=T_index0 / T_index0.sum()) if alpha == 1: - T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr() + T = nx.tocsr( + emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False) + ) else: - new_T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr() + new_T = nx.tocsr( + emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False) + ) T = (1 - alpha) * T + alpha * new_T - # To limit the number of non 0, the values bellow the threshold are set to 0. - T.data[T.data < threshold_plan] = 0 - T.eliminate_zeros() + # To limit the number of non 0, the values below the threshold are set to 0. + T = nx.eliminate_zeros(T, threshold=threshold_plan) if cpt % 10 == 0 or cpt == (max_iter - 1): - gw_dist_estimated = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun, - p=p, q=q, T=T, std=False, random_state=generator) + gw_dist_estimated = GW_distance_estimation( + C1=C1, C2=C2, loss_fun=loss_fun, + p=p, q=q, T=T, std=False, random_state=generator + ) if gw_dist_estimated < best_gw_dist_estimated: best_gw_dist_estimated = gw_dist_estimated - best_T = T.copy() + best_T = nx.copy(T) if verbose: if cpt % 200 == 0: @@ -791,9 +883,10 @@ def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, if log: log = {} - log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun, - p=p, q=q, T=best_T, - random_state=generator) + log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation( + C1=C1, C2=C2, loss_fun=loss_fun, + p=p, q=q, T=best_T, random_state=generator + ) return best_T, log return best_T @@ -802,71 +895,70 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, nb_samples_grad=100, epsilon=1, max_iter=500, log=False, verbose=False, random_state=None): r""" - Returns the gromov-wasserstein transport between (C1,p) and (C2,q) using a 1-stochastic Frank-Wolfe. - This method as a O(max_iter \times Nlog(N)) time complexity by relying on the 1D Optimal Transport solver. - - The function solves the following optimization problem: - - .. math:: - GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} - - s.t. T 1 = p - - T^T 1= q - - T\geq 0 - - Where : - - - C1 : Metric cost matrix in the source space - - C2 : Metric cost matrix in the target space - - p : distribution in the source space - - q : distribution in the target space - - L : loss function to account for the misfit between the similarity matrices - - Parameters - ---------- - C1 : ndarray, shape (ns, ns) - Metric cost matrix in the source space - C2 : ndarray, shape (nt, nt) - Metric costfr matrix in the target space - p : ndarray, shape (ns,) - Distribution in the source space - q : ndarray, shape (nt,) - Distribution in the target space - loss_fun : function: \mathcal{R} \times \mathcal{R} \shortarrow \mathcal{R} - Loss function used for the distance, the transport plan does not depend on the loss function - nb_samples_grad : int - Number of samples to approximate the gradient - epsilon : float - Weight of the Kullback-Leiber regularization - max_iter : int, optional - Max number of iterations - verbose : bool, optional - Print information along iterations - log : bool, optional - Gives the distance estimated and the standard deviation - random_state : int or RandomState instance, optional - Fix the seed for to allow reproducibility - - Returns - ------- - T : ndarray, shape (ns, nt) - Optimal coupling between the two spaces - - References - ---------- - .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc - "Sampled Gromov Wasserstein." - Machine Learning Journal (MLJ). 2021. - - """ - C1 = np.asarray(C1, dtype=np.float64) - C2 = np.asarray(C2, dtype=np.float64) - p = np.asarray(p, dtype=np.float64) - q = np.asarray(q, dtype=np.float64) - len_p = len(p) - len_q = len(q) + Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a 1-stochastic Frank-Wolfe. + This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times N \log(N))` time complexity by relying on the 1D Optimal Transport solver. + + The function solves the following optimization problem: + + .. math:: + \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T}^T \mathbf{1} &= \mathbf{q} + + \mathbf{T} &\geq 0 + + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Distribution in the source space + q : array-like, shape (nt,) + Distribution in the target space + loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}` + Loss function used for the distance, the transport plan does not depend on the loss function + nb_samples_grad : int + Number of samples to approximate the gradient + epsilon : float + Weight of the Kullback-Leibler regularization + max_iter : int, optional + Max number of iterations + verbose : bool, optional + Print information along iterations + log : bool, optional + Gives the distance estimated and the standard deviation + random_state : int or RandomState instance, optional + Fix the seed for reproducibility + + Returns + ------- + T : array-like, shape (`ns`, `nt`) + Optimal coupling between the two spaces + + References + ---------- + .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc + "Sampled Gromov Wasserstein." + Machine Learning Journal (MLJ). 2021. + + """ + C1, C2, p, q = list_to_array(C1, C2, p, q) + nx = get_backend(C1, C2, p, q) + + len_p = p.shape[0] + len_q = q.shape[0] generator = check_random_state(random_state) @@ -880,12 +972,12 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad, 1 else: nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad - T = np.outer(p, q) + T = nx.outer(p, q) # continue_loop allows to stop the loop if there is several successive small modification of T. continue_loop = 0 # The gradient of GW is more complex if the two matrices are not symmetric. - C_are_symmetric = np.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and np.allclose(C2, C2.T, rtol=1e-10, atol=1e-10) + C_are_symmetric = nx.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and nx.allclose(C2, C2.T, rtol=1e-10, atol=1e-10) for cpt in range(max_iter): index0 = generator.choice(len_p, size=nb_samples_grad_p, p=p, replace=False) @@ -893,28 +985,30 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, for i, index0_i in enumerate(index0): index1 = generator.choice(len_q, size=nb_samples_grad_q, - p=T[index0_i, :] / T[index0_i, :].sum(), + p=T[index0_i, :] / nx.sum(T[index0_i, :]), replace=False) # If the matrices C are not symmetric, the gradient has 2 terms, thus the term is chosen randomly. if (not C_are_symmetric) and generator.rand(1) > 0.5: - Lik += np.mean(loss_fun(np.expand_dims(C1[:, np.repeat(index0[i], nb_samples_grad_q)], 1), - np.expand_dims(C2[:, index1], 0)), - axis=2) + Lik += nx.mean(loss_fun( + C1[:, [index0[i]] * nb_samples_grad_q][:, None, :], + C2[:, index1][None, :, :] + ), axis=2) else: - Lik += np.mean(loss_fun(np.expand_dims(C1[np.repeat(index0[i], nb_samples_grad_q), :], 2), - np.expand_dims(C2[index1, :], 1)), - axis=0) + Lik += nx.mean(loss_fun( + C1[[index0[i]] * nb_samples_grad_q, :][:, :, None], + C2[index1, :][:, None, :] + ), axis=0) - max_Lik = np.max(Lik) + max_Lik = nx.max(Lik) if max_Lik == 0: continue # This division by the max is here to facilitate the choice of epsilon. Lik /= max_Lik if epsilon > 0: - # Set to infinity all the numbers bellow exp(-200) to avoid log of 0. - log_T = np.log(np.clip(T, np.exp(-200), 1)) - log_T[log_T == -200] = -np.inf + # Set to infinity all the numbers below exp(-200) to avoid log of 0. + log_T = nx.log(nx.clip(T, np.exp(-200), 1)) + log_T = nx.where(log_T == -200, -np.inf, log_T) Lik = Lik - epsilon * log_T try: @@ -925,11 +1019,11 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, else: new_T = emd(a=p, b=q, M=Lik) - change_T = ((T - new_T) ** 2).mean() + change_T = nx.mean((T - new_T) ** 2) if change_T <= 10e-20: continue_loop += 1 if continue_loop > 100: # Number max of low modifications of T - T = new_T.copy() + T = nx.copy(new_T) break else: continue_loop = 0 @@ -938,12 +1032,14 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, if cpt % 200 == 0: print('{:5s}|{:12s}'.format('It.', '||T_n - T_{n+1}||') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(cpt, change_T)) - T = new_T.copy() + T = nx.copy(new_T) if log: log = {} - log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun, - p=p, q=q, T=T, random_state=generator) + log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation( + C1=C1, C2=C2, loss_fun=loss_fun, + p=p, q=q, T=T, random_state=generator + ) return T, log return T @@ -951,38 +1047,37 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False): r""" - Returns the gromov-wasserstein transport between (C1,p) and (C2,q) - - (C1,p) and (C2,q) + Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` The function solves the following optimization problem: .. math:: - GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T)) + \mathbf{GW} = \mathop{\arg\min}_\mathbf{T} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T})) - s.t. T 1 = p + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} - T^T 1= q + \mathbf{T}^T \mathbf{1} &= \mathbf{q} - T\geq 0 + \mathbf{T} &\geq 0 Where : - - C1 : Metric cost matrix in the source space - - C2 : Metric cost matrix in the target space - - p : distribution in the source space - - q : distribution in the target space - - L : loss function to account for the misfit between the similarity matrices - - H : entropy + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices + - `H`: entropy Parameters ---------- - C1 : ndarray, shape (ns, ns) + C1 : array-like, shape (ns, ns) Metric cost matrix in the source space - C2 : ndarray, shape (nt, nt) - Metric costfr matrix in the target space - p : ndarray, shape (ns,) + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) Distribution in the source space - q : ndarray, shape (nt,) + q : array-like, shape (nt,) Distribution in the target space loss_fun : string Loss function used for the solver either 'square_loss' or 'kl_loss' @@ -999,21 +1094,20 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, Returns ------- - T : ndarray, shape (ns, nt) + T : array-like, shape (`ns`, `nt`) Optimal coupling between the two spaces References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. """ + C1, C2, p, q = list_to_array(C1, C2, p, q) + nx = get_backend(C1, C2, p, q) - C1 = np.asarray(C1, dtype=np.float64) - C2 = np.asarray(C2, dtype=np.float64) - - T = np.outer(p, q) # Initialization + T = nx.outer(p, q) constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) @@ -1035,7 +1129,7 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = np.linalg.norm(T - Tprev) + err = nx.norm(T - Tprev) if log: log['err'].append(err) @@ -1058,32 +1152,31 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False): r""" - Returns the entropic gromov-wasserstein discrepancy between the two measured similarity matrices - - (C1,p) and (C2,q) + Returns the entropic gromov-wasserstein discrepancy between the two measured similarity matrices :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` The function solves the following optimization problem: .. math:: - GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T)) + GW = \min_\mathbf{T} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T})) Where : - - C1 : Metric cost matrix in the source space - - C2 : Metric cost matrix in the target space - - p : distribution in the source space - - q : distribution in the target space - - L : loss function to account for the misfit between the similarity matrices - - H : entropy + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices + - `H`: entropy Parameters ---------- - C1 : ndarray, shape (ns, ns) + C1 : array-like, shape (ns, ns) Metric cost matrix in the source space - C2 : ndarray, shape (nt, nt) - Metric costfr matrix in the target space - p : ndarray, shape (ns,) + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) Distribution in the source space - q : ndarray, shape (nt,) + q : array-like, shape (nt,) Distribution in the target space loss_fun : str Loss function used for the solver either 'square_loss' or 'kl_loss' @@ -1105,7 +1198,7 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. @@ -1122,40 +1215,39 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, - max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None): + max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None): r""" - Returns the gromov-wasserstein barycenters of S measured similarity matrices - - (Cs)_{s=1}^{s=S} + Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}` The function solves the following optimization problem: .. math:: - C = argmin_{C\in R^{NxN}} \sum_s \lambda_s GW(C,C_s,p,p_s) + \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s) Where : - - :math:`C_s` : metric cost matrix - - :math:`p_s` : distribution + - :math:`\mathbf{C}_s`: metric cost matrix + - :math:`\mathbf{p}_s`: distribution Parameters ---------- N : int Size of the targeted barycenter - Cs : list of S np.ndarray of shape (ns,ns) + Cs : list of S array-like of shape (ns,ns) Metric cost matrices - ps : list of S np.ndarray of shape (ns,) - Sample weights in the S spaces - p : ndarray, shape(N,) + ps : list of S array-like of shape (ns,) + Sample weights in the `S` spaces + p : array-like, shape(N,) Weights in the targeted barycenter lambdas : list of float - List of the S spaces' weights. + List of the `S` spaces' weights. loss_fun : callable Tensor-matrix multiplication function based on specific loss function. update : callable - function(p,lambdas,T,Cs) that updates C according to a specific Kernel - with the S Ts couplings calculated at each iteration + function(:math:`\mathbf{p}`, lambdas, :math:`\mathbf{T}`, :math:`\mathbf{Cs}`) that updates + :math:`\mathbf{C}` according to a specific Kernel with the `S` :math:`\mathbf{T}_s` couplings + calculated at each iteration epsilon : float Regularization term >0 max_iter : int, optional @@ -1166,32 +1258,36 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, Print information along iterations. log : bool, optional Record log if True. - init_C : bool | ndarray, shape (N, N) - Random initial value for the C matrix provided by user. + init_C : bool | array-like, shape (N, N) + Random initial value for the :math:`\mathbf{C}` matrix provided by user. + random_state : int or RandomState instance, optional + Fix the seed for reproducibility Returns ------- - C : ndarray, shape (N, N) + C : array-like, shape (`N`, `N`) Similarity matrix in the barycenter space (permutated arbitrarily) References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. """ + Cs = list_to_array(*Cs) + ps = list_to_array(*ps) + p = list_to_array(p) + nx = get_backend(*Cs, *ps, p) S = len(Cs) - Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)] - lambdas = np.asarray(lambdas, dtype=np.float64) - # Initialization of C : random SPD matrix (if not provided by user) if init_C is None: - # XXX use random state - xalea = np.random.randn(N, 2) + generator = check_random_state(random_state) + xalea = generator.randn(N, 2) C = dist(xalea, xalea) C /= C.max() + C = nx.from_numpy(C, type_as=p) else: C = init_C @@ -1214,7 +1310,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = np.linalg.norm(C - Cprev) + err = nx.norm(C - Cprev) error.append(err) if log: @@ -1232,38 +1328,39 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, - max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None): + max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None): r""" - Returns the gromov-wasserstein barycenters of S measured similarity matrices - - (Cs)_{s=1}^{s=S} + Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}` - The function solves the following optimization problem with block - coordinate descent: + The function solves the following optimization problem with block coordinate descent: .. math:: - C = argmin_C\in R^NxN \sum_s \lambda_s GW(C,Cs,p,ps) + + \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s) Where : - - Cs : metric cost matrix - - ps : distribution + - :math:`\mathbf{C}_s`: metric cost matrix + - :math:`\mathbf{p}_s`: distribution Parameters ---------- N : int Size of the targeted barycenter - Cs : list of S np.ndarray of shape (ns, ns) + Cs : list of S array-like of shape (ns, ns) Metric cost matrices - ps : list of S np.ndarray of shape (ns,) - Sample weights in the S spaces - p : ndarray, shape (N,) + ps : list of S array-like of shape (ns,) + Sample weights in the `S` spaces + p : array-like, shape (N,) Weights in the targeted barycenter lambdas : list of float - List of the S spaces' weights - loss_fun : tensor-matrix multiplication function based on specific loss function - update : function(p,lambdas,T,Cs) that updates C according to a specific Kernel - with the S Ts couplings calculated at each iteration + List of the `S` spaces' weights + loss_fun : callable + tensor-matrix multiplication function based on specific loss function + update : callable + function(:math:`\mathbf{p}`, lambdas, :math:`\mathbf{T}`, :math:`\mathbf{Cs}`) that updates + :math:`\mathbf{C}` according to a specific Kernel with the `S` :math:`\mathbf{T}_s` couplings + calculated at each iteration max_iter : int, optional Max number of iterations tol : float, optional @@ -1272,32 +1369,37 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, Print information along iterations. log : bool, optional Record log if True. - init_C : bool | ndarray, shape(N,N) - Random initial value for the C matrix provided by user. + init_C : bool | array-like, shape(N,N) + Random initial value for the :math:`\mathbf{C}` matrix provided by user. + random_state : int or RandomState instance, optional + Fix the seed for reproducibility Returns ------- - C : ndarray, shape (N, N) + C : array-like, shape (`N`, `N`) Similarity matrix in the barycenter space (permutated arbitrarily) References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. """ - S = len(Cs) + Cs = list_to_array(*Cs) + ps = list_to_array(*ps) + p = list_to_array(p) + nx = get_backend(*Cs, *ps, p) - Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)] - lambdas = np.asarray(lambdas, dtype=np.float64) + S = len(Cs) # Initialization of C : random SPD matrix (if not provided by user) if init_C is None: - # XXX : should use a random state and not use the global seed - xalea = np.random.randn(N, 2) + generator = check_random_state(random_state) + xalea = generator.randn(N, 2) C = dist(xalea, xalea) C /= C.max() + C = nx.from_numpy(C, type_as=p) else: C = init_C @@ -1320,7 +1422,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = np.linalg.norm(C - Cprev) + err = nx.norm(C - Cprev) error.append(err) if log: @@ -1339,21 +1441,21 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False, p=None, loss_fun='square_loss', max_iter=100, tol=1e-9, - verbose=False, log=False, init_C=None, init_X=None): - """Compute the fgw barycenter as presented eq (5) in [24]. + verbose=False, log=False, init_C=None, init_X=None, random_state=None): + """Compute the fgw barycenter as presented eq (5) in :ref:`[24] ` Parameters ---------- - N : integer + N : int Desired number of samples of the target barycenter - Ys: list of ndarray, each element has shape (ns,d) + Ys: list of array-like, each element has shape (ns,d) Features of all samples - Cs : list of ndarray, each element has shape (ns,ns) + Cs : list of array-like, each element has shape (ns,ns) Structure matrices of all samples - ps : list of ndarray, each element has shape (ns,) + ps : list of array-like, each element has shape (ns,) Masses of all samples. lambdas : list of float - List of the S spaces' weights + List of the `S` spaces' weights alpha : float Alpha parameter for the fgw distance fixed_structure : bool @@ -1370,41 +1472,46 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ Print information along iterations. log : bool, optional Record log if True. - init_C : ndarray, shape (N,N), optional + init_C : array-like, shape (N,N), optional Initialization for the barycenters' structure matrix. If not set a random init is used. - init_X : ndarray, shape (N,d), optional + init_X : array-like, shape (N,d), optional Initialization for the barycenters' features. If not set a random init is used. + random_state : int or RandomState instance, optional + Fix the seed for reproducibility Returns ------- - X : ndarray, shape (N, d) + X : array-like, shape (`N`, `d`) Barycenters' features - C : ndarray, shape (N, N) + C : array-like, shape (`N`, `N`) Barycenters' structure matrix - log_: dict + log : dict Only returned when log=True. It contains the keys: - T : list of (N,ns) transport matrices - Ms : all distance matrices between the feature of the barycenter and the - other features dist(X,Ys) shape (N,ns) + - :math:`\mathbf{T}`: list of (`N`, `ns`) transport matrices + - :math:`(\mathbf{M}_s)_s`: all distance matrices between the feature of the barycenter and the other features :math:`(dist(\mathbf{X}, \mathbf{Y}_s))_s` shape (`N`, `ns`) + + + .. _references-fgw-barycenters: References ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. """ + Cs = list_to_array(*Cs) + ps = list_to_array(*ps) + Ys = list_to_array(*Ys) + p = list_to_array(p) + nx = get_backend(*Cs, *Ys, *ps) + S = len(Cs) d = Ys[0].shape[1] # dimension on the node features if p is None: - p = np.ones(N) / N - - Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)] - Ys = [np.asarray(Ys[s], dtype=np.float64) for s in range(S)] - - lambdas = np.asarray(lambdas, dtype=np.float64) + p = nx.ones(N, type_as=Cs[0]) / N if fixed_structure: if init_C is None: @@ -1413,8 +1520,10 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ C = init_C else: if init_C is None: - xalea = np.random.randn(N, 2) + generator = check_random_state(random_state) + xalea = generator.randn(N, 2) C = dist(xalea, xalea) + C = nx.from_numpy(C, type_as=ps[0]) else: C = init_C @@ -1425,13 +1534,13 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ X = init_X else: if init_X is None: - X = np.zeros((N, d)) + X = nx.zeros((N, d), type_as=ps[0]) else: X = init_X - T = [np.outer(p, q) for q in ps] + T = [nx.outer(p, q) for q in ps] - Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))] # Ms is N,ns + Ms = [dist(X, Ys[s]) for s in range(len(Ys))] cpt = 0 err_feature = 1 @@ -1451,20 +1560,19 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ Ys_temp = [y.T for y in Ys] X = update_feature_matrix(lambdas, Ys_temp, T, p).T - Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))] + Ms = [dist(X, Ys[s]) for s in range(len(Ys))] if not fixed_structure: if loss_fun == 'square_loss': T_temp = [t.T for t in T] - C = update_sructure_matrix(p, lambdas, T_temp, Cs) + C = update_structure_matrix(p, lambdas, T_temp, Cs) T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)] # T is N,ns - err_feature = np.linalg.norm(X - Xprev.reshape(N, d)) - err_structure = np.linalg.norm(C - Cprev) - + err_feature = nx.norm(X - nx.reshape(Xprev, (N, d))) + err_structure = nx.norm(C - Cprev) if log: log_['err_feature'].append(err_feature) log_['err_structure'].append(err_structure) @@ -1490,64 +1598,80 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ return X, C -def update_sructure_matrix(p, lambdas, T, Cs): - """Updates C according to the L2 Loss kernel with the S Ts couplings. +def update_structure_matrix(p, lambdas, T, Cs): + """Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings. It is calculated at each iteration Parameters ---------- - p : ndarray, shape (N,) + p : array-like, shape (N,) Masses in the targeted barycenter. lambdas : list of float - List of the S spaces' weights. - T : list of S ndarray of shape (ns, N) - The S Ts couplings calculated at each iteration. - Cs : list of S ndarray, shape (ns, ns) - Metric cost matrices. + List of the `S` spaces' weights. + T : list of S array-like of shape (ns, N) + The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. + Cs : list of S array-like, shape (ns, ns) + Metric cost matrices. Returns ------- - C : ndarray, shape (nt, nt) - Updated C matrix. + C : array-like, shape (`nt`, `nt`) + Updated :math:`\mathbf{C}` matrix. """ - tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))]) - ppt = np.outer(p, p) + p = list_to_array(p) + T = list_to_array(*T) + Cs = list_to_array(*Cs) + nx = get_backend(*Cs, *T, p) - return np.divide(tmpsum, ppt) + tmpsum = sum([ + lambdas[s] * nx.dot( + nx.dot(T[s].T, Cs[s]), + T[s] + ) for s in range(len(T)) + ]) + ppt = nx.outer(p, p) + return tmpsum / ppt def update_feature_matrix(lambdas, Ys, Ts, p): - """Updates the feature with respect to the S Ts couplings. + """Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" - in [24] calculated at each iteration + in :ref:`[24] ` calculated at each iteration Parameters ---------- - p : ndarray, shape (N,) + p : array-like, shape (N,) masses in the targeted barycenter lambdas : list of float - List of the S spaces' weights - Ts : list of S np.ndarray(ns,N) - the S Ts couplings calculated at each iteration - Ys : list of S ndarray, shape(d,ns) + List of the `S` spaces' weights + Ts : list of S array-like, shape (ns,N) + The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration + Ys : list of S array-like, shape (d,ns) The features. Returns ------- - X : ndarray, shape (d, N) + X : array-like, shape (`d`, `N`) + + .. _references-update-feature-matrix: References ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain - and Courty Nicolas + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. """ - p = np.array(1. / p).reshape(-1,) - - tmpsum = sum([lambdas[s] * np.dot(Ys[s], Ts[s].T) * p[None, :] for s in range(len(Ts))]) - + p = list_to_array(p) + Ts = list_to_array(*Ts) + Ys = list_to_array(*Ys) + nx = get_backend(*Ys, *Ts, p) + + p = 1. / p + tmpsum = sum([ + lambdas[s] * nx.dot(Ys[s], Ts[s].T) * p[None, :] + for s in range(len(Ts)) + ]) return tmpsum diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index c6757d1..4e95ccf 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -691,10 +691,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the transportation matrix) """ - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - x_a = np.asarray(x_a, dtype=np.float64) - x_b = np.asarray(x_b, dtype=np.float64) + a, b, x_a, x_b = list_to_array(a, b, x_a, x_b) + nx = get_backend(x_a, x_b) assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \ "emd_1d should only be used with monodimensional data" @@ -702,27 +700,43 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, "emd_1d should only be used with monodimensional data" # if empty array given then use uniform distributions - if a.ndim == 0 or len(a) == 0: - a = np.ones((x_a.shape[0],), dtype=np.float64) / x_a.shape[0] - if b.ndim == 0 or len(b) == 0: - b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0] + if a is None or a.ndim == 0 or len(a) == 0: + a = nx.ones((x_a.shape[0],), type_as=x_a) / x_a.shape[0] + if b is None or b.ndim == 0 or len(b) == 0: + b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0] # ensure that same mass - np.testing.assert_almost_equal(a.sum(0),b.sum(0),err_msg='a and b vector must have the same sum') - b=b*a.sum()/b.sum() - - x_a_1d = x_a.reshape((-1,)) - x_b_1d = x_b.reshape((-1,)) - perm_a = np.argsort(x_a_1d) - perm_b = np.argsort(x_b_1d) - - G_sorted, indices, cost = emd_1d_sorted(a[perm_a], b[perm_b], - x_a_1d[perm_a], x_b_1d[perm_b], - metric=metric, p=p) - G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])), - shape=(a.shape[0], b.shape[0])) + np.testing.assert_almost_equal( + nx.sum(a, axis=0), + nx.sum(b, axis=0), + err_msg='a and b vector must have the same sum' + ) + b = b * nx.sum(a) / nx.sum(b) + + x_a_1d = nx.reshape(x_a, (-1,)) + x_b_1d = nx.reshape(x_b, (-1,)) + perm_a = nx.argsort(x_a_1d) + perm_b = nx.argsort(x_b_1d) + + G_sorted, indices, cost = emd_1d_sorted( + nx.to_numpy(a[perm_a]), + nx.to_numpy(b[perm_b]), + nx.to_numpy(x_a_1d[perm_a]), + nx.to_numpy(x_b_1d[perm_b]), + metric=metric, p=p + ) + + G = nx.coo_matrix( + G_sorted, + perm_a[indices[:, 0]], + perm_b[indices[:, 1]], + shape=(a.shape[0], b.shape[0]), + type_as=x_a + ) if dense: - G = G.toarray() + G = nx.todense(G) + elif str(nx) == "jax": + warnings.warn("JAX does not support sparse matrices, converting to dense") if log: log = {'cost': cost} return G, log diff --git a/ot/optim.py b/ot/optim.py index 34cbb17..6456c03 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -23,7 +23,7 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, r""" Armijo linesearch function that works with matrices - Find an approximate minimum of :math:`f(x_k + \\alpha \cdot p_k)` that satisfies the + Find an approximate minimum of :math:`f(x_k + \alpha \cdot p_k)` that satisfies the armijo conditions. Parameters @@ -129,7 +129,7 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val, .. _references-solve-linesearch: References ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain and Courty Nicolas + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. """ @@ -162,13 +162,13 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + \mathrm{reg} \cdot f(\gamma) + \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg} \cdot f(\gamma) - s.t. \ \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix @@ -309,13 +309,13 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + \mathrm{reg_1}\cdot\Omega(\gamma) + \mathrm{reg_2}\cdot f(\gamma) + \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg_1}\cdot\Omega(\gamma) + \mathrm{reg_2}\cdot f(\gamma) - s.t. \ \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix @@ -452,7 +452,7 @@ def solve_1d_linesearch_quad(a, b, c): .. math:: - arg\min_{0 \leq x \leq 1} f(x) = ax^{2} + bx + c + \mathop{\arg \min}_{0 \leq x \leq 1} f(x) = ax^{2} + bx + c Parameters ---------- diff --git a/test/test_backend.py b/test/test_backend.py index 5853282..0f11ace 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -207,6 +207,22 @@ def test_empty_backend(): nx.stack([M, M]) with pytest.raises(NotImplementedError): nx.reshape(M, (5, 3, 2)) + with pytest.raises(NotImplementedError): + nx.coo_matrix(M, M, M) + with pytest.raises(NotImplementedError): + nx.issparse(M) + with pytest.raises(NotImplementedError): + nx.tocsr(M) + with pytest.raises(NotImplementedError): + nx.eliminate_zeros(M) + with pytest.raises(NotImplementedError): + nx.todense(M) + with pytest.raises(NotImplementedError): + nx.where(M, M, M) + with pytest.raises(NotImplementedError): + nx.copy(M) + with pytest.raises(NotImplementedError): + nx.allclose(M, M) def test_func_backends(nx): @@ -216,6 +232,11 @@ def test_func_backends(nx): v = rnd.randn(3) val = np.array([1.0]) + # Sparse tensors test + sp_row = np.array([0, 3, 1, 0, 3]) + sp_col = np.array([0, 3, 1, 2, 2]) + sp_data = np.array([4, 5, 7, 9, 0]) + lst_tot = [] for nx in [ot.backend.NumpyBackend(), nx]: @@ -229,6 +250,10 @@ def test_func_backends(nx): vb = nx.from_numpy(v) val = nx.from_numpy(val) + sp_rowb = nx.from_numpy(sp_row) + sp_colb = nx.from_numpy(sp_col) + sp_datab = nx.from_numpy(sp_data) + A = nx.set_gradients(val, v, v) lst_b.append(nx.to_numpy(A)) lst_name.append('set_gradients') @@ -438,6 +463,37 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append('reshape') + sp_Mb = nx.coo_matrix(sp_datab, sp_rowb, sp_colb, shape=(4, 4)) + nx.todense(Mb) + lst_b.append(nx.to_numpy(nx.todense(sp_Mb))) + lst_name.append('coo_matrix') + + assert not nx.issparse(Mb), 'Assert fail on: issparse (expected False)' + assert nx.issparse(sp_Mb) or nx.__name__ == "jax", 'Assert fail on: issparse (expected True)' + + A = nx.tocsr(sp_Mb) + lst_b.append(nx.to_numpy(nx.todense(A))) + lst_name.append('tocsr') + + A = nx.eliminate_zeros(nx.copy(sp_datab), threshold=5.) + lst_b.append(nx.to_numpy(A)) + lst_name.append('eliminate_zeros (dense)') + + A = nx.eliminate_zeros(sp_Mb) + lst_b.append(nx.to_numpy(nx.todense(A))) + lst_name.append('eliminate_zeros (sparse)') + + A = nx.where(Mb >= nx.stack([nx.linspace(0, 1, 10)] * 3, axis=1), Mb, 0.0) + lst_b.append(nx.to_numpy(A)) + lst_name.append('where') + + A = nx.copy(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('copy') + + assert nx.allclose(Mb, Mb), 'Assert fail on: allclose (expected True)' + assert not nx.allclose(2 * Mb, Mb), 'Assert fail on: allclose (expected False)' + lst_tot.append(lst_b) lst_np = lst_tot[0] diff --git a/test/test_bregman.py b/test/test_bregman.py index c1120ba..6923d31 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -477,8 +477,8 @@ def test_lazy_empirical_sinkhorn(nx): b = ot.unif(n) numIterMax = 1000 - X_s = np.reshape(np.arange(n), (n, 1)) - X_t = np.reshape(np.arange(0, n), (n, 1)) + X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1)) + X_t = np.reshape(np.arange(0, n, dtype=np.float64), (n, 1)) M = ot.dist(X_s, X_t) M_m = ot.dist(X_s, X_t, metric='euclidean') diff --git a/test/test_gromov.py b/test/test_gromov.py index 0242d72..509c54d 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -8,11 +8,12 @@ import numpy as np import ot +from ot.backend import NumpyBackend import pytest -def test_gromov(): +def test_gromov(nx): n_samples = 50 # nb samples mu_s = np.array([0, 0]) @@ -31,37 +32,50 @@ def test_gromov(): C1 /= C1.max() C2 /= C2.max() + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + pb = nx.from_numpy(p) + qb = nx.from_numpy(q) + G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True) + Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)) # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples) - np.testing.assert_allclose( - G, np.flipud(Id), atol=1e-04) + np.testing.assert_allclose(Gb, np.flipud(Id), atol=1e-04) gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True) + gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=True) gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False) + gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False) G = log['T'] + Gb = nx.to_numpy(logb['T']) - np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) + np.testing.assert_allclose(gw, gwb, atol=1e-06) + np.testing.assert_allclose(gwb, 0, atol=1e-1, rtol=1e-1) - np.testing.assert_allclose(gw, gw_val, atol=1e-1, rtol=1e-1) # cf log=False + np.testing.assert_allclose(gw_val, gw_valb, atol=1e-06) + np.testing.assert_allclose(gwb, gw_valb, atol=1e-1, rtol=1e-1) # cf log=False # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov -def test_entropic_gromov(): +@pytest.skip_backend("jax", reason="test very slow with jax backend") +def test_entropic_gromov(nx): n_samples = 50 # nb samples mu_s = np.array([0, 0]) @@ -80,30 +94,44 @@ def test_entropic_gromov(): C1 /= C1.max() C2 /= C2.max() + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + pb = nx.from_numpy(p) + qb = nx.from_numpy(q) + G = ot.gromov.entropic_gromov_wasserstein( C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True) + Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein( + C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True + )) # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov gw, log = ot.gromov.entropic_gromov_wasserstein2( C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True) + gwb, logb = ot.gromov.entropic_gromov_wasserstein2( + C1b, C2b, pb, qb, 'kl_loss', epsilon=1e-2, log=True) G = log['T'] + Gb = nx.to_numpy(logb['T']) + np.testing.assert_allclose(gw, gwb, atol=1e-06) np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov -def test_pointwise_gromov(): +def test_pointwise_gromov(nx): n_samples = 50 # nb samples mu_s = np.array([0, 0]) @@ -122,33 +150,52 @@ def test_pointwise_gromov(): C1 /= C1.max() C2 /= C2.max() + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + pb = nx.from_numpy(p) + qb = nx.from_numpy(q) + def loss(x, y): return np.abs(x - y) + def lossb(x, y): + return nx.abs(x - y) + G, log = ot.gromov.pointwise_gromov_wasserstein( C1, C2, p, q, loss, max_iter=100, log=True, verbose=True, random_state=42) + G = NumpyBackend().todense(G) + Gb, logb = ot.gromov.pointwise_gromov_wasserstein( + C1b, C2b, pb, qb, lossb, max_iter=100, log=True, verbose=True, random_state=42) + Gb = nx.to_numpy(nx.todense(Gb)) # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p[:, np.newaxis], G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( - q[np.newaxis, :], G.sum(0), atol=1e-04) # cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov - assert log['gw_dist_estimated'] == 0.0 - assert log['gw_dist_std'] == 0.0 + np.testing.assert_allclose(logb['gw_dist_estimated'], 0.0, atol=1e-08) + np.testing.assert_allclose(logb['gw_dist_std'], 0.0, atol=1e-08) G, log = ot.gromov.pointwise_gromov_wasserstein( C1, C2, p, q, loss, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42) + G = NumpyBackend().todense(G) + Gb, logb = ot.gromov.pointwise_gromov_wasserstein( + C1b, C2b, pb, qb, lossb, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42) + Gb = nx.to_numpy(nx.todense(Gb)) - assert log['gw_dist_estimated'] == 0.10342276348494964 - assert log['gw_dist_std'] == 0.0015952535464736394 + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(logb['gw_dist_estimated'], 0.10342276348494964, atol=1e-8) + np.testing.assert_allclose(logb['gw_dist_std'], 0.0015952535464736394, atol=1e-8) -def test_sampled_gromov(): +@pytest.skip_backend("jax", reason="test very slow with jax backend") +def test_sampled_gromov(nx): n_samples = 50 # nb samples - mu_s = np.array([0, 0]) - cov_s = np.array([[1, 0], [0, 1]]) + mu_s = np.array([0, 0], dtype=np.float64) + cov_s = np.array([[1, 0], [0, 1]], dtype=np.float64) xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) @@ -163,23 +210,35 @@ def test_sampled_gromov(): C1 /= C1.max() C2 /= C2.max() + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + pb = nx.from_numpy(p) + qb = nx.from_numpy(q) + def loss(x, y): return np.abs(x - y) + def lossb(x, y): + return nx.abs(x - y) + G, log = ot.gromov.sampled_gromov_wasserstein( C1, C2, p, q, loss, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42) + Gb, logb = ot.gromov.sampled_gromov_wasserstein( + C1b, C2b, pb, qb, lossb, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42) + Gb = nx.to_numpy(Gb) # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov - assert log['gw_dist_estimated'] == 0.05679474884977278 - assert log['gw_dist_std'] == 0.0005986592106971995 + np.testing.assert_allclose(logb['gw_dist_estimated'], 0.05679474884977278, atol=1e-08) + np.testing.assert_allclose(logb['gw_dist_std'], 0.0005986592106971995, atol=1e-08) -def test_gromov_barycenter(): +def test_gromov_barycenter(nx): ns = 10 nt = 20 @@ -188,26 +247,42 @@ def test_gromov_barycenter(): C1 = ot.dist(Xs) C2 = ot.dist(Xt) - + p1 = ot.unif(ns) + p2 = ot.unif(nt) n_samples = 3 - Cb = ot.gromov.gromov_barycenters(n_samples, [C1, C2], - [ot.unif(ns), ot.unif(nt) - ], ot.unif(n_samples), [.5, .5], - 'square_loss', # 5e-4, - max_iter=100, tol=1e-3, - verbose=True) - np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) + p = ot.unif(n_samples) - Cb2 = ot.gromov.gromov_barycenters(n_samples, [C1, C2], - [ot.unif(ns), ot.unif(nt) - ], ot.unif(n_samples), [.5, .5], - 'kl_loss', # 5e-4, - max_iter=100, tol=1e-3) - np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples)) + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + p1b = nx.from_numpy(p1) + p2b = nx.from_numpy(p2) + pb = nx.from_numpy(p) + + Cb = ot.gromov.gromov_barycenters( + n_samples, [C1, C2], [p1, p2], p, [.5, .5], + 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42 + ) + Cbb = nx.to_numpy(ot.gromov.gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], + 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42 + )) + np.testing.assert_allclose(Cb, Cbb, atol=1e-06) + np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) + + Cb2 = ot.gromov.gromov_barycenters( + n_samples, [C1, C2], [p1, p2], p, [.5, .5], + 'kl_loss', max_iter=100, tol=1e-3, random_state=42 + ) + Cb2b = nx.to_numpy(ot.gromov.gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], + 'kl_loss', max_iter=100, tol=1e-3, random_state=42 + )) + np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06) + np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples)) @pytest.mark.filterwarnings("ignore:divide") -def test_gromov_entropic_barycenter(): +def test_gromov_entropic_barycenter(nx): ns = 10 nt = 20 @@ -216,26 +291,41 @@ def test_gromov_entropic_barycenter(): C1 = ot.dist(Xs) C2 = ot.dist(Xt) - + p1 = ot.unif(ns) + p2 = ot.unif(nt) n_samples = 2 - Cb = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2], - [ot.unif(ns), ot.unif(nt) - ], ot.unif(n_samples), [.5, .5], - 'square_loss', 1e-3, - max_iter=50, tol=1e-3, - verbose=True) - np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) - - Cb2 = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2], - [ot.unif(ns), ot.unif(nt) - ], ot.unif(n_samples), [.5, .5], - 'kl_loss', 1e-3, - max_iter=100, tol=1e-3) - np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples)) - - -def test_fgw(): + p = ot.unif(n_samples) + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + p1b = nx.from_numpy(p1) + p2b = nx.from_numpy(p2) + pb = nx.from_numpy(p) + + Cb = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1, C2], [p1, p2], p, [.5, .5], + 'square_loss', 1e-3, max_iter=50, tol=1e-3, verbose=True, random_state=42 + ) + Cbb = nx.to_numpy(ot.gromov.entropic_gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], + 'square_loss', 1e-3, max_iter=50, tol=1e-3, verbose=True, random_state=42 + )) + np.testing.assert_allclose(Cb, Cbb, atol=1e-06) + np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) + + Cb2 = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1, C2], [p1, p2], p, [.5, .5], + 'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42 + ) + Cb2b = nx.to_numpy(ot.gromov.entropic_gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], + 'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42 + )) + np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06) + np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples)) + + +def test_fgw(nx): n_samples = 50 # nb samples mu_s = np.array([0, 0]) @@ -260,33 +350,46 @@ def test_fgw(): M = ot.dist(ys, yt) M /= M.max() + Mb = nx.from_numpy(M) + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + pb = nx.from_numpy(p) + qb = nx.from_numpy(q) + G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True) + Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True) + Gb = nx.to_numpy(Gb) # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence fgw + p, Gb.sum(1), atol=1e-04) # cf convergence fgw np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence fgw + q, Gb.sum(0), atol=1e-04) # cf convergence fgw Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples) np.testing.assert_allclose( - G, np.flipud(Id), atol=1e-04) # cf convergence gromov + Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True) + fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True) G = log['T'] + Gb = nx.to_numpy(logb['T']) - np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) + np.testing.assert_allclose(fgw, fgwb, atol=1e-08) + np.testing.assert_allclose(fgwb, 0, atol=1e-1, rtol=1e-1) # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov -def test_fgw_barycenter(): +def test_fgw_barycenter(nx): np.random.seed(42) ns = 50 @@ -300,30 +403,44 @@ def test_fgw_barycenter(): C1 = ot.dist(Xs) C2 = ot.dist(Xt) - + p1, p2 = ot.unif(ns), ot.unif(nt) n_samples = 3 - X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5, - fixed_structure=False, fixed_features=False, - p=ot.unif(n_samples), loss_fun='square_loss', - max_iter=100, tol=1e-3) - np.testing.assert_allclose(C.shape, (n_samples, n_samples)) - np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) + p = ot.unif(n_samples) + + ysb = nx.from_numpy(ys) + ytb = nx.from_numpy(yt) + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + p1b = nx.from_numpy(p1) + p2b = nx.from_numpy(p2) + pb = nx.from_numpy(p) + + Xb, Cb = ot.gromov.fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, fixed_structure=False, + fixed_features=False, p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, random_state=12345 + ) xalea = np.random.randn(n_samples, 2) init_C = ot.dist(xalea, xalea) - - X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], ps=[ot.unif(ns), ot.unif(nt)], lambdas=[.5, .5], alpha=0.5, - fixed_structure=True, init_C=init_C, fixed_features=False, - p=ot.unif(n_samples), loss_fun='square_loss', - max_iter=100, tol=1e-3) - np.testing.assert_allclose(C.shape, (n_samples, n_samples)) - np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) + init_Cb = nx.from_numpy(init_C) + + Xb, Cb = ot.gromov.fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=[.5, .5], + alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False, + p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3 + ) + Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb) + np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) + np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) init_X = np.random.randn(n_samples, ys.shape[1]) - - X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5, - fixed_structure=False, fixed_features=True, init_X=init_X, - p=ot.unif(n_samples), loss_fun='square_loss', - max_iter=100, tol=1e-3, log=True) - np.testing.assert_allclose(C.shape, (n_samples, n_samples)) - np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) + init_Xb = nx.from_numpy(init_X) + + Xb, Cb, logb = ot.gromov.fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, + fixed_structure=False, fixed_features=True, init_X=init_Xb, + p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, log=True, random_state=98765 + ) + Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb) + np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) + np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) -- cgit v1.2.3 From 6775a527f9d3c801f8cdd805d8f205b6a75551b9 Mon Sep 17 00:00:00 2001 From: Nicolas Courty Date: Tue, 2 Nov 2021 14:19:57 +0100 Subject: [MRG] Sliced and 1D Wasserstein distances : backend versions (#256) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add numpy and torch backends * stat sets on functions * proper import * install recent torch on windows * install recent torch on windows * now testing all functions in backedn * add jax backedn * clenaup windowds * proper convert for jax backedn * pep8 * try again windows tests * test jax conversion * try proper widows tests * emd fuction ses backedn * better test partial OT * proper tests to_numpy and teplate Backend * pep8 * pep8 x2 * feaking sinkhorn works with torch * sinkhorn2 compatible * working ot.emd2 * important detach * it should work * jax autodiff emd * pep8 * no tast same for jax * new independat tests per backedn * freaking pep8 * add tests for gradients * deprecate ot.gpu * worging dist function * working dist * dist done in backedn * not in * remove indexing * change accuacy for jax * first pull backend * projection simplex * projection simplex * projection simplex * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 * add backedn discusion to quickstart guide * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 + better doc * proper links * corect doctest * big debug documentation * doctest again * doctest again bis * doctest again ter (last one or i kill myself) * backend test + doc proj simplex * correction test_utils * correction test_utils * correction cumsum * correction flip * correction flip v2 * more debug * more debug * more debug + pep8 * pep8 * argh * proj_simplex * backedn works for sort * proj simplex * jax sucks * update doc * Update test/test_utils.py Co-authored-by: Alexandre Gramfort * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort * Update test/test_utils.py Co-authored-by: Alexandre Gramfort * Update ot/utils.py Co-authored-by: Alexandre Gramfort * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort * Update ot/lp/__init__.py Co-authored-by: Alexandre Gramfort * begin comment alex * comment alex part 2 * optimize test gromov * proj_simplex on vectors * add awesome gradient decsnt example on the weights * pep98 of course * proof read example by alex * pep8 again * encoding oos in translation * correct legend * new backend functions for sliced * small indent pb * Optimized backendversion of sliced W * error in sliced W * after master merge * error sliced * error sliced * pep8 * test_sliced pep8 * doctest + precision for sliced * doctest * type win test_backend gather * type win test_backend gather * Update sliced.py change argument of padding pad_width * Update backend.py update redefinition * Update backend.py pep8 * Update backend.py pep 8 again.... * pep8 * build docs * emd2_1D example * refectoring emd_1d and variants * remove unused previous wasserstein_1d * pep8 * upate example * move stuff * tesys should work + implemù random backend * test random generayor functions * correction * better random generation * update sliced * update sliced * proper tests sliced * max sliced * chae file nam * add stuff * example sliced flow and barycenter * correct typo + update readme * exemple sliced flow done * pep8 * solver1d works * pep8 Co-authored-by: Rémi Flamary Co-authored-by: Alexandre Gramfort --- README.md | 11 +- docs/source/readme.rst | 51 ++- .../backends/plot_sliced_wass_grad_flow_pytorch.py | 185 +++++++++++ examples/backends/plot_wass1d_torch.py | 152 +++++++++ examples/sliced-wasserstein/plot_variance.py | 2 +- ot/__init__.py | 5 +- ot/backend.py | 98 ++++++ ot/lp/__init__.py | 367 ++------------------- ot/lp/solver_1d.py | 367 +++++++++++++++++++++ ot/sliced.py | 181 ++++++++-- test/test_1d_solver.py | 85 +++++ test/test_backend.py | 36 ++ test/test_ot.py | 57 +--- test/test_sliced.py | 90 ++++- test/test_utils.py | 2 +- 15 files changed, 1244 insertions(+), 445 deletions(-) create mode 100644 examples/backends/plot_sliced_wass_grad_flow_pytorch.py create mode 100644 examples/backends/plot_wass1d_torch.py create mode 100644 ot/lp/solver_1d.py create mode 100644 test/test_1d_solver.py (limited to 'ot/lp') diff --git a/README.md b/README.md index f0e5227..cfb9744 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ POT provides the following generic OT solvers (links to examples): * [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] formulations). -* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32]. +* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36]. * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/) arrays. POT provides the following Machine Learning related solvers: @@ -285,4 +285,11 @@ You can also post bug reports and feature requests in Github issues. Make sure t [33] Kerdoncuff T., Emonet R., Marc S. [Sampled Gromov Wasserstein](https://hal.archives-ouvertes.fr/hal-03232509/document), Machine Learning Journal (MJL), 2021 -[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. +[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). [Interpolating between optimal transport and MMD using Sinkhorn divergences](http://proceedings.mlr.press/v89/feydy19a/feydy19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. + +[35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., ... & Schwing, A. G. (2019). [Max-sliced wasserstein distance and its use for gans](https://openaccess.thecvf.com/content_CVPR_2019/papers/Deshpande_Max-Sliced_Wasserstein_Distance_and_Its_Use_for_GANs_CVPR_2019_paper.pdf). In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656). + +[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R. +(2019, May). [Sliced-Wasserstein flows: Nonparametric generative modeling +via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on +Machine Learning (pp. 4104-4113). PMLR. diff --git a/docs/source/readme.rst b/docs/source/readme.rst index 82d3e6c..ee32e2b 100644 --- a/docs/source/readme.rst +++ b/docs/source/readme.rst @@ -24,7 +24,7 @@ POT provides the following generic OT solvers (links to examples): for regularized OT [7]. - Entropic regularization OT solver with `Sinkhorn Knopp Algorithm `__ - [2] , stabilized version [9] [10], greedy Sinkhorn [22] and + [2] , stabilized version [9] [10] [34], greedy Sinkhorn [22] and `Screening Sinkhorn [26] `__. - Bregman projections for `Wasserstein @@ -54,6 +54,9 @@ POT provides the following generic OT solvers (links to examples): solver `__ for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19]) +- `Stochastic solver of Gromov + Wasserstein `__ + for large-scale problem with any loss functions [33] - Non regularized `free support Wasserstein barycenters `__ [20]. @@ -137,19 +140,12 @@ following Python modules: - Numpy (>=1.16) - Scipy (>=1.0) -- Cython (>=0.23) (build only, not necessary when installing wheels - from pip or conda) +- Cython (>=0.23) (build only, not necessary when installing from pip + or conda) Pip installation ^^^^^^^^^^^^^^^^ -Note that due to a limitation of pip, ``cython`` and ``numpy`` need to -be installed prior to installing POT. This can be done easily with - -.. code:: console - - pip install numpy cython - You can install the toolbox through PyPI with: .. code:: console @@ -183,7 +179,8 @@ without errors: import ot -Note that for easier access the module is name ot instead of pot. +Note that for easier access the module is named ``ot`` instead of +``pot``. Dependencies ~~~~~~~~~~~~ @@ -222,7 +219,7 @@ Short examples .. code:: python - # a and b are 1D histograms (sum to 1 and positive) + # a,b are 1D histograms (sum to 1 and positive) # M is the ground cost matrix Wd = ot.emd2(a, b, M) # exact linear program Wd_reg = ot.sinkhorn2(a, b, M, reg) # entropic regularized OT @@ -232,7 +229,7 @@ Short examples .. code:: python - # a and b are 1D histograms (sum to 1 and positive) + # a,b are 1D histograms (sum to 1 and positive) # M is the ground cost matrix T = ot.emd(a, b, M) # exact linear program T_reg = ot.sinkhorn(a, b, M, reg) # entropic regularized OT @@ -287,6 +284,10 @@ The contributors to this library are - `Ievgen Redko `__ (Laplacian DA, JCPOT) - `Adrien Corenflos `__ (Sliced Wasserstein Distance) +- `Tanguy Kerdoncuff `__ (Sampled Gromov + Wasserstein) +- `Minhui Huang `__ (Projection Robust + Wasserstein Distance) This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various @@ -476,6 +477,30 @@ of measures `__, Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 +[32] Huang, M., Ma S., Lai, L. (2021). `A Riemannian Block Coordinate +Descent Method for Computing the Projection Robust Wasserstein +Distance `__, +Proceedings of the 38th International Conference on Machine Learning +(ICML). + +[33] Kerdoncuff T., Emonet R., Marc S. `Sampled Gromov +Wasserstein `__, +Machine Learning Journal (MJL), 2021 + +[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., +& Peyré, G. (2019, April). `Interpolating between optimal transport and +MMD using Sinkhorn +divergences `__. +In The 22nd International Conference on Artificial Intelligence and +Statistics (pp. 2681-2690). PMLR. + +[35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., +Koyejo, S., ... & Schwing, A. G. (2019). `Max-sliced wasserstein +distance and its use for +gans `__. +In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern +Recognition (pp. 10648-10656). + .. |PyPI version| image:: https://badge.fury.io/py/POT.svg :target: https://badge.fury.io/py/POT .. |Anaconda Cloud| image:: https://anaconda.org/conda-forge/pot/badges/version.svg diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py new file mode 100644 index 0000000..05b9952 --- /dev/null +++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py @@ -0,0 +1,185 @@ +r""" +================================= +Sliced Wasserstein barycenter and gradient flow with PyTorch +================================= + +In this exemple we use the pytorch backend to optimize the sliced Wasserstein +loss between two empirical distributions [31]. + +In the first example one we perform a +gradient flow on the support of a distribution that minimize the sliced +Wassersein distance as poposed in [36]. + +In the second exemple we optimize with a gradient descent the sliced +Wasserstein barycenter between two distributions as in [31]. + +[31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of +measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 + +[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R. +(2019, May). Sliced-Wasserstein flows: Nonparametric generative modeling +via optimal transport and diffusions. In International Conference on +Machine Learning (pp. 4104-4113). PMLR. + + +""" +# Author: Rémi Flamary +# +# License: MIT License + + +# %% +# Loading the data + + +import numpy as np +import matplotlib.pylab as pl +import torch +import ot +import matplotlib.animation as animation + +I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::4, ::4, 2] +I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::4, ::4, 2] + +sz = I2.shape[0] +XX, YY = np.meshgrid(np.arange(sz), np.arange(sz)) + +x1 = np.stack((XX[I1 == 0], YY[I1 == 0]), 1) * 1.0 +x2 = np.stack((XX[I2 == 0] + 60, -YY[I2 == 0] + 32), 1) * 1.0 +x3 = np.stack((XX[I2 == 0], -YY[I2 == 0] + 32), 1) * 1.0 + +pl.figure(1, (8, 4)) +pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5) +pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5) + +# %% +# Sliced Wasserstein gradient flow with Pytorch +# --------------------------------------------- + + +device = "cuda" if torch.cuda.is_available() else "cpu" + +# use pyTorch for our data +x1_torch = torch.tensor(x1).to(device=device).requires_grad_(True) +x2_torch = torch.tensor(x2).to(device=device) + + +lr = 1e3 +nb_iter_max = 100 + +x_all = np.zeros((nb_iter_max, x1.shape[0], 2)) + +loss_iter = [] + +# generator for random permutations +gen = torch.Generator() +gen.manual_seed(42) + +for i in range(nb_iter_max): + + loss = ot.sliced_wasserstein_distance(x1_torch, x2_torch, n_projections=20, seed=gen) + + loss_iter.append(loss.clone().detach().cpu().numpy()) + loss.backward() + + # performs a step of projected gradient descent + with torch.no_grad(): + grad = x1_torch.grad + x1_torch -= grad * lr / (1 + i / 5e1) # step + x1_torch.grad.zero_() + x_all[i, :, :] = x1_torch.clone().detach().cpu().numpy() + +xb = x1_torch.clone().detach().cpu().numpy() + +pl.figure(2, (8, 4)) +pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu^{(0)}$') +pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$') +pl.scatter(xb[:, 0], xb[:, 1], alpha=0.5, label='$\mu^{(100)}$') +pl.title('Sliced Wasserstein gradient flow') +pl.legend() +ax = pl.axis() + +# %% +# Animate trajectories of the gradient flow along iteration +# ------------------------------------------------------- + +pl.figure(3, (8, 4)) + + +def _update_plot(i): + pl.clf() + pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu^{(0)}$') + pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$') + pl.scatter(x_all[i, :, 0], x_all[i, :, 1], alpha=0.5, label='$\mu^{(100)}$') + pl.title('Sliced Wasserstein gradient flow Iter. {}'.format(i)) + pl.axis(ax) + return 1 + + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100, repeat_delay=2000) + +# %% +# Compute the Sliced Wasserstein Barycenter +# +x1_torch = torch.tensor(x1).to(device=device) +x3_torch = torch.tensor(x3).to(device=device) +xbinit = np.random.randn(500, 2) * 10 + 16 +xbary_torch = torch.tensor(xbinit).to(device=device).requires_grad_(True) + +lr = 1e3 +nb_iter_max = 100 + +x_all = np.zeros((nb_iter_max, xbary_torch.shape[0], 2)) + +loss_iter = [] + +# generator for random permutations +gen = torch.Generator() +gen.manual_seed(42) + +alpha = 0.5 + +for i in range(nb_iter_max): + + loss = alpha * ot.sliced_wasserstein_distance(xbary_torch, x3_torch, n_projections=50, seed=gen) \ + + (1 - alpha) * ot.sliced_wasserstein_distance(xbary_torch, x1_torch, n_projections=50, seed=gen) + + loss_iter.append(loss.clone().detach().cpu().numpy()) + loss.backward() + + # performs a step of projected gradient descent + with torch.no_grad(): + grad = xbary_torch.grad + xbary_torch -= grad * lr # / (1 + i / 5e1) # step + xbary_torch.grad.zero_() + x_all[i, :, :] = xbary_torch.clone().detach().cpu().numpy() + +xb = xbary_torch.clone().detach().cpu().numpy() + +pl.figure(4, (8, 4)) +pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu$') +pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$') +pl.scatter(xb[:, 0] + 30, xb[:, 1], alpha=0.5, label='Barycenter') +pl.title('Sliced Wasserstein barycenter') +pl.legend() +ax = pl.axis() + + +# %% +# Animate trajectories of the barycenter along gradient descent +# ------------------------------------------------------- + +pl.figure(5, (8, 4)) + + +def _update_plot(i): + pl.clf() + pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu^{(0)}$') + pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$') + pl.scatter(x_all[i, :, 0] + 30, x_all[i, :, 1], alpha=0.5, label='$\mu^{(100)}$') + pl.title('Sliced Wasserstein barycenter Iter. {}'.format(i)) + pl.axis(ax) + return 1 + + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100, repeat_delay=2000) diff --git a/examples/backends/plot_wass1d_torch.py b/examples/backends/plot_wass1d_torch.py new file mode 100644 index 0000000..0abdd6d --- /dev/null +++ b/examples/backends/plot_wass1d_torch.py @@ -0,0 +1,152 @@ +r""" +================================= +Wasserstein 1D with PyTorch +================================= + +In this small example, we consider the following minization problem: + +.. math:: + \mu^* = \min_\mu W(\mu,\nu) + +where :math:`\nu` is a reference 1D measure. The problem is handled +by a projected gradient descent method, where the gradient is computed +by pyTorch automatic differentiation. The projection on the simplex +ensures that the iterate will remain on the probability simplex. + +This example illustrates both `wasserstein_1d` function and backend use within +the POT framework. +""" +# Author: Nicolas Courty +# Rémi Flamary +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import matplotlib as mpl +import torch + +from ot.lp import wasserstein_1d +from ot.datasets import make_1D_gauss as gauss +from ot.utils import proj_simplex + +red = np.array(mpl.colors.to_rgb('red')) +blue = np.array(mpl.colors.to_rgb('blue')) + + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a = gauss(n, m=20, s=5) # m= mean, s= std +b = gauss(n, m=60, s=10) + +# enforce sum to one on the support +a = a / a.sum() +b = b / b.sum() + + +device = "cuda" if torch.cuda.is_available() else "cpu" + +# use pyTorch for our data +x_torch = torch.tensor(x).to(device=device) +a_torch = torch.tensor(a).to(device=device).requires_grad_(True) +b_torch = torch.tensor(b).to(device=device) + +lr = 1e-6 +nb_iter_max = 800 + +loss_iter = [] + +pl.figure(1, figsize=(8, 4)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') + +for i in range(nb_iter_max): + # Compute the Wasserstein 1D with torch backend + loss = wasserstein_1d(x_torch, x_torch, a_torch, b_torch, p=2) + # record the corresponding loss value + loss_iter.append(loss.clone().detach().cpu().numpy()) + loss.backward() + + # performs a step of projected gradient descent + with torch.no_grad(): + grad = a_torch.grad + a_torch -= a_torch.grad * lr # step + a_torch.grad.zero_() + a_torch.data = proj_simplex(a_torch) # projection onto the simplex + + # plot one curve every 10 iterations + if i % 10 == 0: + mix = float(i) / nb_iter_max + pl.plot(x, a_torch.clone().detach().cpu().numpy(), c=(1 - mix) * blue + mix * red) + +pl.legend() +pl.title('Distribution along the iterations of the projected gradient descent') +pl.show() + +pl.figure(2) +pl.plot(range(nb_iter_max), loss_iter, lw=3) +pl.title('Evolution of the loss along iterations', fontsize=16) +pl.show() + +# %% +# Wasserstein barycenter +# --------- +# In this example, we consider the following Wasserstein barycenter problem +# $$ \\eta^* = \\min_\\eta\;\;\; (1-t)W(\\mu,\\eta) + tW(\\eta,\\nu)$$ +# where :math:`\\mu` and :math:`\\nu` are reference 1D measures, and :math:`t` +# is a parameter :math:`\in [0,1]`. The problem is handled by a project gradient +# descent method, where the gradient is computed by pyTorch automatic differentiation. +# The projection on the simplex ensures that the iterate will remain on the +# probability simplex. +# +# This example illustrates both `wasserstein_1d` function and backend use within the +# POT framework. + + +device = "cuda" if torch.cuda.is_available() else "cpu" + +# use pyTorch for our data +x_torch = torch.tensor(x).to(device=device) +a_torch = torch.tensor(a).to(device=device) +b_torch = torch.tensor(b).to(device=device) +bary_torch = torch.tensor((a + b).copy() / 2).to(device=device).requires_grad_(True) + + +lr = 1e-6 +nb_iter_max = 2000 + +loss_iter = [] + +# instant of the interpolation +t = 0.5 + +for i in range(nb_iter_max): + # Compute the Wasserstein 1D with torch backend + loss = (1 - t) * wasserstein_1d(x_torch, x_torch, a_torch.detach(), bary_torch, p=2) + t * wasserstein_1d(x_torch, x_torch, b_torch, bary_torch, p=2) + # record the corresponding loss value + loss_iter.append(loss.clone().detach().cpu().numpy()) + loss.backward() + + # performs a step of projected gradient descent + with torch.no_grad(): + grad = bary_torch.grad + bary_torch -= bary_torch.grad * lr # step + bary_torch.grad.zero_() + bary_torch.data = proj_simplex(bary_torch) # projection onto the simplex + +pl.figure(3, figsize=(8, 4)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') +pl.plot(x, bary_torch.clone().detach().cpu().numpy(), c='green', label='W barycenter') +pl.legend() +pl.title('Wasserstein barycenter computed by gradient descent') +pl.show() + +pl.figure(4) +pl.plot(range(nb_iter_max), loss_iter, lw=3) +pl.title('Evolution of the loss along iterations', fontsize=16) +pl.show() diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py index 27df656..7d73907 100644 --- a/examples/sliced-wasserstein/plot_variance.py +++ b/examples/sliced-wasserstein/plot_variance.py @@ -63,7 +63,7 @@ res = np.empty((n_seed, 25)) # %% Compute statistics for seed in range(n_seed): for i, n_projections in enumerate(n_projections_arr): - res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed) + res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed=seed) res_mean = np.mean(res, axis=0) res_std = np.std(res, axis=0) diff --git a/ot/__init__.py b/ot/__init__.py index 5bd4bab..f20332c 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -42,7 +42,7 @@ from .bregman import sinkhorn, sinkhorn2, barycenter from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2) from .da import sinkhorn_lpl1_mm -from .sliced import sliced_wasserstein_distance +from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance # utils functions from .utils import dist, unif, tic, toc, toq @@ -51,8 +51,9 @@ __version__ = "0.8.0dev" __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', - 'emd_1d', 'emd2_1d', 'wasserstein_1d', + 'emd2_1d', 'wasserstein_1d', 'backend', 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim', 'sinkhorn_unbalanced', 'barycenter_unbalanced', 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', + 'max_sliced_wasserstein_distance', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath'] diff --git a/ot/backend.py b/ot/backend.py index 358297c..d3df44c 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -103,6 +103,8 @@ class Backend(): __name__ = None __type__ = None + rng_ = None + def __str__(self): return self.__name__ @@ -540,6 +542,36 @@ class Backend(): """ raise NotImplementedError() + def seed(self, seed=None): + r""" + Sets the seed for the random generator. + + This function follows the api from :any:`numpy.random.seed` + + See: https://numpy.org/doc/stable/reference/generated/numpy.random.seed.html + """ + raise NotImplementedError() + + def rand(self, *size, type_as=None): + r""" + Generate uniform random numbers. + + This function follows the api from :any:`numpy.random.rand` + + See: https://numpy.org/doc/stable/reference/generated/numpy.random.rand.html + """ + raise NotImplementedError() + + def randn(self, *size, type_as=None): + r""" + Generate normal Gaussian random numbers. + + This function follows the api from :any:`numpy.random.rand` + + See: https://numpy.org/doc/stable/reference/generated/numpy.random.rand.html + """ + raise NotImplementedError() + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): r""" Creates a sparse tensor in COOrdinate format. @@ -632,6 +664,8 @@ class NumpyBackend(Backend): __name__ = 'numpy' __type__ = np.ndarray + rng_ = np.random.RandomState() + def to_numpy(self, a): return a @@ -793,6 +827,16 @@ class NumpyBackend(Backend): def reshape(self, a, shape): return np.reshape(a, shape) + def seed(self, seed=None): + if seed is not None: + self.rng_.seed(seed) + + def rand(self, *size, type_as=None): + return self.rng_.rand(*size) + + def randn(self, *size, type_as=None): + return self.rng_.randn(*size) + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): if type_as is None: return coo_matrix((data, (rows, cols)), shape=shape) @@ -845,6 +889,11 @@ class JaxBackend(Backend): __name__ = 'jax' __type__ = jax_type + rng_ = None + + def __init__(self): + self.rng_ = jax.random.PRNGKey(42) + def to_numpy(self, a): return np.array(a) @@ -1010,6 +1059,24 @@ class JaxBackend(Backend): def reshape(self, a, shape): return jnp.reshape(a, shape) + def seed(self, seed=None): + if seed is not None: + self.rng_ = jax.random.PRNGKey(seed) + + def rand(self, *size, type_as=None): + self.rng_, subkey = jax.random.split(self.rng_) + if type_as is not None: + return jax.random.uniform(subkey, shape=size, dtype=type_as.dtype) + else: + return jax.random.uniform(subkey, shape=size) + + def randn(self, *size, type_as=None): + self.rng_, subkey = jax.random.split(self.rng_) + if type_as is not None: + return jax.random.normal(subkey, shape=size, dtype=type_as.dtype) + else: + return jax.random.normal(subkey, shape=size) + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): # Currently, JAX does not support sparse matrices data = self.to_numpy(data) @@ -1064,8 +1131,13 @@ class TorchBackend(Backend): __name__ = 'torch' __type__ = torch_type + rng_ = None + def __init__(self): + self.rng_ = torch.Generator() + self.rng_.seed() + from torch.autograd import Function # define a function that takes inputs val and grads @@ -1102,12 +1174,16 @@ class TorchBackend(Backend): return res def zeros(self, shape, type_as=None): + if isinstance(shape, int): + shape = (shape,) if type_as is None: return torch.zeros(shape) else: return torch.zeros(shape, dtype=type_as.dtype, device=type_as.device) def ones(self, shape, type_as=None): + if isinstance(shape, int): + shape = (shape,) if type_as is None: return torch.ones(shape) else: @@ -1120,6 +1196,8 @@ class TorchBackend(Backend): return torch.arange(start, stop, step, device=type_as.device) def full(self, shape, fill_value, type_as=None): + if isinstance(shape, int): + shape = (shape,) if type_as is None: return torch.full(shape, fill_value) else: @@ -1293,6 +1371,26 @@ class TorchBackend(Backend): def reshape(self, a, shape): return torch.reshape(a, shape) + def seed(self, seed=None): + if isinstance(seed, int): + self.rng_.manual_seed(seed) + elif isinstance(seed, torch.Generator): + self.rng_ = seed + else: + raise ValueError("Non compatible seed : {}".format(seed)) + + def rand(self, *size, type_as=None): + if type_as is not None: + return torch.rand(size=size, generator=self.rng_, dtype=type_as.dtype, device=type_as.device) + else: + return torch.rand(size=size, generator=self.rng_) + + def randn(self, *size, type_as=None): + if type_as is not None: + return torch.randn(size=size, dtype=type_as.dtype, generator=self.rng_, device=type_as.device) + else: + return torch.randn(size=size, generator=self.rng_) + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): if type_as is None: return torch.sparse_coo_tensor(torch.stack([rows, cols]), data, size=shape) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 4e95ccf..2c18a88 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -13,20 +13,23 @@ import multiprocessing import sys import numpy as np -from scipy.sparse import coo_matrix import warnings from . import cvx from .cvx import barycenter + # import compiled emd from .emd_wrap import emd_c, check_result, emd_1d_sorted +from .solver_1d import emd_1d, emd2_1d, wasserstein_1d + from ..utils import dist, list_to_array from ..utils import parmap from ..backend import get_backend -__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', +__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted', 'emd_1d', 'emd2_1d', 'wasserstein_1d'] + def check_number_threads(numThreads): """Checks whether or not the requested number of threads has a valid value. @@ -115,10 +118,10 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M): .. warning:: This function is necessary because the C++ solver in emd_c - discards all samples in the distributions with - zeros weights. This means that while the primal variable (transport + discards all samples in the distributions with + zeros weights. This means that while the primal variable (transport matrix) is exact, the solver only returns feasible dual potentials - on the samples with weights different from zero. + on the samples with weights different from zero. First we compute the constraints violations: @@ -215,26 +218,26 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): format .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. + from all compatible backends. Uses the algorithm proposed in [1]_ Parameters ---------- - a : (ns,) array-like, float + a : (ns,) array-like, float Source histogram (uniform weight if empty list) - b : (nt,) array-like, float - Target histogram (uniform weight if empty list) - M : (ns,nt) array-like, float - Loss matrix (c-order array in numpy with type float64) - numItermax : int, optional (default=100000) + b : (nt,) array-like, float + Target histogram (uniform weight if empty list) + M : (ns,nt) array-like, float + Loss matrix (c-order array in numpy with type float64) + numItermax : int, optional (default=100000) The maximum number of iterations before stopping the optimization - algorithm if it has not converged. - log: bool, optional (default=False) - If True, returns a dictionary containing the cost and dual variables. - Otherwise returns only the optimal transportation matrix. + algorithm if it has not converged. + log: bool, optional (default=False) + If True, returns a dictionary containing the cost and dual variables. + Otherwise returns only the optimal transportation matrix. center_dual: boolean, optional (default=True) - If True, centers the dual potential using function + If True, centers the dual potential using function :ref:`center_ot_dual`. numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) If compiled with OpenMP, chooses the number of threads to parallelize. @@ -242,9 +245,9 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): Returns ------- - gamma: array-like, shape (ns, nt) + gamma: array-like, shape (ns, nt) Optimal transportation matrix for the given - parameters + parameters log: dict, optional If input log is true, a dictionary containing the cost and dual variables and exit status @@ -277,10 +280,10 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): regularized OT""" # convert to numpy if list - a, b, M = list_to_array(a, b, M) + a, b, M = list_to_array(a, b, M) a0, b0, M0 = a, b, M - nx = get_backend(M0, a0, b0) + nx = get_backend(M0, a0, b0) # convert to numpy M = nx.to_numpy(M) @@ -302,9 +305,9 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): "Dimension mismatch, check dimensions of M with a and b" # ensure that same mass - np.testing.assert_almost_equal(a.sum(0), - b.sum(0), err_msg='a and b vector must have the same sum') - b=b*a.sum()/b.sum() + np.testing.assert_almost_equal(a.sum(0), + b.sum(0), err_msg='a and b vector must have the same sum') + b = b * a.sum() / b.sum() asel = a != 0 bsel = b != 0 @@ -415,10 +418,10 @@ def emd2(a, b, M, processes=1, ot.bregman.sinkhorn : Entropic regularized OT ot.optim.cg : General regularized OT""" - a, b, M = list_to_array(a, b, M) + a, b, M = list_to_array(a, b, M) a0, b0, M0 = a, b, M - nx = get_backend(M0, a0, b0) + nx = get_backend(M0, a0, b0) # convert to numpy M = nx.to_numpy(M) @@ -427,7 +430,7 @@ def emd2(a, b, M, processes=1, a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64, order= 'C') + M = np.asarray(M, dtype=np.float64, order='C') # if empty array given then use uniform distributions if len(a) == 0: @@ -463,8 +466,8 @@ def emd2(a, b, M, processes=1, log['v'] = nx.from_numpy(v, type_as=b0) log['warning'] = result_code_string log['result_code'] = result_code - cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), - (a0,b0, M0), (log['u'], log['v'], G)) + cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), + (a0, b0, M0), (log['u'], log['v'], G)) return [cost, log] else: def f(b): @@ -479,8 +482,8 @@ def emd2(a, b, M, processes=1, G = nx.from_numpy(G, type_as=M0) cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), - (a0,b0, M0), (nx.from_numpy(u, type_as=a0), - nx.from_numpy(v, type_as=b0),G)) + (a0, b0, M0), (nx.from_numpy(u, type_as=a0), + nx.from_numpy(v, type_as=b0), G)) check_result(result_code) return cost @@ -603,305 +606,3 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None return X, log_dict else: return X - - -def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, - log=False): - r"""Solves the Earth Movers distance problem between 1d measures and returns - the OT matrix - - - .. math:: - \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j]) - - s.t. \gamma 1 = a, - \gamma^T 1= b, - \gamma\geq 0 - where : - - - d is the metric - - x_a and x_b are the samples - - a and b are the sample weights - - When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`. - - Uses the algorithm detailed in [1]_ - - Parameters - ---------- - x_a : (ns,) or (ns, 1) ndarray, float64 - Source dirac locations (on the real line) - x_b : (nt,) or (ns, 1) ndarray, float64 - Target dirac locations (on the real line) - a : (ns,) ndarray, float64, optional - Source histogram (default is uniform weight) - b : (nt,) ndarray, float64, optional - Target histogram (default is uniform weight) - metric: str, optional (default='sqeuclidean') - Metric to be used. Only strings listed in :func:`ot.dist` are accepted. - Due to implementation details, this function runs faster when - `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used. - p: float, optional (default=1.0) - The p-norm to apply for if metric='minkowski' - dense: boolean, optional (default=True) - If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). - Otherwise returns a sparse representation using scipy's `coo_matrix` - format. Due to implementation details, this function runs faster when - `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics - are used. - log: boolean, optional (default=False) - If True, returns a dictionary containing the cost. - Otherwise returns only the optimal transportation matrix. - - Returns - ------- - gamma: (ns, nt) ndarray - Optimal transportation matrix for the given parameters - log: dict - If input log is True, a dictionary containing the cost - - - Examples - -------- - - Simple example with obvious solution. The function emd_1d accepts lists and - performs automatic conversion to numpy arrays - - >>> import ot - >>> a=[.5, .5] - >>> b=[.5, .5] - >>> x_a = [2., 0.] - >>> x_b = [0., 3.] - >>> ot.emd_1d(x_a, x_b, a, b) - array([[0. , 0.5], - [0.5, 0. ]]) - >>> ot.emd_1d(x_a, x_b) - array([[0. , 0.5], - [0.5, 0. ]]) - - References - ---------- - - .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal - Transport", 2018. - - See Also - -------- - ot.lp.emd : EMD for multidimensional distributions - ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the - transportation matrix) - """ - a, b, x_a, x_b = list_to_array(a, b, x_a, x_b) - nx = get_backend(x_a, x_b) - - assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \ - "emd_1d should only be used with monodimensional data" - assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \ - "emd_1d should only be used with monodimensional data" - - # if empty array given then use uniform distributions - if a is None or a.ndim == 0 or len(a) == 0: - a = nx.ones((x_a.shape[0],), type_as=x_a) / x_a.shape[0] - if b is None or b.ndim == 0 or len(b) == 0: - b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0] - - # ensure that same mass - np.testing.assert_almost_equal( - nx.sum(a, axis=0), - nx.sum(b, axis=0), - err_msg='a and b vector must have the same sum' - ) - b = b * nx.sum(a) / nx.sum(b) - - x_a_1d = nx.reshape(x_a, (-1,)) - x_b_1d = nx.reshape(x_b, (-1,)) - perm_a = nx.argsort(x_a_1d) - perm_b = nx.argsort(x_b_1d) - - G_sorted, indices, cost = emd_1d_sorted( - nx.to_numpy(a[perm_a]), - nx.to_numpy(b[perm_b]), - nx.to_numpy(x_a_1d[perm_a]), - nx.to_numpy(x_b_1d[perm_b]), - metric=metric, p=p - ) - - G = nx.coo_matrix( - G_sorted, - perm_a[indices[:, 0]], - perm_b[indices[:, 1]], - shape=(a.shape[0], b.shape[0]), - type_as=x_a - ) - if dense: - G = nx.todense(G) - elif str(nx) == "jax": - warnings.warn("JAX does not support sparse matrices, converting to dense") - if log: - log = {'cost': cost} - return G, log - return G - - -def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, - log=False): - r"""Solves the Earth Movers distance problem between 1d measures and returns - the loss - - - .. math:: - \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j]) - - s.t. \gamma 1 = a, - \gamma^T 1= b, - \gamma\geq 0 - where : - - - d is the metric - - x_a and x_b are the samples - - a and b are the sample weights - - When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`. - - Uses the algorithm detailed in [1]_ - - Parameters - ---------- - x_a : (ns,) or (ns, 1) ndarray, float64 - Source dirac locations (on the real line) - x_b : (nt,) or (ns, 1) ndarray, float64 - Target dirac locations (on the real line) - a : (ns,) ndarray, float64, optional - Source histogram (default is uniform weight) - b : (nt,) ndarray, float64, optional - Target histogram (default is uniform weight) - metric: str, optional (default='sqeuclidean') - Metric to be used. Only strings listed in :func:`ot.dist` are accepted. - Due to implementation details, this function runs faster when - `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics - are used. - p: float, optional (default=1.0) - The p-norm to apply for if metric='minkowski' - dense: boolean, optional (default=True) - If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). - Otherwise returns a sparse representation using scipy's `coo_matrix` - format. Only used if log is set to True. Due to implementation details, - this function runs faster when dense is set to False. - log: boolean, optional (default=False) - If True, returns a dictionary containing the transportation matrix. - Otherwise returns only the loss. - - Returns - ------- - loss: float - Cost associated to the optimal transportation - log: dict - If input log is True, a dictionary containing the Optimal transportation - matrix for the given parameters - - - Examples - -------- - - Simple example with obvious solution. The function emd2_1d accepts lists and - performs automatic conversion to numpy arrays - - >>> import ot - >>> a=[.5, .5] - >>> b=[.5, .5] - >>> x_a = [2., 0.] - >>> x_b = [0., 3.] - >>> ot.emd2_1d(x_a, x_b, a, b) - 0.5 - >>> ot.emd2_1d(x_a, x_b) - 0.5 - - References - ---------- - - .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal - Transport", 2018. - - See Also - -------- - ot.lp.emd2 : EMD for multidimensional distributions - ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix - instead of the cost) - """ - # If we do not return G (log==False), then we should not to cast it to dense - # (useless overhead) - G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p, - dense=dense and log, log=True) - cost = log_emd['cost'] - if log: - log_emd = {'G': G} - return cost, log_emd - return cost - - -def wasserstein_1d(x_a, x_b, a=None, b=None, p=1.): - r"""Solves the p-Wasserstein distance problem between 1d measures and returns - the distance - - .. math:: - \min_\gamma \left( \sum_i \sum_j \gamma_{ij} \|x_a[i] - x_b[j]\|^p \right)^{1/p} - - s.t. \gamma 1 = a, - \gamma^T 1= b, - \gamma\geq 0 - - where : - - - x_a and x_b are the samples - - a and b are the sample weights - - Uses the algorithm detailed in [1]_ - - Parameters - ---------- - x_a : (ns,) or (ns, 1) ndarray, float64 - Source dirac locations (on the real line) - x_b : (nt,) or (ns, 1) ndarray, float64 - Target dirac locations (on the real line) - a : (ns,) ndarray, float64, optional - Source histogram (default is uniform weight) - b : (nt,) ndarray, float64, optional - Target histogram (default is uniform weight) - p: float, optional (default=1.0) - The order of the p-Wasserstein distance to be computed - - Returns - ------- - dist: float - p-Wasserstein distance - - - Examples - -------- - - Simple example with obvious solution. The function wasserstein_1d accepts - lists and performs automatic conversion to numpy arrays - - >>> import ot - >>> a=[.5, .5] - >>> b=[.5, .5] - >>> x_a = [2., 0.] - >>> x_b = [0., 3.] - >>> ot.wasserstein_1d(x_a, x_b, a, b) - 0.5 - >>> ot.wasserstein_1d(x_a, x_b) - 0.5 - - References - ---------- - - .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal - Transport", 2018. - - See Also - -------- - ot.lp.emd_1d : EMD for 1d distributions - """ - cost_emd = emd2_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p, - dense=False, log=False) - return np.power(cost_emd, 1. / p) diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py new file mode 100644 index 0000000..42554aa --- /dev/null +++ b/ot/lp/solver_1d.py @@ -0,0 +1,367 @@ +# -*- coding: utf-8 -*- +""" +Exact solvers for the 1D Wasserstein distance using cvxopt +""" + +# Author: Remi Flamary +# Author: Nicolas Courty +# +# License: MIT License + +import numpy as np +import warnings + +from .emd_wrap import emd_1d_sorted +from ..backend import get_backend +from ..utils import list_to_array + + +def quantile_function(qs, cws, xs): + r""" Computes the quantile function of an empirical distribution + + Parameters + ---------- + qs: array-like, shape (n,) + Quantiles at which the quantile function is evaluated + cws: array-like, shape (m, ...) + cumulative weights of the 1D empirical distribution, if batched, must be similar to xs + xs: array-like, shape (n, ...) + locations of the 1D empirical distribution, batched against the `xs.ndim - 1` first dimensions + + Returns + ------- + q: array-like, shape (..., n) + The quantiles of the distribution + """ + nx = get_backend(qs, cws) + n = xs.shape[0] + if nx.__name__ == 'torch': + # this is to ensure the best performance for torch searchsorted + # and avoid a warninng related to non-contiguous arrays + cws = cws.T.contiguous() + qs = qs.T.contiguous() + else: + cws = cws.T + qs = qs.T + idx = nx.searchsorted(cws, qs).T + return nx.take_along_axis(xs, nx.clip(idx, 0, n - 1), axis=0) + + +def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True): + r""" + Computes the 1 dimensional OT loss [15] between two (batched) empirical + distributions + + .. math: + OT_{loss} = \int_0^1 |cdf_u^{-1}(q) cdf_v^{-1}(q)|^p dq + + It is formally the p-Wasserstein distance raised to the power p. + We do so in a vectorized way by first building the individual quantile functions then integrating them. + + This function should be preferred to `emd_1d` whenever the backend is + different to numpy, and when gradients over + either sample positions or weights are required. + + Parameters + ---------- + u_values: array-like, shape (n, ...) + locations of the first empirical distribution + v_values: array-like, shape (m, ...) + locations of the second empirical distribution + u_weights: array-like, shape (n, ...), optional + weights of the first empirical distribution, if None then uniform weights are used + v_weights: array-like, shape (m, ...), optional + weights of the second empirical distribution, if None then uniform weights are used + p: int, optional + order of the ground metric used, should be at least 1 (see [2, Chap. 2], default is 1 + require_sort: bool, optional + sort the distributions atoms locations, if False we will consider they have been sorted prior to being passed to + the function, default is True + + Returns + ------- + cost: float/array-like, shape (...) + the batched EMD + + References + ---------- + .. [15] Peyré, G., & Cuturi, M. (2018). Computational Optimal Transport. + + """ + + assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p) + + if u_weights is not None and v_weights is not None: + nx = get_backend(u_values, v_values, u_weights, v_weights) + else: + nx = get_backend(u_values, v_values) + + n = u_values.shape[0] + m = v_values.shape[0] + + if u_weights is None: + u_weights = nx.full(u_values.shape, 1. / n) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + if v_weights is None: + v_weights = nx.full(v_values.shape, 1. / m) + elif v_weights.ndim != v_values.ndim: + v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) + + if require_sort: + u_sorter = nx.argsort(u_values, 0) + u_values = nx.take_along_axis(u_values, u_sorter, 0) + + v_sorter = nx.argsort(v_values, 0) + v_values = nx.take_along_axis(v_values, v_sorter, 0) + + u_weights = nx.take_along_axis(u_weights, u_sorter, 0) + v_weights = nx.take_along_axis(v_weights, v_sorter, 0) + + u_cumweights = nx.cumsum(u_weights, 0) + v_cumweights = nx.cumsum(v_weights, 0) + + qs = nx.sort(nx.concatenate((u_cumweights, v_cumweights), 0), 0) + u_quantiles = quantile_function(qs, u_cumweights, u_values) + v_quantiles = quantile_function(qs, v_cumweights, v_values) + qs = nx.zero_pad(qs, pad_width=[(1, 0)] + (qs.ndim - 1) * [(0, 0)]) + delta = qs[1:, ...] - qs[:-1, ...] + diff_quantiles = nx.abs(u_quantiles - v_quantiles) + + if p == 1: + return nx.sum(delta * nx.abs(diff_quantiles), axis=0) + return nx.sum(delta * nx.power(diff_quantiles, p), axis=0) + + +def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, + log=False): + r"""Solves the Earth Movers distance problem between 1d measures and returns + the OT matrix + + + .. math:: + \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j]) + + s.t. \gamma 1 = a, + \gamma^T 1= b, + \gamma\geq 0 + where : + + - d is the metric + - x_a and x_b are the samples + - a and b are the sample weights + + When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`. + + Uses the algorithm detailed in [1]_ + + Parameters + ---------- + x_a : (ns,) or (ns, 1) ndarray, float64 + Source dirac locations (on the real line) + x_b : (nt,) or (ns, 1) ndarray, float64 + Target dirac locations (on the real line) + a : (ns,) ndarray, float64, optional + Source histogram (default is uniform weight) + b : (nt,) ndarray, float64, optional + Target histogram (default is uniform weight) + metric: str, optional (default='sqeuclidean') + Metric to be used. Only strings listed in :func:`ot.dist` are accepted. + Due to implementation details, this function runs faster when + `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used. + p: float, optional (default=1.0) + The p-norm to apply for if metric='minkowski' + dense: boolean, optional (default=True) + If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). + Otherwise returns a sparse representation using scipy's `coo_matrix` + format. Due to implementation details, this function runs faster when + `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics + are used. + log: boolean, optional (default=False) + If True, returns a dictionary containing the cost. + Otherwise returns only the optimal transportation matrix. + + Returns + ------- + gamma: (ns, nt) ndarray + Optimal transportation matrix for the given parameters + log: dict + If input log is True, a dictionary containing the cost + + + Examples + -------- + + Simple example with obvious solution. The function emd_1d accepts lists and + performs automatic conversion to numpy arrays + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> x_a = [2., 0.] + >>> x_b = [0., 3.] + >>> ot.emd_1d(x_a, x_b, a, b) + array([[0. , 0.5], + [0.5, 0. ]]) + >>> ot.emd_1d(x_a, x_b) + array([[0. , 0.5], + [0.5, 0. ]]) + + References + ---------- + + .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + Transport", 2018. + + See Also + -------- + ot.lp.emd : EMD for multidimensional distributions + ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the + transportation matrix) + """ + a, b, x_a, x_b = list_to_array(a, b, x_a, x_b) + nx = get_backend(x_a, x_b) + + assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \ + "emd_1d should only be used with monodimensional data" + assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \ + "emd_1d should only be used with monodimensional data" + + # if empty array given then use uniform distributions + if a is None or a.ndim == 0 or len(a) == 0: + a = nx.ones((x_a.shape[0],), type_as=x_a) / x_a.shape[0] + if b is None or b.ndim == 0 or len(b) == 0: + b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0] + + # ensure that same mass + np.testing.assert_almost_equal( + nx.sum(a, axis=0), + nx.sum(b, axis=0), + err_msg='a and b vector must have the same sum' + ) + b = b * nx.sum(a) / nx.sum(b) + + x_a_1d = nx.reshape(x_a, (-1,)) + x_b_1d = nx.reshape(x_b, (-1,)) + perm_a = nx.argsort(x_a_1d) + perm_b = nx.argsort(x_b_1d) + + G_sorted, indices, cost = emd_1d_sorted( + nx.to_numpy(a[perm_a]), + nx.to_numpy(b[perm_b]), + nx.to_numpy(x_a_1d[perm_a]), + nx.to_numpy(x_b_1d[perm_b]), + metric=metric, p=p + ) + + G = nx.coo_matrix( + G_sorted, + perm_a[indices[:, 0]], + perm_b[indices[:, 1]], + shape=(a.shape[0], b.shape[0]), + type_as=x_a + ) + if dense: + G = nx.todense(G) + elif str(nx) == "jax": + warnings.warn("JAX does not support sparse matrices, converting to dense") + if log: + log = {'cost': cost} + return G, log + return G + + +def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, + log=False): + r"""Solves the Earth Movers distance problem between 1d measures and returns + the loss + + + .. math:: + \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j]) + + s.t. \gamma 1 = a, + \gamma^T 1= b, + \gamma\geq 0 + where : + + - d is the metric + - x_a and x_b are the samples + - a and b are the sample weights + + When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`. + + Uses the algorithm detailed in [1]_ + + Parameters + ---------- + x_a : (ns,) or (ns, 1) ndarray, float64 + Source dirac locations (on the real line) + x_b : (nt,) or (ns, 1) ndarray, float64 + Target dirac locations (on the real line) + a : (ns,) ndarray, float64, optional + Source histogram (default is uniform weight) + b : (nt,) ndarray, float64, optional + Target histogram (default is uniform weight) + metric: str, optional (default='sqeuclidean') + Metric to be used. Only strings listed in :func:`ot.dist` are accepted. + Due to implementation details, this function runs faster when + `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics + are used. + p: float, optional (default=1.0) + The p-norm to apply for if metric='minkowski' + dense: boolean, optional (default=True) + If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). + Otherwise returns a sparse representation using scipy's `coo_matrix` + format. Only used if log is set to True. Due to implementation details, + this function runs faster when dense is set to False. + log: boolean, optional (default=False) + If True, returns a dictionary containing the transportation matrix. + Otherwise returns only the loss. + + Returns + ------- + loss: float + Cost associated to the optimal transportation + log: dict + If input log is True, a dictionary containing the Optimal transportation + matrix for the given parameters + + + Examples + -------- + + Simple example with obvious solution. The function emd2_1d accepts lists and + performs automatic conversion to numpy arrays + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> x_a = [2., 0.] + >>> x_b = [0., 3.] + >>> ot.emd2_1d(x_a, x_b, a, b) + 0.5 + >>> ot.emd2_1d(x_a, x_b) + 0.5 + + References + ---------- + + .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + Transport", 2018. + + See Also + -------- + ot.lp.emd2 : EMD for multidimensional distributions + ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix + instead of the cost) + """ + # If we do not return G (log==False), then we should not to cast it to dense + # (useless overhead) + G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p, + dense=dense and log, log=True) + cost = log_emd['cost'] + if log: + log_emd = {'G': G} + return cost, log_emd + return cost diff --git a/ot/sliced.py b/ot/sliced.py index 4792576..d3dc3f2 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -1,61 +1,73 @@ """ -Sliced Wasserstein Distance. +Sliced OT Distances """ # Author: Adrien Corenflos +# Nicolas Courty +# Rémi Flamary # # License: MIT License import numpy as np +from .backend import get_backend, NumpyBackend +from .utils import list_to_array -def get_random_projections(n_projections, d, seed=None): +def get_random_projections(d, n_projections, seed=None, backend=None, type_as=None): r""" Generates n_projections samples from the uniform on the unit sphere of dimension d-1: :math:`\mathcal{U}(\mathcal{S}^{d-1})` Parameters ---------- - n_projections : int - number of samples requested d : int dimension of the space + n_projections : int + number of samples requested seed: int or RandomState, optional Seed used for numpy random number generator + backend: + Backend to ue for random generation Returns ------- - out: ndarray, shape (n_projections, d) + out: ndarray, shape (d, n_projections) The uniform unit vectors on the sphere Examples -------- >>> n_projections = 100 >>> d = 5 - >>> projs = get_random_projections(n_projections, d) - >>> np.allclose(np.sum(np.square(projs), 1), 1.) # doctest: +NORMALIZE_WHITESPACE + >>> projs = get_random_projections(d, n_projections) + >>> np.allclose(np.sum(np.square(projs), 0), 1.) # doctest: +NORMALIZE_WHITESPACE True """ - if not isinstance(seed, np.random.RandomState): - random_state = np.random.RandomState(seed) + if backend is None: + nx = NumpyBackend() + else: + nx = backend + + if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy': + projections = seed.randn(d, n_projections) else: - random_state = seed + if seed is not None: + nx.seed(seed) + projections = nx.randn(d, n_projections, type_as=type_as) - projections = random_state.normal(0., 1., [n_projections, d]) - norm = np.linalg.norm(projections, ord=2, axis=1, keepdims=True) - projections = projections / norm + projections = projections / nx.sqrt(nx.sum(projections**2, 0, keepdims=True)) return projections -def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed=None, log=False): +def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, + projections=None, seed=None, log=False): r""" - Computes a Monte-Carlo approximation of the 2-Sliced Wasserstein distance + Computes a Monte-Carlo approximation of the p-Sliced Wasserstein distance .. math:: - \mathcal{SWD}_2(\mu, \nu) = \underset{\theta \sim \mathcal{U}(\mathbb{S}^{d-1})}{\mathbb{E}}[\mathcal{W}_2^2(\theta_\# \mu, \theta_\# \nu)]^{\frac{1}{2}} + \mathcal{SWD}_p(\mu, \nu) = \underset{\theta \sim \mathcal{U}(\mathbb{S}^{d-1})}{\mathbb{E}}[\mathcal{W}_p^p(\theta_\# \mu, \theta_\# \nu)]^{\frac{1}{p}} where : @@ -74,8 +86,12 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed samples weights in the target domain n_projections : int, optional Number of projections used for the Monte-Carlo approximation + p: float, optional = + Power p used for computing the sliced Wasserstein + projections: shape (dim, n_projections), optional + Projection matrix (n_projections and seed are not used in this case) seed: int or RandomState or None, optional - Seed used for numpy random number generator + Seed used for random number generator log: bool, optional if True, sliced_wasserstein_distance returns the projections used and their associated EMD. @@ -100,10 +116,18 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed .. [31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 """ - from .lp import emd2_1d + from .lp import wasserstein_1d - X_s = np.asanyarray(X_s) - X_t = np.asanyarray(X_t) + X_s, X_t = list_to_array(X_s, X_t) + + if a is not None and b is not None and projections is None: + nx = get_backend(X_s, X_t, a, b) + elif a is not None and b is not None and projections is not None: + nx = get_backend(X_s, X_t, a, b, projections) + elif a is None and b is None and projections is not None: + nx = get_backend(X_s, X_t, projections) + else: + nx = get_backend(X_s, X_t) n = X_s.shape[0] m = X_t.shape[0] @@ -114,31 +138,120 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed X_t.shape[1])) if a is None: - a = np.full(n, 1 / n) + a = nx.full(n, 1 / n) if b is None: - b = np.full(m, 1 / m) + b = nx.full(m, 1 / m) d = X_s.shape[1] - projections = get_random_projections(n_projections, d, seed) + if projections is None: + projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s) + + X_s_projections = nx.dot(X_s, projections) + X_t_projections = nx.dot(X_t, projections) - X_s_projections = np.dot(projections, X_s.T) - X_t_projections = np.dot(projections, X_t.T) + projected_emd = wasserstein_1d(X_s_projections, X_t_projections, a, b, p=p) + res = (nx.sum(projected_emd) / n_projections) ** (1.0 / p) if log: - projected_emd = np.empty(n_projections) + return res, {"projections": projections, "projected_emds": projected_emd} + return res + + +def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, + projections=None, seed=None, log=False): + r""" + Computes a Monte-Carlo approximation of the max p-Sliced Wasserstein distance + + .. math:: + \mathcal{Max-SWD}_p(\mu, \nu) = \underset{\theta _in + \mathcal{U}(\mathbb{S}^{d-1})}{\max} [\mathcal{W}_p^p(\theta_\# + \mu, \theta_\# \nu)]^{\frac{1}{p}} + + where : + + - :math:`\theta_\# \mu` stands for the pushforwars of the projection :math:`\mathbb{R}^d \ni X \mapsto \langle \theta, X \rangle` + + + Parameters + ---------- + X_s : ndarray, shape (n_samples_a, dim) + samples in the source domain + X_t : ndarray, shape (n_samples_b, dim) + samples in the target domain + a : ndarray, shape (n_samples_a,), optional + samples weights in the source domain + b : ndarray, shape (n_samples_b,), optional + samples weights in the target domain + n_projections : int, optional + Number of projections used for the Monte-Carlo approximation + p: float, optional = + Power p used for computing the sliced Wasserstein + projections: shape (dim, n_projections), optional + Projection matrix (n_projections and seed are not used in this case) + seed: int or RandomState or None, optional + Seed used for random number generator + log: bool, optional + if True, sliced_wasserstein_distance returns the projections used and their associated EMD. + + Returns + ------- + cost: float + Sliced Wasserstein Cost + log : dict, optional + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> n_samples_a = 20 + >>> reg = 0.1 + >>> X = np.random.normal(0., 1., (n_samples_a, 5)) + >>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE + 0.0 + + References + ---------- + + .. [35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., ... & Schwing, A. G. (2019). Max-sliced wasserstein distance and its use for gans. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656). + """ + from .lp import wasserstein_1d + + X_s, X_t = list_to_array(X_s, X_t) + + if a is not None and b is not None and projections is None: + nx = get_backend(X_s, X_t, a, b) + elif a is not None and b is not None and projections is not None: + nx = get_backend(X_s, X_t, a, b, projections) + elif a is None and b is None and projections is not None: + nx = get_backend(X_s, X_t, projections) else: - projected_emd = None + nx = get_backend(X_s, X_t) + + n = X_s.shape[0] + m = X_t.shape[0] + + if X_s.shape[1] != X_t.shape[1]: + raise ValueError( + "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1], + X_t.shape[1])) + + if a is None: + a = nx.full(n, 1 / n) + if b is None: + b = nx.full(m, 1 / m) + + d = X_s.shape[1] + + if projections is None: + projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s) - res = 0. + X_s_projections = nx.dot(X_s, projections) + X_t_projections = nx.dot(X_t, projections) - for i, (X_s_proj, X_t_proj) in enumerate(zip(X_s_projections, X_t_projections)): - emd = emd2_1d(X_s_proj, X_t_proj, a, b, log=False, dense=False) - if projected_emd is not None: - projected_emd[i] = emd - res += emd + projected_emd = wasserstein_1d(X_s_projections, X_t_projections, a, b, p=p) - res = (res / n_projections) ** 0.5 + res = nx.max(projected_emd) ** (1.0 / p) if log: return res, {"projections": projections, "projected_emds": projected_emd} return res diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py new file mode 100644 index 0000000..2c470c2 --- /dev/null +++ b/test/test_1d_solver.py @@ -0,0 +1,85 @@ +"""Tests for module 1d Wasserstein solver""" + +# Author: Adrien Corenflos +# Nicolas Courty +# +# License: MIT License + +import numpy as np +import pytest + +import ot +from ot.lp import wasserstein_1d + +from ot.backend import get_backend_list +from scipy.stats import wasserstein_distance + +backend_list = get_backend_list() + + +def test_emd_1d_emd2_1d_with_weights(): + # test emd1d gives similar results as emd + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.randn(n, 1) + v = rng.randn(m, 1) + + w_u = rng.uniform(0., 1., n) + w_u = w_u / w_u.sum() + + w_v = rng.uniform(0., 1., m) + w_v = w_v / w_v.sum() + + M = ot.dist(u, v, metric='sqeuclidean') + + G, log = ot.emd(w_u, w_v, M, log=True) + wass = log["cost"] + G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True) + wass1d = log["cost"] + wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False) + wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False) + + # check loss is similar + np.testing.assert_allclose(wass, wass1d) + np.testing.assert_allclose(wass, wass1d_emd2) + + # check loss is similar to scipy's implementation for Euclidean metric + wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v) + np.testing.assert_allclose(wass_sp, wass1d_euc) + + # check constraints + np.testing.assert_allclose(w_u, G.sum(1)) + np.testing.assert_allclose(w_v, G.sum(0)) + + +@pytest.mark.parametrize('nx', backend_list) +def test_wasserstein_1d(nx): + from scipy.stats import wasserstein_distance + + rng = np.random.RandomState(0) + + n = 100 + x = np.linspace(0, 5, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + rho_v = np.abs(rng.randn(n)) + rho_v /= rho_v.sum() + + xb = nx.from_numpy(x) + rho_ub = nx.from_numpy(rho_u) + rho_vb = nx.from_numpy(rho_v) + + # test 1 : wasserstein_1d should be close to scipy W_1 implementation + np.testing.assert_almost_equal(wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1), + wasserstein_distance(x, x, rho_u, rho_v)) + + # test 2 : wasserstein_1d should be close to one when only translating the support + np.testing.assert_almost_equal(wasserstein_1d(xb, xb + 1, p=2), + 1.) + + # test 3 : arrays test + X = np.stack((np.linspace(0, 5, n), np.linspace(0, 5, n) * 10), -1) + Xb = nx.from_numpy(X) + res = wasserstein_1d(Xb, Xb, rho_ub, rho_vb, p=2) + np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4) diff --git a/test/test_backend.py b/test/test_backend.py index 0f11ace..1832b91 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -208,6 +208,11 @@ def test_empty_backend(): with pytest.raises(NotImplementedError): nx.reshape(M, (5, 3, 2)) with pytest.raises(NotImplementedError): + nx.seed(42) + with pytest.raises(NotImplementedError): + nx.rand() + with pytest.raises(NotImplementedError): + nx.randn() nx.coo_matrix(M, M, M) with pytest.raises(NotImplementedError): nx.issparse(M) @@ -248,6 +253,7 @@ def test_func_backends(nx): Mb = nx.from_numpy(M) vb = nx.from_numpy(v) + val = nx.from_numpy(val) sp_rowb = nx.from_numpy(sp_row) @@ -255,6 +261,7 @@ def test_func_backends(nx): sp_datab = nx.from_numpy(sp_data) A = nx.set_gradients(val, v, v) + lst_b.append(nx.to_numpy(A)) lst_name.append('set_gradients') @@ -505,6 +512,35 @@ def test_func_backends(nx): assert np.allclose(a1, a2, atol=1e-7) +def test_random_backends(nx): + + tmp_u = nx.rand() + + assert tmp_u < 1 + + tmp_n = nx.randn() + + nx.seed(0) + M1 = nx.to_numpy(nx.rand(5, 2)) + nx.seed(0) + M2 = nx.to_numpy(nx.rand(5, 2, type_as=tmp_n)) + + assert np.all(M1 >= 0) + assert np.all(M1 < 1) + assert M1.shape == (5, 2) + assert np.allclose(M1, M2) + + nx.seed(0) + M1 = nx.to_numpy(nx.randn(5, 2)) + nx.seed(0) + M2 = nx.to_numpy(nx.randn(5, 2, type_as=tmp_u)) + + nx.seed(42) + v1 = nx.randn() + v2 = nx.randn() + assert v1 != v2 + + def test_gradients_backends(): rnd = np.random.RandomState(0) diff --git a/test/test_ot.py b/test/test_ot.py index 4dfc510..5bfde1d 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -8,11 +8,11 @@ import warnings import numpy as np import pytest -from scipy.stats import wasserstein_distance import ot from ot.datasets import make_1D_gauss as gauss from ot.backend import torch +from scipy.stats import wasserstein_distance def test_emd_dimension_and_mass_mismatch(): @@ -165,61 +165,6 @@ def test_emd_1d_emd2_1d(): ot.emd_1d(u, v, [], []) -def test_emd_1d_emd2_1d_with_weights(): - # test emd1d gives similar results as emd - n = 20 - m = 30 - rng = np.random.RandomState(0) - u = rng.randn(n, 1) - v = rng.randn(m, 1) - - w_u = rng.uniform(0., 1., n) - w_u = w_u / w_u.sum() - - w_v = rng.uniform(0., 1., m) - w_v = w_v / w_v.sum() - - M = ot.dist(u, v, metric='sqeuclidean') - - G, log = ot.emd(w_u, w_v, M, log=True) - wass = log["cost"] - G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True) - wass1d = log["cost"] - wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False) - wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False) - - # check loss is similar - np.testing.assert_allclose(wass, wass1d) - np.testing.assert_allclose(wass, wass1d_emd2) - - # check loss is similar to scipy's implementation for Euclidean metric - wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v) - np.testing.assert_allclose(wass_sp, wass1d_euc) - - # check constraints - np.testing.assert_allclose(w_u, G.sum(1)) - np.testing.assert_allclose(w_v, G.sum(0)) - - -def test_wass_1d(): - # test emd1d gives similar results as emd - n = 20 - m = 30 - rng = np.random.RandomState(0) - u = rng.randn(n, 1) - v = rng.randn(m, 1) - - M = ot.dist(u, v, metric='sqeuclidean') - - G, log = ot.emd([], [], M, log=True) - wass = log["cost"] - - wass1d = ot.wasserstein_1d(u, v, [], [], p=2.) - - # check loss is similar - np.testing.assert_allclose(np.sqrt(wass), wass1d) - - def test_emd_empty(): # test emd and emd2 for simple identity n = 100 diff --git a/test/test_sliced.py b/test/test_sliced.py index a07d975..0bd74ec 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -1,6 +1,7 @@ """Tests for module sliced""" # Author: Adrien Corenflos +# Nicolas Courty # # License: MIT License @@ -14,7 +15,7 @@ from ot.sliced import get_random_projections def test_get_random_projections(): rng = np.random.RandomState(0) projections = get_random_projections(1000, 50, rng) - np.testing.assert_almost_equal(np.sum(projections ** 2, 1), 1.) + np.testing.assert_almost_equal(np.sum(projections ** 2, 0), 1.) def test_sliced_same_dist(): @@ -48,12 +49,12 @@ def test_sliced_log(): y = rng.randn(n, 4) u = ot.utils.unif(n) - res, log = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng, log=True) + res, log = ot.sliced_wasserstein_distance(x, y, u, u, 10, p=1, seed=rng, log=True) assert len(log) == 2 projections = log["projections"] projected_emds = log["projected_emds"] - assert len(projections) == len(projected_emds) == 10 + assert projections.shape[1] == len(projected_emds) == 10 for emd in projected_emds: assert emd > 0 @@ -83,3 +84,86 @@ def test_1d_sliced_equals_emd(): res = ot.sliced_wasserstein_distance(x, y, a, u, 10, seed=42) expected = ot.emd2_1d(x.squeeze(), y.squeeze(), a, u) np.testing.assert_almost_equal(res ** 2, expected) + + +def test_max_sliced_same_dist(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + res = ot.max_sliced_wasserstein_distance(x, x, u, u, 10, seed=rng) + np.testing.assert_almost_equal(res, 0.) + + +def test_max_sliced_different_dists(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + y = rng.randn(n, 2) + + res, log = ot.max_sliced_wasserstein_distance(x, y, u, u, 10, seed=rng, log=True) + assert res > 0. + + +def test_sliced_backend(nx): + + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(2 * n, 2) + + P = rng.randn(2, 20) + P = P / np.sqrt((P**2).sum(0, keepdims=True)) + + n_projections = 20 + + xb = nx.from_numpy(x) + yb = nx.from_numpy(y) + Pb = nx.from_numpy(P) + + val0 = ot.sliced_wasserstein_distance(x, y, projections=P) + + val = ot.sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0) + val2 = ot.sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0) + + assert val > 0 + assert val == val2 + + valb = nx.to_numpy(ot.sliced_wasserstein_distance(xb, yb, projections=Pb)) + + assert np.allclose(val0, valb) + + +def test_max_sliced_backend(nx): + + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(2 * n, 2) + + P = rng.randn(2, 20) + P = P / np.sqrt((P**2).sum(0, keepdims=True)) + + n_projections = 20 + + xb = nx.from_numpy(x) + yb = nx.from_numpy(y) + Pb = nx.from_numpy(P) + + val0 = ot.max_sliced_wasserstein_distance(x, y, projections=P) + + val = ot.max_sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0) + val2 = ot.max_sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0) + + assert val > 0 + assert val == val2 + + valb = nx.to_numpy(ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)) + + assert np.allclose(val0, valb) diff --git a/test/test_utils.py b/test/test_utils.py index 0650ce2..40f4e49 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -109,7 +109,7 @@ def test_dist(): D2 = ot.dist(x, x) D3 = ot.dist(x) - D4 = ot.dist(x, x, metric='minkowski', p=0.5) + D4 = ot.dist(x, x, metric='minkowski', p=2) assert D4[0, 1] == D4[1, 0] -- cgit v1.2.3 From 9c6ac880d426b7577918b0c77bd74b3b01930ef6 Mon Sep 17 00:00:00 2001 From: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> Date: Wed, 3 Nov 2021 17:29:16 +0100 Subject: [MRG] Docs updates (#298) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * bregman docs * sliced docs * docs partial * unbalanced docs * stochastic docs * plot docs * datasets docs * utils docs * dr docs * dr docs corrected * smooth docs * docs da * pep8 * docs gromov * more space after min and argmin * docs lp * bregman docs * bregman docs mistake corrected * pep8 Co-authored-by: Rémi Flamary --- ot/bregman.py | 236 ++++++++++++++------------ ot/da.py | 499 ++++++++++++++++++++++++++++++------------------------ ot/datasets.py | 12 +- ot/dr.py | 40 +++-- ot/gromov.py | 35 ++-- ot/lp/__init__.py | 100 ++++++----- ot/optim.py | 8 +- ot/partial.py | 283 +++++++++++++++++-------------- ot/plot.py | 10 +- ot/sliced.py | 7 +- ot/smooth.py | 179 ++++++++++++-------- ot/stochastic.py | 192 ++++++++++----------- ot/unbalanced.py | 206 +++++++++++----------- ot/utils.py | 118 ++++++------- 14 files changed, 1048 insertions(+), 877 deletions(-) (limited to 'ot/lp') diff --git a/ot/bregman.py b/ot/bregman.py index 786f151..cce52e2 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -33,7 +33,8 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, The function solves the following optimization problem: .. math:: - \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg}\cdot\Omega(\gamma) s.t. \ \gamma \mathbf{1} &= \mathbf{a} @@ -45,9 +46,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\Omega` is the entropic regularization term - :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target - weights (histograms, both sum to 1) + weights (histograms, both sum to 1) .. note:: This function is backend-compatible and will work on arrays from all compatible backends. @@ -70,7 +71,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, version of the sinkhorn :py:func:`ot.bregman.screenkhorn` aim at providing a fast approximation of the Sinkhorn problem. For use of GPU and gradient computation with small number of iterations we strongly recommend the - :any:`ot.bregman.sinkhorn_log` solver that will no need to check for + :py:func:`ot.bregman.sinkhorn_log` solver that will no need to check for numerical problems. @@ -189,7 +190,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg}\cdot\Omega(\gamma) s.t. \ \gamma \mathbf{1} &= \mathbf{a} @@ -201,9 +203,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\Omega` is the entropic regularization term - :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target - weights (histograms, both sum to 1) + weights (histograms, both sum to 1) .. note:: This function is backend-compatible and will work on arrays from all compatible backends. @@ -217,17 +219,17 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, By default and when using a regularization parameter that is not too small the default sinkhorn solver should be enough. If you need to use a small regularization to get sharper OT matrices, you should use the - :any:`ot.bregman.sinkhorn_log` solver that will avoid numerical + :py:func:`ot.bregman.sinkhorn_log` solver that will avoid numerical errors. This last solver can be very slow in practice and might not even converge to a reasonable OT matrix in a finite time. This is why - :any:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value + :py:func:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value of the regularization (and using warm start) sometimes leads to better solutions. Note that the greedy version of the sinkhorn - :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening - version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a + :py:func:`ot.bregman.greenkhorn` can also lead to a speedup and the screening + version of the sinkhorn :py:func:`ot.bregman.screenkhorn` aim a providing a fast approximation of the Sinkhorn problem. For use of GPU and gradient computation with small number of iterations we strongly recommend the - :any:`ot.bregman.sinkhorn_log` solver that will no need to check for + :py:func:`ot.bregman.sinkhorn_log` solver that will no need to check for numerical problems. Parameters @@ -301,15 +303,15 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. + See Also -------- ot.lp.emd : Unregularized OT ot.optim.cg : General regularized OT ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] ` ot.bregman.greenkhorn : Greenkhorn :ref:`[21] ` - ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn :ref:`[9] ` - :ref:`[10] ` - + ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn + :ref:`[9] ` :ref:`[10] ` """ M, a, b = list_to_array(M, a, b) @@ -362,7 +364,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, The function solves the following optimization problem: .. math:: - \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg}\cdot\Omega(\gamma) s.t. \ \gamma \mathbf{1} &= \mathbf{a} @@ -373,9 +376,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\Omega` is the entropic regularization term - :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target - weights (histograms, both sum to 1) + weights (histograms, both sum to 1) The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[2] ` @@ -543,7 +546,8 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, The function solves the following optimization problem: .. math:: - \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg}\cdot\Omega(\gamma) s.t. \ \gamma \mathbf{1} &= \mathbf{a} @@ -553,12 +557,13 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) The algorithm used for solving the problem is the Sinkhorn-Knopp matrix - scaling algorithm :ref:`[2] ` with the - implementation from :ref:`[34] ` + scaling algorithm :ref:`[2] ` with the + implementation from :ref:`[34] ` Parameters @@ -744,7 +749,8 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, The function solves the following optimization problem: .. math:: - \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg}\cdot\Omega(\gamma) s.t. \ \gamma \mathbf{1} &= \mathbf{a} @@ -755,9 +761,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\Omega` is the entropic regularization term - :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target - weights (histograms, both sum to 1) + weights (histograms, both sum to 1) Parameters @@ -903,7 +909,8 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, The function solves the following optimization problem: .. math:: - \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg}\cdot\Omega(\gamma) s.t. \ \gamma \mathbf{1} &= \mathbf{a} @@ -914,9 +921,9 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\Omega` is the entropic regularization term - :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target - weights (histograms, both sum to 1) + weights (histograms, both sum to 1) The algorithm used for solving the problem is the Sinkhorn-Knopp matrix @@ -1145,20 +1152,24 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, Solve the entropic regularization optimal transport problem with log stabilization and epsilon scaling. The function solves the following optimization problem: + .. math:: - \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg}\cdot\Omega(\gamma) s.t. \ \gamma \mathbf{1} &= \mathbf{a} \gamma^T \mathbf{1} &= \mathbf{b} \gamma &\geq 0 + where : + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\Omega` is the entropic regularization term - :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights - (histograms, both sum to 1) + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[2] ` but with the log stabilization @@ -1340,17 +1351,17 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, The function solves the following optimization problem: .. math:: - \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) where : - - :math:`OT_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein - distance (see :py:func:`ot.bregman.sinkhorn`) - if `method` is `sinkhorn` or `sinkhorn_stabilized` or `sinkhorn_log`. + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein + distance (see :py:func:`ot.bregman.sinkhorn`) + if `method` is `sinkhorn` or `sinkhorn_stabilized` or `sinkhorn_log`. - :math:`\mathbf{a}_i` are training distributions in the columns of matrix - :math:`\mathbf{A}` + :math:`\mathbf{A}` - `reg` and :math:`\mathbf{M}` are respectively the regularization term and - the cost matrix for OT + the cost matrix for OT The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3] ` @@ -1424,16 +1435,16 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, The function solves the following optimization problem: .. math:: - \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) where : - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance - (see :py:func:`ot.bregman.sinkhorn`) + (see :py:func:`ot.bregman.sinkhorn`) - :math:`\mathbf{a}_i` are training distributions in the columns of matrix - :math:`\mathbf{A}` + :math:`\mathbf{A}` - `reg` and :math:`\mathbf{M}` are respectively the regularization term and - the cost matrix for OT + the cost matrix for OT The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3]`. @@ -1598,16 +1609,16 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, The function solves the following optimization problem: .. math:: - \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) where : - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein - distance (see :py:func:`ot.bregman.sinkhorn`) + distance (see :py:func:`ot.bregman.sinkhorn`) - :math:`\mathbf{a}_i` are training distributions in the columns of matrix - :math:`\mathbf{A}` + :math:`\mathbf{A}` - `reg` and :math:`\mathbf{M}` are respectively the regularization term and - the cost matrix for OT + the cost matrix for OT The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3] ` @@ -1736,24 +1747,24 @@ def barycenter_debiased(A, M, reg, weights=None, method="sinkhorn", numItermax=1 The function solves the following optimization problem: .. math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i) where : - :math:`S_{reg}(\cdot,\cdot)` is the debiased Sinkhorn divergence - (see :py:func:`ot.bregman.emirical_sinkhorn_divergence`) + (see :py:func:`ot.bregman.empirical_sinkhorn_divergence`) - :math:`\mathbf{a}_i` are training distributions in the columns of matrix - :math:`\mathbf{A}` + :math:`\mathbf{A}` - `reg` and :math:`\mathbf{M}` are respectively the regularization term and - the cost matrix for OT + the cost matrix for OT The algorithm used for solving the problem is the debiased Sinkhorn - algorithm as proposed in :ref:`[37] ` + algorithm as proposed in :ref:`[37] ` Parameters ---------- A : array-like, shape (dim, n_hists) - `n_hists` training distributions :math:`a_i` of size `dim` + `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim` M : array-like, shape (dim, dim) loss matrix for OT reg : float @@ -1761,7 +1772,7 @@ def barycenter_debiased(A, M, reg, weights=None, method="sinkhorn", numItermax=1 method : str (optional) method used for the solver either 'sinkhorn' or 'sinkhorn_log' weights : array-like, shape (n_hists,) - Weights of each histogram :math:`a_i` on the simplex (barycentric coodinates) + Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) numItermax : int, optional Max number of iterations stopThr : float, optional @@ -1774,7 +1785,6 @@ def barycenter_debiased(A, M, reg, weights=None, method="sinkhorn", numItermax=1 if True, raises a warning if the algorithm doesn't convergence. - Returns ------- a : (dim,) array-like @@ -1782,12 +1792,12 @@ def barycenter_debiased(A, M, reg, weights=None, method="sinkhorn", numItermax=1 log : dict log dictionary return only if log==True in parameters - .. _references-sinkhorn-debiased: - References - ---------- - .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International - Conference on Machine Learning, PMLR 119:4692-4701, 2020 + .. _references-barycenter-debiased: + References + ---------- + .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International + Conference on Machine Learning, PMLR 119:4692-4701, 2020 """ if method.lower() == 'sinkhorn': @@ -1934,20 +1944,20 @@ def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000, def convolutional_barycenter2d(A, reg, weights=None, method="sinkhorn", numItermax=10000, stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs): - r"""Compute the entropic regularized wasserstein barycenter of distributions A - where A is a collection of 2D images. + r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` + where :math:`\mathbf{A}` is a collection of 2D images. The function solves the following optimization problem: .. math:: - \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) where : - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein - distance (see :py:func:`ot.bregman.sinkhorn`) + distance (see :py:func:`ot.bregman.sinkhorn`) - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions - of matrix :math:`\mathbf{A}` + of matrix :math:`\mathbf{A}` - `reg` is the regularization strength scalar value The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm @@ -2166,24 +2176,24 @@ def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn", numItermax=10000, stopThr=1e-3, verbose=False, log=False, warn=True, **kwargs): - r"""Compute the debiased sinkhorn barycenter of distributions A - where A is a collection of 2D images. + r"""Compute the debiased sinkhorn barycenter of distributions :math:`\mathbf{A}` + where :math:`\mathbf{A}` is a collection of 2D images. The function solves the following optimization problem: .. math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i) where : - :math:`S_{reg}(\cdot,\cdot)` is the debiased entropic regularized Wasserstein - distance (see :py:func:`ot.bregman.sinkhorn_debiased`) + distance (see :py:func:`ot.bregman.barycenter_debiased`) - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two - dimensions of matrix :math:`\mathbf{A}` + dimensions of matrix :math:`\mathbf{A}` - `reg` is the regularization strength scalar value The algorithm used for solving the problem is the debiased Sinkhorn scaling - algorithm as proposed in :ref:`[37] ` + algorithm as proposed in :ref:`[37] ` Parameters ---------- @@ -2217,7 +2227,7 @@ def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn", log dictionary return only if log==True in parameters - .. _references-sinkhorn-debiased: + .. _references-convolutional-barycenter2d-debiased: References ---------- @@ -2406,23 +2416,25 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, .. math:: - \mathbf{h} = \mathop{\arg \min}_\mathbf{h} (1- \alpha) W_{\mathbf{M}, \mathrm{reg}}(\mathbf{a},\mathbf{Dh})+\alpha W_{\mathbf{M_0},\mathrm{reg}_0}(\mathbf{h}_0,\mathbf{h}) + \mathbf{h} = \mathop{\arg \min}_\mathbf{h} \quad + (1 - \alpha) W_{\mathbf{M}, \mathrm{reg}}(\mathbf{a}, \mathbf{Dh}) + + \alpha W_{\mathbf{M_0}, \mathrm{reg}_0}(\mathbf{h}_0, \mathbf{h}) where : - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance - with M loss matrix (see :py:func:`ot.bregman.sinkhorn`) + with :math:`\mathbf{M}` loss matrix (see :py:func:`ot.bregman.sinkhorn`) - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, - its expected shape is `(dim_a, n_atoms)` + its expected shape is `(dim_a, n_atoms)` - :math:`\mathbf{h}` is the estimated unmixing of dimension `n_atoms` - :math:`\mathbf{a}` is an observed distribution of dimension `dim_a` - :math:`\mathbf{h}_0` is a prior on :math:`\mathbf{h}` of dimension `dim_prior` - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the - cost matrix (`dim_a`, `dim_a`) for OT data fitting + cost matrix (`dim_a`, `dim_a`) for OT data fitting - `reg`:math:`_0` and :math:`\mathbf{M_0}` are respectively the regularization - term and the cost matrix (`dim_prior`, `n_atoms`) regularization - - :math:`\\alpha` weight data fitting and regularization + term and the cost matrix (`dim_prior`, `n_atoms`) regularization + - :math:`\alpha` weight data fitting and regularization The optimization problem is solved following the algorithm described in :ref:`[4] ` @@ -2535,7 +2547,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, .. math:: - \mathbf{h} = \mathop{\arg \min}_{\mathbf{h}} \sum_{k=1}^{K} \lambda_k + \mathbf{h} = \mathop{\arg \min}_{\mathbf{h}} \quad \sum_{k=1}^{K} \lambda_k W_{reg}((\mathbf{D}_2^{(k)} \mathbf{h})^T, \mathbf{a}) s.t. \ \forall k, \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h} @@ -2544,15 +2556,15 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, - :math:`\lambda_k` is the weight of `k`-th source domain - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance - (see :py:func:`ot.bregman.sinkhorn`) + (see :py:func:`ot.bregman.sinkhorn`) - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to `k`-th source domain - defined as in [p. 5, :ref:`27 `], its expected shape - is :math:`(n_k, C)` where :math:`n_k` is the number of elements in the `k`-th source - domain and `C` is the number of classes + defined as in [p. 5, :ref:`27 `], its expected shape + is :math:`(n_k, C)` where :math:`n_k` is the number of elements in the `k`-th source + domain and `C` is the number of classes - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size `C` - :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n` - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in - [p. 5, :ref:`27 `], its expected shape is :math:`(n_k, C)` + [p. 5, :ref:`27 `], its expected shape is :math:`(n_k, C)` The problem consist in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain. @@ -2714,18 +2726,19 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', The function solves the following optimization problem: .. math:: - \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg} \cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot\Omega(\gamma) - s.t. \ \gamma \mathbf{1} &= a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T \mathbf{1} &= b + \gamma^T \mathbf{1} &= \mathbf{b} \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix - :math:`\Omega` is the entropic regularization term - :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) @@ -2900,18 +2913,19 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg} \cdot\Omega(\gamma) + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot\Omega(\gamma) - s.t. \ \gamma \mathbf{1} &= a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T \mathbf{1} &= b + \gamma^T \mathbf{1} &= \mathbf{b} \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix - :math:`\Omega` is the entropic regularization term - :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) @@ -3055,18 +3069,21 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli .. math:: - W &= \min_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg} \cdot\Omega(\gamma) + W &= \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot\Omega(\gamma) - W_a &= \min_{\gamma_a} <\gamma_a, \mathbf{M_a}>_F + \mathrm{reg} \cdot\Omega(\gamma_a) + W_a &= \min_{\gamma_a} \quad \langle \gamma_a, \mathbf{M_a} \rangle_F + + \mathrm{reg} \cdot\Omega(\gamma_a) - W_b &= \min_{\gamma_b} <\gamma_b, \mathbf{M_b}>_F + \mathrm{reg} \cdot\Omega(\gamma_b) + W_b &= \min_{\gamma_b} \quad \langle \gamma_b, \mathbf{M_b} \rangle_F + + \mathrm{reg} \cdot\Omega(\gamma_b) S &= W - \frac{W_a + W_b}{2} .. math:: - s.t. \ \gamma \mathbf{1} &= a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T \mathbf{1} &= b + \gamma^T \mathbf{1} &= \mathbf{b} \gamma &\geq 0 @@ -3084,10 +3101,10 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli where : - :math:`\mathbf{M}` (resp. :math:`\mathbf{M_a}`, :math:`\mathbf{M_b}`) - is the (`n_samples_a`, `n_samples_b`) metric cost matrix - (resp (`n_samples_a, n_samples_a`) and (`n_samples_b`, `n_samples_b`)) + is the (`n_samples_a`, `n_samples_b`) metric cost matrix + (resp (`n_samples_a, n_samples_a`) and (`n_samples_b`, `n_samples_b`)) - :math:`\Omega` is the entropic regularization term - :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) @@ -3198,7 +3215,10 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, .. math:: - (\mathbf{u}, \mathbf{v}) = \mathop{\arg \min}_{\mathbf{u}, \mathbf{v}} \ \mathbf{1}_{ns}^T \mathbf{B}(\mathbf{u}, \mathbf{v}) \mathbf{1}_{nt} - <\kappa \mathbf{u}, \mathbf{a}> - <\mathbf{v} / \kappa, \mathbf{b}> + (\mathbf{u}, \mathbf{v}) = \mathop{\arg \min}_{\mathbf{u}, \mathbf{v}} \quad + \mathbf{1}_{ns}^T \mathbf{B}(\mathbf{u}, \mathbf{v}) \mathbf{1}_{nt} - + \langle \kappa \mathbf{u}, \mathbf{a} \rangle - + \langle \frac{1}{\kappa} \mathbf{v}, \mathbf{b} \rangle where: @@ -3249,13 +3269,15 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, If `True`, display informations about the cardinals of the active sets and the parameters kappa and epsilon - Dependency - ---------- - To gain more efficiency, screenkhorn needs to call the "Bottleneck" - package (https://pypi.org/project/Bottleneck/) - in the screening pre-processing step. If Bottleneck isn't installed, - the following error message appears: - "Bottleneck module doesn't exist. Install it from https://pypi.org/project/Bottleneck/" + + .. admonition:: Dependency + + To gain more efficiency, :py:func:`ot.bregman.screenkhorn` needs to call the "Bottleneck" + package (https://pypi.org/project/Bottleneck/) in the screening pre-processing step. + + If Bottleneck isn't installed, the following error message appears: + + "Bottleneck module doesn't exist. Install it from https://pypi.org/project/Bottleneck/" Returns diff --git a/ot/da.py b/ot/da.py index cdc747c..4fd97df 100644 --- a/ot/da.py +++ b/ot/da.py @@ -33,27 +33,29 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma) - + \eta \Omega_g(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot \Omega_e(\gamma) + \eta \ \Omega_g(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma \mathbf{1} = \mathbf{a} + + \gamma^T \mathbf{1} = \mathbf{b} + + \gamma \geq 0 - \gamma^T 1= b - \gamma\geq 0 where : - - M is the (ns,nt) metric cost matrix + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix - :math:`\Omega_e` is the entropic regularization term :math:`\Omega_e (\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\Omega_g` is the group lasso regularization term :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1` - where :math:`\mathcal{I}_c` are the index of samples from class c + where :math:`\mathcal{I}_c` are the index of samples from class `c` in the source domain. - - a and b are source and target weights (sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) The algorithm used for solving the problem is the generalized conditional - gradient as proposed in [5]_ [7]_ + gradient as proposed in :ref:`[5, 7] `. Parameters @@ -84,19 +86,20 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, Returns ------- - gamma : (ns x nt) ndarray + gamma : (ns, nt) ndarray Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters + .. _references-sinkhorn-lpl1-mm: References ---------- - .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567. @@ -144,27 +147,29 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)+ - \eta \Omega_g(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot \Omega_e(\gamma) + \eta \ \Omega_g(\gamma) + + s.t. \ \gamma \mathbf{1} = \mathbf{a} - s.t. \gamma 1 = a + \gamma^T \mathbf{1} = \mathbf{b} + + \gamma \geq 0 - \gamma^T 1= b - \gamma\geq 0 where : - - M is the (ns,nt) metric cost matrix + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix - :math:`\Omega_e` is the entropic regularization term :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\Omega_g` is the group lasso regulaization term :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^2` where :math:`\mathcal{I}_c` are the index of samples from class - c in the source domain. - - a and b are source and target weights (sum to 1) + `c` in the source domain. + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) The algorithm used for solving the problem is the generalised conditional - gradient as proposed in [5]_ [7]_ + gradient as proposed in :ref:`[5, 7] `. Parameters @@ -195,18 +200,19 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, Returns ------- - gamma : (ns x nt) ndarray + gamma : (ns, nt) ndarray Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters + .. _references-sinkhorn-l1l2-gl: References ---------- - .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567. @@ -245,38 +251,40 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, verbose2=False, numItermax=100, numInnerItermax=10, stopInnerThr=1e-6, stopThr=1e-5, log=False, **kwargs): - r"""Joint OT and linear mapping estimation as proposed in [8] + r"""Joint OT and linear mapping estimation as proposed in + :ref:`[8] `. The function solves the following optimization problem: .. math:: - \min_{\gamma,L}\quad \|L(X_s) -n_s\gamma X_t\|^2_F + - \mu<\gamma,M>_F + \eta \|L -I\|^2_F + \min_{\gamma,L}\quad \|L(\mathbf{X_s}) - n_s\gamma \mathbf{X_t} \|^2_F + + \mu \langle \gamma, \mathbf{M} \rangle_F + \eta \|L - \mathbf{I}\|^2_F + + s.t. \ \gamma \mathbf{1} = \mathbf{a} - s.t. \gamma 1 = a + \gamma^T \mathbf{1} = \mathbf{b} - \gamma^T 1= b + \gamma \geq 0 - \gamma\geq 0 where : - - M is the (ns,nt) squared euclidean cost matrix between samples in - Xs and Xt (scaled by ns) - - :math:`L` is a dxd linear operator that approximates the barycentric + - :math:`\mathbf{M}` is the (`ns`, `nt`) squared euclidean cost matrix between samples in + :math:`\mathbf{X_s}` and :math:`\mathbf{X_t}` (scaled by :math:`n_s`) + - :math:`L` is a :math:`d\times d` linear operator that approximates the barycentric mapping - - :math:`I` is the identity matrix (neutral linear mapping) - - a and b are uniform source and target weights + - :math:`\mathbf{I}` is the identity matrix (neutral linear mapping) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are uniform source and target weights The problem consist in solving jointly an optimal transport matrix :math:`\gamma` and a linear mapping that fits the barycentric mapping - :math:`n_s\gamma X_t`. + :math:`n_s\gamma \mathbf{X_t}`. One can also estimate a mapping with constant bias (see supplementary - material of [8]) using the bias optional argument. + material of :ref:`[8] `) using the bias optional argument. The algorithm used for solving the problem is the block coordinate - descent that alternates between updates of G (using conditionnal gradient) - and the update of L using a classical least square solver. + descent that alternates between updates of :math:`\mathbf{G}` (using conditionnal gradient) + and the update of :math:`\mathbf{L}` using a classical least square solver. Parameters @@ -307,17 +315,17 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, Returns ------- - gamma : (ns x nt) ndarray + gamma : (ns, nt) ndarray Optimal transportation matrix for the given parameters - L : (d x d) ndarray - Linear mapping matrix (d+1 x d if bias) + L : (d, d) ndarray + Linear mapping matrix ((:math:`d+1`, `d`) if bias) log : dict log dictionary return only if log==True in parameters + .. _references-joint-OT-mapping-linear: References ---------- - .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for discrete optimal transport", Neural Information Processing Systems (NIPS), 2016. @@ -434,37 +442,41 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', numItermax=100, numInnerItermax=10, stopInnerThr=1e-6, stopThr=1e-5, log=False, **kwargs): - r"""Joint OT and nonlinear mapping estimation with kernels as proposed in [8] + r"""Joint OT and nonlinear mapping estimation with kernels as proposed in + :ref:`[8] `. The function solves the following optimization problem: .. math:: - \min_{\gamma,L\in\mathcal{H}}\quad \|L(X_s) - - n_s\gamma X_t\|^2_F + \mu<\gamma,M>_F + \eta \|L\|^2_\mathcal{H} + \min_{\gamma, L\in\mathcal{H}}\quad \|L(\mathbf{X_s}) - + n_s\gamma \mathbf{X_t}\|^2_F + \mu \langle \gamma, \mathbf{M} \rangle_F + + \eta \|L\|^2_\mathcal{H} + + s.t. \ \gamma \mathbf{1} = \mathbf{a} + + \gamma^T \mathbf{1} = \mathbf{b} - s.t. \gamma 1 = a + \gamma \geq 0 - \gamma^T 1= b - \gamma\geq 0 where : - - M is the (ns,nt) squared euclidean cost matrix between samples in - Xs and Xt (scaled by ns) - - :math:`L` is a ns x d linear operator on a kernel matrix that + - :math:`\mathbf{M}` is the (`ns`, `nt`) squared euclidean cost matrix between samples in + :math:`\mathbf{X_s}` and :math:`\mathbf{X_t}` (scaled by :math:`n_s`) + - :math:`L` is a :math:`n_s \times d` linear operator on a kernel matrix that approximates the barycentric mapping - - a and b are uniform source and target weights + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are uniform source and target weights The problem consist in solving jointly an optimal transport matrix :math:`\gamma` and the nonlinear mapping that fits the barycentric mapping - :math:`n_s\gamma X_t`. + :math:`n_s\gamma \mathbf{X_t}`. One can also estimate a mapping with constant bias (see supplementary - material of [8]) using the bias optional argument. + material of :ref:`[8] `) using the bias optional argument. The algorithm used for solving the problem is the block coordinate - descent that alternates between updates of G (using conditionnal gradient) - and the update of L using a classical kernel least square solver. + descent that alternates between updates of :math:`\mathbf{G}` (using conditionnal gradient) + and the update of :math:`\mathbf{L}` using a classical kernel least square solver. Parameters @@ -478,7 +490,7 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', eta : float, optional Regularization term for the linear mapping L (>0) kerneltype : str,optional - kernel used by calling function ot.utils.kernel (gaussian by default) + kernel used by calling function :py:func:`ot.utils.kernel` (gaussian by default) sigma : float, optional Gaussian kernel bandwidth. bias : bool,optional @@ -501,17 +513,17 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', Returns ------- - gamma : (ns x nt) ndarray + gamma : (ns, nt) ndarray Optimal transportation matrix for the given parameters - L : (ns x d) ndarray - Nonlinear mapping matrix (ns+1 x d if bias) + L : (ns, d) ndarray + Nonlinear mapping matrix ((:math:`n_s+1`, `d`) if bias) log : dict log dictionary return only if log==True in parameters + .. _references-joint-OT-mapping-kernel: References ---------- - .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for discrete optimal transport", Neural Information Processing Systems (NIPS), 2016. @@ -645,26 +657,27 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, wt=None, bias=True, log=False): - r""" return OT linear operator between samples + r"""Return OT linear operator between samples. The function estimates the optimal linear operator that aligns the two empirical distributions. This is equivalent to estimating the closed - form mapping between two Gaussian distributions :math:`N(\mu_s,\Sigma_s)` - and :math:`N(\mu_t,\Sigma_t)` as proposed in [14] and discussed in remark - 2.29 in [15]. + form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)` + and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in + :ref:`[14] ` and discussed in remark 2.29 in + :ref:`[15] `. The linear operator from source to target :math:`M` .. math:: - M(x)=Ax+b + M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b} where : .. math:: - A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2} + \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2} \Sigma_s^{-1/2} - .. math:: - b=\mu_t-A\mu_s + + \mathbf{b} &= \mu_t - \mathbf{A} \mu_s Parameters ---------- @@ -673,35 +686,35 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, xt : np.ndarray (nt,d) samples in the target domain reg : float,optional - regularization added to the diagonals of convariances (>0) + regularization added to the diagonals of covariances (>0) ws : np.ndarray (ns,1), optional weights for the source samples wt : np.ndarray (ns,1), optional weights for the target samples bias: boolean, optional - estimate bias b else b=0 (default:True) + estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True) log : bool, optional record log if True Returns ------- - A : (d x d) ndarray + A : (d, d) ndarray Linear operator - b : (1 x d) ndarray + b : (1, d) ndarray bias log : dict log dictionary return only if log==True in parameters + .. _references-OT-mapping-linear: References ---------- - .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of distributions", Journal of Optimization Theory and Applications Vol 43, 1984 - .. [15] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + .. [15] Peyré, G., & Cuturi, M. (2017). "Computational Optimal Transport", 2018. @@ -754,24 +767,34 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al r"""Solve the optimal transport problem (OT) with Laplacian regularization .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + eta\Omega_\alpha(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \eta \cdot \Omega_\alpha(\gamma) - s.t.\ \gamma 1 = a + s.t. \ \gamma \mathbf{1} = \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} = \mathbf{b} - \gamma\geq 0 + \gamma \geq 0 where: - - a and b are source and target weights (sum to 1) - - xs and xt are source and target samples - - M is the (ns,nt) metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) + - :math:`\mathbf{x_s}` and :math:`\mathbf{x_t}` are source and target samples + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix - :math:`\Omega_\alpha` is the Laplacian regularization term - :math:`\Omega_\alpha = (1-\alpha)/n_s^2\sum_{i,j}S^s_{i,j}\|T(\mathbf{x}^s_i)-T(\mathbf{x}^s_j)\|^2+\alpha/n_t^2\sum_{i,j}S^t_{i,j}^'\|T(\mathbf{x}^t_i)-T(\mathbf{x}^t_j)\|^2` - with :math:`S^s_{i,j}, S^t_{i,j}` denoting source and target similarity matrices and :math:`T(\cdot)` being a barycentric mapping - The algorithm used for solving the problem is the conditional gradient algorithm as proposed in [5]. + .. math:: + \Omega_\alpha = \frac{1 - \alpha}{n_s^2} \sum_{i,j} + \mathbf{S^s}_{i,j} \|T(\mathbf{x}^s_i) - T(\mathbf{x}^s_j) \|^2 + + \frac{\alpha}{n_t^2} \sum_{i,j} + \mathbf{S^t}_{i,j} \|T(\mathbf{x}^t_i) - T(\mathbf{x}^t_j) \|^2 + + + with :math:`\mathbf{S^s}_{i,j}, \mathbf{S^t}_{i,j}` denoting source and target similarity + matrices and :math:`T(\cdot)` being a barycentric mapping. + + The algorithm used for solving the problem is the conditional gradient algorithm as proposed in + :ref:`[5] `. Parameters ---------- @@ -811,22 +834,23 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al Returns ------- - gamma : (ns x nt) ndarray + gamma : (ns, nt) ndarray Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters + .. _references-emd-laplace: References ---------- - .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE - Transactions on Pattern Analysis and Machine Intelligence , + Transactions on Pattern Analysis and Machine Intelligence, vol.PP, no.99, pp.1-1 + .. [30] R. Flamary, N. Courty, D. Tuia, A. Rakotomamonjy, "Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching," - in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. + in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. See Also -------- @@ -882,7 +906,7 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al def distribution_estimation_uniform(X): - """estimates a uniform distribution from an array of samples X + """estimates a uniform distribution from an array of samples :math:`\mathbf{X}` Parameters ---------- @@ -892,7 +916,7 @@ def distribution_estimation_uniform(X): Returns ------- mu : array-like, shape (n_samples,) - The uniform distribution estimated from X + The uniform distribution estimated from :math:`\mathbf{X}` """ return unif(X.shape[0]) @@ -902,32 +926,32 @@ class BaseTransport(BaseEstimator): """Base class for OTDA objects - Notes - ----- - All estimators should specify all the parameters that can be set - at the class level in their ``__init__`` as explicit keyword - arguments (no ``*args`` or ``**kwargs``). + .. note:: + All estimators should specify all the parameters that can be set + at the class level in their ``__init__`` as explicit keyword + arguments (no ``*args`` or ``**kwargs``). - the fit method should: + The fit method should: - estimate a cost matrix and store it in a `cost_` attribute - - estimate a coupling matrix and store it in a `coupling_` - attribute + - estimate a coupling matrix and store it in a `coupling_` attribute - estimate distributions from source and target data and store them in - mu_s and mu_t attributes - - store Xs and Xt in attributes to be used later on in transform and - inverse_transform methods + `mu_s` and `mu_t` attributes + - store `Xs` and `Xt` in attributes to be used later on in `transform` and + `inverse_transform` methods + + `transform` method should always get as input a `Xs` parameter - transform method should always get as input a Xs parameter - inverse_transform method should always get as input a Xt parameter + `inverse_transform` method should always get as input a `Xt` parameter - transform_labels method should always get as input a ys parameter - inverse_transform_labels method should always get as input a yt parameter + `transform_labels` method should always get as input a `ys` parameter + + `inverse_transform_labels` method should always get as input a `yt` parameter """ def fit(self, Xs=None, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples - (Xs, ys) and (Xt, yt) + :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters ---------- @@ -938,8 +962,8 @@ class BaseTransport(BaseEstimator): Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -987,8 +1011,8 @@ class BaseTransport(BaseEstimator): def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples - (Xs, ys) and (Xt, yt) and transports source samples Xs onto target - ones Xt + :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` + and transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` Parameters ---------- @@ -999,8 +1023,8 @@ class BaseTransport(BaseEstimator): Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -1014,7 +1038,7 @@ class BaseTransport(BaseEstimator): return self.fit(Xs, ys, Xt, yt).transform(Xs, ys, Xt, yt) def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): - """Transports source samples Xs onto target ones Xt + """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` Parameters ---------- @@ -1025,8 +1049,8 @@ class BaseTransport(BaseEstimator): Xt : array-like, shape (n_target_samples, n_features) The target input samples. yt : array-like, shape (n_target_samples,) - The class labels for target. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels for target. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -1081,7 +1105,8 @@ class BaseTransport(BaseEstimator): return transp_Xs def transform_labels(self, ys=None): - """Propagate source labels ys to obtain estimated target labels as in [27] + """Propagate source labels :math:`\mathbf{y_s}` to obtain estimated target labels as in + :ref:`[27] `. Parameters ---------- @@ -1093,9 +1118,10 @@ class BaseTransport(BaseEstimator): transp_ys : array-like, shape (n_target_samples, nb_classes) Estimated soft target labels. + + .. _references-basetransport-transform-labels: References ---------- - .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia "Optimal transport for multi-source domain adaptation under target shift", International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. @@ -1126,7 +1152,7 @@ class BaseTransport(BaseEstimator): def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): - """Transports target samples Xt onto source samples Xs + """Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}` Parameters ---------- @@ -1137,8 +1163,8 @@ class BaseTransport(BaseEstimator): Xt : array-like, shape (n_target_samples, n_features) The target input samples. yt : array-like, shape (n_target_samples,) - The target class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The target class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -1192,7 +1218,8 @@ class BaseTransport(BaseEstimator): return transp_Xt def inverse_transform_labels(self, yt=None): - """Propagate target labels yt to obtain estimated source labels ys + """Propagate target labels :math:`\mathbf{y_t}` to obtain estimated source labels + :math:`\mathbf{y_s}` Parameters ---------- @@ -1232,35 +1259,37 @@ class LinearTransport(BaseTransport): The function estimates the optimal linear operator that aligns the two empirical distributions. This is equivalent to estimating the closed - form mapping between two Gaussian distributions :math:`N(\mu_s,\Sigma_s)` - and :math:`N(\mu_t,\Sigma_t)` as proposed in [14] and discussed in - remark 2.29 in [15]. + form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)` + and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in + :ref:`[14] ` and discussed in remark 2.29 in + :ref:`[15] `. The linear operator from source to target :math:`M` .. math:: - M(x)=Ax+b + M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b} where : .. math:: - A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2} + \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2} \Sigma_s^{-1/2} - .. math:: - b=\mu_t-A\mu_s + + \mathbf{b} &= \mu_t - \mathbf{A} \mu_s Parameters ---------- reg : float,optional - regularization added to the daigonals of convariances (>0) + regularization added to the daigonals of covariances (>0) bias: boolean, optional - estimate bias b else b=0 (default:True) + estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True) log : bool, optional record log if True + + .. _references-lineartransport: References ---------- - .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of distributions", Journal of Optimization Theory and Applications Vol 43, 1984 @@ -1279,7 +1308,7 @@ class LinearTransport(BaseTransport): def fit(self, Xs=None, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples - (Xs, ys) and (Xt, yt) + :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters ---------- @@ -1290,8 +1319,8 @@ class LinearTransport(BaseTransport): Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -1325,7 +1354,7 @@ class LinearTransport(BaseTransport): return self def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): - """Transports source samples Xs onto target ones Xt + """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` Parameters ---------- @@ -1336,8 +1365,8 @@ class LinearTransport(BaseTransport): Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -1358,7 +1387,7 @@ class LinearTransport(BaseTransport): def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): - """Transports target samples Xt onto target samples Xs + """Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}` Parameters ---------- @@ -1369,8 +1398,8 @@ class LinearTransport(BaseTransport): Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -1392,7 +1421,7 @@ class LinearTransport(BaseTransport): class SinkhornTransport(BaseTransport): - """Domain Adapatation OT method based on Sinkhorn Algorithm + """Domain Adaptation OT method based on Sinkhorn Algorithm Parameters ---------- @@ -1400,7 +1429,7 @@ class SinkhornTransport(BaseTransport): Entropic regularization parameter max_iter : int, float, optional (default=1000) The minimum number of iteration before stopping the optimization - algorithm if no it has not converged + algorithm if it has not converged tol : float, optional (default=10e-9) The precision required to stop the optimization algorithm. verbose : bool, optional (default=False) @@ -1417,8 +1446,8 @@ class SinkhornTransport(BaseTransport): out_of_sample_map : string, optional (default="ferradans") The kind of out of sample mapping to apply to transport samples from a domain into another one. Currently the only possible option is - "ferradans" which uses the method proposed in [6]. - limit_max: float, optional (defaul=np.infty) + "ferradans" which uses the method proposed in :ref:`[6] `. + limit_max: float, optional (default=np.infty) Controls the semi supervised mode. Transport between labeled source and target samples of different classes will exhibit an cost defined by this variable @@ -1428,16 +1457,20 @@ class SinkhornTransport(BaseTransport): coupling_ : array-like, shape (n_source_samples, n_target_samples) The optimal coupling log_ : dictionary - The dictionary of log, empty dic if parameter log is not True + The dictionary of log, empty dict if parameter log is not True + + .. _references-sinkhorntransport: References ---------- .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. @@ -1461,7 +1494,7 @@ class SinkhornTransport(BaseTransport): def fit(self, Xs=None, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples - (Xs, ys) and (Xt, yt) + :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters ---------- @@ -1472,8 +1505,8 @@ class SinkhornTransport(BaseTransport): Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -1504,7 +1537,7 @@ class SinkhornTransport(BaseTransport): class EMDTransport(BaseTransport): - """Domain Adapatation OT method based on Earth Mover's Distance + """Domain Adaptation OT method based on Earth Mover's Distance Parameters ---------- @@ -1520,7 +1553,7 @@ class EMDTransport(BaseTransport): out_of_sample_map : string, optional (default="ferradans") The kind of out of sample mapping to apply to transport samples from a domain into another one. Currently the only possible option is - "ferradans" which uses the method proposed in [6]. + "ferradans" which uses the method proposed in :ref:`[6] `. limit_max: float, optional (default=10) Controls the semi supervised mode. Transport between labeled source and target samples of different classes will exhibit an infinite cost @@ -1534,14 +1567,16 @@ class EMDTransport(BaseTransport): coupling_ : array-like, shape (n_source_samples, n_target_samples) The optimal coupling + + .. _references-emdtransport: References ---------- .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, - "Optimal Transport for Domain Adaptation," in IEEE Transactions - on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + "Optimal Transport for Domain Adaptation," in IEEE Transactions + on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). - Regularized discrete optimal transport. SIAM Journal on Imaging - Sciences, 7(3), 1853-1882. + Regularized discrete optimal transport. SIAM Journal on Imaging + Sciences, 7(3), 1853-1882. """ def __init__(self, metric="sqeuclidean", norm=None, log=False, @@ -1558,7 +1593,7 @@ class EMDTransport(BaseTransport): def fit(self, Xs, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples - (Xs, ys) and (Xt, yt) + :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters ---------- @@ -1569,8 +1604,8 @@ class EMDTransport(BaseTransport): Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -1597,8 +1632,7 @@ class EMDTransport(BaseTransport): class SinkhornLpl1Transport(BaseTransport): - - """Domain Adapatation OT method based on sinkhorn algorithm + + r"""Domain Adaptation OT method based on sinkhorn algorithm + LpL1 class regularization. Parameters @@ -1609,7 +1643,7 @@ class SinkhornLpl1Transport(BaseTransport): Class regularization parameter max_iter : int, float, optional (default=10) The minimum number of iteration before stopping the optimization - algorithm if no it has not converged + algorithm if it has not converged max_inner_iter : int, float, optional (default=200) The number of iteration in the inner loop log : bool, optional (default=False) @@ -1628,8 +1662,8 @@ class SinkhornLpl1Transport(BaseTransport): out_of_sample_map : string, optional (default="ferradans") The kind of out of sample mapping to apply to transport samples from a domain into another one. Currently the only possible option is - "ferradans" which uses the method proposed in [6]. - limit_max: float, optional (defaul=np.infty) + "ferradans" which uses the method proposed in :ref:`[6] `. + limit_max: float, optional (default=np.infty) Controls the semi supervised mode. Transport between labeled source and target samples of different classes will exhibit a cost defined by limit_max. @@ -1639,16 +1673,19 @@ class SinkhornLpl1Transport(BaseTransport): coupling_ : array-like, shape (n_source_samples, n_target_samples) The optimal coupling + + .. _references-sinkhornlpl1transport: References ---------- - .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + .. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567. + .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. @@ -1675,7 +1712,7 @@ class SinkhornLpl1Transport(BaseTransport): def fit(self, Xs, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples - (Xs, ys) and (Xt, yt) + :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters ---------- @@ -1686,8 +1723,8 @@ class SinkhornLpl1Transport(BaseTransport): Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -1719,13 +1756,14 @@ class SinkhornLpl1Transport(BaseTransport): class EMDLaplaceTransport(BaseTransport): - """Domain Adapatation OT method based on Earth Mover's Distance with Laplacian regularization + """Domain Adaptation OT method based on Earth Mover's Distance with Laplacian regularization Parameters ---------- reg_type : string optional (default='pos') Type of the regularization term: 'pos' and 'disp' for - regularization term defined in [2] and [6], respectively. + regularization term defined in :ref:`[2] ` and + :ref:`[6] `, respectively. reg_lap : float, optional (default=1) Laplacian regularization parameter reg_src : float, optional (default=0.5) @@ -1756,24 +1794,27 @@ class EMDLaplaceTransport(BaseTransport): out_of_sample_map : string, optional (default="ferradans") The kind of out of sample mapping to apply to transport samples from a domain into another one. Currently the only possible option is - "ferradans" which uses the method proposed in [6]. + "ferradans" which uses the method proposed in :ref:`[6] `. Attributes ---------- coupling_ : array-like, shape (n_source_samples, n_target_samples) The optimal coupling + + .. _references-emdlaplacetransport: References ---------- .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + .. [2] R. Flamary, N. Courty, D. Tuia, A. Rakotomamonjy, "Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching," - in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. + in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. + .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). - Regularized discrete optimal transport. SIAM Journal on Imaging - Sciences, 7(3), 1853-1882. + Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. """ def __init__(self, reg_type='pos', reg_lap=1., reg_src=1., metric="sqeuclidean", @@ -1799,7 +1840,7 @@ class EMDLaplaceTransport(BaseTransport): def fit(self, Xs, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples - (Xs, ys) and (Xt, yt) + :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters ---------- @@ -1810,8 +1851,8 @@ class EMDLaplaceTransport(BaseTransport): Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -1840,8 +1881,8 @@ class EMDLaplaceTransport(BaseTransport): class SinkhornL1l2Transport(BaseTransport): - """Domain Adapatation OT method based on sinkhorn algorithm + - l1l2 class regularization. + """Domain Adaptation OT method based on sinkhorn algorithm + + L1L2 class regularization. Parameters ---------- @@ -1851,7 +1892,7 @@ class SinkhornL1l2Transport(BaseTransport): Class regularization parameter max_iter : int, float, optional (default=10) The minimum number of iteration before stopping the optimization - algorithm if no it has not converged + algorithm if it has not converged max_inner_iter : int, float, optional (default=200) The number of iteration in the inner loop tol : float, optional (default=10e-9) @@ -1870,7 +1911,7 @@ class SinkhornL1l2Transport(BaseTransport): out_of_sample_map : string, optional (default="ferradans") The kind of out of sample mapping to apply to transport samples from a domain into another one. Currently the only possible option is - "ferradans" which uses the method proposed in [6]. + "ferradans" which uses the method proposed in :ref:`[6] `. limit_max: float, optional (default=10) Controls the semi supervised mode. Transport between labeled source and target samples of different classes will exhibit an infinite cost @@ -1881,18 +1922,21 @@ class SinkhornL1l2Transport(BaseTransport): coupling_ : array-like, shape (n_source_samples, n_target_samples) The optimal coupling log_ : dictionary - The dictionary of log, empty dic if parameter log is not True + The dictionary of log, empty dict if parameter log is not True + + .. _references-sinkhornl1l2transport: References ---------- - .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + .. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567. + .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. @@ -1919,7 +1963,7 @@ class SinkhornL1l2Transport(BaseTransport): def fit(self, Xs, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples - (Xs, ys) and (Xt, yt) + :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters ---------- @@ -1930,8 +1974,8 @@ class SinkhornL1l2Transport(BaseTransport): Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -1973,7 +2017,7 @@ class MappingTransport(BaseEstimator): mu : float, optional (default=1) Weight for the linear OT loss (>0) eta : float, optional (default=0.001) - Regularization term for the linear mapping L (>0) + Regularization term for the linear mapping `L` (>0) bias : bool, optional (default=False) Estimate linear mapping with constant bias metric : string, optional (default="sqeuclidean") @@ -2004,17 +2048,20 @@ class MappingTransport(BaseEstimator): ---------- coupling_ : array-like, shape (n_source_samples, n_target_samples) The optimal coupling - mapping_ : array-like, shape (n_features (+ 1), n_features) - (if bias) for kernel == linear + mapping_ : The associated mapping - array-like, shape (n_source_samples (+ 1), n_features) - (if bias) for kernel == gaussian + + - array-like, shape (`n_features` (+ 1), `n_features`), + (if bias) for kernel == linear + + - array-like, shape (`n_source_samples` (+ 1), `n_features`), + (if bias) for kernel == gaussian log_ : dictionary - The dictionary of log, empty dic if parameter log is not True + The dictionary of log, empty dict if parameter log is not True + References ---------- - .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for discrete optimal transport", Neural Information Processing Systems (NIPS), 2016. @@ -2042,7 +2089,8 @@ class MappingTransport(BaseEstimator): def fit(self, Xs=None, ys=None, Xt=None, yt=None): """Builds an optimal coupling and estimates the associated mapping - from source and target sets of samples (Xs, ys) and (Xt, yt) + from source and target sets of samples + :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters ---------- @@ -2053,8 +2101,8 @@ class MappingTransport(BaseEstimator): Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -2098,7 +2146,7 @@ class MappingTransport(BaseEstimator): return self def transform(self, Xs): - """Transports source samples Xs onto target ones Xt + """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` Parameters ---------- @@ -2138,7 +2186,7 @@ class MappingTransport(BaseEstimator): class UnbalancedSinkhornTransport(BaseTransport): - """Domain Adapatation unbalanced OT method based on sinkhorn algorithm + """Domain Adaptation unbalanced OT method based on sinkhorn algorithm Parameters ---------- @@ -2151,7 +2199,7 @@ class UnbalancedSinkhornTransport(BaseTransport): 'sinkhorn_epsilon_scaling', see those function for specific parameters max_iter : int, float, optional (default=10) The minimum number of iteration before stopping the optimization - algorithm if no it has not converged + algorithm if it has not converged tol : float, optional (default=10e-9) Stop threshold on error (inner sinkhorn solver) (>0) verbose : bool, optional (default=False) @@ -2168,7 +2216,7 @@ class UnbalancedSinkhornTransport(BaseTransport): out_of_sample_map : string, optional (default="ferradans") The kind of out of sample mapping to apply to transport samples from a domain into another one. Currently the only possible option is - "ferradans" which uses the method proposed in [6]. + "ferradans" which uses the method proposed in :ref:`[6] `. limit_max: float, optional (default=10) Controls the semi supervised mode. Transport between labeled source and target samples of different classes will exhibit an infinite cost @@ -2179,14 +2227,16 @@ class UnbalancedSinkhornTransport(BaseTransport): coupling_ : array-like, shape (n_source_samples, n_target_samples) The optimal coupling log_ : dictionary - The dictionary of log, empty dic if parameter log is not True + The dictionary of log, empty dict if parameter log is not True + + .. _references-unbalancedsinkhorntransport: References ---------- - .. [1] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). - Scaling algorithms for unbalanced transport problems. arXiv preprint - arXiv:1607.05816. + Scaling algorithms for unbalanced transport problems. arXiv preprint + arXiv:1607.05816. + .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. @@ -2212,7 +2262,7 @@ class UnbalancedSinkhornTransport(BaseTransport): def fit(self, Xs, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples - (Xs, ys) and (Xt, yt) + :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters ---------- @@ -2223,8 +2273,8 @@ class UnbalancedSinkhornTransport(BaseTransport): Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -2258,7 +2308,7 @@ class UnbalancedSinkhornTransport(BaseTransport): class JCPOTTransport(BaseTransport): - """Domain Adapatation OT method for multi-source target shift based on Wasserstein barycenter algorithm. + """Domain Adaptation OT method for multi-source target shift based on Wasserstein barycenter algorithm. Parameters ---------- @@ -2266,7 +2316,7 @@ class JCPOTTransport(BaseTransport): Entropic regularization parameter max_iter : int, float, optional (default=10) The minimum number of iteration before stopping the optimization - algorithm if no it has not converged + algorithm if it has not converged tol : float, optional (default=10e-9) Stop threshold on error (inner sinkhorn solver) (>0) verbose : bool, optional (default=False) @@ -2283,7 +2333,7 @@ class JCPOTTransport(BaseTransport): out_of_sample_map : string, optional (default="ferradans") The kind of out of sample mapping to apply to transport samples from a domain into another one. Currently the only possible option is - "ferradans" which uses the method proposed in [6]. + "ferradans" which uses the method proposed in :ref:`[6] `. Attributes ---------- @@ -2292,11 +2342,12 @@ class JCPOTTransport(BaseTransport): proportions_ : array-like, shape (n_classes,) Estimated class proportions in the target domain log_ : dictionary - The dictionary of log, empty dic if parameter log is not True + The dictionary of log, empty dict if parameter log is not True + + .. _references-jcpottransport: References ---------- - .. [1] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia "Optimal transport for multi-source domain adaptation under target shift", International Conference on Artificial Intelligence and Statistics (AISTATS), @@ -2323,7 +2374,7 @@ class JCPOTTransport(BaseTransport): def fit(self, Xs, ys=None, Xt=None, yt=None): """Building coupling matrices from a list of source and target sets of samples - (Xs, ys) and (Xt, yt) + :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters ---------- @@ -2334,8 +2385,8 @@ class JCPOTTransport(BaseTransport): Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -2368,7 +2419,7 @@ class JCPOTTransport(BaseTransport): return self def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): - """Transports source samples Xs onto target ones Xt + """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` Parameters ---------- @@ -2379,8 +2430,8 @@ class JCPOTTransport(BaseTransport): Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -2440,7 +2491,8 @@ class JCPOTTransport(BaseTransport): return transp_Xs def transform_labels(self, ys=None): - """Propagate source labels ys to obtain target labels as in [27] + """Propagate source labels :math:`\mathbf{y_s}` to obtain target labels as in + :ref:`[27] ` Parameters ---------- @@ -2451,6 +2503,14 @@ class JCPOTTransport(BaseTransport): ------- yt : array-like, shape (n_target_samples, nb_classes) Estimated soft target labels. + + + .. _references-jcpottransport-transform-labels: + References + ---------- + .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia + "Optimal transport for multi-source domain adaptation under target shift", + International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. """ # check the necessary inputs parameters are here @@ -2482,11 +2542,12 @@ class JCPOTTransport(BaseTransport): return yt.T def inverse_transform_labels(self, yt=None): - """Propagate source labels ys to obtain target labels + """Propagate target labels :math:`\mathbf{y_t}` to obtain estimated source labels + :math:`\mathbf{y_s}` Parameters ---------- - yt : array-like, shape (n_source_samples,) + yt : array-like, shape (n_target_samples,) The target class labels Returns diff --git a/ot/datasets.py b/ot/datasets.py index b86ef3b..ad6390c 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -13,7 +13,7 @@ from .utils import check_random_state, deprecated def make_1D_gauss(n, m, s): - """return a 1D histogram for a gaussian distribution (n bins, mean m and std s) + """return a 1D histogram for a gaussian distribution (`n` bins, mean `m` and std `s`) Parameters ---------- @@ -26,7 +26,7 @@ def make_1D_gauss(n, m, s): Returns ------- - h : ndarray (n,) + h : ndarray (`n`,) 1D histogram for a gaussian distribution """ x = np.arange(n, dtype=np.float64) @@ -41,7 +41,7 @@ def get_1D_gauss(n, m, sigma): def make_2D_samples_gauss(n, m, sigma, random_state=None): - """Return n samples drawn from 2D gaussian N(m,sigma) + """Return `n` samples drawn from 2D gaussian :math:`\mathcal{N}(m, \sigma)` Parameters ---------- @@ -59,8 +59,8 @@ def make_2D_samples_gauss(n, m, sigma, random_state=None): Returns ------- - X : ndarray, shape (n, 2) - n samples drawn from N(m, sigma). + X : ndarray, shape (`n`, 2) + n samples drawn from :math:`\mathcal{N}(m, \sigma)`. """ generator = check_random_state(random_state) @@ -102,7 +102,7 @@ def make_data_classif(dataset, n, nz=.5, theta=0, p=.5, random_state=None, **kwa Returns ------- X : ndarray, shape (n, d) - n observation of size d + `n` observation of size `d` y : ndarray, shape (n,) labels of the samples. """ diff --git a/ot/dr.py b/ot/dr.py index 7469270..c2f51f8 100644 --- a/ot/dr.py +++ b/ot/dr.py @@ -22,7 +22,7 @@ from pymanopt.solvers import SteepestDescent, TrustRegions def dist(x1, x2): - """ Compute squared euclidean distance between samples (autograd) + r""" Compute squared euclidean distance between samples (autograd) """ x1p2 = np.sum(np.square(x1), 1) x2p2 = np.sum(np.square(x2), 1) @@ -30,7 +30,7 @@ def dist(x1, x2): def sinkhorn(w1, w2, M, reg, k): - """Sinkhorn algorithm with fixed number of iteration (autograd) + r"""Sinkhorn algorithm with fixed number of iteration (autograd) """ K = np.exp(-M / reg) ui = np.ones((M.shape[0],)) @@ -43,14 +43,14 @@ def sinkhorn(w1, w2, M, reg, k): def split_classes(X, y): - """split samples in X by classes in y + r"""split samples in :math:`\mathbf{X}` by classes in :math:`\mathbf{y}` """ lstsclass = np.unique(y) return [X[y == i, :].astype(np.float32) for i in lstsclass] def fda(X, y, p=2, reg=1e-16): - """Fisher Discriminant Analysis + r"""Fisher Discriminant Analysis Parameters ---------- @@ -111,18 +111,19 @@ def fda(X, y, p=2, reg=1e-16): def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False): r""" - Wasserstein Discriminant Analysis [11]_ + Wasserstein Discriminant Analysis :ref:`[11] ` The function solves the following optimization problem: .. math:: - P = \\text{arg}\min_P \\frac{\\sum_i W(PX^i,PX^i)}{\\sum_{i,j\\neq i} W(PX^i,PX^j)} + \mathbf{P} = \mathop{\arg \min}_\mathbf{P} \quad + \frac{\sum\limits_i W(P \mathbf{X}^i, P \mathbf{X}^i)}{\sum\limits_{i, j \neq i} W(P \mathbf{X}^i, P \mathbf{X}^j)} where : - - :math:`P` is a linear projection operator in the Stiefel(p,d) manifold + - :math:`P` is a linear projection operator in the Stiefel(`p`, `d`) manifold - :math:`W` is entropic regularized Wasserstein distances - - :math:`X^i` are samples in the dataset corresponding to class i + - :math:`\mathbf{X}^i` are samples in the dataset corresponding to class i Parameters ---------- @@ -140,7 +141,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no P0 : ndarray, shape (d, p) Initial starting point for projection. normalize : bool, optional - Normalise the Wasserstaiun distane by the average distance on P0 (default : False) + Normalise the Wasserstaiun distance by the average distance on P0 (default : False) verbose : int, optional Print information along iterations. @@ -151,6 +152,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no proj : callable Projection function including mean centering. + + .. _references-wda: References ---------- .. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). @@ -217,27 +220,28 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0): r""" - Projection Robust Wasserstein Distance [32] + Projection Robust Wasserstein Distance :ref:`[32] ` The function solves the following optimization problem: .. math:: - \max_{U \in St(d, k)} \min_{\pi \in \Pi(\mu,\nu)} \sum_{i,j} \pi_{i,j} \|U^T(x_i - y_j)\|^2 - reg * H(\pi) + \max_{U \in St(d, k)} \ \min_{\pi \in \Pi(\mu,\nu)} \quad \sum_{i,j} \pi_{i,j} + \|U^T(\mathbf{x}_i - \mathbf{y}_j)\|^2 - \mathrm{reg} \cdot H(\pi) - - :math:`U` is a linear projection operator in the Stiefel(d, k) manifold + - :math:`U` is a linear projection operator in the Stiefel(`d`, `k`) manifold - :math:`H(\pi)` is entropy regularizer - - :math:`x_i`, :math:`y_j` are samples of measures \mu and \nu respectively + - :math:`\mathbf{x}_i`, :math:`\mathbf{y}_j` are samples of measures :math:`\mu` and :math:`\nu` respectively Parameters ---------- X : ndarray, shape (n, d) - Samples from measure \mu + Samples from measure :math:`\mu` Y : ndarray, shape (n, d) - Samples from measure \nu + Samples from measure :math:`\nu` a : ndarray, shape (n, ) - weights for measure \mu + weights for measure :math:`\mu` b : ndarray, shape (n, ) - weights for measure \nu + weights for measure :math:`\nu` tau : float stepsize for Riemannian Gradient Descent U0 : ndarray, shape (d, p) @@ -258,6 +262,8 @@ def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopTh U : ndarray, shape (d, k) Projection operator. + + .. _references-projection-robust-wasserstein: References ---------- .. [32] Huang, M. , Ma S. & Lai L. (2021). diff --git a/ot/gromov.py b/ot/gromov.py index a0fbf48..465693d 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -327,7 +327,8 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs The function solves the following optimization problem: .. math:: - \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} Where : @@ -410,7 +411,8 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg The function solves the following optimization problem: .. math:: - GW = \min_\mathbf{T} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + GW = \min_\mathbf{T} \quad \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} Where : @@ -487,8 +489,8 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, Computes the FGW transport between two graphs (see :ref:`[24] `) .. math:: - \gamma = \mathop{\arg \min}_\gamma (1 - \alpha) <\gamma, \mathbf{M}>_F + \alpha \sum_{i,j,k,l} - L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + \gamma = \mathop{\arg \min}_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} @@ -569,7 +571,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 Computes the FGW distance between two graphs see (see :ref:`[24] `) .. math:: - \min_\gamma (1 - \alpha) <\gamma, \mathbf{M}>_F + \alpha \sum_{i,j,k,l} + \min_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} @@ -591,9 +593,9 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 M : array-like, shape (ns, nt) Metric cost matrix between features across domains C1 : array-like, shape (ns, ns) - Metric cost matrix respresentative of the structure in the source space. + Metric cost matrix representative of the structure in the source space. C2 : array-like, shape (nt, nt) - Metric cost matrix espresentative of the structure in the target space. + Metric cost matrix representative of the structure in the target space. p : array-like, shape (ns,) Distribution in the source space. q : array-like, shape (nt,) @@ -612,8 +614,8 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 Returns ------- - gamma : array-like, shape (ns, nt) - Optimal transportation matrix for the given parameters. + fgw-distance : float + Fused gromov wasserstein distance for the given parameters. log : dict Log dictionary return only if log==True in parameters. @@ -780,7 +782,8 @@ def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, The function solves the following optimization problem: .. math:: - \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} @@ -901,7 +904,8 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, The function solves the following optimization problem: .. math:: - \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} @@ -1052,7 +1056,7 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, The function solves the following optimization problem: .. math:: - \mathbf{GW} = \mathop{\arg\min}_\mathbf{T} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T})) + \mathbf{GW} = \mathop{\arg\min}_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T})) s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} @@ -1157,7 +1161,8 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, The function solves the following optimization problem: .. math:: - GW = \min_\mathbf{T} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T})) + GW = \min_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) + \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T})) Where : @@ -1223,7 +1228,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, .. math:: - \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s) + \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s) Where : @@ -1336,7 +1341,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, .. math:: - \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s) + \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s) Where : diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 2c18a88..5da897d 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -62,7 +62,7 @@ def center_ot_dual(alpha0, beta0, a=None, b=None): is the following: .. math:: - \alpha^T a= \beta^T b + \alpha^T \mathbf{a} = \beta^T \mathbf{b} in addition to the OT problem constraints. @@ -70,11 +70,11 @@ def center_ot_dual(alpha0, beta0, a=None, b=None): a constant from both :math:`\alpha_0` and :math:`\beta_0`. .. math:: - c=\frac{\beta0^T b-\alpha_0^T a}{1^Tb+1^Ta} + c &= \frac{\beta_0^T \mathbf{b} - \alpha_0^T \mathbf{a}}{\mathbf{1}^T \mathbf{b} + \mathbf{1}^T \mathbf{a}} - \alpha=\alpha_0+c + \alpha &= \alpha_0 + c - \beta=\beta0+c + \beta &= \beta_0 + c Parameters ---------- @@ -117,7 +117,7 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M): The feasible values are computed efficiently but rather coarsely. .. warning:: - This function is necessary because the C++ solver in emd_c + This function is necessary because the C++ solver in `emd_c` discards all samples in the distributions with zeros weights. This means that while the primal variable (transport matrix) is exact, the solver only returns feasible dual potentials @@ -126,26 +126,26 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M): First we compute the constraints violations: .. math:: - V=\alpha+\beta^T-M + \mathbf{V} = \alpha + \beta^T - \mathbf{M} - Next we compute the max amount of violation per row (alpha) and - columns (beta) + Next we compute the max amount of violation per row (:math:`\alpha`) and + columns (:math:`beta`) .. math:: - v^a_i=\max_j V_{i,j} + \mathbf{v^a}_i = \max_j \mathbf{V}_{i,j} - v^b_j=\max_i V_{i,j} + \mathbf{v^b}_j = \max_i \mathbf{V}_{i,j} Finally we update the dual potential with 0 weights if a constraint is violated .. math:: - \alpha_i = \alpha_i -v^a_i \quad \text{ if } a_i=0 \text{ and } v^a_i>0 + \alpha_i = \alpha_i - \mathbf{v^a}_i \quad \text{ if } \mathbf{a}_i=0 \text{ and } \mathbf{v^a}_i>0 - \beta_j = \beta_j -v^b_j \quad \text{ if } b_j=0 \text{ and } v^b_j>0 + \beta_j = \beta_j - \mathbf{v^b}_j \quad \text{ if } \mathbf{b}_j=0 \text{ and } \mathbf{v^b}_j > 0 In the end the dual potentials are centered using function - :ref:`center_ot_dual`. + :py:func:`ot.lp.center_ot_dual`. Note that all those updates do not change the objective value of the solution but provide dual potentials that do not violate the constraints. @@ -201,26 +201,28 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): r"""Solves the Earth Movers distance problem and returns the OT matrix - .. math:: \gamma = arg\min_\gamma <\gamma,M>_F + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + s.t. \ \gamma \mathbf{1} = \mathbf{a} - s.t. \gamma 1 = a + \gamma^T \mathbf{1} = \mathbf{b} - \gamma^T 1= b + \gamma \geq 0 - \gamma\geq 0 where : - - M is the metric cost matrix - - a and b are the sample weights + - :math:`\mathbf{M}` is the metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights - .. warning:: Note that the M matrix in numpy needs to be a C-order + .. warning:: Note that the :math:`\mathbf{M}` matrix in numpy needs to be a C-order numpy.array in float64 format. It will be converted if not in this format .. note:: This function is backend-compatible and will work on arrays from all compatible backends. - Uses the algorithm proposed in [1]_ + Uses the algorithm proposed in :ref:`[1] `. Parameters ---------- @@ -267,17 +269,19 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): array([[0.5, 0. ], [0. , 0.5]]) + + .. _references-emd: References ---------- - .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, December). Displacement interpolation using Lagrangian mass transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM. See Also -------- - ot.bregman.sinkhorn : Entropic regularized OT ot.optim.cg : General - regularized OT""" + ot.bregman.sinkhorn : Entropic regularized OT + ot.optim.cg : General regularized OT + """ # convert to numpy if list a, b, M = list_to_array(a, b, M) @@ -340,22 +344,23 @@ def emd2(a, b, M, processes=1, r"""Solves the Earth Movers distance problem and returns the loss .. math:: - \min_\gamma <\gamma,M>_F + \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F - s.t. \gamma 1 = a + s.t. \ \gamma \mathbf{1} = \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} = \mathbf{b} + + \gamma \geq 0 - \gamma\geq 0 where : - - M is the metric cost matrix - - a and b are the sample weights + - :math:`\mathbf{M}` is the metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights .. note:: This function is backend-compatible and will work on arrays from all compatible backends. - Uses the algorithm proposed in [1]_ + Uses the algorithm proposed in :ref:`[1] `. Parameters ---------- @@ -405,9 +410,10 @@ def emd2(a, b, M, processes=1, >>> ot.emd2(a,b,M) 0.0 + + .. _references-emd2: References ---------- - .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, December). Displacement interpolation using Lagrangian mass transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. @@ -416,7 +422,8 @@ def emd2(a, b, M, processes=1, See Also -------- ot.bregman.sinkhorn : Entropic regularized OT - ot.optim.cg : General regularized OT""" + ot.optim.cg : General regularized OT + """ a, b, M = list_to_array(a, b, M) @@ -508,29 +515,35 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance), formally: .. math:: - \min_X \sum_{i=1}^N w_i W_2^2(b, X, a_i, X_i) + \min_\mathbf{X} \quad \sum_{i=1}^N w_i W_2^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i) where : - :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one - - the :math:`a_i \in \mathbb{R}^{k_i}` are the empirical measures weights and sum to one for each :math:`i` - - the :math:`X_i \in \mathbb{R}^{k_i, d}` are the empirical measures atoms locations - - :math:`b \in \mathbb{R}^{k}` is the desired weights vector of the barycenter + - the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the empirical measures weights and sum to one for each :math:`i` + - the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}` are the empirical measures atoms locations + - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter - This problem is considered in [1] (Algorithm 2). There are two differences with the following codes: + This problem is considered in :ref:`[1] ` (Algorithm 2). + There are two differences with the following codes: - we do not optimize over the weights - - we do not do line search for the locations updates, we use i.e. theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting. + - we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in + :ref:`[1] ` (Algorithm 2). This can be seen as a discrete + implementation of the fixed-point algorithm of + :ref:`[2] ` proposed in the continuous setting. Parameters ---------- measures_locations : list of N (k_i,d) numpy.ndarray - The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list) + The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space + (:math:`k_i` can be different for each element of the list) measures_weights : list of N (k_i,) numpy.ndarray - Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure + Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one + representing the weights of each discrete input measure X_init : (k,d) np.ndarray - Initialization of the support locations (on k atoms) of the barycenter + Initialization of the support locations (on `k` atoms) of the barycenter b : (k,) np.ndarray Initialization of the weights of the barycenter (non-negatives, sum to 1) weights : (N,) np.ndarray @@ -554,9 +567,10 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None X : (k,d) np.ndarray Support locations (on k atoms) of the barycenter + + .. _references-free-support-barycenter: References ---------- - .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. diff --git a/ot/optim.py b/ot/optim.py index 6456c03..cc286b6 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -162,7 +162,8 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, The function solves the following optimization problem: .. math:: - \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg} \cdot f(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot f(\gamma) s.t. \ \gamma \mathbf{1} &= \mathbf{a} @@ -309,7 +310,8 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, The function solves the following optimization problem: .. math:: - \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg_1}\cdot\Omega(\gamma) + \mathrm{reg_2}\cdot f(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg_1}\cdot\Omega(\gamma) + \mathrm{reg_2}\cdot f(\gamma) s.t. \ \gamma \mathbf{1} &= \mathbf{a} @@ -452,7 +454,7 @@ def solve_1d_linesearch_quad(a, b, c): .. math:: - \mathop{\arg \min}_{0 \leq x \leq 1} f(x) = ax^{2} + bx + c + \mathop{\arg \min}_{0 \leq x \leq 1} \quad f(x) = ax^{2} + bx + c Parameters ---------- diff --git a/ot/partial.py b/ot/partial.py index 814d779..b7093e4 100755 --- a/ot/partial.py +++ b/ot/partial.py @@ -20,13 +20,16 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, The function considers the following problem: .. math:: - \gamma = \arg\min_\gamma <\gamma,(M-\lambda)>_F + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, (\mathbf{M} - \lambda) \rangle_F - s.t. - \gamma\geq 0 \\ - \gamma 1 \leq a\\ - \gamma^T 1 \leq b\\ - 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\} + .. math:: + s.t. \ \gamma \mathbf{1} &\leq \mathbf{a} + + \gamma^T \mathbf{1} &\leq \mathbf{b} + + \gamma &\geq 0 + + \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} or equivalently (see Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. @@ -34,33 +37,32 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, metrics. Foundations of Computational Mathematics, 18(1), 1-44.) .. math:: - \gamma = \arg\min_\gamma <\gamma,M>_F + \sqrt(\lambda/2) - (\|\gamma 1 - a\|_1 + \|\gamma^T 1 - b\|_1) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \sqrt{\frac{\lambda}{2} (\|\gamma \mathbf{1} - \mathbf{a}\|_1 + \|\gamma^T \mathbf{1} - \mathbf{b}\|_1)} - s.t. - \gamma\geq 0 \\ + s.t. \ \gamma \geq 0 where : - - M is the metric cost matrix - - a and b are source and target unbalanced distributions - - :math:`\lambda` is the lagragian cost. Tuning its value allows attaining - a given mass to be transported m + - :math:`\mathbf{M}` is the metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - :math:`\lambda` is the lagrangian cost. Tuning its value allows attaining + a given mass to be transported `m` - The formulation of the problem has been proposed in [28]_ + The formulation of the problem has been proposed in :ref:`[28] ` Parameters ---------- a : np.ndarray (dim_a,) - Unnormalized histogram of dimension dim_a + Unnormalized histogram of dimension `dim_a` b : np.ndarray (dim_b,) - Unnormalized histograms of dimension dim_b + Unnormalized histograms of dimension `dim_b` M : np.ndarray (dim_a, dim_b) cost matrix for the quadratic cost reg_m : float, optional - Lagragian cost + Lagrangian cost nb_dummies : int, optional, default:1 number of reservoir points to be added (to avoid numerical instabilities, increase its value if an error is raised) @@ -69,6 +71,7 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, **kwargs : dict parameters can be directly passed to the emd solver + .. warning:: When dealing with a large number of points, the EMD solver may face some instabilities, especially when the mass associated to the dummy @@ -77,7 +80,7 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, Returns ------- - gamma : (dim_a x dim_b) ndarray + gamma : (dim_a, dim_b) ndarray Optimal transportation matrix for the given parameters log : dict log dictionary returned only if `log` is `True` @@ -97,9 +100,10 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, array([[0.1, 0. ], [0. , 0. ]]) + + .. _references-partial-wasserstein-lagrange: References ---------- - .. [28] Caffarelli, L. A., & McCann, R. J. (2010) Free boundaries in optimal transport and Monge-Ampere obstacle problems. Annals of mathematics, 673-730. @@ -162,27 +166,30 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): The function considers the following problem: .. math:: - \gamma = \arg\min_\gamma <\gamma,M>_F + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + .. math:: + s.t. \ \gamma \mathbf{1} &\leq \mathbf{a} + + \gamma^T \mathbf{1} &\leq \mathbf{b} + + \gamma &\geq 0 - s.t. - \gamma\geq 0 \\ - \gamma 1 \leq a\\ - \gamma^T 1 \leq b\\ - 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\} + \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} where : - - M is the metric cost matrix - - a and b are source and target unbalanced distributions - - m is the amount of mass to be transported + - :math:`\mathbf{M}` is the metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - `m` is the amount of mass to be transported Parameters ---------- a : np.ndarray (dim_a,) - Unnormalized histogram of dimension dim_a + Unnormalized histogram of dimension `dim_a` b : np.ndarray (dim_b,) - Unnormalized histograms of dimension dim_b + Unnormalized histograms of dimension `dim_b` M : np.ndarray (dim_a, dim_b) cost matrix for the quadratic cost m : float, optional @@ -205,7 +212,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): Returns ------- - :math:`gamma` : (dim_a x dim_b) ndarray + gamma : (dim_a, dim_b) ndarray Optimal transportation matrix for the given parameters log : dict log dictionary returned only if `log` is `True` @@ -278,27 +285,30 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): The function considers the following problem: .. math:: - \gamma = \arg\min_\gamma <\gamma,M>_F + \gamma = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F - s.t. - \gamma\geq 0 \\ - \gamma 1 \leq a\\ - \gamma^T 1 \leq b\\ - 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\} + .. math:: + s.t. \ \gamma \mathbf{1} &\leq \mathbf{a} + + \gamma^T \mathbf{1} &\leq \mathbf{b} + + \gamma &\geq 0 + + \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} where : - - M is the metric cost matrix - - a and b are source and target unbalanced distributions - - m is the amount of mass to be transported + - :math:`\mathbf{M}` is the metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - `m` is the amount of mass to be transported Parameters ---------- a : np.ndarray (dim_a,) - Unnormalized histogram of dimension dim_a + Unnormalized histogram of dimension `dim_a` b : np.ndarray (dim_b,) - Unnormalized histograms of dimension dim_b + Unnormalized histograms of dimension `dim_b` M : np.ndarray (dim_a, dim_b) cost matrix for the quadratic cost m : float, optional @@ -321,8 +331,8 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): Returns ------- - :math:`gamma` : (dim_a x dim_b) ndarray - Optimal transportation matrix for the given parameters + GW: float + partial GW discrepancy log : dict log dictionary returned only if `log` is `True` @@ -360,8 +370,8 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): def gwgrad_partial(C1, C2, T): - """Compute the GW gradient. Note: we can not use the trick in [12]_ as - the marginals may not sum to 1. + """Compute the GW gradient. Note: we can not use the trick in :ref:`[12] ` + as the marginals may not sum to 1. Parameters ---------- @@ -379,6 +389,8 @@ def gwgrad_partial(C1, C2, T): numpy.array of shape (n_p+nb_dummies, n_u) gradient + + .. _references-gwgrad-partial: References ---------- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, @@ -425,22 +437,25 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, The function considers the following problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + .. math:: + s.t. \ \gamma \mathbf{1} &\leq \mathbf{a} + + \gamma^T \mathbf{1} &\leq \mathbf{b} - s.t. \gamma 1 \leq a \\ - \gamma^T 1 \leq b \\ - \gamma\geq 0 \\ - 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\} \\ + \gamma &\geq 0 + + \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} where : - - M is the metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma) - =\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are the sample weights - - m is the amount of mass to be transported + - :math:`\mathbf{M}` is the metric cost matrix + - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights + - `m` is the amount of mass to be transported - The formulation of the problem has been proposed in [29]_ + The formulation of the problem has been proposed in :ref:`[29] ` Parameters @@ -454,7 +469,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, q : ndarray, shape (nt,) Distribution in the target space m : float, optional - Amount of mass to be transported (default: min (|p|_1, |q|_1)) + Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) nb_dummies : int, optional Number of dummy points to add (avoid instabilities in the EMD solver) G0 : ndarray, shape (ns, nt), optional @@ -476,7 +491,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, Returns ------- - gamma : (dim_a x dim_b) ndarray + gamma : (dim_a, dim_b) ndarray Optimal transportation matrix for the given parameters log : dict log dictionary returned only if `log` is `True` @@ -503,6 +518,8 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, [0. , 0. , 0.25, 0. ], [0. , 0. , 0. , 0. ]]) + + .. _references-partial-gromov-wasserstein: References ---------- .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal @@ -597,22 +614,25 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None, The function considers the following problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + GW = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + .. math:: + s.t. \ \gamma \mathbf{1} &\leq \mathbf{a} + + \gamma^T \mathbf{1} &\leq \mathbf{b} + + \gamma &\geq 0 - s.t. \gamma 1 \leq a \\ - \gamma^T 1 \leq b \\ - \gamma\geq 0 \\ - 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\} \\ + \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} where : - - M is the metric cost matrix - - :math:`\Omega` is the entropic regularization term - :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are the sample weights - - m is the amount of mass to be transported + - :math:`\mathbf{M}` is the metric cost matrix + - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights + - `m` is the amount of mass to be transported - The formulation of the problem has been proposed in [29]_ + The formulation of the problem has been proposed in :ref:`[29] ` Parameters @@ -626,7 +646,7 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None, q : ndarray, shape (nt,) Distribution in the target space m : float, optional - Amount of mass to be transported (default: min (|p|_1, |q|_1)) + Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) nb_dummies : int, optional Number of dummy points to add (avoid instabilities in the EMD solver) G0 : ndarray, shape (ns, nt), optional @@ -655,7 +675,7 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None, Returns ------- - partial_gw_dist : (dim_a x dim_b) ndarray + partial_gw_dist : float partial GW discrepancy log : dict log dictionary returned only if `log` is `True` @@ -676,6 +696,8 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None, >>> np.round(partial_gromov_wasserstein2(C1, C2, a, b, m=0.25),2) 0.0 + + .. _references-partial-gromov-wasserstein2: References ---------- .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal @@ -706,30 +728,29 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, The function considers the following problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot\Omega(\gamma) - s.t. \gamma 1 \leq a \\ - \gamma^T 1 \leq b \\ - \gamma\geq 0 \\ - 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\} \\ + s.t. \gamma \mathbf{1} &\leq \mathbf{a} \\ + \gamma^T \mathbf{1} &\leq \mathbf{b} \\ + \gamma &\geq 0 \\ + \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} \\ where : - - M is the metric cost matrix - - :math:`\Omega` is the entropic regularization term - :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are the sample weights - - m is the amount of mass to be transported + - :math:`\mathbf{M}` is the metric cost matrix + - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights + - `m` is the amount of mass to be transported - The formulation of the problem has been proposed in [3]_ (prop. 5) + The formulation of the problem has been proposed in :ref:`[3] ` (prop. 5) Parameters ---------- a : np.ndarray (dim_a,) - Unnormalized histogram of dimension dim_a + Unnormalized histogram of dimension `dim_a` b : np.ndarray (dim_b,) - Unnormalized histograms of dimension dim_b + Unnormalized histograms of dimension `dim_b` M : np.ndarray (dim_a, dim_b) cost matrix reg : float @@ -748,7 +769,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, Returns ------- - gamma : (dim_a x dim_b) ndarray + gamma : (dim_a, dim_b) ndarray Optimal transportation matrix for the given parameters log : dict log dictionary returned only if `log` is `True` @@ -764,6 +785,8 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, array([[0.06, 0.02], [0.01, 0. ]]) + + .. _references-entropic-partial-wasserstein: References ---------- .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. @@ -838,32 +861,34 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, numItermax=1000, tol=1e-7, log=False, verbose=False): r""" - Returns the partial Gromov-Wasserstein transport between (C1,p) and (C2,q) + Returns the partial Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` The function solves the following optimization problem: .. math:: - GW = \arg\min_{\gamma} \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})\cdot - \gamma_{i,j}\cdot\gamma_{k,l} + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_{\gamma} \quad \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})\cdot + \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma) + + .. math:: + s.t. \ \gamma &\geq 0 + + \gamma \mathbf{1} &\leq \mathbf{a} - s.t. - \gamma\geq 0 \\ - \gamma 1 \leq a\\ - \gamma^T 1 \leq b\\ - 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\} + \gamma^T \mathbf{1} &\leq \mathbf{b} + + \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} where : - - C1 is the metric cost matrix in the source space - - C2 is the metric cost matrix in the target space - - p and q are the sample weights - - L : quadratic loss function - - :math:`\Omega` is the entropic regularization term - :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - m is the amount of mass to be transported + - :math:`\mathbf{C_1}` is the metric cost matrix in the source space + - :math:`\mathbf{C_2}` is the metric cost matrix in the target space + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights + - `L`: quadratic loss function + - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - `m` is the amount of mass to be transported - The formulation of the GW problem has been proposed in [12]_ and the - partial GW in [29]_. + The formulation of the GW problem has been proposed in :ref:`[12] ` and the partial GW in :ref:`[29] ` Parameters ---------- @@ -878,7 +903,7 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, reg: float entropic regularization parameter m : float, optional - Amount of mass to be transported (default: min (|p|_1, |q|_1)) + Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) G0 : ndarray, shape (ns, nt), optional Initialisation of the transportation matrix numItermax : int, optional @@ -913,17 +938,20 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, Returns ------- - :math: `gamma` : (dim_a x dim_b) ndarray + :math: `gamma` : (dim_a, dim_b) ndarray Optimal transportation matrix for the given parameters log : dict log dictionary returned only if `log` is `True` + + .. _references-entropic-partial-gromov-wassertein: References ---------- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. - .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal Transport with Applications on Positive-Unlabeled Learning". NeurIPS. @@ -977,33 +1005,33 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None, numItermax=1000, tol=1e-7, log=False, verbose=False): r""" - Returns the partial Gromov-Wasserstein discrepancy between (C1,p) and - (C2,q) + Returns the partial Gromov-Wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` The function solves the following optimization problem: .. math:: - GW = \arg\min_{\gamma} \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})\cdot - \gamma_{i,j}\cdot\gamma_{k,l} + reg\cdot\Omega(\gamma) + GW = \min_{\gamma} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})\cdot + \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma) + + .. math:: + s.t. \ \gamma &\geq 0 + + \gamma \mathbf{1} &\leq \mathbf{a} + + \gamma^T \mathbf{1} &\leq \mathbf{b} - s.t. - \gamma\geq 0 \\ - \gamma 1 \leq a\\ - \gamma^T 1 \leq b\\ - 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\} + \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} where : - - C1 is the metric cost matrix in the source space - - C2 is the metric cost matrix in the target space - - p and q are the sample weights - - L : quadratic loss function - - :math:`\Omega` is the entropic regularization term - :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - m is the amount of mass to be transported + - :math:`\mathbf{C_1}` is the metric cost matrix in the source space + - :math:`\mathbf{C_2}` is the metric cost matrix in the target space + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights + - `L` : quadratic loss function + - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - `m` is the amount of mass to be transported - The formulation of the GW problem has been proposed in [12]_ and the - partial GW in [29]_. + The formulation of the GW problem has been proposed in :ref:`[12] ` and the partial GW in :ref:`[29] ` Parameters @@ -1019,7 +1047,7 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None, reg: float entropic regularization parameter m : float, optional - Amount of mass to be transported (default: min (|p|_1, |q|_1)) + Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) G0 : ndarray, shape (ns, nt), optional Initialisation of the transportation matrix numItermax : int, optional @@ -1052,11 +1080,14 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None, >>> np.round(entropic_partial_gromov_wasserstein2(C1, C2, a, b,50), 2) 1.87 + + .. _references-entropic-partial-gromov-wassertein2: References ---------- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal Transport with Applications on Positive-Unlabeled Learning". NeurIPS. diff --git a/ot/plot.py b/ot/plot.py index ad436b4..3e3bed7 100644 --- a/ot/plot.py +++ b/ot/plot.py @@ -18,10 +18,10 @@ from matplotlib import gridspec def plot1D_mat(a, b, M, title=''): - """ Plot matrix M with the source and target 1D distribution + """ Plot matrix :math:`\mathbf{M}` with the source and target 1D distribution - Creates a subplot with the source distribution a on the left and - target distribution b on the tot. The matrix M is shown in between. + Creates a subplot with the source distribution :math:`\mathbf{a}` on the left and + target distribution :math:`\mathbf{b}` on the top. The matrix :math:`\mathbf{M}` is shown in between. Parameters @@ -61,10 +61,10 @@ def plot1D_mat(a, b, M, title=''): def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs): - """ Plot matrix M in 2D with lines using alpha values + """ Plot matrix :math:`\mathbf{G}` in 2D with lines using alpha values Plot lines between source and target 2D samples with a color - proportional to the value of the matrix G between samples. + proportional to the value of the matrix :math:`\mathbf{G}` between samples. Parameters diff --git a/ot/sliced.py b/ot/sliced.py index d3dc3f2..7c09111 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -17,7 +17,7 @@ from .utils import list_to_array def get_random_projections(d, n_projections, seed=None, backend=None, type_as=None): r""" - Generates n_projections samples from the uniform on the unit sphere of dimension d-1: :math:`\mathcal{U}(\mathcal{S}^{d-1})` + Generates n_projections samples from the uniform on the unit sphere of dimension :math:`d-1`: :math:`\mathcal{U}(\mathcal{S}^{d-1})` Parameters ---------- @@ -67,11 +67,12 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, Computes a Monte-Carlo approximation of the p-Sliced Wasserstein distance .. math:: - \mathcal{SWD}_p(\mu, \nu) = \underset{\theta \sim \mathcal{U}(\mathbb{S}^{d-1})}{\mathbb{E}}[\mathcal{W}_p^p(\theta_\# \mu, \theta_\# \nu)]^{\frac{1}{p}} + \mathcal{SWD}_p(\mu, \nu) = \underset{\theta \sim \mathcal{U}(\mathbb{S}^{d-1})}{\mathbb{E}}\left(\mathcal{W}_p^p(\theta_\# \mu, \theta_\# \nu)\right)^{\frac{1}{p}} + where : - - :math:`\theta_\# \mu` stands for the pushforwars of the projection :math:`\mathbb{R}^d \ni X \mapsto \langle \theta, X \rangle` + - :math:`\theta_\# \mu` stands for the pushforwards of the projection :math:`X \in \mathbb{R}^d \mapsto \langle \theta, X \rangle` Parameters diff --git a/ot/smooth.py b/ot/smooth.py index ea26bae..6855005 100644 --- a/ot/smooth.py +++ b/ot/smooth.py @@ -47,15 +47,24 @@ from scipy.optimize import minimize def projection_simplex(V, z=1, axis=None): - """ Projection of x onto the simplex, scaled by z + r""" Projection of :math:`\mathbf{V}` onto the simplex, scaled by `z` - P(x; z) = argmin_{y >= 0, sum(y) = z} ||y - x||^2 + .. math:: + P\left(\mathbf{V}, z\right) = \mathop{\arg \min}_{\substack{\mathbf{y} >= 0 \\ \sum_i \mathbf{y}_i = z}} \quad \|\mathbf{y} - \mathbf{V}\|^2 + + Parameters + ---------- + V: ndarray, rank 2 z: float or array - If array, len(z) must be compatible with V + If array, len(z) must be compatible with :math:`\mathbf{V}` axis: None or int - - axis=None: project V by P(V.ravel(); z) - - axis=1: project each V[i] by P(V[i]; z[i]) - - axis=0: project each V[:, j] by P(V[:, j]; z[j]) + - axis=None: project :math:`\mathbf{V}` by :math:`P(\mathbf{V}.\mathrm{ravel}(), z)` + - axis=1: project each :math:`\mathbf{V}_i` by :math:`P(\mathbf{V}_i, z_i)` + - axis=0: project each :math:`\mathbf{V}_{:, j}` by :math:`P(\mathbf{V}_{:, j}, z_j)` + + Returns + ------- + projection: ndarray, shape :math:`\mathbf{V}`.shape """ if axis == 1: n_features = V.shape[1] @@ -77,12 +86,12 @@ def projection_simplex(V, z=1, axis=None): class Regularization(object): - """Base class for Regularization objects + r"""Base class for Regularization objects Notes ----- - This class is not intended for direct use but as aparent for true - regularizatiojn implementation. + This class is not intended for direct use but as apparent for true + regularization implementation. """ def __init__(self, gamma=1.0): @@ -98,40 +107,48 @@ class Regularization(object): self.gamma = gamma def delta_Omega(X): - """ - Compute delta_Omega(X[:, j]) for each X[:, j]. - delta_Omega(x) = sup_{y >= 0} y^T x - Omega(y). + r""" + Compute :math:`\delta_\Omega(\mathbf{X}_{:, j})` for each :math:`\mathbf{X}_{:, j}`. + + .. math:: + \delta_\Omega(\mathbf{x}) = \sup_{\mathbf{y} >= 0} \ + \mathbf{y}^T \mathbf{x} - \Omega(\mathbf{y}) Parameters ---------- - X: array, shape = len(a) x len(b) + X: array, shape = (len(a), len(b)) Input array. Returns ------- - v: array, len(b) - Values: v[j] = delta_Omega(X[:, j]) - G: array, len(a) x len(b) - Gradients: G[:, j] = nabla delta_Omega(X[:, j]) + v: array, (len(b), ) + Values: :math:`\mathbf{v}_j = \delta_\Omega(\mathbf{X}_{:, j})` + G: array, (len(a), len(b)) + Gradients: :math:`\mathbf{G}_{:, j} = \nabla \delta_\Omega(\mathbf{X}_{:, j})` """ raise NotImplementedError def max_Omega(X, b): - """ - Compute max_Omega_j(X[:, j]) for each X[:, j]. - max_Omega_j(x) = sup_{y >= 0, sum(y) = 1} y^T x - Omega(b[j] y) / b[j]. + r""" + Compute :math:`\mathrm{max}_{\Omega, j}(\mathbf{X}_{:, j})` for each :math:`\mathbf{X}_{:, j}`. + + .. math:: + \mathrm{max}_{\Omega, j}(\mathbf{x}) = + \sup_{\substack{\mathbf{y} >= 0 \ \sum_i \mathbf{y}_i = 1}} + \mathbf{y}^T \mathbf{x} - \frac{1}{\mathbf{b}_j} \Omega(\mathbf{b}_j \mathbf{y}) Parameters ---------- - X: array, shape = len(a) x len(b) + X: array, shape = (len(a), len(b)) Input array. + b: array, shape = (len(b), ) Returns ------- - v: array, len(b) - Values: v[j] = max_Omega_j(X[:, j]) - G: array, len(a) x len(b) - Gradients: G[:, j] = nabla max_Omega_j(X[:, j]) + v: array, (len(b), ) + Values: :math:`\mathbf{v}_j = \mathrm{max}_{\Omega, j}(\mathbf{X}_{:, j})` + G: array, (len(a), len(b)) + Gradients: :math:`\mathbf{G}_{:, j} = \nabla \mathrm{max}_{\Omega, j}(\mathbf{X}_{:, j})` """ raise NotImplementedError @@ -192,7 +209,7 @@ class SquaredL2(Regularization): def dual_obj_grad(alpha, beta, a, b, C, regul): - """ + r""" Compute objective value and gradients of dual objective. Parameters @@ -203,19 +220,19 @@ def dual_obj_grad(alpha, beta, a, b, C, regul): a: array, shape = len(a) b: array, shape = len(b) Input histograms (should be non-negative and sum to 1). - C: array, shape = len(a) x len(b) + C: array, shape = (len(a), len(b)) Ground cost matrix. regul: Regularization object - Should implement a delta_Omega(X) method. + Should implement a `delta_Omega(X)` method. Returns ------- obj: float Objective value (higher is better). grad_alpha: array, shape = len(a) - Gradient w.r.t. alpha. + Gradient w.r.t. `alpha`. grad_beta: array, shape = len(b) - Gradient w.r.t. beta. + Gradient w.r.t. `beta`. """ obj = np.dot(alpha, a) + np.dot(beta, b) grad_alpha = a.copy() @@ -242,13 +259,13 @@ def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500, Parameters ---------- - a: array, shape = len(a) - b: array, shape = len(b) + a: array, shape = (len(a), ) + b: array, shape = (len(b), ) Input histograms (should be non-negative and sum to 1). - C: array, shape = len(a) x len(b) + C: array, shape = (len(a), len(b)) Ground cost matrix. regul: Regularization object - Should implement a delta_Omega(X) method. + Should implement a `delta_Omega(X)` method. method: str Solver to be used (passed to `scipy.optimize.minimize`). tol: float @@ -258,8 +275,8 @@ def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500, Returns ------- - alpha: array, shape = len(a) - beta: array, shape = len(b) + alpha: array, shape = (len(a), ) + beta: array, shape = (len(b), ) Dual potentials. """ @@ -302,10 +319,10 @@ def semi_dual_obj_grad(alpha, a, b, C, regul): a: array, shape = len(a) b: array, shape = len(b) Input histograms (should be non-negative and sum to 1). - C: array, shape = len(a) x len(b) + C: array, shape = (len(a), len(b)) Ground cost matrix. regul: Regularization object - Should implement a max_Omega(X) method. + Should implement a `max_Omega(X)` method. Returns ------- @@ -337,13 +354,13 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500, Parameters ---------- - a: array, shape = len(a) - b: array, shape = len(b) + a: array, shape = (len(a), ) + b: array, shape = (len(b), ) Input histograms (should be non-negative and sum to 1). - C: array, shape = len(a) x len(b) + C: array, shape = (len(a), len(b)) Ground cost matrix. regul: Regularization object - Should implement a max_Omega(X) method. + Should implement a `max_Omega(X)` method. method: str Solver to be used (passed to `scipy.optimize.minimize`). tol: float @@ -353,7 +370,7 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500, Returns ------- - alpha: array, shape = len(a) + alpha: array, shape = (len(a), ) Semi-dual potentials. """ @@ -371,7 +388,7 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500, def get_plan_from_dual(alpha, beta, C, regul): - """ + r""" Retrieve optimal transportation plan from optimal dual potentials. Parameters @@ -379,14 +396,14 @@ def get_plan_from_dual(alpha, beta, C, regul): alpha: array, shape = len(a) beta: array, shape = len(b) Optimal dual potentials. - C: array, shape = len(a) x len(b) + C: array, shape = (len(a), len(b)) Ground cost matrix. regul: Regularization object - Should implement a delta_Omega(X) method. + Should implement a `delta_Omega(X)` method. Returns ------- - T: array, shape = len(a) x len(b) + T: array, shape = (len(a), len(b)) Optimal transportation plan. """ X = alpha[:, np.newaxis] + beta - C @@ -394,7 +411,7 @@ def get_plan_from_dual(alpha, beta, C, regul): def get_plan_from_semi_dual(alpha, b, C, regul): - """ + r""" Retrieve optimal transportation plan from optimal semi-dual potentials. Parameters @@ -403,14 +420,14 @@ def get_plan_from_semi_dual(alpha, b, C, regul): Optimal semi-dual potentials. b: array, shape = len(b) Second input histogram (should be non-negative and sum to 1). - C: array, shape = len(a) x len(b) + C: array, shape = (len(a), len(b)) Ground cost matrix. regul: Regularization object - Should implement a delta_Omega(X) method. + Should implement a `delta_Omega(X)` method. Returns ------- - T: array, shape = len(a) x len(b) + T: array, shape = (len(a), len(b)) Optimal transportation plan. """ X = alpha[:, np.newaxis] - C @@ -422,19 +439,21 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, r""" Solve the regularized OT problem in the dual and return the OT matrix - The function solves the smooth relaxed dual formulation (7) in [17]_ : + The function solves the smooth relaxed dual formulation (7) in + :ref:`[17] `: .. math:: - \max_{\alpha,\beta}\quad a^T\alpha+b^T\beta-\sum_j\delta_\Omega(\alpha+\beta_j-\mathbf{m}_j) + \max_{\alpha,\beta}\quad \mathbf{a}^T\alpha + \mathbf{b}^T\beta - + \sum_j \delta_\Omega \left(\alpha+\beta_j-\mathbf{m}_j \right) where : - - :math:`\mathbf{m}_j` is the jth column of the cost matrix + - :math:`\mathbf{m}_j` is the j-th column of the cost matrix - :math:`\delta_\Omega` is the convex conjugate of the regularization term :math:`\Omega` - - a and b are source and target weights (sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) The OT matrix can is reconstructed from the gradient of :math:`\delta_\Omega` - (See [17]_ Proposition 1). + (See :ref:`[17] ` Proposition 1). The optimization algorithm is using gradient decent (L-BFGS by default). @@ -444,15 +463,19 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, samples weights in the source domain b : np.ndarray (nt,) or np.ndarray (nt,nbb) samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) M : np.ndarray (ns,nt) loss matrix reg : float Regularization term >0 reg_type : str - Regularization type, can be the following (default ='l2'): - - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn [2]_) - - 'l2' : Squared Euclidean regularization + Regularization type, can be the following (default ='l2'): + + - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn + :ref:`[2] `) + + - 'l2' : Squared Euclidean regularization method : str Solver to use for scipy.optimize.minimize numItermax : int, optional @@ -467,15 +490,15 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, Returns ------- - gamma : (ns x nt) ndarray + gamma : (ns, nt) ndarray Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters + .. _references-smooth-ot-dual: References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). @@ -514,21 +537,23 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr= r""" Solve the regularized OT problem in the semi-dual and return the OT matrix - The function solves the smooth relaxed dual formulation (10) in [17]_ : + The function solves the smooth relaxed dual formulation (10) in + :ref:`[17] `: .. math:: - \max_{\alpha}\quad a^T\alpha-OT_\Omega^*(\alpha,b) + \max_{\alpha}\quad \mathbf{a}^T\alpha- \mathrm{OT}_\Omega^*(\alpha, \mathbf{b}) where : .. math:: - OT_\Omega^*(\alpha,b)=\sum_j b_j + \mathrm{OT}_\Omega^*(\alpha,b)=\sum_j \mathbf{b}_j - - :math:`\mathbf{m}_j` is the jth column of the cost matrix - - :math:`OT_\Omega^*(\alpha,b)` is defined in Eq. (9) in [17] - - a and b are source and target weights (sum to 1) + - :math:`\mathbf{m}_j` is the j-th column of the cost matrix + - :math:`\mathrm{OT}_\Omega^*(\alpha,b)` is defined in Eq. (9) in + :ref:`[17] ` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) - The OT matrix can is reconstructed using [17]_ Proposition 2. + The OT matrix can is reconstructed using :ref:`[17] ` Proposition 2. The optimization algorithm is using gradient decent (L-BFGS by default). @@ -538,15 +563,19 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr= samples weights in the source domain b : np.ndarray (nt,) or np.ndarray (nt,nbb) samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) + and fixed:math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) M : np.ndarray (ns,nt) loss matrix reg : float Regularization term >0 reg_type : str - Regularization type, can be the following (default ='l2'): - - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn [2]_) - - 'l2' : Squared Euclidean regularization + Regularization type, can be the following (default ='l2'): + + - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn + :ref:`[2] `) + + - 'l2' : Squared Euclidean regularization method : str Solver to use for scipy.optimize.minimize numItermax : int, optional @@ -561,15 +590,15 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr= Returns ------- - gamma : (ns x nt) ndarray + gamma : (ns, nt) ndarray Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters + .. _references-smooth-ot-semi-dual: References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). diff --git a/ot/stochastic.py b/ot/stochastic.py index 13ed9cc..693675f 100644 --- a/ot/stochastic.py +++ b/ot/stochastic.py @@ -18,22 +18,25 @@ import numpy as np def coordinate_grad_semi_dual(b, M, reg, beta, i): r''' - Compute the coordinate gradient update for regularized discrete distributions for (i, :) + Compute the coordinate gradient update for regularized discrete distributions for :math:`(i, :)` The function computes the gradient of the semi dual problem: .. math:: - \max_v \sum_i (\sum_j v_j * b_j - reg * log(\sum_j exp((v_j - M_{i,j})/reg) * b_j)) * a_i + \max_\mathbf{v} \ \sum_i \mathbf{a}_i \left[ \sum_j \mathbf{v}_j \mathbf{b}_j - \mathrm{reg} + \cdot \log \left( \sum_j \mathbf{b}_j + \exp \left( \frac{\mathbf{v}_j - \mathbf{M}_{i,j}}{\mathrm{reg}} + \right) \right) \right] Where : - - M is the (ns,nt) metric cost matrix - - v is a dual variable in R^J + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`\mathbf{v}` is a dual variable in :math:`\mathbb{R}^{nt}` - reg is the regularization term - - a and b are source and target weights (sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) The algorithm used for solving the problem is the ASGD & SAG algorithms - as proposed in [18]_ [alg.1 & alg.2] + as proposed in :ref:`[18] ` [alg.1 & alg.2] Parameters @@ -47,7 +50,7 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i): v : ndarray, shape (nt,) Dual variable. i : int - Picked number i. + Picked number `i`. Returns ------- @@ -74,12 +77,10 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i): [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]]) + .. _references-coordinate-grad-semi-dual: References ---------- - [Genevay et al., 2016] : - Stochastic Optimization for Large-scale Optimal Transport, - Advances in Neural Information Processing Systems (2016), - arXiv preprint arxiv:1605.08527. + .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) Stochastic Optimization for Large-scale Optimal Transport. Advances in Neural Information Processing Systems (2016). ''' r = M[i, :] - beta exp_beta = np.exp(-r / reg) * b @@ -88,29 +89,29 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i): def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None): - r''' - Compute the SAG algorithm to solve the regularized discrete measures - optimal transport max problem + r""" + Compute the SAG algorithm to solve the regularized discrete measures optimal transport max problem The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma \mathbf{1} = \mathbf{a} - \gamma^T 1 = b + \gamma^T \mathbf{1} = \mathbf{b} \gamma \geq 0 Where : - - M is the (ns,nt) metric cost matrix + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix - :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) The algorithm used for solving the problem is the SAG algorithm - as proposed in [18]_ [alg.1] + as proposed in :ref:`[18] ` [alg.1] Parameters @@ -131,7 +132,7 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None): Returns ------- - v : ndarray, shape (nt,) + v : ndarray, shape (`nt`,) Dual variable. Examples @@ -154,14 +155,12 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None): [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01], [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]]) + + .. _references-sag-entropic-transport: References ---------- - - [Genevay et al., 2016] : - Stochastic Optimization for Large-scale Optimal Transport, - Advances in Neural Information Processing Systems (2016), - arXiv preprint arxiv:1605.08527. - ''' + .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) Stochastic Optimization for Large-scale Optimal Transport. Advances in Neural Information Processing Systems (2016). + """ if lr is None: lr = 1. / max(a / reg) @@ -187,22 +186,23 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None): The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg}\cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \gamma \mathbf{1} = \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} = \mathbf{b} \gamma \geq 0 Where : - - M is the (ns,nt) metric cost matrix + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix - :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) The algorithm used for solving the problem is the ASGD algorithm - as proposed in [18]_ [alg.2] + as proposed in :ref:`[18] ` [alg.2] Parameters @@ -220,7 +220,7 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None): Returns ------- - ave_v : ndarray, shape (nt,) + ave_v : ndarray, shape (`nt`,) dual variable Examples @@ -243,13 +243,11 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None): [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01], [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]]) + + .. _references-averaged-sgd-entropic-transport: References ---------- - - [Genevay et al., 2016] : - Stochastic Optimization for Large-scale Optimal Transport, - Advances in Neural Information Processing Systems (2016), - arXiv preprint arxiv:1605.08527. + .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) Stochastic Optimization for Large-scale Optimal Transport. Advances in Neural Information Processing Systems (2016). ''' if lr is None: @@ -271,20 +269,21 @@ def c_transform_entropic(b, M, reg, beta): r''' The goal is to recover u from the c-transform. - The function computes the c_transform of a dual variable from the other + The function computes the c-transform of a dual variable from the other dual variable: .. math:: - u = v^{c,reg} = -reg \sum_j exp((v - M)/reg) b_j + \mathbf{u} = \mathbf{v}^{c,reg} = - \mathrm{reg} \sum_j \mathbf{b}_j + \exp\left( \frac{\mathbf{v} - \mathbf{M}}{\mathrm{reg}} \right) Where : - - M is the (ns,nt) metric cost matrix - - u, v are dual variables in R^IxR^J + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`\mathbf{u}`, :math:`\mathbf{v}` are dual variables in :math:`\mathbb{R}^{ns} \times \mathbb{R}^{nt}` - reg is the regularization term It is used to recover an optimal u from optimal v solving the semi dual - problem, see Proposition 2.1 of [18]_ + problem, see Proposition 2.1 of :ref:`[18] ` Parameters @@ -300,7 +299,7 @@ def c_transform_entropic(b, M, reg, beta): Returns ------- - u : ndarray, shape (ns,) + u : ndarray, shape (`ns`,) Dual variable. Examples @@ -323,13 +322,11 @@ def c_transform_entropic(b, M, reg, beta): [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01], [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]]) + + .. _references-c-transform-entropic: References ---------- - - [Genevay et al., 2016] : - Stochastic Optimization for Large-scale Optimal Transport, - Advances in Neural Information Processing Systems (2016), - arXiv preprint arxiv:1605.08527. + .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) Stochastic Optimization for Large-scale Optimal Transport. Advances in Neural Information Processing Systems (2016). ''' n_source = np.shape(M)[0] @@ -345,27 +342,28 @@ def c_transform_entropic(b, M, reg, beta): def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None, log=False): r''' - Compute the transportation matrix to solve the regularized discrete - measures optimal transport max problem + Compute the transportation matrix to solve the regularized discrete measures optimal transport max problem The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma \mathbf{1} = \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} = \mathbf{b} \gamma \geq 0 Where : - - M is the (ns,nt) metric cost matrix + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix - :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) + The algorithm used for solving the problem is the SAG or ASGD algorithms - as proposed in [18]_ + as proposed in :ref:`[18] ` Parameters @@ -419,13 +417,11 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None, [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01], [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]]) + + .. _references-solve-semi-dual-entropic: References ---------- - - [Genevay et al., 2016] : - Stochastic Optimization for Large-scale Optimal Transport, - Advances in Neural Information Processing Systems (2016), - arXiv preprint arxiv:1605.08527. + .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) Stochastic Optimization for Large-scale Optimal Transport. Advances in Neural Information Processing Systems (2016). ''' if method.lower() == "sag": @@ -459,26 +455,30 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha, r''' Computes the partial gradient of the dual optimal transport problem. - For each (i,j) in a batch of coordinates, the partial gradients are : + For each :math:`(i,j)` in a batch of coordinates, the partial gradients are : .. math:: - \partial_{u_i} F = u_i * b_s/l_{v} - \sum_{j \in B_v} exp((u_i + v_j - M_{i,j})/reg) * a_i * b_j + \partial_{\mathbf{u}_i} F = \frac{b_s}{l_v} \mathbf{u}_i - + \sum_{j \in B_v} \mathbf{a}_i \mathbf{b}_j + \exp\left( \frac{\mathbf{u}_i + \mathbf{v}_j - \mathbf{M}_{i,j}}{\mathrm{reg}} \right) - \partial_{v_j} F = v_j * b_s/l_{u} - \sum_{i \in B_u} exp((u_i + v_j - M_{i,j})/reg) * a_i * b_j + \partial_{\mathbf{v}_j} F = \frac{b_s}{l_u} \mathbf{v}_j - + \sum_{i \in B_u} \mathbf{a}_i \mathbf{b}_j + \exp\left( \frac{\mathbf{u}_i + \mathbf{v}_j - \mathbf{M}_{i,j}}{\mathrm{reg}} \right) Where : - - M is the (ns,nt) metric cost matrix - - u, v are dual variables in R^ixR^J + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`\mathbf{u}`, :math:`\mathbf{v}` are dual variables in :math:`\mathbb{R}^{ns} \times \mathbb{R}^{nt}` - reg is the regularization term - :math:`B_u` and :math:`B_v` are lists of index - - :math:`b_s` is the size of the batchs :math:`B_u` and :math:`B_v` - - :math:`l_u` and :math:`l_v` are the lenghts of :math:`B_u` and :math:`B_v` - - a and b are source and target weights (sum to 1) + - :math:`b_s` is the size of the batches :math:`B_u` and :math:`B_v` + - :math:`l_u` and :math:`l_v` are the lengths of :math:`B_u` and :math:`B_v` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) The algorithm used for solving the dual problem is the SGD algorithm - as proposed in [19]_ [alg.1] + as proposed in :ref:`[19] ` [alg.1] Parameters @@ -504,7 +504,7 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha, Returns ------- - grad : ndarray, shape (ns,) + grad : ndarray, shape (`ns`,) partial grad F Examples @@ -533,12 +533,11 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha, [5.06266486e-02, 2.16230494e-03, 2.26215141e-03, 6.81514609e-04], [6.06713990e-02, 3.98139808e-02, 5.46829338e-02, 8.62371424e-06]]) + + .. _references-batch-grad-dual: References ---------- - - [Seguy et al., 2018] : - International Conference on Learning Representation (2018), - arXiv preprint arxiv:1711.02283. + .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018) ''' G = - (np.exp((alpha[batch_alpha, None] + beta[None, batch_beta] - M[batch_alpha, :][:, batch_beta]) / reg) * @@ -555,25 +554,25 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha, def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr): r''' - Compute the sgd algorithm to solve the regularized discrete measures - optimal transport dual problem + Compute the sgd algorithm to solve the regularized discrete measures optimal transport dual problem The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma \mathbf{1} = \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} = \mathbf{b} \gamma \geq 0 Where : - - M is the (ns,nt) metric cost matrix + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix - :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) Parameters ---------- @@ -632,9 +631,7 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr): References ---------- - [Seguy et al., 2018] : - International Conference on Learning Representation (2018), - arXiv preprint arxiv:1711.02283. + .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018) ''' n_source = np.shape(M)[0] @@ -657,25 +654,25 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr): def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1, log=False): r''' - Compute the transportation matrix to solve the regularized discrete measures - optimal transport dual problem + Compute the transportation matrix to solve the regularized discrete measures optimal transport dual problem The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma \mathbf{1} = \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} = \mathbf{b} \gamma \geq 0 Where : - - M is the (ns,nt) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (sum to 1) + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) Parameters ---------- @@ -736,10 +733,7 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1, References ---------- - - [Seguy et al., 2018] : - International Conference on Learning Representation (2018), - arXiv preprint arxiv:1711.02283. + .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018) ''' opt_alpha, opt_beta = sgd_entropic_regularization(a, b, M, reg, batch_size, diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 6a61aa1..15e180b 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -23,29 +23,31 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b) + W = \min_\gamma \ \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg}\cdot\Omega(\gamma) + + \mathrm{reg_m} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_m} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b}) s.t. - \gamma\geq 0 + \gamma \geq 0 + where : - - M is the (dim_a, dim_b) metric cost matrix - - :math:`\Omega` is the entropic regularization - term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target unbalanced distributions + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions - KL is the Kullback-Leibler divergence The algorithm used for solving the problem is the generalized - Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ + Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] ` Parameters ---------- a : np.ndarray (dim_a,) - Unnormalized histogram of dimension dim_a + Unnormalized histogram of dimension `dim_a` b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) - One or multiple unnormalized histograms of dimension dim_b - If many, compute all the OT distances (a, b_i) + One or multiple unnormalized histograms of dimension `dim_b`. + If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i` M : np.ndarray (dim_a, dim_b) loss matrix reg : float @@ -68,14 +70,14 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, Returns ------- if n_hists == 1: - gamma : (dim_a x dim_b) ndarray + - gamma : (dim_a, dim_b) ndarray Optimal transportation matrix for the given parameters - log : dict + - log : dict log dictionary returned only if `log` is `True` else: - ot_distance : (n_hists,) ndarray - the OT distance between `a` and each of the histograms `b_i` - log : dict + - ot_distance : (n_hists,) ndarray + the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` + - log : dict log dictionary returned only if `log` is `True` Examples @@ -90,9 +92,9 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, [0.18807035, 0.51122823]]) + .. _references-sinkhorn-unbalanced: References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 @@ -111,11 +113,11 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, See Also -------- - ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn [10] + ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn :ref:`[10] ` ot.unbalanced.sinkhorn_stabilized_unbalanced: - Unbalanced Stabilized sinkhorn [9][10] + Unbalanced Stabilized sinkhorn :ref:`[9, 10] ` ot.unbalanced.sinkhorn_reg_scaling_unbalanced: - Unbalanced Sinkhorn with epslilon scaling [9][10] + Unbalanced Sinkhorn with epslilon scaling :ref:`[9, 10] ` """ @@ -151,29 +153,30 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b) + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg}\cdot\Omega(\gamma) + + \mathrm{reg_m} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_m} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b}) s.t. \gamma\geq 0 where : - - M is the (dim_a, dim_b) metric cost matrix - - :math:`\Omega` is the entropic regularization term - :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target unbalanced distributions + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions - KL is the Kullback-Leibler divergence The algorithm used for solving the problem is the generalized - Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ + Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] ` Parameters ---------- a : np.ndarray (dim_a,) - Unnormalized histogram of dimension dim_a + Unnormalized histogram of dimension `dim_a` b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) - One or multiple unnormalized histograms of dimension dim_b - If many, compute all the OT distances (a, b_i) + One or multiple unnormalized histograms of dimension `dim_b`. + If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i` M : np.ndarray (dim_a, dim_b) loss matrix reg : float @@ -196,7 +199,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', Returns ------- ot_distance : (n_hists,) ndarray - the OT distance between `a` and each of the histograms `b_i` + the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` log : dict log dictionary returned only if `log` is `True` @@ -211,10 +214,9 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', array([0.31912866]) - + .. _references-sinkhorn-unbalanced2: References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 @@ -232,9 +234,9 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', See Also -------- - ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10] - ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10] - ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10] + ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn :ref:`[10] ` + ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn :ref:`[9, 10] ` + ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epslilon scaling :ref:`[9, 10] ` """ b = np.asarray(b, dtype=np.float64) @@ -270,26 +272,29 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \reg_m KL(\gamma 1, a) + \reg_m KL(\gamma^T 1, b) + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg}\cdot\Omega(\gamma) + + \mathrm{reg_m} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_m} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b}) s.t. - \gamma\geq 0 + \gamma \geq 0 + where : - - M is the (dim_a, dim_b) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target unbalanced distributions + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions - KL is the Kullback-Leibler divergence - The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ + The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] ` Parameters ---------- a : np.ndarray (dim_a,) - Unnormalized histogram of dimension dim_a + Unnormalized histogram of dimension `dim_a` b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) - One or multiple unnormalized histograms of dimension dim_b + One or multiple unnormalized histograms of dimension `dim_b` If many, compute all the OT distances (a, b_i) M : np.ndarray (dim_a, dim_b) loss matrix @@ -310,15 +315,16 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, Returns ------- if n_hists == 1: - gamma : (dim_a x dim_b) ndarray + - gamma : (dim_a, dim_b) ndarray Optimal transportation matrix for the given parameters - log : dict + - log : dict log dictionary returned only if `log` is `True` else: - ot_distance : (n_hists,) ndarray - the OT distance between `a` and each of the histograms `b_i` - log : dict + - ot_distance : (n_hists,) ndarray + the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` + - log : dict log dictionary returned only if `log` is `True` + Examples -------- @@ -330,9 +336,10 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, array([[0.51122823, 0.18807035], [0.18807035, 0.51122823]]) + + .. _references-sinkhorn-knopp-unbalanced: References ---------- - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. @@ -445,32 +452,34 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 problem and return the loss The function solves the following optimization problem using log-domain - stabilization as proposed in [10]: + stabilization as proposed in :ref:`[10] `: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b) + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg}\cdot\Omega(\gamma) + + \mathrm{reg_m} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_m} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b}) s.t. - \gamma\geq 0 + \gamma \geq 0 + where : - - M is the (dim_a, dim_b) metric cost matrix - - :math:`\Omega` is the entropic regularization - term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target unbalanced distributions + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions - KL is the Kullback-Leibler divergence The algorithm used for solving the problem is the generalized - Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ + Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] ` Parameters ---------- a : np.ndarray (dim_a,) - Unnormalized histogram of dimension dim_a + Unnormalized histogram of dimension `dim_a` b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) - One or multiple unnormalized histograms of dimension dim_b - If many, compute all the OT distances (a, b_i) + One or multiple unnormalized histograms of dimension `dim_b`. + If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i` M : np.ndarray (dim_a, dim_b) loss matrix reg : float @@ -492,14 +501,14 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 Returns ------- if n_hists == 1: - gamma : (dim_a x dim_b) ndarray + - gamma : (dim_a, dim_b) ndarray Optimal transportation matrix for the given parameters - log : dict + - log : dict log dictionary returned only if `log` is `True` else: - ot_distance : (n_hists,) ndarray - the OT distance between `a` and each of the histograms `b_i` - log : dict + - ot_distance : (n_hists,) ndarray + the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` + - log : dict log dictionary returned only if `log` is `True` Examples -------- @@ -512,9 +521,10 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 array([[0.51122823, 0.18807035], [0.18807035, 0.51122823]]) + + .. _references-sinkhorn-stabilized-unbalanced: References ---------- - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. @@ -654,29 +664,27 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, numItermax=1000, stopThr=1e-6, verbose=False, log=False): - r"""Compute the entropic unbalanced wasserstein barycenter of A with stabilization. + r"""Compute the entropic unbalanced wasserstein barycenter of :math:`\mathbf{A}` with stabilization. The function solves the following optimization problem: .. math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{u_{reg}}(\mathbf{a},\mathbf{a}_i) where : - - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized - Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced) - - :math:`\mathbf{a}_i` are training distributions in the columns of - matrix :math:`\mathbf{A}` - - reg and :math:`\mathbf{M}` are respectively the regularization term and - the cost matrix for OT + - :math:`W_{u_{reg}}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see :py:func:`ot.unbalanced.sinkhorn_unbalanced`) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` + - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT - reg_mis the marginal relaxation hyperparameter - The algorithm used for solving the problem is the generalized - Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + + The algorithm used for solving the problem is the generalized + Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10] ` Parameters ---------- A : np.ndarray (dim, n_hists) - `n_hists` training distributions a_i of dimension dim + `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim` M : np.ndarray (dim, dim) ground metric matrix for OT. reg : float @@ -706,9 +714,9 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, log dictionary return only if log==True in parameters + .. _references-barycenter-unbalanced-stabilized: References ---------- - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. @@ -806,29 +814,27 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False): - r"""Compute the entropic unbalanced wasserstein barycenter of A. + r"""Compute the entropic unbalanced wasserstein barycenter of :math:`\mathbf{A}`. - The function solves the following optimization problem with a + The function solves the following optimization problem with :math:`\mathbf{a}` .. math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{u_{reg}}(\mathbf{a},\mathbf{a}_i) where : - - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized - Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix - :math:`\mathbf{A}` - - reg and :math:`\mathbf{M}` are respectively the regularization term and - the cost matrix for OT + - :math:`W_{u_{reg}}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see :py:func:`ot.unbalanced.sinkhorn_unbalanced`) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` + - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT - reg_mis the marginal relaxation hyperparameter + The algorithm used for solving the problem is the generalized - Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10] ` Parameters ---------- A : np.ndarray (dim, n_hists) - `n_hists` training distributions a_i of dimension dim + `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim` M : np.ndarray (dim, dim) ground metric matrix for OT. reg : float @@ -856,9 +862,9 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, log dictionary return only if log==True in parameters + .. _references-barycenter-unbalanced-sinkhorn: References ---------- - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. @@ -936,29 +942,27 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs): - r"""Compute the entropic unbalanced wasserstein barycenter of A. + r"""Compute the entropic unbalanced wasserstein barycenter of :math:`\mathbf{A}`. - The function solves the following optimization problem with a + The function solves the following optimization problem with :math:`\mathbf{a}` .. math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{u_{reg}}(\mathbf{a},\mathbf{a}_i) where : - - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized - Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix - :math:`\mathbf{A}` - - reg and :math:`\mathbf{M}` are respectively the regularization term and - the cost matrix for OT + - :math:`W_{u_{reg}}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see :py:func:`ot.unbalanced.sinkhorn_unbalanced`) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` + - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT - reg_mis the marginal relaxation hyperparameter + The algorithm used for solving the problem is the generalized - Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10] ` Parameters ---------- A : np.ndarray (dim, n_hists) - `n_hists` training distributions a_i of dimension dim + `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim` M : np.ndarray (dim, dim) ground metric matrix for OT. reg : float @@ -986,9 +990,9 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, log dictionary return only if log==True in parameters + .. _references-barycenter-unbalanced: References ---------- - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. diff --git a/ot/utils.py b/ot/utils.py index 0608aee..c878563 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -21,26 +21,26 @@ __time_tic_toc = time.time() def tic(): - """ Python implementation of Matlab tic() function """ + r""" Python implementation of Matlab tic() function """ global __time_tic_toc __time_tic_toc = time.time() def toc(message='Elapsed time : {} s'): - """ Python implementation of Matlab toc() function """ + r""" Python implementation of Matlab toc() function """ t = time.time() print(message.format(t - __time_tic_toc)) return t - __time_tic_toc def toq(): - """ Python implementation of Julia toc() function """ + r""" Python implementation of Julia toc() function """ t = time.time() return t - __time_tic_toc def kernel(x1, x2, method='gaussian', sigma=1, **kwargs): - """Compute kernel matrix""" + r"""Compute kernel matrix""" nx = get_backend(x1, x2) @@ -50,13 +50,13 @@ def kernel(x1, x2, method='gaussian', sigma=1, **kwargs): def laplacian(x): - """Compute Laplacian matrix""" + r"""Compute Laplacian matrix""" L = np.diag(np.sum(x, axis=0)) - x return L def list_to_array(*lst): - """ Convert a list if in numpy format """ + r""" Convert a list if in numpy format """ if len(lst) > 1: return [np.array(a) if isinstance(a, list) else a for a in lst] else: @@ -64,17 +64,18 @@ def list_to_array(*lst): def proj_simplex(v, z=1): - r""" compute the closest point (orthogonal projection) on the - generalized (n-1)-simplex of a vector v wrt. to the Euclidean + r"""Compute the closest point (orthogonal projection) on the + generalized `(n-1)`-simplex of a vector :math:`\mathbf{v}` wrt. to the Euclidean distance, thus solving: + .. math:: - \mathcal{P}(w) \in arg\min_\gamma || \gamma - v ||_2 + \mathcal{P}(w) \in \mathop{\arg \min}_\gamma \| \gamma - \mathbf{v} \|_2 - s.t. \gamma^T 1= z + s.t. \ \gamma^T \mathbf{1} = z - \gamma\geq 0 + \gamma \geq 0 - If v is a 2d array, compute all the projections wrt. axis 0 + If :math:`\mathbf{v}` is a 2d array, compute all the projections wrt. axis 0 .. note:: This function is backend-compatible and will work on arrays from all compatible backends. @@ -87,7 +88,7 @@ def proj_simplex(v, z=1): Returns ------- - h : ndarray, shape (n,d) + h : ndarray, shape (`n`, `d`) Array of projections on the simplex """ nx = get_backend(v) @@ -116,26 +117,24 @@ def proj_simplex(v, z=1): def unif(n): - """ return a uniform histogram of length n (simplex) + r""" + Return a uniform histogram of length `n` (simplex). Parameters ---------- - n : int number of bins in the histogram Returns ------- - h : np.array (n,) - histogram of length n such that h_i=1/n for all i - - + h : np.array (`n`,) + histogram of length `n` such that :math:`\forall i, \mathbf{h}_i = \frac{1}{n}` """ return np.ones((n,)) / n def clean_zeros(a, b, M): - """ Remove all components with zeros weights in a and b + r""" Remove all components with zeros weights in :math:`\mathbf{a}` and :math:`\mathbf{b}` """ M2 = M[a > 0, :][:, b > 0].copy() # copy force c style matrix (froemd) a2 = a[a > 0] @@ -144,8 +143,8 @@ def clean_zeros(a, b, M): def euclidean_distances(X, Y, squared=False): - """ - Considering the rows of X (and Y=X) as vectors, compute the + r""" + Considering the rows of :math:`\mathbf{X}` (and :math:`\mathbf{Y} = \mathbf{X}`) as vectors, compute the distance matrix between each pair of vectors. .. note:: This function is backend-compatible and will work on arrays @@ -153,14 +152,14 @@ def euclidean_distances(X, Y, squared=False): Parameters ---------- - X : {array-like}, shape (n_samples_1, n_features) - Y : {array-like}, shape (n_samples_2, n_features) + X : array-like, shape (n_samples_1, n_features) + Y : array-like, shape (n_samples_2, n_features) squared : boolean, optional Return squared Euclidean distances. Returns ------- - distances : {array}, shape (n_samples_1, n_samples_2) + distances : array-like, shape (`n_samples_1`, `n_samples_2`) """ nx = get_backend(X, Y) @@ -184,7 +183,7 @@ def euclidean_distances(X, Y, squared=False): def dist(x1, x2=None, metric='sqeuclidean', p=2): - """Compute distance between samples in x1 and x2 + r"""Compute distance between samples in :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}` .. note:: This function is backend-compatible and will work on arrays from all compatible backends. @@ -193,9 +192,9 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2): ---------- x1 : array-like, shape (n1,d) - matrix with n1 samples of size d + matrix with `n1` samples of size `d` x2 : array-like, shape (n2,d), optional - matrix with n2 samples of size d (if None then x2=x1) + matrix with `n2` samples of size `d` (if None then :math:`\mathbf{x_2} = \mathbf{x_1}`) metric : str | callable, optional 'sqeuclidean' or 'euclidean' on all backends. On numpy the function also accepts from the scipy.spatial.distance.cdist function : 'braycurtis', @@ -208,7 +207,7 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2): Returns ------- - M : array-like, shape (n1, n2) + M : array-like, shape (`n1`, `n2`) distance matrix computed with given metric """ @@ -226,7 +225,7 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2): def dist0(n, method='lin_square'): - """Compute standard cost matrices of size (n, n) for OT problems + r"""Compute standard cost matrices of size (`n`, `n`) for OT problems Parameters ---------- @@ -235,11 +234,11 @@ def dist0(n, method='lin_square'): method : str, optional Type of loss matrix chosen from: - * 'lin_square' : linear sampling between 0 and n-1, quadratic loss + * 'lin_square' : linear sampling between 0 and `n-1`, quadratic loss Returns ------- - M : ndarray, shape (n1,n2) + M : ndarray, shape (`n1`, `n2`) Distance matrix computed with given metric. """ res = 0 @@ -250,7 +249,7 @@ def dist0(n, method='lin_square'): def cost_normalization(C, norm=None): - """ Apply normalization to the loss matrix + r""" Apply normalization to the loss matrix Parameters ---------- @@ -262,7 +261,7 @@ def cost_normalization(C, norm=None): Returns ------- - C : ndarray, shape (n1, n2) + C : ndarray, shape (`n1`, `n2`) The input cost matrix normalized according to given norm. """ @@ -284,23 +283,23 @@ def cost_normalization(C, norm=None): def dots(*args): - """ dots function for multiple matrix multiply """ + r""" dots function for multiple matrix multiply """ return reduce(np.dot, args) def label_normalization(y, start=0): - """ Transform labels to start at a given value + r""" Transform labels to start at a given value Parameters ---------- y : array-like, shape (n, ) The vector of labels to be normalized. start : int - Desired value for the smallest label in y (default=0) + Desired value for the smallest label in :math:`\mathbf{y}` (default=0) Returns ------- - y : array-like, shape (n1, ) + y : array-like, shape (`n1`, ) The input vector of labels normalized according to given start value. """ @@ -311,14 +310,14 @@ def label_normalization(y, start=0): def parmap(f, X, nprocs="default"): - """ paralell map for multiprocessing. + r""" parallel map for multiprocessing. The function has been deprecated and only performs a regular map. """ return list(map(f, X)) def check_params(**kwargs): - """check_params: check whether some parameters are missing + r"""check_params: check whether some parameters are missing """ missing_params = [] @@ -339,14 +338,14 @@ def check_params(**kwargs): def check_random_state(seed): - """Turn seed into a np.random.RandomState instance + r"""Turn `seed` into a np.random.RandomState instance Parameters ---------- seed : None | int | instance of RandomState - If seed is None, return the RandomState singleton used by np.random. - If seed is an int, return a new RandomState instance seeded with seed. - If seed is already a RandomState instance, return it. + If `seed` is None, return the RandomState singleton used by np.random. + If `seed` is an int, return a new RandomState instance seeded with `seed`. + If `seed` is already a RandomState instance, return it. Otherwise raise ValueError. """ if seed is None or seed is np.random: @@ -360,18 +359,21 @@ def check_random_state(seed): class deprecated(object): - """Decorator to mark a function or class as deprecated. + r"""Decorator to mark a function or class as deprecated. deprecated class from scikit-learn package https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/deprecation.py Issue a warning when the function is called/the class is instantiated and adds a warning to the docstring. The optional extra argument will be appended to the deprecation message - and the docstring. Note: to use this with the default value for extra, put - in an empty of parentheses: - >>> from ot.deprecation import deprecated # doctest: +SKIP - >>> @deprecated() # doctest: +SKIP - ... def some_function(): pass # doctest: +SKIP + and the docstring. + + .. note:: + To use this with the default value for extra, use empty parentheses: + + >>> from ot.deprecation import deprecated # doctest: +SKIP + >>> @deprecated() # doctest: +SKIP + ... def some_function(): pass # doctest: +SKIP Parameters ---------- @@ -386,7 +388,7 @@ class deprecated(object): self.extra = extra def __call__(self, obj): - """Call method + r"""Call method Parameters ---------- obj : object @@ -417,7 +419,7 @@ class deprecated(object): return cls def _decorate_fun(self, fun): - """Decorate function fun""" + r"""Decorate function fun""" msg = "Function %s is deprecated" % fun.__name__ if self.extra: @@ -443,7 +445,7 @@ class deprecated(object): def _is_deprecated(func): - """Helper to check if func is wraped by our deprecated decorator""" + r"""Helper to check if func is wraped by our deprecated decorator""" if sys.version_info < (3, 5): raise NotImplementedError("This is only available for python3.5 " "or above") @@ -457,7 +459,7 @@ def _is_deprecated(func): class BaseEstimator(object): - """Base class for most objects in POT + r"""Base class for most objects in POT Code adapted from sklearn BaseEstimator class @@ -470,7 +472,7 @@ class BaseEstimator(object): @classmethod def _get_param_names(cls): - """Get parameter names for the estimator""" + r"""Get parameter names for the estimator""" # fetch the constructor or the original constructor before # deprecation wrapping if any @@ -497,7 +499,7 @@ class BaseEstimator(object): return sorted([p.name for p in parameters]) def get_params(self, deep=True): - """Get parameters for this estimator. + r"""Get parameters for this estimator. Parameters ---------- @@ -534,7 +536,7 @@ class BaseEstimator(object): return out def set_params(self, **params): - """Set the parameters of this estimator. + r"""Set the parameters of this estimator. The method works on simple estimators as well as on nested objects (such as pipelines). The latter have parameters of the form @@ -574,7 +576,7 @@ class BaseEstimator(object): class UndefinedParameter(Exception): - """ + r""" Aim at raising an Exception when a undefined parameter is called """ -- cgit v1.2.3 From 0e431c203a66c6d48e6bb1efeda149460472a0f0 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Thu, 4 Nov 2021 15:19:57 +0100 Subject: [MRG] Add tests about type and GPU for emd/emd2 + 1d variants + wasserstein1d (#304) * new test gpu * pep 8 of couse * debug torch * jax with gpu * device put * device put * it works * emd1d and emd2_1d working * emd_1d and emd2_1d done * cleanup * of course * should work on gpu now * tests done+ pep8 --- ot/backend.py | 20 ++++++++++- ot/lp/solver_1d.py | 14 ++++---- test/test_1d_solver.py | 93 ++++++++++++++++++++++++++++++++++++++++++++++++++ test/test_ot.py | 67 +++++++++++++++--------------------- 4 files changed, 146 insertions(+), 48 deletions(-) (limited to 'ot/lp') diff --git a/ot/backend.py b/ot/backend.py index d3df44c..55e10d3 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -102,6 +102,7 @@ class Backend(): __name__ = None __type__ = None + __type_list__ = None rng_ = None @@ -663,6 +664,8 @@ class NumpyBackend(Backend): __name__ = 'numpy' __type__ = np.ndarray + __type_list__ = [np.array(1, dtype=np.float32), + np.array(1, dtype=np.float64)] rng_ = np.random.RandomState() @@ -888,12 +891,17 @@ class JaxBackend(Backend): __name__ = 'jax' __type__ = jax_type + __type_list__ = None rng_ = None def __init__(self): self.rng_ = jax.random.PRNGKey(42) + for d in jax.devices(): + self.__type_list__ = [jax.device_put(jnp.array(1, dtype=np.float32), d), + jax.device_put(jnp.array(1, dtype=np.float64), d)] + def to_numpy(self, a): return np.array(a) @@ -901,7 +909,7 @@ class JaxBackend(Backend): if type_as is None: return jnp.array(a) else: - return jnp.array(a).astype(type_as.dtype) + return jax.device_put(jnp.array(a).astype(type_as.dtype), type_as.device_buffer.device()) def set_gradients(self, val, inputs, grads): from jax.flatten_util import ravel_pytree @@ -1130,6 +1138,7 @@ class TorchBackend(Backend): __name__ = 'torch' __type__ = torch_type + __type_list__ = None rng_ = None @@ -1138,6 +1147,13 @@ class TorchBackend(Backend): self.rng_ = torch.Generator() self.rng_.seed() + self.__type_list__ = [torch.tensor(1, dtype=torch.float32), + torch.tensor(1, dtype=torch.float64)] + + if torch.cuda.is_available(): + self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda')) + self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda')) + from torch.autograd import Function # define a function that takes inputs val and grads @@ -1160,6 +1176,8 @@ class TorchBackend(Backend): return a.cpu().detach().numpy() def from_numpy(self, a, type_as=None): + if isinstance(a, float): + a = np.array(a) if type_as is None: return torch.from_numpy(a) else: diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 42554aa..8b4d0c3 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -235,8 +235,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, # ensure that same mass np.testing.assert_almost_equal( - nx.sum(a, axis=0), - nx.sum(b, axis=0), + nx.to_numpy(nx.sum(a, axis=0)), + nx.to_numpy(nx.sum(b, axis=0)), err_msg='a and b vector must have the same sum' ) b = b * nx.sum(a) / nx.sum(b) @@ -247,10 +247,10 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, perm_b = nx.argsort(x_b_1d) G_sorted, indices, cost = emd_1d_sorted( - nx.to_numpy(a[perm_a]), - nx.to_numpy(b[perm_b]), - nx.to_numpy(x_a_1d[perm_a]), - nx.to_numpy(x_b_1d[perm_b]), + nx.to_numpy(a[perm_a]).astype(np.float64), + nx.to_numpy(b[perm_b]).astype(np.float64), + nx.to_numpy(x_a_1d[perm_a]).astype(np.float64), + nx.to_numpy(x_b_1d[perm_b]).astype(np.float64), metric=metric, p=p ) @@ -266,7 +266,7 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, elif str(nx) == "jax": warnings.warn("JAX does not support sparse matrices, converting to dense") if log: - log = {'cost': cost} + log = {'cost': nx.from_numpy(cost, type_as=x_a)} return G, log return G diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 2c470c2..77b1234 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -83,3 +83,96 @@ def test_wasserstein_1d(nx): Xb = nx.from_numpy(X) res = wasserstein_1d(Xb, Xb, rho_ub, rho_vb, p=2) np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4) + + +@pytest.mark.parametrize('nx', backend_list) +def test_wasserstein_1d_type_devices(nx): + + rng = np.random.RandomState(0) + + n = 10 + x = np.linspace(0, 5, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + rho_v = np.abs(rng.randn(n)) + rho_v /= rho_v.sum() + + for tp in nx.__type_list__: + + print(tp.dtype) + + xb = nx.from_numpy(x, type_as=tp) + rho_ub = nx.from_numpy(rho_u, type_as=tp) + rho_vb = nx.from_numpy(rho_v, type_as=tp) + + res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1) + + if not str(nx) == 'numpy': + assert res.dtype == xb.dtype + + +def test_emd_1d_emd2_1d(): + # test emd1d gives similar results as emd + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.randn(n, 1) + v = rng.randn(m, 1) + + M = ot.dist(u, v, metric='sqeuclidean') + + G, log = ot.emd([], [], M, log=True) + wass = log["cost"] + G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True) + wass1d = log["cost"] + wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False) + wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False) + + # check loss is similar + np.testing.assert_allclose(wass, wass1d) + np.testing.assert_allclose(wass, wass1d_emd2) + + # check loss is similar to scipy's implementation for Euclidean metric + wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,))) + np.testing.assert_allclose(wass_sp, wass1d_euc) + + # check constraints + np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1)) + np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0)) + + # check G is similar + np.testing.assert_allclose(G, G_1d, atol=1e-15) + + # check AssertionError is raised if called on non 1d arrays + u = np.random.randn(n, 2) + v = np.random.randn(m, 2) + with pytest.raises(AssertionError): + ot.emd_1d(u, v, [], []) + + +def test_emd1d_type_devices(nx): + + rng = np.random.RandomState(0) + + n = 10 + x = np.linspace(0, 5, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + rho_v = np.abs(rng.randn(n)) + rho_v /= rho_v.sum() + + for tp in nx.__type_list__: + + print(tp.dtype) + + xb = nx.from_numpy(x, type_as=tp) + rho_ub = nx.from_numpy(rho_u, type_as=tp) + rho_vb = nx.from_numpy(rho_v, type_as=tp) + + emd = ot.emd_1d(xb, xb, rho_ub, rho_vb) + + emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb) + + assert emd.dtype == xb.dtype + if not str(nx) == 'numpy': + assert emd2.dtype == xb.dtype diff --git a/test/test_ot.py b/test/test_ot.py index 5bfde1d..dc3930a 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -12,7 +12,6 @@ import pytest import ot from ot.datasets import make_1D_gauss as gauss from ot.backend import torch -from scipy.stats import wasserstein_distance def test_emd_dimension_and_mass_mismatch(): @@ -77,6 +76,33 @@ def test_emd2_backends(nx): np.allclose(val, nx.to_numpy(valb)) +def test_emd_emd2_types_devices(nx): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + for tp in nx.__type_list__: + + print(tp.dtype) + + ab = nx.from_numpy(a, type_as=tp) + Mb = nx.from_numpy(M, type_as=tp) + + Gb = ot.emd(ab, ab, Mb) + + w = ot.emd2(ab, ab, Mb) + + assert Gb.dtype == Mb.dtype + if not str(nx) == 'numpy': + assert w.dtype == Mb.dtype + + def test_emd2_gradients(): n_samples = 100 n_features = 2 @@ -126,45 +152,6 @@ def test_emd_emd2(): np.testing.assert_allclose(w, 0) -def test_emd_1d_emd2_1d(): - # test emd1d gives similar results as emd - n = 20 - m = 30 - rng = np.random.RandomState(0) - u = rng.randn(n, 1) - v = rng.randn(m, 1) - - M = ot.dist(u, v, metric='sqeuclidean') - - G, log = ot.emd([], [], M, log=True) - wass = log["cost"] - G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True) - wass1d = log["cost"] - wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False) - wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False) - - # check loss is similar - np.testing.assert_allclose(wass, wass1d) - np.testing.assert_allclose(wass, wass1d_emd2) - - # check loss is similar to scipy's implementation for Euclidean metric - wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,))) - np.testing.assert_allclose(wass_sp, wass1d_euc) - - # check constraints - np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1)) - np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0)) - - # check G is similar - np.testing.assert_allclose(G, G_1d, atol=1e-15) - - # check AssertionError is raised if called on non 1d arrays - u = np.random.randn(n, 2) - v = np.random.randn(m, 2) - with pytest.raises(AssertionError): - ot.emd_1d(u, v, [], []) - - def test_emd_empty(): # test emd and emd2 for simple identity n = 100 -- cgit v1.2.3