From 24a7a0439e631e90ff84ce84d0a78bc22846cf71 Mon Sep 17 00:00:00 2001 From: Panayiotis Panayiotou Date: Mon, 24 Aug 2020 15:40:05 +0300 Subject: Check if alpha is not None when restricting it to be at most 1 (#199) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Check if alpha is not None when restricting it to be at most 1 * Write check more clearly * Add no regression test for line search armijo returning None for alpha Co-authored-by: Rémi Flamary --- test/test_optim.py | 10 ++++++++++ 1 file changed, 10 insertions(+) (limited to 'test') diff --git a/test/test_optim.py b/test/test_optim.py index 87b0268..48de38a 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -104,3 +104,13 @@ def test_solve_1d_linesearch_quad_funct(): np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(1, -1, 0), 0.5) np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 5, 0), 0) np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1) + + +def test_line_search_armijo(): + xk = np.array([[0.25, 0.25], [0.25, 0.25]]) + pk = np.array([[-0.25, 0.25], [0.25, -0.25]]) + gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]]) + old_fval = -123 + # Should not throw an exception and return None for alpha + alpha, _, _ = ot.optim.line_search_armijo(lambda x: 1, xk, pk, gfk, old_fval) + assert alpha is None -- cgit v1.2.3 From 7adc1b1aa73c55dc07983ff08dcb23fd71e9e8b6 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Thu, 22 Oct 2020 10:16:40 +0200 Subject: [MRG] Cleanup minimal build and add separate build for pep8 (#210) * cleanup requiorement minimal * add pep8 build * cleanup sklearn * skip test if no sklearn * debug build yaml * comment error out in test (test sklearn) * maybe small stuff for better robustness : copy the sub-array * bump verison minimal build * update version strict requireent * update version strict requirement last change --- .github/requirements_strict.txt | 9 +++------ .github/workflows/build_tests.yml | 37 +++++++++++++++++++++++++------------ .gitignore | 3 +++ Makefile | 4 ++-- ot/lp/__init__.py | 2 +- test/test_da.py | 8 ++++++++ 6 files changed, 42 insertions(+), 21 deletions(-) (limited to 'test') diff --git a/.github/requirements_strict.txt b/.github/requirements_strict.txt index d7539c5..9a1ada4 100644 --- a/.github/requirements_strict.txt +++ b/.github/requirements_strict.txt @@ -1,7 +1,4 @@ -numpy==1.16.* -scipy==1.0.* -cython==0.23.* -matplotlib -cvxopt -scikit-learn +numpy +scipy>=1.3 +cython pytest diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index 41b08b3..fa814ba 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -30,14 +30,7 @@ jobs: run: | python -m pip install --upgrade pip pip install -r requirements.txt - pip install flake8 pytest "pytest-cov<2.6" codecov - pip install -U "sklearn" - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 examples/ ot/ test/ --count --max-line-length=127 --statistics + pip install pytest "pytest-cov<2.6" codecov - name: Install POT run: | pip install -e . @@ -48,6 +41,29 @@ jobs: run: | codecov + pep8: + runs-on: ubuntu-latest + strategy: + max-parallel: 4 + matrix: + python-version: [3.8] + + steps: + - uses: actions/checkout@v1 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install flake8 + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 examples/ ot/ test/ --count --max-line-length=127 --statistics linux-minimal-deps: @@ -55,7 +71,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: [3.6] + python-version: [3.8] steps: - uses: actions/checkout@v1 @@ -68,7 +84,6 @@ jobs: python -m pip install --upgrade pip pip install -r .github/requirements_strict.txt pip install pytest - pip install -U "sklearn" - name: Install POT run: | pip install -e . @@ -95,7 +110,6 @@ jobs: python -m pip install --upgrade pip pip install -r requirements.txt pip install pytest "pytest-cov<2.6" - pip install -U "sklearn" - name: Install POT run: | pip install -e . @@ -122,7 +136,6 @@ jobs: python -m pip install --upgrade pip pip install -r requirements.txt pip install pytest "pytest-cov<2.6" - pip install -U "sklearn" - name: Install POT run: | pip install -e . diff --git a/.gitignore b/.gitignore index a2ace7c..b44ea43 100644 --- a/.gitignore +++ b/.gitignore @@ -40,6 +40,9 @@ var/ *.manifest *.spec +# env +pythonenv3.8/ + # Installer logs pip-log.txt pip-delete-this-directory.txt diff --git a/Makefile b/Makefile index 70cdbdd..32332b4 100644 --- a/Makefile +++ b/Makefile @@ -45,10 +45,10 @@ pep8 : flake8 examples/ ot/ test/ test : FORCE pep8 - $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/ --cov=ot --cov-report html:cov_html + $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/ pytest : FORCE - $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/ --cov=ot + $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/ release : twine upload dist/* diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 2a1b082..f08e020 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -426,7 +426,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), nb = b.shape[1] if processes > 1: - res = parmap(f, [b[:, i] for i in range(nb)], processes) + res = parmap(f, [b[:, i].copy() for i in range(nb)], processes) else: res = list(map(f, [b[:, i].copy() for i in range(nb)])) diff --git a/test/test_da.py b/test/test_da.py index 3b28119..52c6a48 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -6,11 +6,18 @@ import numpy as np from numpy.testing import assert_allclose, assert_equal +import pytest import ot from ot.datasets import make_data_classif from ot.utils import unif +try: # test if cudamat installed + import sklearn # noqa: F401 + nosklearn = False +except ImportError: + nosklearn = True + def test_sinkhorn_lpl1_transport_class(): """test_sinkhorn_transport @@ -691,6 +698,7 @@ def test_jcpot_barycenter(): np.testing.assert_allclose(prop, [1 - pt, pt], rtol=1e-3, atol=1e-3) +@pytest.mark.skipif(nosklearn, reason="No sklearn available") def test_emd_laplace_class(): """test_emd_laplace_transport """ -- cgit v1.2.3 From 78b44af2434f494c8f9e4c8c91003fbc0e1d4415 Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Thu, 22 Oct 2020 09:28:53 +0100 Subject: [MRG] Sliced wasserstein (#203) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * example for log treatment in bregman.py * Improve doc * Revert "example for log treatment in bregman.py" This reverts commit 9f51c14e * Add comments by Flamary * Delete repetitive description * Added raw string to avoid pbs with backslashes * Implements sliced wasserstein * Changed formatting of string for py3.5 support * Docstest, expected 0.0 and not 0. * Adressed comments by @rflamary * No 3d plot here * add sliced to the docs * Incorporate comments by @rflamary * add link to pdf Co-authored-by: Rémi Flamary --- README.md | 4 + docs/source/all.rst | 1 + examples/sliced-wasserstein/README.txt | 4 + examples/sliced-wasserstein/plot_variance.py | 84 ++++++++++++++++ ot/__init__.py | 3 +- ot/sliced.py | 144 +++++++++++++++++++++++++++ test/test_sliced.py | 85 ++++++++++++++++ 7 files changed, 324 insertions(+), 1 deletion(-) create mode 100644 examples/sliced-wasserstein/README.txt create mode 100644 examples/sliced-wasserstein/plot_variance.py create mode 100644 ot/sliced.py create mode 100644 test/test_sliced.py (limited to 'test') diff --git a/README.md b/README.md index e3598f1..6fe528a 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,7 @@ POT provides the following generic OT solvers (links to examples): * [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] formulations). +* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32]. POT provides the following Machine Learning related solvers: @@ -180,6 +181,7 @@ The contributors to this library are * [Romain Tavenard](https://rtavenar.github.io/) (1d Wasserstein) * [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn) * [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT) +* [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance) This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages): @@ -263,3 +265,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t [29] Chapel, L., Alaya, M., Gasso, G. (2019). [Partial Gromov-Wasserstein with Applications on Positive-Unlabeled Learning](https://arxiv.org/abs/2002.08276), arXiv preprint arXiv:2002.08276. [30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). [Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching](https://remi.flamary.com/biblio/flamary2014optlaplace.pdf), NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. + +[31] Bonneel, Nicolas, et al. [Sliced and radon wasserstein barycenters of measures](https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf), Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 diff --git a/docs/source/all.rst b/docs/source/all.rst index d7b878f..f1f7075 100644 --- a/docs/source/all.rst +++ b/docs/source/all.rst @@ -27,6 +27,7 @@ API and modules stochastic unbalanced partial + sliced .. autosummary:: :toctree: ../modules/generated/ diff --git a/examples/sliced-wasserstein/README.txt b/examples/sliced-wasserstein/README.txt new file mode 100644 index 0000000..a575345 --- /dev/null +++ b/examples/sliced-wasserstein/README.txt @@ -0,0 +1,4 @@ + + +Sliced Wasserstein Distance +--------------------------- \ No newline at end of file diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py new file mode 100644 index 0000000..f3deeff --- /dev/null +++ b/examples/sliced-wasserstein/plot_variance.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +""" +============================== +2D Sliced Wasserstein Distance +============================== + +This example illustrates the computation of the sliced Wasserstein Distance as proposed in [31]. + +[31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 + +""" + +# Author: Adrien Corenflos +# +# License: MIT License + +import matplotlib.pylab as pl +import numpy as np + +import ot + +############################################################################## +# Generate data +# ------------- + +# %% parameters and data generation + +n = 500 # nb samples + +mu_s = np.array([0, 0]) +cov_s = np.array([[1, 0], [0, 1]]) + +mu_t = np.array([4, 4]) +cov_t = np.array([[1, -.8], [-.8, 1]]) + +xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) +xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) + +a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples + +############################################################################## +# Plot data +# --------- + +# %% plot samples + +pl.figure(1) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.legend(loc=0) +pl.title('Source and target distributions') + +################################################################################### +# Compute Sliced Wasserstein distance for different seeds and number of projections +# ----------- + +n_seed = 50 +n_projections_arr = np.logspace(0, 3, 25, dtype=int) +res = np.empty((n_seed, 25)) + +# %% Compute statistics +for seed in range(n_seed): + for i, n_projections in enumerate(n_projections_arr): + res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed) + +res_mean = np.mean(res, axis=0) +res_std = np.std(res, axis=0) + +################################################################################### +# Plot Sliced Wasserstein Distance +# ----------- + +pl.figure(2) +pl.plot(n_projections_arr, res_mean, label="SWD") +pl.fill_between(n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5) + +pl.legend() +pl.xscale('log') + +pl.xlabel("Number of projections") +pl.ylabel("Distance") +pl.title('Sliced Wasserstein Distance with 95% confidence inverval') + +pl.show() diff --git a/ot/__init__.py b/ot/__init__.py index 0e6e2e2..ec3ede2 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -39,6 +39,7 @@ from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d from .bregman import sinkhorn, sinkhorn2, barycenter from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2 from .da import sinkhorn_lpl1_mm +from .sliced import sliced_wasserstein_distance # utils functions from .utils import dist, unif, tic, toc, toq @@ -50,4 +51,4 @@ __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets' 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim', 'sinkhorn_unbalanced', 'barycenter_unbalanced', - 'sinkhorn_unbalanced2'] + 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance'] diff --git a/ot/sliced.py b/ot/sliced.py new file mode 100644 index 0000000..4792576 --- /dev/null +++ b/ot/sliced.py @@ -0,0 +1,144 @@ +""" +Sliced Wasserstein Distance. + +""" + +# Author: Adrien Corenflos +# +# License: MIT License + + +import numpy as np + + +def get_random_projections(n_projections, d, seed=None): + r""" + Generates n_projections samples from the uniform on the unit sphere of dimension d-1: :math:`\mathcal{U}(\mathcal{S}^{d-1})` + + Parameters + ---------- + n_projections : int + number of samples requested + d : int + dimension of the space + seed: int or RandomState, optional + Seed used for numpy random number generator + + Returns + ------- + out: ndarray, shape (n_projections, d) + The uniform unit vectors on the sphere + + Examples + -------- + >>> n_projections = 100 + >>> d = 5 + >>> projs = get_random_projections(n_projections, d) + >>> np.allclose(np.sum(np.square(projs), 1), 1.) # doctest: +NORMALIZE_WHITESPACE + True + + """ + + if not isinstance(seed, np.random.RandomState): + random_state = np.random.RandomState(seed) + else: + random_state = seed + + projections = random_state.normal(0., 1., [n_projections, d]) + norm = np.linalg.norm(projections, ord=2, axis=1, keepdims=True) + projections = projections / norm + return projections + + +def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed=None, log=False): + r""" + Computes a Monte-Carlo approximation of the 2-Sliced Wasserstein distance + + .. math:: + \mathcal{SWD}_2(\mu, \nu) = \underset{\theta \sim \mathcal{U}(\mathbb{S}^{d-1})}{\mathbb{E}}[\mathcal{W}_2^2(\theta_\# \mu, \theta_\# \nu)]^{\frac{1}{2}} + + where : + + - :math:`\theta_\# \mu` stands for the pushforwars of the projection :math:`\mathbb{R}^d \ni X \mapsto \langle \theta, X \rangle` + + + Parameters + ---------- + X_s : ndarray, shape (n_samples_a, dim) + samples in the source domain + X_t : ndarray, shape (n_samples_b, dim) + samples in the target domain + a : ndarray, shape (n_samples_a,), optional + samples weights in the source domain + b : ndarray, shape (n_samples_b,), optional + samples weights in the target domain + n_projections : int, optional + Number of projections used for the Monte-Carlo approximation + seed: int or RandomState or None, optional + Seed used for numpy random number generator + log: bool, optional + if True, sliced_wasserstein_distance returns the projections used and their associated EMD. + + Returns + ------- + cost: float + Sliced Wasserstein Cost + log : dict, optional + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> n_samples_a = 20 + >>> reg = 0.1 + >>> X = np.random.normal(0., 1., (n_samples_a, 5)) + >>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE + 0.0 + + References + ---------- + + .. [31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 + """ + from .lp import emd2_1d + + X_s = np.asanyarray(X_s) + X_t = np.asanyarray(X_t) + + n = X_s.shape[0] + m = X_t.shape[0] + + if X_s.shape[1] != X_t.shape[1]: + raise ValueError( + "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1], + X_t.shape[1])) + + if a is None: + a = np.full(n, 1 / n) + if b is None: + b = np.full(m, 1 / m) + + d = X_s.shape[1] + + projections = get_random_projections(n_projections, d, seed) + + X_s_projections = np.dot(projections, X_s.T) + X_t_projections = np.dot(projections, X_t.T) + + if log: + projected_emd = np.empty(n_projections) + else: + projected_emd = None + + res = 0. + + for i, (X_s_proj, X_t_proj) in enumerate(zip(X_s_projections, X_t_projections)): + emd = emd2_1d(X_s_proj, X_t_proj, a, b, log=False, dense=False) + if projected_emd is not None: + projected_emd[i] = emd + res += emd + + res = (res / n_projections) ** 0.5 + if log: + return res, {"projections": projections, "projected_emds": projected_emd} + return res diff --git a/test/test_sliced.py b/test/test_sliced.py new file mode 100644 index 0000000..a07d975 --- /dev/null +++ b/test/test_sliced.py @@ -0,0 +1,85 @@ +"""Tests for module sliced""" + +# Author: Adrien Corenflos +# +# License: MIT License + +import numpy as np +import pytest + +import ot +from ot.sliced import get_random_projections + + +def test_get_random_projections(): + rng = np.random.RandomState(0) + projections = get_random_projections(1000, 50, rng) + np.testing.assert_almost_equal(np.sum(projections ** 2, 1), 1.) + + +def test_sliced_same_dist(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + res = ot.sliced_wasserstein_distance(x, x, u, u, 10, seed=rng) + np.testing.assert_almost_equal(res, 0.) + + +def test_sliced_bad_shapes(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(n, 4) + u = ot.utils.unif(n) + + with pytest.raises(ValueError): + _ = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng) + + +def test_sliced_log(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 4) + y = rng.randn(n, 4) + u = ot.utils.unif(n) + + res, log = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng, log=True) + assert len(log) == 2 + projections = log["projections"] + projected_emds = log["projected_emds"] + + assert len(projections) == len(projected_emds) == 10 + for emd in projected_emds: + assert emd > 0 + + +def test_sliced_different_dists(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + y = rng.randn(n, 2) + + res = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng) + assert res > 0. + + +def test_1d_sliced_equals_emd(): + n = 100 + m = 120 + rng = np.random.RandomState(0) + + x = rng.randn(n, 1) + a = rng.uniform(0, 1, n) + a /= a.sum() + y = rng.randn(m, 1) + u = ot.utils.unif(m) + res = ot.sliced_wasserstein_distance(x, y, a, u, 10, seed=42) + expected = ot.emd2_1d(x.squeeze(), y.squeeze(), a, u) + np.testing.assert_almost_equal(res ** 2, expected) -- cgit v1.2.3 From 93785eba11b59d544f1edde6661e93ee587148ee Mon Sep 17 00:00:00 2001 From: Laetitia Chapel Date: Thu, 22 Oct 2020 10:58:31 +0200 Subject: [MRG] Fix bugs for partial OT (#215) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * bugfix * update refs partial OT * fixes small typos in plot_partial_wass_and_gromov * fix small bugs in partial.py * update README * pep8 bugfix * modif doctest * fix bugtests * update on test_partial and test on the numerical precision on ot/partial * resolve merge pb Co-authored-by: Rémi Flamary --- README.md | 2 +- .../plot_partial_wass_and_gromov.py | 23 ++++--- ot/partial.py | 71 +++++++++++++--------- test/test_partial.py | 6 +- 4 files changed, 60 insertions(+), 42 deletions(-) (limited to 'test') diff --git a/README.md b/README.md index 6fe528a..238faed 100644 --- a/README.md +++ b/README.md @@ -262,7 +262,7 @@ You can also post bug reports and feature requests in Github issues. Make sure t [28] Caffarelli, L. A., McCann, R. J. (2010). [Free boundaries in optimal transport and Monge-Ampere obstacle problems](http://www.math.toronto.edu/~mccann/papers/annals2010.pdf), Annals of mathematics, 673-730. -[29] Chapel, L., Alaya, M., Gasso, G. (2019). [Partial Gromov-Wasserstein with Applications on Positive-Unlabeled Learning](https://arxiv.org/abs/2002.08276), arXiv preprint arXiv:2002.08276. +[29] Chapel, L., Alaya, M., Gasso, G. (2020). [Partial Optimal Transport with Applications on Positive-Unlabeled Learning](https://arxiv.org/abs/2002.08276), Advances in Neural Information Processing Systems (NeurIPS), 2020. [30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). [Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching](https://remi.flamary.com/biblio/flamary2014optlaplace.pdf), NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. diff --git a/examples/unbalanced-partial/plot_partial_wass_and_gromov.py b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py index 0c5cbf9..ac4194c 100755 --- a/examples/unbalanced-partial/plot_partial_wass_and_gromov.py +++ b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py @@ -4,7 +4,7 @@ Partial Wasserstein and Gromov-Wasserstein example ================================================== -This example is designed to show how to use the Partial (Gromov-)Wassertsein +This example is designed to show how to use the Partial (Gromov-)Wasserstein distance computation in POT. """ @@ -123,11 +123,12 @@ C1 = sp.spatial.distance.cdist(xs, xs) C2 = sp.spatial.distance.cdist(xt, xt) # transport 100% of the mass -print('-----m = 1') +print('------m = 1') m = 1 res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True) res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10, - m=m, log=True) + m=m, log=True, + verbose=True) print('Wasserstein distance (m = 1): ' + str(log0['partial_gw_dist'])) print('Entropic Wasserstein distance (m = 1): ' + str(log['partial_gw_dist'])) @@ -136,18 +137,20 @@ pl.figure(1, (10, 5)) pl.title("mass to be transported m = 1") pl.subplot(1, 2, 1) pl.imshow(res0, cmap='jet') -pl.title('Wasserstein') +pl.title('Gromov-Wasserstein') pl.subplot(1, 2, 2) pl.imshow(res, cmap='jet') -pl.title('Entropic Wasserstein') +pl.title('Entropic Gromov-Wasserstein') pl.show() # transport 2/3 of the mass -print('-----m = 2/3') +print('------m = 2/3') m = 2 / 3 -res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True) +res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True, + verbose=True) res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10, - m=m, log=True) + m=m, log=True, + verbose=True) print('Partial Wasserstein distance (m = 2/3): ' + str(log0['partial_gw_dist'])) @@ -158,8 +161,8 @@ pl.figure(1, (10, 5)) pl.title("mass to be transported m = 2/3") pl.subplot(1, 2, 1) pl.imshow(res0, cmap='jet') -pl.title('Partial Wasserstein') +pl.title('Partial Gromov-Wasserstein') pl.subplot(1, 2, 2) pl.imshow(res, cmap='jet') -pl.title('Entropic partial Wasserstein') +pl.title('Entropic partial Gromov-Wasserstein') pl.show() diff --git a/ot/partial.py b/ot/partial.py index eb707d8..814d779 100755 --- a/ot/partial.py +++ b/ot/partial.py @@ -230,9 +230,9 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): .. [28] Caffarelli, L. A., & McCann, R. J. (2010) Free boundaries in optimal transport and Monge-Ampere obstacle problems. Annals of mathematics, 673-730. - .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov- - Wasserstein with Applications on Positive-Unlabeled Learning". - arXiv preprint arXiv:2002.08276. + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. See Also -------- @@ -254,7 +254,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies) a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies) M_extended = np.zeros((len(a_extended), len(b_extended))) - M_extended[-1, -1] = np.max(M) * 1e5 + M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e5 M_extended[:len(a), :len(b)] = M gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True, @@ -344,14 +344,13 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): .. [28] Caffarelli, L. A., & McCann, R. J. (2010) Free boundaries in optimal transport and Monge-Ampere obstacle problems. Annals of mathematics, 673-730. - .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov- - Wasserstein with Applications on Positive-Unlabeled Learning". - arXiv preprint arXiv:2002.08276. + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. """ partial_gw, log_w = partial_wasserstein(a, b, M, m, nb_dummies, log=True, **kwargs) - log_w['T'] = partial_gw if log: @@ -501,14 +500,14 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, >>> np.round(partial_gromov_wasserstein(C1, C2, a, b, m=0.25),2) array([[0. , 0. , 0. , 0. ], [0. , 0. , 0. , 0. ], - [0. , 0. , 0. , 0. ], - [0. , 0. , 0. , 0.25]]) + [0. , 0. , 0.25, 0. ], + [0. , 0. , 0. , 0. ]]) References ---------- - .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov- - Wasserstein with Applications on Positive-Unlabeled Learning". - arXiv preprint arXiv:2002.08276. + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. """ @@ -530,20 +529,18 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, cpt = 0 err = 1 - eps = 1e-20 + if log: log = {'err': []} while (err > tol and cpt < numItermax): - Gprev = G0 + Gprev = np.copy(G0) M = gwgrad_partial(C1, C2, G0) - M[M < eps] = np.quantile(M, thres) - M_emd = np.zeros(dim_G_extended) M_emd[:len(p), :len(q)] = M - M_emd[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e5 + M_emd[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e2 M_emd = np.asarray(M_emd, dtype=np.float64) Gc, logemd = emd(p_extended, q_extended, M_emd, log=True, **kwargs) @@ -565,6 +562,22 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, print('{:5d}|{:8e}|{:8e}'.format(cpt, err, gwloss_partial(C1, C2, G0))) + deltaG = G0 - Gprev + a = gwloss_partial(C1, C2, deltaG) + b = 2 * np.sum(M * deltaG) + if b > 0: # due to numerical precision + gamma = 0 + cpt = numItermax + elif a > 0: + gamma = min(1, np.divide(-b, 2.0 * a)) + else: + if (a + b) < 0: + gamma = 1 + else: + gamma = 0 + cpt = numItermax + + G0 = Gprev + gamma * deltaG cpt += 1 if log: @@ -665,9 +678,9 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None, References ---------- - .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov- - Wasserstein with Applications on Positive-Unlabeled Learning". - arXiv preprint arXiv:2002.08276. + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. """ @@ -887,12 +900,12 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, >>> y = np.array([3,2,98,199]).reshape((-1,1)) >>> C1 = sp.spatial.distance.cdist(x, x) >>> C2 = sp.spatial.distance.cdist(y, y) - >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b,50), 2) + >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b, 50), 2) array([[0.12, 0.13, 0. , 0. ], [0.13, 0.12, 0. , 0. ], [0. , 0. , 0.25, 0. ], [0. , 0. , 0. , 0.25]]) - >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b, 50, m=0.25), 2) + >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b, 50,0.25), 2) array([[0.02, 0.03, 0. , 0.03], [0.03, 0.03, 0. , 0.03], [0. , 0. , 0.03, 0. ], @@ -910,9 +923,9 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. - .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov- - Wasserstein with Applications on Positive-Unlabeled Learning". - arXiv preprint arXiv:2002.08276. + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. See Also -------- @@ -1044,9 +1057,9 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None, .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. - .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov- - Wasserstein with Applications on Positive-Unlabeled Learning". - arXiv preprint arXiv:2002.08276. + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. """ partial_gw, log_gw = entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, diff --git a/test/test_partial.py b/test/test_partial.py index 510e081..121f345 100755 --- a/test/test_partial.py +++ b/test/test_partial.py @@ -51,10 +51,12 @@ def test_raise_errors(): ot.partial.partial_gromov_wasserstein(M, M, p, q, m=-1, log=True) with pytest.raises(ValueError): - ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=2, log=True) + ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=2, + log=True) with pytest.raises(ValueError): - ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=-1, log=True) + ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=-1, + log=True) def test_partial_wasserstein_lagrange(): -- cgit v1.2.3 From 2e97be778d2d72d7a66b3721ee697399522538ba Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Thu, 8 Apr 2021 11:09:50 +0200 Subject: [MRG] ADD JMLR paper to the readme and documentation (#231) * add JMLR reefrence to eradme and doc --- README.md | 20 ++++++++++++-------- docs/source/readme.rst | 50 +++++++++++++++++++++++++++++++++----------------- test/test_ot.py | 6 +++--- 3 files changed, 48 insertions(+), 28 deletions(-) (limited to 'test') diff --git a/README.md b/README.md index 238faed..7321aff 100644 --- a/README.md +++ b/README.md @@ -50,19 +50,23 @@ Some other examples are available in the [documentation](https://pythonot.githu #### Using and citing the toolbox If you use this toolbox in your research and find it useful, please cite POT -using the following reference: +using the following reference from our [JMLR paper](https://jmlr.org/papers/v22/20-451.html): ``` -Rémi Flamary and Nicolas Courty, POT Python Optimal Transport library, -Website: https://pythonot.github.io/, 2017 +Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer;, POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021. +Website: https://pythonot.github.io/ ``` In Bibtex format: ``` -@misc{flamary2017pot, -title={POT Python Optimal Transport library}, -author={Flamary, R{'e}mi and Courty, Nicolas}, -url={https://pythonot.github.io/}, -year={2017} +@article{flamary2021pot, + author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer}, + title = {POT: Python Optimal Transport}, + journal = {Journal of Machine Learning Research}, + year = {2021}, + volume = {22}, + number = {78}, + pages = {1-8}, + url = {http://jmlr.org/papers/v22/20-451.html} } ``` diff --git a/docs/source/readme.rst b/docs/source/readme.rst index b8cb48c..f35f01b 100644 --- a/docs/source/readme.rst +++ b/docs/source/readme.rst @@ -66,6 +66,9 @@ POT provides the following generic OT solvers (links to examples): - `Partial Wasserstein and Gromov-Wasserstein `__ (exact [29] and entropic [3] formulations). +- `Sliced + Wasserstein `__ + [31, 32]. POT provides the following Machine Learning related solvers: @@ -96,22 +99,27 @@ Using and citing the toolbox ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ If you use this toolbox in your research and find it useful, please cite -POT using the following reference: +POT using the following reference from our `JMLR +paper `__: :: - Rémi Flamary and Nicolas Courty, POT Python Optimal Transport library, - Website: https://pythonot.github.io/, 2017 + Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer;, POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021. + Website: https://pythonot.github.io/ In Bibtex format: :: - @misc{flamary2017pot, - title={POT Python Optimal Transport library}, - author={Flamary, R{'e}mi and Courty, Nicolas}, - url={https://pythonot.github.io/}, - year={2017} + @article{flamary2021pot, + author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer}, + title = {POT: Python Optimal Transport}, + journal = {Journal of Machine Learning Research}, + year = {2021}, + volume = {22}, + number = {78}, + pages = {1-8}, + url = {http://jmlr.org/papers/v22/20-451.html} } Installation @@ -269,6 +277,8 @@ The contributors to this library are - `Romain Tavenard `__ (1d Wasserstein) - `Mokhtar Z. Alaya `__ (Screenkhorn) - `Ievgen Redko `__ (Laplacian DA, JCPOT) +- `Adrien Corenflos `__ (Sliced + Wasserstein Distance) This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various @@ -285,20 +295,21 @@ Contributions and code of conduct --------------------------------- Every contribution is welcome and should respect the `contribution -guidelines `__. Each member of the project is expected -to follow the `code of conduct `__. +guidelines <.github/CONTRIBUTING.md>`__. Each member of the project is +expected to follow the `code of conduct <.github/CODE_OF_CONDUCT.md>`__. Support ------- You can ask questions and join the development discussion: -- On the `POT Slack channel `__ +- On the POT `slack channel `__ +- On the POT `gitter channel `__ - On the POT `mailing list `__ You can also post bug reports and feature requests in Github issues. -Make sure to read our `guidelines `__ first. +Make sure to read our `guidelines <.github/CONTRIBUTING.md>`__ first. References ---------- @@ -439,10 +450,10 @@ optimal transport and Monge-Ampere obstacle problems `__, Annals of mathematics, 673-730. -[29] Chapel, L., Alaya, M., Gasso, G. (2019). `Partial -Gromov-Wasserstein with Applications on Positive-Unlabeled -Learning `__, arXiv preprint -arXiv:2002.08276. +[29] Chapel, L., Alaya, M., Gasso, G. (2020). `Partial Optimal Transport +with Applications on Positive-Unlabeled +Learning `__, Advances in Neural +Information Processing Systems (NeurIPS), 2020. [30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). `Optimal transport with Laplacian regularization: Applications to domain @@ -450,11 +461,16 @@ adaptation and shape matching `__, NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. +[31] Bonneel, Nicolas, et al. `Sliced and radon wasserstein barycenters +of +measures `__, +Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 + .. |PyPI version| image:: https://badge.fury.io/py/POT.svg :target: https://badge.fury.io/py/POT .. |Anaconda Cloud| image:: https://anaconda.org/conda-forge/pot/badges/version.svg :target: https://anaconda.org/conda-forge/pot -.. |Build Status| image:: https://github.com/PythonOT/POT/workflows/build/badge.svg +.. |Build Status| image:: https://github.com/PythonOT/POT/workflows/build/badge.svg?branch=master&event=push :target: https://github.com/PythonOT/POT/actions .. |Codecov Status| image:: https://codecov.io/gh/PythonOT/POT/branch/master/graph/badge.svg :target: https://codecov.io/gh/PythonOT/POT diff --git a/test/test_ot.py b/test/test_ot.py index b7306f6..f45e4c9 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -291,17 +291,17 @@ def test_warnings(): print('Computing {} EMD '.format(1)) ot.emd(a, b, M, numItermax=1) assert "numItermax" in str(w[-1].message) - assert len(w) == 1 + #assert len(w) == 1 a[0] = 100 print('Computing {} EMD '.format(2)) ot.emd(a, b, M) assert "infeasible" in str(w[-1].message) - assert len(w) == 2 + #assert len(w) == 2 a[0] = -1 print('Computing {} EMD '.format(2)) ot.emd(a, b, M) assert "infeasible" in str(w[-1].message) - assert len(w) == 3 + #assert len(w) == 3 def test_dual_variables(): -- cgit v1.2.3 From 2a3f2241951ea9cc044b4fba8a382b6ae9630513 Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Mon, 19 Apr 2021 14:57:51 +0300 Subject: BUG/DOC FIX - Sinkhorn divergence used the wrong weights, and sinkhorn2 didn't support epsilon_scaling method. (#235) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * FIX: 1. Documentation of loss specific functions 2. Sinkhorn divergence weights handling 3. Sinkhorn2 does not support epsilon scaling, so I removed it (it *should* arguably support it, but this would require a refactoring of the sinkhorn iterates pretty much everywhere, maybe should be done in torch first?) * Had some PEP8 issues Co-authored-by: Rémi Flamary --- ot/bregman.py | 53 +++++++++++++++++++++++++--------------------------- test/test_bregman.py | 13 +++++++------ 2 files changed, 32 insertions(+), 34 deletions(-) (limited to 'test') diff --git a/ot/bregman.py b/ot/bregman.py index dcd35e1..559db14 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -14,11 +14,13 @@ Bregman projections solvers for entropic regularized OT # # License: MIT License -import numpy as np import warnings -from .utils import unif, dist + +import numpy as np from scipy.optimize import fmin_l_bfgs_b +from ot.utils import unif, dist + def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): @@ -179,8 +181,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, reg : float Regularization term >0 method : str - method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or - 'sinkhorn_epsilon_scaling', see those function for specific parameters + method used for the solver either 'sinkhorn', 'sinkhorn_stabilized', see those function for specific parameters numItermax : int, optional Max number of iterations stopThr : float, optional @@ -207,7 +208,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, Returns ------- - W : (n_hists) ndarray or float + W : (n_hists) ndarray Optimal transportation loss for the given parameters log : dict log dictionary return only if log==True in parameters @@ -244,12 +245,12 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2] ot.bregman.greenkhorn : Greenkhorn [21] ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] - ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10] """ b = np.asarray(b, dtype=np.float64) if len(b.shape) < 2: b = b[:, None] + if method.lower() == 'sinkhorn': return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, @@ -258,10 +259,6 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) - elif method.lower() == 'sinkhorn_epsilon_scaling': - return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) else: raise ValueError("Unknown method '%s'." % method) @@ -745,8 +742,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, # remove numerical problems and store them in K if np.abs(u).max() > tau or np.abs(v).max() > tau: if n_hists: - alpha, beta = alpha + reg * \ - np.max(np.log(u), 1), beta + reg * np.max(np.log(v)) + alpha, beta = alpha + reg * np.max(np.log(u), 1), beta + reg * np.max(np.log(v)) else: alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v) if n_hists: @@ -1747,7 +1743,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', >>> reg = 0.1 >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1)) >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1)) - >>> empirical_sinkhorn(X_s, X_t, reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE + >>> empirical_sinkhorn(X_s, X_t, reg=reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE array([[4.99977301e-01, 2.26989344e-05], [2.26989344e-05, 4.99977301e-01]]) @@ -1825,8 +1821,8 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num Returns ------- - gamma : ndarray, shape (n_samples_a, n_samples_b) - Regularized optimal transportation matrix for the given parameters + W : (n_hists) ndarray or float + Optimal transportation loss for the given parameters log : dict log dictionary return only if log==True in parameters @@ -1838,8 +1834,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num >>> reg = 0.1 >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1)) >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1)) - >>> empirical_sinkhorn2(X_s, X_t, reg, verbose=False) - array([4.53978687e-05]) + >>> b = np.full((n_samples_b, 3), 1/n_samples_b) + >>> empirical_sinkhorn2(X_s, X_t, b=b, reg=reg, verbose=False) + array([4.53978687e-05, 4.53978687e-05, 4.53978687e-05]) References @@ -1935,8 +1932,8 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli Returns ------- - gamma : ndarray, shape (n_samples_a, n_samples_b) - Regularized optimal transportation matrix for the given parameters + W : (1,) ndarray + Optimal transportation symmetrized loss for the given parameters log : dict log dictionary return only if log==True in parameters @@ -1959,13 +1956,13 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, + sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, + sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b) + sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b) log = {} log['sinkhorn_loss_ab'] = sinkhorn_loss_ab @@ -1981,13 +1978,13 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, + sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, + sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b) + sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b) return max(0, sinkhorn_div) @@ -2212,11 +2209,11 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res # box constraints in L-BFGS-B (see Proposition 1 in [26]) bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / ( - ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget + ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget bounds_v = [( - max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))), - epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget + max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))), + epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget # pre-calculated constants for the objective vec_eps_IJc = epsilon * kappa * (K_IJc * np.ones(nt - nt_budget).reshape((1, -1))).sum(axis=1) diff --git a/test/test_bregman.py b/test/test_bregman.py index 6aa4e08..331acd3 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -6,9 +6,10 @@ # License: MIT License import numpy as np -import ot import pytest +import ot + def test_sinkhorn(): # test sinkhorn @@ -257,7 +258,8 @@ def test_empirical_sinkhorn(): def test_empirical_sinkhorn_divergence(): # Test sinkhorn divergence n = 10 - a = ot.unif(n) + a = np.linspace(1, n, n) + a /= a.sum() b = ot.unif(n) X_s = np.reshape(np.arange(n), (n, 1)) X_t = np.reshape(np.arange(0, n * 2, 2), (n, 1)) @@ -265,16 +267,15 @@ def test_empirical_sinkhorn_divergence(): M_s = ot.dist(X_s, X_s) M_t = ot.dist(X_t, X_t) - emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1) + emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b) sinkhorn_div = (ot.sinkhorn2(a, b, M, 1) - 1 / 2 * ot.sinkhorn2(a, a, M_s, 1) - 1 / 2 * ot.sinkhorn2(b, b, M_t, 1)) - emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, log=True) + emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b, log=True) sink_div_log_ab, log_s_ab = ot.sinkhorn2(a, b, M, 1, log=True) sink_div_log_a, log_s_a = ot.sinkhorn2(a, a, M_s, 1, log=True) sink_div_log_b, log_s_b = ot.sinkhorn2(b, b, M_t, 1, log=True) sink_div_log = sink_div_log_ab - 1 / 2 * (sink_div_log_a + sink_div_log_b) - - # check constratints + # check constraints np.testing.assert_allclose( emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn np.testing.assert_allclose( -- cgit v1.2.3 From cd3ce6140d7a2dbe2bcf05927a8dd8289f4ce9e2 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 19 Apr 2021 15:03:57 +0200 Subject: [MRG] Cleanup test warnings (#242) * remove warnings in tests from docstrings * working tets for bregman implemneted methods * pep8 --- ot/da.py | 12 ++++++------ ot/dr.py | 2 +- ot/gpu/bregman.py | 2 +- ot/gromov.py | 20 ++++++++++---------- ot/lp/cvx.py | 3 +-- ot/optim.py | 4 ++-- test/test_bregman.py | 3 ++- 7 files changed, 23 insertions(+), 23 deletions(-) (limited to 'test') diff --git a/ot/da.py b/ot/da.py index f1e4769..cdc747c 100644 --- a/ot/da.py +++ b/ot/da.py @@ -26,7 +26,7 @@ from .optim import gcg def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerItermax=200, stopInnerThr=1e-9, verbose=False, log=False): - """ + r""" Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization @@ -137,7 +137,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerItermax=200, stopInnerThr=1e-9, verbose=False, log=False): - """ + r""" Solve the entropic regularization optimal transport problem with group lasso regularization @@ -245,7 +245,7 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, verbose2=False, numItermax=100, numInnerItermax=10, stopInnerThr=1e-6, stopThr=1e-5, log=False, **kwargs): - """Joint OT and linear mapping estimation as proposed in [8] + r"""Joint OT and linear mapping estimation as proposed in [8] The function solves the following optimization problem: @@ -434,7 +434,7 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', numItermax=100, numInnerItermax=10, stopInnerThr=1e-6, stopThr=1e-5, log=False, **kwargs): - """Joint OT and nonlinear mapping estimation with kernels as proposed in [8] + r"""Joint OT and nonlinear mapping estimation with kernels as proposed in [8] The function solves the following optimization problem: @@ -645,7 +645,7 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, wt=None, bias=True, log=False): - """ return OT linear operator between samples + r""" return OT linear operator between samples The function estimates the optimal linear operator that aligns the two empirical distributions. This is equivalent to estimating the closed @@ -1228,7 +1228,7 @@ class BaseTransport(BaseEstimator): class LinearTransport(BaseTransport): - """ OT linear operator between empirical distributions + r""" OT linear operator between empirical distributions The function estimates the optimal linear operator that aligns the two empirical distributions. This is equivalent to estimating the closed diff --git a/ot/dr.py b/ot/dr.py index 11d2e10..b7a1af0 100644 --- a/ot/dr.py +++ b/ot/dr.py @@ -109,7 +109,7 @@ def fda(X, y, p=2, reg=1e-16): def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None): - """ + r""" Wasserstein Discriminant Analysis [11]_ The function solves the following optimization problem: diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py index 2e2df83..82f34f3 100644 --- a/ot/gpu/bregman.py +++ b/ot/gpu/bregman.py @@ -15,7 +15,7 @@ from . import utils def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, log=False, to_numpy=True, **kwargs): - """ + r""" Solve the entropic regularization optimal transport on GPU If the input matrix are in numpy format, they will be uploaded to the diff --git a/ot/gromov.py b/ot/gromov.py index 4427a96..8f457e9 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -19,7 +19,7 @@ from .optim import cg def init_matrix(C1, C2, p, q, loss_fun='square_loss'): - """Return loss matrices and tensors for Gromov-Wasserstein fast computation + r"""Return loss matrices and tensors for Gromov-Wasserstein fast computation Returns the value of \mathcal{L}(C1,C2) \otimes T with the selected loss function as the loss function of Gromow-Wasserstein discrepancy. @@ -109,7 +109,7 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'): def tensor_product(constC, hC1, hC2, T): - """Return the tensor for Gromov-Wasserstein fast computation + r"""Return the tensor for Gromov-Wasserstein fast computation The tensor is computed as described in Proposition 1 Eq. (6) in [12]. @@ -262,7 +262,7 @@ def update_kl_loss(p, lambdas, T, Cs): def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): - """ + r""" Returns the gromov-wasserstein transport between (C1,p) and (C2,q) The function solves the following optimization problem: @@ -343,7 +343,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): - """ + r""" Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q) The function solves the following optimization problem: @@ -420,7 +420,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): - """ + r""" Computes the FGW transport between two graphs see [24] .. math:: @@ -496,7 +496,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): - """ + r""" Computes the FGW distance between two graphs see [24] .. math:: @@ -574,7 +574,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False): - """ + r""" Returns the gromov-wasserstein transport between (C1,p) and (C2,q) (C1,p) and (C2,q) @@ -681,7 +681,7 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False): - """ + r""" Returns the entropic gromov-wasserstein discrepancy between the two measured similarity matrices (C1,p) and (C2,q) @@ -747,7 +747,7 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None): - """ + r""" Returns the gromov-wasserstein barycenters of S measured similarity matrices (Cs)_{s=1}^{s=S} @@ -857,7 +857,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None): - """ + r""" Returns the gromov-wasserstein barycenters of S measured similarity matrices (Cs)_{s=1}^{s=S} diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index 8e763be..869d450 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -27,7 +27,7 @@ def scipy_sparse_to_spmatrix(A): def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'): - """Compute the Wasserstein barycenter of distributions A + r"""Compute the Wasserstein barycenter of distributions A The function solves the following optimization problem [16]: @@ -76,7 +76,6 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po .. [16] Agueh, M., & Carlier, G. (2011). Barycenters in the Wasserstein space. SIAM Journal on Mathematical Analysis, 43(2), 904-924. - """ if weights is None: diff --git a/ot/optim.py b/ot/optim.py index 1902907..abe9e6a 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -139,7 +139,7 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val, def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs): - """ + r""" Solve the general regularized OT problem with conditional gradient The function solves the following optimization problem: @@ -278,7 +278,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False): - """ + r""" Solve the general regularized OT problem with the generalized conditional gradient The function solves the following optimization problem: diff --git a/test/test_bregman.py b/test/test_bregman.py index 331acd3..1ebd21f 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -321,8 +321,9 @@ def test_implemented_methods(): # make dists unbalanced b = ot.utils.unif(n) A = rng.rand(n, 2) + A /= A.sum(0, keepdims=True) M = ot.dist(x, x) - epsilon = 1. + epsilon = 1.0 for method in IMPLEMENTED_METHODS: ot.bregman.sinkhorn(a, b, M, epsilon, method=method) -- cgit v1.2.3 From 184f8f4f7ac78f1dd7f653496d2753211a4e3426 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Tue, 1 Jun 2021 10:10:54 +0200 Subject: [MRG] POT numpy/torch/jax backends (#249) * add numpy and torch backends * stat sets on functions * proper import * install recent torch on windows * install recent torch on windows * now testing all functions in backedn * add jax backedn * clenaup windowds * proper convert for jax backedn * pep8 * try again windows tests * test jax conversion * try proper widows tests * emd fuction ses backedn * better test partial OT * proper tests to_numpy and teplate Backend * pep8 * pep8 x2 * feaking sinkhorn works with torch * sinkhorn2 compatible * working ot.emd2 * important detach * it should work * jax autodiff emd * pep8 * no tast same for jax * new independat tests per backedn * freaking pep8 * add tests for gradients * deprecate ot.gpu * worging dist function * working dist * dist done in backedn * not in * remove indexing * change accuacy for jax * first pull backend * projection simplex * projection simplex * projection simplex * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 * add backedn discusion to quickstart guide * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 + better doc * proper links * corect doctest * big debug documentation * doctest again * doctest again bis * doctest again ter (last one or i kill myself) * backend test + doc proj simplex * correction test_utils * correction test_utils * correction cumsum * correction flip * correction flip v2 * more debug * more debug * more debug + pep8 * pep8 * argh * proj_simplex * backedn works for sort * proj simplex * jax sucks * update doc * Update test/test_utils.py Co-authored-by: Alexandre Gramfort * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort * Update test/test_utils.py Co-authored-by: Alexandre Gramfort * Update ot/utils.py Co-authored-by: Alexandre Gramfort * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort * Update ot/lp/__init__.py Co-authored-by: Alexandre Gramfort * begin comment alex * comment alex part 2 * optimize test gromov * proj_simplex on vectors * add awesome gradient decsnt example on the weights * pep98 of course * proof read example by alex * pep8 again * encoding oos in translation * correct legend Co-authored-by: Nicolas Courty Co-authored-by: Alexandre Gramfort --- .github/requirements_test_windows.txt | 10 + .github/workflows/build_tests.yml | 9 +- README.md | 8 +- docs/source/quickstart.rst | 68 +++- docs/source/readme.rst | 70 ++-- examples/README.txt | 2 +- examples/backends/README.txt | 4 + examples/backends/plot_unmix_optim_torch.py | 161 +++++++++ ot/__init__.py | 1 + ot/backend.py | 536 ++++++++++++++++++++++++++++ ot/bregman.py | 141 ++++---- ot/gpu/__init__.py | 4 +- ot/lp/__init__.py | 137 ++++--- ot/utils.py | 128 +++++-- requirements.txt | 3 + test/test_backend.py | 364 +++++++++++++++++++ test/test_bregman.py | 74 ++++ test/test_gromov.py | 10 +- test/test_ot.py | 91 ++++- test/test_partial.py | 4 +- test/test_utils.py | 76 +++- 21 files changed, 1692 insertions(+), 209 deletions(-) create mode 100644 .github/requirements_test_windows.txt create mode 100644 examples/backends/README.txt create mode 100644 examples/backends/plot_unmix_optim_torch.py create mode 100644 ot/backend.py create mode 100644 test/test_backend.py (limited to 'test') diff --git a/.github/requirements_test_windows.txt b/.github/requirements_test_windows.txt new file mode 100644 index 0000000..331dd57 --- /dev/null +++ b/.github/requirements_test_windows.txt @@ -0,0 +1,10 @@ +numpy +scipy>=1.3 +cython +matplotlib +autograd +pymanopt==0.2.4; python_version <'3' +pymanopt; python_version >= '3' +cvxopt +scikit-learn +pytest \ No newline at end of file diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index 2fc6770..92a07b5 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -40,7 +40,7 @@ jobs: pip install -e . - name: Run tests run: | - python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot + python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes - name: Upload codecov run: | codecov @@ -142,11 +142,12 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt - pip install pytest "pytest-cov<2.6" + python -m pip install -r .github/requirements_test_windows.txt + python -m pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html + python -m pip install pytest "pytest-cov<2.6" - name: Install POT run: | - pip install -e . + python -m pip install -e . - name: Run tests run: | python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot diff --git a/README.md b/README.md index f5d18c1..e5e16e0 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ POT provides the following generic OT solvers (links to examples): * [OT Network Simplex solver](https://pythonot.github.io/auto_examples/plot_OT_1D.html) for the linear program/ Earth Movers Distance [1] . * [Conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) [6] and [Generalized conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) for regularized OT [7]. -* Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html) with optional GPU implementation (requires cupy). +* Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html). * Bregman projections for [Wasserstein barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) [3], [convolutional barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_convolutional_barycenter.html) [21] and unmixing [4]. * Sinkhorn divergence [23] and entropic regularization OT from empirical data. * [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17]. @@ -33,6 +33,7 @@ POT provides the following generic OT solvers (links to examples): * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] formulations). * [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32]. +* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/) arrays. POT provides the following Machine Learning related solvers: @@ -77,8 +78,7 @@ The library has been tested on Linux, MacOSX and Windows. It requires a C++ comp - Numpy (>=1.16) - Scipy (>=1.0) -- Cython (>=0.23) -- Matplotlib (>=1.5) +- Cython (>=0.23) (build only, not necessary when installing wheels from pip or conda) #### Pip installation @@ -129,7 +129,7 @@ Some sub-modules require additional dependences which are discussed below pip install pymanopt autograd ``` -* **ot.gpu** (GPU accelerated OT) depends on cupy that have to be installed following instructions on [this page](https://docs-cupy.chainer.org/en/stable/install.html). Obviously you will need CUDA installed and a compatible GPU. +* **ot.gpu** (GPU accelerated OT) depends on cupy that have to be installed following instructions on [this page](https://docs-cupy.chainer.org/en/stable/install.html). Obviously you will need CUDA installed and a compatible GPU. Note that this module is deprecated since version 0.8 and will be deleted in the future. GPU is now handled automatically through the backends and several solver already can run on GPU using the Pytorch backend. ## Examples diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index cf5d6aa..fd046a1 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -15,6 +15,12 @@ are also available as notebooks on the POT Github. in ML applications we refer the reader to the following `OTML tutorial `_. +.. note:: + + Since version 0.8, POT provides a backend to automatically solve some OT + problems independently from the toolbox used by the user (numpy/torch/jax). + We provide a discussion about which functions are compatible in section + `Backend section <#solving-ot-with-multiple-backends>`_ . Why Optimal Transport ? @@ -158,7 +164,6 @@ Wasserstein but has better computational and `statistical properties `_. - Optimal transport and Wasserstein distance ------------------------------------------ @@ -922,6 +927,13 @@ The implementations of FGW and FGW barycenter is provided in functions GPU acceleration ^^^^^^^^^^^^^^^^ +.. warning:: + + The :any:`ot.gpu` has been deprecated since the release 0.8 of POT and + should not be used. The GPU implementation (in Pytorch for instance) can be + used with the novel backends using the compatible functions from POT. + + We provide several implementation of our OT solvers in :any:`ot.gpu`. Those implementations use the :code:`cupy` toolbox that obviously need to be installed. @@ -950,6 +962,60 @@ explicitly. use it you have to specifically import it with :code:`import ot.gpu` . +Solving OT with Multiple backends +--------------------------------- + +.. _backends_section: + +Since version 0.8, POT provides a backend that allows to code solvers +independently from the type of the input arrays. The idea is to provide the user +with a package that works seamlessly and returns a solution for instance as a +Pytorch tensors when the function has Pytorch tensors as input. + + +How it works +^^^^^^^^^^^^ + +The aim of the backend is to use the same function independently of the type of +the input arrays. + +For instance when executing the following code + +.. code:: python + + # a and b are 1D histograms (sum to 1 and positive) + # M is the ground cost matrix + T = ot.emd(a, b, M) # exact linear program + w = ot.emd2(a, b, M) # Wasserstein computation + +the functions :any:`ot.emd` and :any:`ot.emd2` can take inputs of the type +:any:`numpy.array`, :any:`torch.tensor` or :any:`jax.numpy.array`. The output of +the function will be the same type as the inputs and on the same device. When +possible all computations are done on the same device and also when possible the +output will be differentiable with respect to the input of the function. + + + +List of compatible Backends +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +- `Numpy `_ (all functions and solvers) +- `Pytorch `_ (all outputs differentiable w.r.t. inputs) +- `Jax `_ (Some functions are differentiable some require a wrapper) + +List of compatible functions +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +This list will get longer for new releases and will hopefully disappear when POT +become fully implemented with the backend. + +- :any:`ot.emd` +- :any:`ot.emd2` +- :any:`ot.sinkhorn` +- :any:`ot.sinkhorn2` +- :any:`ot.dist` + + FAQ --- diff --git a/docs/source/readme.rst b/docs/source/readme.rst index 3b594c2..82d3e6c 100644 --- a/docs/source/readme.rst +++ b/docs/source/readme.rst @@ -26,8 +26,7 @@ POT provides the following generic OT solvers (links to examples): Algorithm `__ [2] , stabilized version [9] [10], greedy Sinkhorn [22] and `Screening Sinkhorn - [26] `__ - with optional GPU implementation (requires cupy). + [26] `__. - Bregman projections for `Wasserstein barycenter `__ [3], `convolutional @@ -69,6 +68,11 @@ POT provides the following generic OT solvers (links to examples): - `Sliced Wasserstein `__ [31, 32]. +- `Several + backends `__ + for easy use of POT with + `Pytorch `__/`jax `__/`Numpy `__ + arrays. POT provides the following Machine Learning related solvers: @@ -104,12 +108,14 @@ paper `__: :: - Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer;, POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021. + Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer, + POT Python Optimal Transport library, + Journal of Machine Learning Research, 22(78):1−8, 2021. Website: https://pythonot.github.io/ In Bibtex format: -:: +.. code:: bibtex @article{flamary2021pot, author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer}, @@ -131,8 +137,8 @@ following Python modules: - Numpy (>=1.16) - Scipy (>=1.0) -- Cython (>=0.23) -- Matplotlib (>=1.5) +- Cython (>=0.23) (build only, not necessary when installing wheels + from pip or conda) Pip installation ^^^^^^^^^^^^^^^^ @@ -140,19 +146,19 @@ Pip installation Note that due to a limitation of pip, ``cython`` and ``numpy`` need to be installed prior to installing POT. This can be done easily with -:: +.. code:: console pip install numpy cython You can install the toolbox through PyPI with: -:: +.. code:: console pip install POT or get the very latest version by running: -:: +.. code:: console pip install -U https://github.com/PythonOT/POT/archive/master.zip # with --user for user install (no root) @@ -163,7 +169,7 @@ If you use the Anaconda python distribution, POT is available in `conda-forge `__. To install it and the required dependencies: -:: +.. code:: console conda install -c conda-forge pot @@ -188,15 +194,17 @@ below - **ot.dr** (Wasserstein dimensionality reduction) depends on autograd and pymanopt that can be installed with: - :: +.. code:: shell - pip install pymanopt autograd + pip install pymanopt autograd - **ot.gpu** (GPU accelerated OT) depends on cupy that have to be installed following instructions on `this page `__. - -obviously you need CUDA installed and a compatible GPU. + Obviously you will need CUDA installed and a compatible GPU. Note + that this module is deprecated since version 0.8 and will be deleted + in the future. GPU is now handled automatically through the backends + and several solver already can run on GPU using the Pytorch backend. Examples -------- @@ -206,36 +214,36 @@ Short examples - Import the toolbox - .. code:: python +.. code:: python - import ot + import ot - Compute Wasserstein distances - .. code:: python +.. code:: python - # a,b are 1D histograms (sum to 1 and positive) - # M is the ground cost matrix - Wd=ot.emd2(a,b,M) # exact linear program - Wd_reg=ot.sinkhorn2(a,b,M,reg) # entropic regularized OT - # if b is a matrix compute all distances to a and return a vector + # a and b are 1D histograms (sum to 1 and positive) + # M is the ground cost matrix + Wd = ot.emd2(a, b, M) # exact linear program + Wd_reg = ot.sinkhorn2(a, b, M, reg) # entropic regularized OT + # if b is a matrix compute all distances to a and return a vector - Compute OT matrix - .. code:: python +.. code:: python - # a,b are 1D histograms (sum to 1 and positive) - # M is the ground cost matrix - T=ot.emd(a,b,M) # exact linear program - T_reg=ot.sinkhorn(a,b,M,reg) # entropic regularized OT + # a and b are 1D histograms (sum to 1 and positive) + # M is the ground cost matrix + T = ot.emd(a, b, M) # exact linear program + T_reg = ot.sinkhorn(a, b, M, reg) # entropic regularized OT - Compute Wasserstein barycenter - .. code:: python +.. code:: python - # A is a n*d matrix containing d 1D histograms - # M is the ground cost matrix - ba=ot.barycenter(A,M,reg) # reg is regularization parameter + # A is a n*d matrix containing d 1D histograms + # M is the ground cost matrix + ba = ot.barycenter(A, M, reg) # reg is regularization parameter Examples and Notebooks ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/examples/README.txt b/examples/README.txt index 69a9f84..b48487f 100644 --- a/examples/README.txt +++ b/examples/README.txt @@ -1,7 +1,7 @@ Examples gallery ================ -This is a gallery of all the POT example files. +This is a gallery of all the POT example files. OT and regularized OT diff --git a/examples/backends/README.txt b/examples/backends/README.txt new file mode 100644 index 0000000..3ee0e27 --- /dev/null +++ b/examples/backends/README.txt @@ -0,0 +1,4 @@ + + +POT backend examples +-------------------- \ No newline at end of file diff --git a/examples/backends/plot_unmix_optim_torch.py b/examples/backends/plot_unmix_optim_torch.py new file mode 100644 index 0000000..9ae66e9 --- /dev/null +++ b/examples/backends/plot_unmix_optim_torch.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +r""" +================================= +Wasserstein unmixing with PyTorch +================================= + +In this example we estimate mixing parameters from distributions that minimize +the Wasserstein distance. In other words we suppose that a target +distribution :math:`\mu^t` can be expressed as a weighted sum of source +distributions :math:`\mu^s_k` with the following model: + +.. math:: + \mu^t = \sum_{k=1}^K w_k\mu^s_k + +where :math:`\mathbf{w}` is a vector of size :math:`K` and belongs in the +distribution simplex :math:`\Delta_K`. + +In order to estimate this weight vector we propose to optimize the Wasserstein +distance between the model and the observed :math:`\mu^t` with respect to +the vector. This leads to the following optimization problem: + +.. math:: + \min_{\mathbf{w}\in\Delta_K} \quad W \left(\mu^t,\sum_{k=1}^K w_k\mu^s_k\right) + +This minimization is done in this example with a simple projected gradient +descent in PyTorch. We use the automatic backend of POT that allows us to +compute the Wasserstein distance with :any:`ot.emd2` with +differentiable losses. + +""" + +# Author: Remi Flamary +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 2 + +import numpy as np +import matplotlib.pylab as pl +import ot +import torch + + +############################################################################## +# Generate data +# ------------- + +#%% Data + +nt = 100 +nt1 = 10 # + +ns1 = 50 +ns = 2 * ns1 + +rng = np.random.RandomState(2) + +xt = rng.randn(nt, 2) * 0.2 +xt[:nt1, 0] += 1 +xt[nt1:, 1] += 1 + + +xs1 = rng.randn(ns1, 2) * 0.2 +xs1[:, 0] += 1 +xs2 = rng.randn(ns1, 2) * 0.2 +xs2[:, 1] += 1 + +xs = np.concatenate((xs1, xs2)) + +# Sample reweighting matrix H +H = np.zeros((ns, 2)) +H[:ns1, 0] = 1 / ns1 +H[ns1:, 1] = 1 / ns1 +# each columns sums to 1 and has weights only for samples form the +# corresponding source distribution + +M = ot.dist(xs, xt) + +############################################################################## +# Plot data +# --------- + +#%% plot the distributions + +pl.figure(1) +pl.scatter(xt[:, 0], xt[:, 1], label='Target $\mu^t$', alpha=0.5) +pl.scatter(xs1[:, 0], xs1[:, 1], label='Source $\mu^s_1$', alpha=0.5) +pl.scatter(xs2[:, 0], xs2[:, 1], label='Source $\mu^s_2$', alpha=0.5) +pl.title('Sources and Target distributions') +pl.legend() + + +############################################################################## +# Optimization of the model wrt the Wasserstein distance +# ------------------------------------------------------ + + +#%% Weights optimization with gradient descent + +# convert numpy arrays to torch tensors +H2 = torch.tensor(H) +M2 = torch.tensor(M) + +# weights for the source distributions +w = torch.tensor(ot.unif(2), requires_grad=True) + +# uniform weights for target +b = torch.tensor(ot.unif(nt)) + +lr = 2e-3 # learning rate +niter = 500 # number of iterations +losses = [] # loss along the iterations + +# loss for the minimal Wasserstein estimator + + +def get_loss(w): + a = torch.mv(H2, w) # distribution reweighting + return ot.emd2(a, b, M2) # squared Wasserstein 2 + + +for i in range(niter): + + loss = get_loss(w) + losses.append(float(loss)) + + loss.backward() + + with torch.no_grad(): + w -= lr * w.grad # gradient step + w[:] = ot.utils.proj_simplex(w) # projection on the simplex + + w.grad.zero_() + + +############################################################################## +# Estimated weights and convergence of the objective +# --------------------------------------------------- + +we = w.detach().numpy() +print('Estimated mixture:', we) + +pl.figure(2) +pl.semilogy(losses) +pl.grid() +pl.title('Wasserstein distance') +pl.xlabel("Iterations") + +############################################################################## +# Ploting the reweighted source distribution +# ------------------------------------------ + +pl.figure(3) + +# compute source weights +ws = H.dot(we) + +pl.scatter(xt[:, 0], xt[:, 1], label='Target $\mu^t$', alpha=0.5) +pl.scatter(xs[:, 0], xs[:, 1], color='C3', s=ws * 20 * ns, label='Weighted sources $\sum_{k} w_k\mu^s_k$', alpha=0.5) +pl.title('Target and reweighted source distributions') +pl.legend() diff --git a/ot/__init__.py b/ot/__init__.py index 5a8a415..3b072c6 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -33,6 +33,7 @@ from . import smooth from . import stochastic from . import unbalanced from . import partial +from . import backend # OT functions from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d diff --git a/ot/backend.py b/ot/backend.py new file mode 100644 index 0000000..d68f5cf --- /dev/null +++ b/ot/backend.py @@ -0,0 +1,536 @@ +# -*- coding: utf-8 -*- +""" +Multi-lib backend for POT +""" + +# Author: Remi Flamary +# Nicolas Courty +# +# License: MIT License + +import numpy as np + +try: + import torch + torch_type = torch.Tensor +except ImportError: + torch = False + torch_type = float + +try: + import jax + import jax.numpy as jnp + jax_type = jax.numpy.ndarray +except ImportError: + jax = False + jax_type = float + +str_type_error = "All array should be from the same type/backend. Current types are : {}" + + +def get_backend_list(): + """ returns the list of available backends)""" + lst = [NumpyBackend(), ] + + if torch: + lst.append(TorchBackend()) + + if jax: + lst.append(JaxBackend()) + + return lst + + +def get_backend(*args): + """returns the proper backend for a list of input arrays + + Also raises TypeError if all arrays are not from the same backend + """ + # check that some arrays given + if not len(args) > 0: + raise ValueError(" The function takes at least one parameter") + # check all same type + + if isinstance(args[0], np.ndarray): + if not len(set(type(a) for a in args)) == 1: + raise ValueError(str_type_error.format([type(a) for a in args])) + return NumpyBackend() + elif torch and isinstance(args[0], torch_type): + if not len(set(type(a) for a in args)) == 1: + raise ValueError(str_type_error.format([type(a) for a in args])) + return TorchBackend() + elif isinstance(args[0], jax_type): + return JaxBackend() + else: + raise ValueError("Unknown type of non implemented backend.") + + +def to_numpy(*args): + """returns numpy arrays from any compatible backend""" + + if len(args) == 1: + return get_backend(args[0]).to_numpy(args[0]) + else: + return [get_backend(a).to_numpy(a) for a in args] + + +class Backend(): + + __name__ = None + __type__ = None + + def __str__(self): + return self.__name__ + + # convert to numpy + def to_numpy(self, a): + raise NotImplementedError() + + # convert from numpy + def from_numpy(self, a, type_as=None): + raise NotImplementedError() + + def set_gradients(self, val, inputs, grads): + """ define the gradients for the value val wrt the inputs """ + raise NotImplementedError() + + def zeros(self, shape, type_as=None): + raise NotImplementedError() + + def ones(self, shape, type_as=None): + raise NotImplementedError() + + def arange(self, stop, start=0, step=1, type_as=None): + raise NotImplementedError() + + def full(self, shape, fill_value, type_as=None): + raise NotImplementedError() + + def eye(self, N, M=None, type_as=None): + raise NotImplementedError() + + def sum(self, a, axis=None, keepdims=False): + raise NotImplementedError() + + def cumsum(self, a, axis=None): + raise NotImplementedError() + + def max(self, a, axis=None, keepdims=False): + raise NotImplementedError() + + def min(self, a, axis=None, keepdims=False): + raise NotImplementedError() + + def maximum(self, a, b): + raise NotImplementedError() + + def minimum(self, a, b): + raise NotImplementedError() + + def dot(self, a, b): + raise NotImplementedError() + + def abs(self, a): + raise NotImplementedError() + + def exp(self, a): + raise NotImplementedError() + + def log(self, a): + raise NotImplementedError() + + def sqrt(self, a): + raise NotImplementedError() + + def norm(self, a): + raise NotImplementedError() + + def any(self, a): + raise NotImplementedError() + + def isnan(self, a): + raise NotImplementedError() + + def isinf(self, a): + raise NotImplementedError() + + def einsum(self, subscripts, *operands): + raise NotImplementedError() + + def sort(self, a, axis=-1): + raise NotImplementedError() + + def argsort(self, a, axis=None): + raise NotImplementedError() + + def flip(self, a, axis=None): + raise NotImplementedError() + + +class NumpyBackend(Backend): + + __name__ = 'numpy' + __type__ = np.ndarray + + def to_numpy(self, a): + return a + + def from_numpy(self, a, type_as=None): + if type_as is None: + return a + elif isinstance(a, float): + return a + else: + return a.astype(type_as.dtype) + + def set_gradients(self, val, inputs, grads): + # no gradients for numpy + return val + + def zeros(self, shape, type_as=None): + if type_as is None: + return np.zeros(shape) + else: + return np.zeros(shape, dtype=type_as.dtype) + + def ones(self, shape, type_as=None): + if type_as is None: + return np.ones(shape) + else: + return np.ones(shape, dtype=type_as.dtype) + + def arange(self, stop, start=0, step=1, type_as=None): + return np.arange(start, stop, step) + + def full(self, shape, fill_value, type_as=None): + if type_as is None: + return np.full(shape, fill_value) + else: + return np.full(shape, fill_value, dtype=type_as.dtype) + + def eye(self, N, M=None, type_as=None): + if type_as is None: + return np.eye(N, M) + else: + return np.eye(N, M, dtype=type_as.dtype) + + def sum(self, a, axis=None, keepdims=False): + return np.sum(a, axis, keepdims=keepdims) + + def cumsum(self, a, axis=None): + return np.cumsum(a, axis) + + def max(self, a, axis=None, keepdims=False): + return np.max(a, axis, keepdims=keepdims) + + def min(self, a, axis=None, keepdims=False): + return np.min(a, axis, keepdims=keepdims) + + def maximum(self, a, b): + return np.maximum(a, b) + + def minimum(self, a, b): + return np.minimum(a, b) + + def dot(self, a, b): + return np.dot(a, b) + + def abs(self, a): + return np.abs(a) + + def exp(self, a): + return np.exp(a) + + def log(self, a): + return np.log(a) + + def sqrt(self, a): + return np.sqrt(a) + + def norm(self, a): + return np.sqrt(np.sum(np.square(a))) + + def any(self, a): + return np.any(a) + + def isnan(self, a): + return np.isnan(a) + + def isinf(self, a): + return np.isinf(a) + + def einsum(self, subscripts, *operands): + return np.einsum(subscripts, *operands) + + def sort(self, a, axis=-1): + return np.sort(a, axis) + + def argsort(self, a, axis=-1): + return np.argsort(a, axis) + + def flip(self, a, axis=None): + return np.flip(a, axis) + + +class JaxBackend(Backend): + + __name__ = 'jax' + __type__ = jax_type + + def to_numpy(self, a): + return np.array(a) + + def from_numpy(self, a, type_as=None): + if type_as is None: + return jnp.array(a) + else: + return jnp.array(a).astype(type_as.dtype) + + def set_gradients(self, val, inputs, grads): + # no gradients for jax because it is functional + + # does not work + # from jax import custom_jvp + # @custom_jvp + # def f(*inputs): + # return val + # f.defjvps(*grads) + # return f(*inputs) + + return val + + def zeros(self, shape, type_as=None): + if type_as is None: + return jnp.zeros(shape) + else: + return jnp.zeros(shape, dtype=type_as.dtype) + + def ones(self, shape, type_as=None): + if type_as is None: + return jnp.ones(shape) + else: + return jnp.ones(shape, dtype=type_as.dtype) + + def arange(self, stop, start=0, step=1, type_as=None): + return jnp.arange(start, stop, step) + + def full(self, shape, fill_value, type_as=None): + if type_as is None: + return jnp.full(shape, fill_value) + else: + return jnp.full(shape, fill_value, dtype=type_as.dtype) + + def eye(self, N, M=None, type_as=None): + if type_as is None: + return jnp.eye(N, M) + else: + return jnp.eye(N, M, dtype=type_as.dtype) + + def sum(self, a, axis=None, keepdims=False): + return jnp.sum(a, axis, keepdims=keepdims) + + def cumsum(self, a, axis=None): + return jnp.cumsum(a, axis) + + def max(self, a, axis=None, keepdims=False): + return jnp.max(a, axis, keepdims=keepdims) + + def min(self, a, axis=None, keepdims=False): + return jnp.min(a, axis, keepdims=keepdims) + + def maximum(self, a, b): + return jnp.maximum(a, b) + + def minimum(self, a, b): + return jnp.minimum(a, b) + + def dot(self, a, b): + return jnp.dot(a, b) + + def abs(self, a): + return jnp.abs(a) + + def exp(self, a): + return jnp.exp(a) + + def log(self, a): + return jnp.log(a) + + def sqrt(self, a): + return jnp.sqrt(a) + + def norm(self, a): + return jnp.sqrt(jnp.sum(jnp.square(a))) + + def any(self, a): + return jnp.any(a) + + def isnan(self, a): + return jnp.isnan(a) + + def isinf(self, a): + return jnp.isinf(a) + + def einsum(self, subscripts, *operands): + return jnp.einsum(subscripts, *operands) + + def sort(self, a, axis=-1): + return jnp.sort(a, axis) + + def argsort(self, a, axis=-1): + return jnp.argsort(a, axis) + + def flip(self, a, axis=None): + return jnp.flip(a, axis) + + +class TorchBackend(Backend): + + __name__ = 'torch' + __type__ = torch_type + + def to_numpy(self, a): + return a.cpu().detach().numpy() + + def from_numpy(self, a, type_as=None): + if type_as is None: + return torch.from_numpy(a) + else: + return torch.as_tensor(a, dtype=type_as.dtype, device=type_as.device) + + def set_gradients(self, val, inputs, grads): + from torch.autograd import Function + + # define a function that takes inputs and return val + class ValFunction(Function): + @staticmethod + def forward(ctx, *inputs): + return val + + @staticmethod + def backward(ctx, grad_output): + # the gradients are grad + return grads + + return ValFunction.apply(*inputs) + + def zeros(self, shape, type_as=None): + if type_as is None: + return torch.zeros(shape) + else: + return torch.zeros(shape, dtype=type_as.dtype, device=type_as.device) + + def ones(self, shape, type_as=None): + if type_as is None: + return torch.ones(shape) + else: + return torch.ones(shape, dtype=type_as.dtype, device=type_as.device) + + def arange(self, stop, start=0, step=1, type_as=None): + if type_as is None: + return torch.arange(start, stop, step) + else: + return torch.arange(start, stop, step, device=type_as.device) + + def full(self, shape, fill_value, type_as=None): + if type_as is None: + return torch.full(shape, fill_value) + else: + return torch.full(shape, fill_value, dtype=type_as.dtype, device=type_as.device) + + def eye(self, N, M=None, type_as=None): + if M is None: + M = N + if type_as is None: + return torch.eye(N, m=M) + else: + return torch.eye(N, m=M, dtype=type_as.dtype, device=type_as.device) + + def sum(self, a, axis=None, keepdims=False): + if axis is None: + return torch.sum(a) + else: + return torch.sum(a, axis, keepdim=keepdims) + + def cumsum(self, a, axis=None): + if axis is None: + return torch.cumsum(a.flatten(), 0) + else: + return torch.cumsum(a, axis) + + def max(self, a, axis=None, keepdims=False): + if axis is None: + return torch.max(a) + else: + return torch.max(a, axis, keepdim=keepdims)[0] + + def min(self, a, axis=None, keepdims=False): + if axis is None: + return torch.min(a) + else: + return torch.min(a, axis, keepdim=keepdims)[0] + + def maximum(self, a, b): + if isinstance(a, int) or isinstance(a, float): + a = torch.tensor([float(a)], dtype=b.dtype, device=b.device) + if isinstance(b, int) or isinstance(b, float): + b = torch.tensor([float(b)], dtype=a.dtype, device=a.device) + return torch.maximum(a, b) + + def minimum(self, a, b): + if isinstance(a, int) or isinstance(a, float): + a = torch.tensor([float(a)], dtype=b.dtype, device=b.device) + if isinstance(b, int) or isinstance(b, float): + b = torch.tensor([float(b)], dtype=a.dtype, device=a.device) + return torch.minimum(a, b) + + def dot(self, a, b): + if len(a.shape) == len(b.shape) == 1: + return torch.dot(a, b) + elif len(a.shape) == 2 and len(b.shape) == 1: + return torch.mv(a, b) + else: + return torch.mm(a, b) + + def abs(self, a): + return torch.abs(a) + + def exp(self, a): + return torch.exp(a) + + def log(self, a): + return torch.log(a) + + def sqrt(self, a): + return torch.sqrt(a) + + def norm(self, a): + return torch.sqrt(torch.sum(torch.square(a))) + + def any(self, a): + return torch.any(a) + + def isnan(self, a): + return torch.isnan(a) + + def isinf(self, a): + return torch.isinf(a) + + def einsum(self, subscripts, *operands): + return torch.einsum(subscripts, *operands) + + def sort(self, a, axis=-1): + sorted0, indices = torch.sort(a, dim=axis) + return sorted0 + + def argsort(self, a, axis=-1): + sorted, indices = torch.sort(a, dim=axis) + return indices + + def flip(self, a, axis=None): + if axis is None: + return torch.flip(a, tuple(i for i in range(len(a.shape)))) + if isinstance(axis, int): + return torch.flip(a, (axis,)) + else: + return torch.flip(a, dims=axis) diff --git a/ot/bregman.py b/ot/bregman.py index 559db14..b10effd 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -19,7 +19,8 @@ import warnings import numpy as np from scipy.optimize import fmin_l_bfgs_b -from ot.utils import unif, dist +from ot.utils import unif, dist, list_to_array +from .backend import get_backend def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, @@ -43,17 +44,36 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - a and b are source and target weights (histograms, both sum to 1) - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm as proposed in [2]_ + + **Choosing a Sinkhorn solver** + + By default and when using a regularization parameter that is not too small + the default sinkhorn solver should be enough. If you need to use a small + regularization to get sharper OT matrices, you should use the + :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical + errors. This last solver can be very slow in practice and might not even + converge to a reasonable OT matrix in a finite time. This is why + :any:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value + of the regularization (and using warm start) sometimes leads to better + solutions. Note that the greedy version of the sinkhorn + :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening + version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a + fast approximation of the Sinkhorn problem. Parameters ---------- - a : ndarray, shape (dim_a,) + a : array-like, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists) + b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets and fixed M if b is a matrix (return OT loss + dual variables in log) - M : ndarray, shape (dim_a, dim_b) + M : array-like, shape (dim_a, dim_b) loss matrix reg : float Regularization term >0 @@ -69,25 +89,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, log : bool, optional record log if True - **Choosing a Sinkhorn solver** - - By default and when using a regularization parameter that is not too small - the default sinkhorn solver should be enough. If you need to use a small - regularization to get sharper OT matrices, you should use the - :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical - errors. This last solver can be very slow in practice and might not even - converge to a reasonable OT matrix in a finite time. This is why - :any:`ot.bregman.sinkhorn_epsilon_scaling` that relie on iterating the value - of the regularization (and using warm start) sometimes leads to better - solutions. Note that the greedy version of the sinkhorn - :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening - version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a - fast approximation of the Sinkhorn problem. - - Returns ------- - gamma : ndarray, shape (dim_a, dim_b) + gamma : array-like, shape (dim_a, dim_b) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -166,17 +170,35 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - a and b are source and target weights (histograms, both sum to 1) + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ + **Choosing a Sinkhorn solver** + + By default and when using a regularization parameter that is not too small + the default sinkhorn solver should be enough. If you need to use a small + regularization to get sharper OT matrices, you should use the + :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical + errors. This last solver can be very slow in practice and might not even + converge to a reasonable OT matrix in a finite time. This is why + :any:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value + of the regularization (and using warm start) sometimes leads to better + solutions. Note that the greedy version of the sinkhorn + :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening + version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a + fast approximation of the Sinkhorn problem. + Parameters ---------- - a : ndarray, shape (dim_a,) + a : array-like, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists) + b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets and fixed M if b is a matrix (return OT loss + dual variables in log) - M : ndarray, shape (dim_a, dim_b) + M : array-like, shape (dim_a, dim_b) loss matrix reg : float Regularization term >0 @@ -191,28 +213,14 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, log : bool, optional record log if True - **Choosing a Sinkhorn solver** - - By default and when using a regularization parameter that is not too small - the default sinkhorn solver should be enough. If you need to use a small - regularization to get sharper OT matrices, you should use the - :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical - errors. This last solver can be very slow in practice and might not even - converge to a reasonable OT matrix in a finite time. This is why - :any:`ot.bregman.sinkhorn_epsilon_scaling` that relie on iterating the value - of the regularization (and using warm start) sometimes leads to better - solutions. Note that the greedy version of the sinkhorn - :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening - version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a - fast approximation of the Sinkhorn problem. - Returns ------- - W : (n_hists) ndarray + W : (n_hists) float/array-like Optimal transportation loss for the given parameters log : dict log dictionary return only if log==True in parameters + Examples -------- @@ -247,7 +255,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] """ - b = np.asarray(b, dtype=np.float64) + + b = list_to_array(b) if len(b.shape) < 2: b = b[:, None] @@ -339,14 +348,14 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, """ - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) if len(a) == 0: - a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M) if len(b) == 0: - b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M) # init data dim_a = len(a) @@ -363,21 +372,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, # we assume that no distances are null except those of the diagonal of # distances if n_hists: - u = np.ones((dim_a, n_hists)) / dim_a - v = np.ones((dim_b, n_hists)) / dim_b + u = nx.ones((dim_a, n_hists), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b else: - u = np.ones(dim_a) / dim_a - v = np.ones(dim_b) / dim_b + u = nx.ones(dim_a, type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b - # print(reg) - - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) - - # print(np.min(K)) - tmp2 = np.empty(b.shape, dtype=M.dtype) + K = nx.exp(M / (-reg)) Kp = (1 / a).reshape(-1, 1) * K cpt = 0 @@ -386,13 +387,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, uprev = u vprev = v - KtransposeU = np.dot(K.T, u) - v = np.divide(b, KtransposeU) - u = 1. / np.dot(Kp, v) + KtransposeU = nx.dot(K.T, u) + v = b / KtransposeU + u = 1. / nx.dot(Kp, v) - if (np.any(KtransposeU == 0) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + if (nx.any(KtransposeU == 0) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop print('Warning: numerical errors at iteration', cpt) @@ -403,11 +404,11 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, # we can speed up the process by checking for the error only all # the 10th iterations if n_hists: - np.einsum('ik,ij,jk->jk', u, K, v, out=tmp2) + tmp2 = nx.einsum('ik,ij,jk->jk', u, K, v) else: # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 - np.einsum('i,ij,j->j', u, K, v, out=tmp2) - err = np.linalg.norm(tmp2 - b) # violation of marginal + tmp2 = nx.einsum('i,ij,j->j', u, K, v) + err = nx.norm(tmp2 - b) # violation of marginal if log: log['err'].append(err) @@ -422,7 +423,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, log['v'] = v if n_hists: # return only loss - res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) + res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M) if log: return res, log else: diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py index 7478fb9..e939610 100644 --- a/ot/gpu/__init__.py +++ b/ot/gpu/__init__.py @@ -25,6 +25,8 @@ result of the function with parameter ``to_numpy=False``. # # License: MIT License +import warnings + from . import bregman from . import da from .bregman import sinkhorn @@ -34,7 +36,7 @@ from . import utils from .utils import dist, to_gpu, to_np - +warnings.warn('This module will be deprecated in the next minor release of POT', category=DeprecationWarning) __all__ = ["utils", "dist", "sinkhorn", diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index d5c3a5e..c8c9da6 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -18,8 +18,9 @@ from . import cvx from .cvx import barycenter # import compiled emd from .emd_wrap import emd_c, check_result, emd_1d_sorted -from ..utils import dist +from ..utils import dist, list_to_array from ..utils import parmap +from ..backend import get_backend __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', 'emd_1d', 'emd2_1d', 'wasserstein_1d'] @@ -176,8 +177,7 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): r"""Solves the Earth Movers distance problem and returns the OT matrix - .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + .. math:: \gamma = arg\min_\gamma <\gamma,M>_F s.t. \gamma 1 = a @@ -189,37 +189,41 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): - M is the metric cost matrix - a and b are the sample weights - .. warning:: - Note that the M matrix needs to be a C-order numpy.array in float64 - format. + .. warning:: Note that the M matrix in numpy needs to be a C-order + numpy.array in float64 format. It will be converted if not in this + format + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. Uses the algorithm proposed in [1]_ Parameters ---------- - a : (ns,) numpy.ndarray, float64 + a : (ns,) array-like, float Source histogram (uniform weight if empty list) - b : (nt,) numpy.ndarray, float64 - Target histogram (uniform weight if empty list) - M : (ns,nt) numpy.ndarray, float64 - Loss matrix (c-order array with type float64) - numItermax : int, optional (default=100000) + b : (nt,) array-like, float + Target histogram (uniform weight if empty list) + M : (ns,nt) array-like, float + Loss matrix (c-order array in numpy with type float64) + numItermax : int, optional (default=100000) The maximum number of iterations before stopping the optimization - algorithm if it has not converged. - log: bool, optional (default=False) - If True, returns a dictionary containing the cost and dual - variables. Otherwise returns only the optimal transportation matrix. + algorithm if it has not converged. + log: bool, optional (default=False) + If True, returns a dictionary containing the cost and dual variables. + Otherwise returns only the optimal transportation matrix. center_dual: boolean, optional (default=True) - If True, centers the dual potential using function + If True, centers the dual potential using function :ref:`center_ot_dual`. Returns ------- - gamma: (ns x nt) numpy.ndarray - Optimal transportation matrix for the given parameters - log: dict - If input log is true, a dictionary containing the cost and dual - variables and exit status + gamma: array-like, shape (ns, nt) + Optimal transportation matrix for the given + parameters + log: dict, optional + If input log is true, a dictionary containing the + cost and dual variables and exit status Examples @@ -232,26 +236,37 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): >>> a=[.5,.5] >>> b=[.5,.5] >>> M=[[0.,1.],[1.,0.]] - >>> ot.emd(a,b,M) + >>> ot.emd(a, b, M) array([[0.5, 0. ], [0. , 0.5]]) References ---------- - .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. - (2011, December). Displacement interpolation using Lagrangian mass - transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. - 158). ACM. + .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, + December). Displacement interpolation using Lagrangian mass transport. + In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM. See Also -------- - ot.bregman.sinkhorn : Entropic regularized OT - ot.optim.cg : General regularized OT""" - + ot.bregman.sinkhorn : Entropic regularized OT ot.optim.cg : General + regularized OT""" + + # convert to numpy if list + a, b, M = list_to_array(a, b, M) + + a0, b0, M0 = a, b, M + nx = get_backend(M0, a0, b0) + + # convert to numpy + M = nx.to_numpy(M) + a = nx.to_numpy(a) + b = nx.to_numpy(b) + + # ensure float64 a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + M = np.asarray(M, dtype=np.float64, order='C') # if empty array given then use uniform distributions if len(a) == 0: @@ -262,6 +277,11 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \ "Dimension mismatch, check dimensions of M with a and b" + # ensure that same mass + np.testing.assert_almost_equal(a.sum(0), + b.sum(0), err_msg='a and b vector must have the same sum') + b=b*a.sum()/b.sum() + asel = a != 0 bsel = b != 0 @@ -277,12 +297,12 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): if log: log = {} log['cost'] = cost - log['u'] = u - log['v'] = v + log['u'] = nx.from_numpy(u, type_as=a0) + log['v'] = nx.from_numpy(v, type_as=b0) log['warning'] = result_code_string log['result_code'] = result_code - return G, log - return G + return nx.from_numpy(G, type_as=M0), log + return nx.from_numpy(G, type_as=M0) def emd2(a, b, M, processes=multiprocessing.cpu_count(), @@ -303,20 +323,19 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), - M is the metric cost matrix - a and b are the sample weights - .. warning:: - Note that the M matrix needs to be a C-order numpy.array in float64 - format. + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. Uses the algorithm proposed in [1]_ Parameters ---------- - a : (ns,) numpy.ndarray, float64 + a : (ns,) array-like, float64 Source histogram (uniform weight if empty list) - b : (nt,) numpy.ndarray, float64 + b : (nt,) array-like, float64 Target histogram (uniform weight if empty list) - M : (ns,nt) numpy.ndarray, float64 - Loss matrix (c-order array with type float64) + M : (ns,nt) array-like, float64 + Loss matrix (for numpy c-order array with type float64) processes : int, optional (default=nb cpu) Nb of processes used for multiple emd computation (not used on windows) numItermax : int, optional (default=100000) @@ -333,9 +352,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), Returns ------- - W: float + W: float, array-like Optimal transportation loss for the given parameters - log: dictnp + log: dict If input log is true, a dictionary containing dual variables and exit status @@ -367,12 +386,22 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), ot.bregman.sinkhorn : Entropic regularized OT ot.optim.cg : General regularized OT""" + a, b, M = list_to_array(a, b, M) + + a0, b0, M0 = a, b, M + nx = get_backend(M0, a0, b0) + + # convert to numpy + M = nx.to_numpy(M) + a = nx.to_numpy(a) + b = nx.to_numpy(b) + a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + M = np.asarray(M, dtype=np.float64, order= 'C') # problem with pikling Forks - if sys.platform.endswith('win32'): + if sys.platform.endswith('win32') or not nx.__name__ == 'numpy': processes = 1 # if empty array given then use uniform distributions @@ -400,12 +429,15 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), result_code_string = check_result(result_code) log = {} + G = nx.from_numpy(G, type_as=M0) if return_matrix: log['G'] = G - log['u'] = u - log['v'] = v + log['u'] = nx.from_numpy(u, type_as=a0) + log['v'] = nx.from_numpy(v, type_as=b0) log['warning'] = result_code_string log['result_code'] = result_code + cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), + (a0,b0, M0), (log['u'], log['v'], G)) return [cost, log] else: def f(b): @@ -418,6 +450,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), if np.any(~asel) or np.any(~bsel): u, v = estimate_dual_null_weights(u, v, a, b, M) + G = nx.from_numpy(G, type_as=M0) + cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), + (a0,b0, M0), (nx.from_numpy(u, type_as=a0), + nx.from_numpy(v, type_as=b0),G)) + check_result(result_code) return cost @@ -637,6 +674,10 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, if b.ndim == 0 or len(b) == 0: b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0] + # ensure that same mass + np.testing.assert_almost_equal(a.sum(0),b.sum(0),err_msg='a and b vector must have the same sum') + b=b*a.sum()/b.sum() + x_a_1d = x_a.reshape((-1,)) x_b_1d = x_b.reshape((-1,)) perm_a = np.argsort(x_a_1d) diff --git a/ot/utils.py b/ot/utils.py index 544c569..4dac0c5 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -16,6 +16,7 @@ from scipy.spatial.distance import cdist import sys import warnings from inspect import signature +from .backend import get_backend __time_tic_toc = time.time() @@ -41,8 +42,11 @@ def toq(): def kernel(x1, x2, method='gaussian', sigma=1, **kwargs): """Compute kernel matrix""" + + nx = get_backend(x1, x2) + if method.lower() in ['gaussian', 'gauss', 'rbf']: - K = np.exp(-dist(x1, x2) / (2 * sigma**2)) + K = nx.exp(-dist(x1, x2) / (2 * sigma**2)) return K @@ -52,6 +56,66 @@ def laplacian(x): return L +def list_to_array(*lst): + """ Convert a list if in numpy format """ + if len(lst) > 1: + return [np.array(a) if isinstance(a, list) else a for a in lst] + else: + return np.array(lst[0]) if isinstance(lst[0], list) else lst[0] + + +def proj_simplex(v, z=1): + r""" compute the closest point (orthogonal projection) on the + generalized (n-1)-simplex of a vector v wrt. to the Euclidean + distance, thus solving: + .. math:: + \mathcal{P}(w) \in arg\min_\gamma || \gamma - v ||_2 + + s.t. \gamma^T 1= z + + \gamma\geq 0 + + If v is a 2d array, compute all the projections wrt. axis 0 + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + + Parameters + ---------- + v : {array-like}, shape (n, d) + z : int, optional + 'size' of the simplex (each vectors sum to z, 1 by default) + + Returns + ------- + h : ndarray, shape (n,d) + Array of projections on the simplex + """ + nx = get_backend(v) + n = v.shape[0] + if v.ndim == 1: + d1 = 1 + v = v[:, None] + else: + d1 = 0 + d = v.shape[1] + + # sort u in ascending order + u = nx.sort(v, axis=0) + # take the descending order + u = nx.flip(u, 0) + cssv = nx.cumsum(u, axis=0) - z + ind = nx.arange(n, type_as=v)[:, None] + 1 + cond = u - cssv / ind > 0 + rho = nx.sum(cond, 0) + theta = cssv[rho - 1, nx.arange(d)] / rho + w = nx.maximum(v - theta[None, :], nx.zeros(v.shape, type_as=v)) + if d1: + return w[:, 0] + else: + return w + + def unif(n): """ return a uniform histogram of length n (simplex) @@ -84,52 +148,68 @@ def euclidean_distances(X, Y, squared=False): """ Considering the rows of X (and Y=X) as vectors, compute the distance matrix between each pair of vectors. + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + Parameters ---------- X : {array-like}, shape (n_samples_1, n_features) Y : {array-like}, shape (n_samples_2, n_features) squared : boolean, optional Return squared Euclidean distances. + Returns ------- distances : {array}, shape (n_samples_1, n_samples_2) """ - XX = np.einsum('ij,ij->i', X, X)[:, np.newaxis] - YY = np.einsum('ij,ij->i', Y, Y)[np.newaxis, :] - distances = np.dot(X, Y.T) - distances *= -2 - distances += XX - distances += YY - np.maximum(distances, 0, out=distances) + + nx = get_backend(X, Y) + + a2 = nx.einsum('ij,ij->i', X, X) + b2 = nx.einsum('ij,ij->i', Y, Y) + + c = -2 * nx.dot(X, Y.T) + c += a2[:, None] + c += b2[None, :] + + c = nx.maximum(c, 0) + + if not squared: + c = nx.sqrt(c) + if X is Y: - # Ensure that distances between vectors and themselves are set to 0.0. - # This may not be the case due to floating point rounding errors. - distances.flat[::distances.shape[0] + 1] = 0.0 - return distances if squared else np.sqrt(distances, out=distances) + c = c * (1 - nx.eye(X.shape[0], type_as=c)) + + return c def dist(x1, x2=None, metric='sqeuclidean'): - """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist + """Compute distance between samples in x1 and x2 + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. Parameters ---------- - x1 : ndarray, shape (n1,d) + x1 : array-like, shape (n1,d) matrix with n1 samples of size d - x2 : array, shape (n2,d), optional + x2 : array-like, shape (n2,d), optional matrix with n2 samples of size d (if None then x2=x1) metric : str | callable, optional - Name of the metric to be computed (full list in the doc of scipy), If a string, - the distance function can be 'braycurtis', 'canberra', 'chebyshev', 'cityblock', - 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski', - 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', + 'sqeuclidean' or 'euclidean' on all backends. On numpy the function also + accepts from the scipy.spatial.distance.cdist function : 'braycurtis', + 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice', + 'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis', + 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'. Returns ------- - M : np.array (n1,n2) + M : array-like, shape (n1, n2) distance matrix computed with given metric """ @@ -137,7 +217,13 @@ def dist(x1, x2=None, metric='sqeuclidean'): x2 = x1 if metric == "sqeuclidean": return euclidean_distances(x1, x2, squared=True) - return cdist(x1, x2, metric=metric) + elif metric == "euclidean": + return euclidean_distances(x1, x2, squared=False) + else: + if not get_backend(x1, x2).__name__ == 'numpy': + raise NotImplementedError() + else: + return cdist(x1, x2, metric=metric) def dist0(n, method='lin_square'): diff --git a/requirements.txt b/requirements.txt index 331dd57..4353247 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,7 @@ pymanopt==0.2.4; python_version <'3' pymanopt; python_version >= '3' cvxopt scikit-learn +torch +jax +jaxlib pytest \ No newline at end of file diff --git a/test/test_backend.py b/test/test_backend.py new file mode 100644 index 0000000..bc5b00c --- /dev/null +++ b/test/test_backend.py @@ -0,0 +1,364 @@ +"""Tests for backend module """ + +# Author: Remi Flamary +# +# License: MIT License + +import ot +import ot.backend +from ot.backend import torch, jax + +import pytest + +import numpy as np +from numpy.testing import assert_array_almost_equal_nulp + +from ot.backend import get_backend, get_backend_list, to_numpy + + +backend_list = get_backend_list() + + +def test_get_backend_list(): + + lst = get_backend_list() + + assert len(lst) > 0 + assert isinstance(lst[0], ot.backend.NumpyBackend) + + +@pytest.mark.parametrize('nx', backend_list) +def test_to_numpy(nx): + + v = nx.zeros(10) + M = nx.ones((10, 10)) + + v2 = to_numpy(v) + assert isinstance(v2, np.ndarray) + + v2, M2 = to_numpy(v, M) + assert isinstance(M2, np.ndarray) + + +def test_get_backend(): + + A = np.zeros((3, 2)) + B = np.zeros((3, 1)) + + nx = get_backend(A) + assert nx.__name__ == 'numpy' + + nx = get_backend(A, B) + assert nx.__name__ == 'numpy' + + # error if no parameters + with pytest.raises(ValueError): + get_backend() + + # error if unknown types + with pytest.raises(ValueError): + get_backend(1, 2.0) + + # test torch + if torch: + + A2 = torch.from_numpy(A) + B2 = torch.from_numpy(B) + + nx = get_backend(A2) + assert nx.__name__ == 'torch' + + nx = get_backend(A2, B2) + assert nx.__name__ == 'torch' + + # test not unique types in input + with pytest.raises(ValueError): + get_backend(A, B2) + + if jax: + + A2 = jax.numpy.array(A) + B2 = jax.numpy.array(B) + + nx = get_backend(A2) + assert nx.__name__ == 'jax' + + nx = get_backend(A2, B2) + assert nx.__name__ == 'jax' + + # test not unique types in input + with pytest.raises(ValueError): + get_backend(A, B2) + + +@pytest.mark.parametrize('nx', backend_list) +def test_convert_between_backends(nx): + + A = np.zeros((3, 2)) + B = np.zeros((3, 1)) + + A2 = nx.from_numpy(A) + B2 = nx.from_numpy(B) + + assert isinstance(A2, nx.__type__) + assert isinstance(B2, nx.__type__) + + nx2 = get_backend(A2, B2) + + assert nx2.__name__ == nx.__name__ + + assert_array_almost_equal_nulp(nx.to_numpy(A2), A) + assert_array_almost_equal_nulp(nx.to_numpy(B2), B) + + +def test_empty_backend(): + + rnd = np.random.RandomState(0) + M = rnd.randn(10, 3) + v = rnd.randn(3) + + nx = ot.backend.Backend() + + with pytest.raises(NotImplementedError): + nx.from_numpy(M) + with pytest.raises(NotImplementedError): + nx.to_numpy(M) + with pytest.raises(NotImplementedError): + nx.set_gradients(0, 0, 0) + with pytest.raises(NotImplementedError): + nx.zeros((10, 3)) + with pytest.raises(NotImplementedError): + nx.ones((10, 3)) + with pytest.raises(NotImplementedError): + nx.arange(10, 1, 2) + with pytest.raises(NotImplementedError): + nx.full((10, 3), 3.14) + with pytest.raises(NotImplementedError): + nx.eye((10, 3)) + with pytest.raises(NotImplementedError): + nx.sum(M) + with pytest.raises(NotImplementedError): + nx.cumsum(M) + with pytest.raises(NotImplementedError): + nx.max(M) + with pytest.raises(NotImplementedError): + nx.min(M) + with pytest.raises(NotImplementedError): + nx.maximum(v, v) + with pytest.raises(NotImplementedError): + nx.minimum(v, v) + with pytest.raises(NotImplementedError): + nx.abs(M) + with pytest.raises(NotImplementedError): + nx.log(M) + with pytest.raises(NotImplementedError): + nx.exp(M) + with pytest.raises(NotImplementedError): + nx.sqrt(M) + with pytest.raises(NotImplementedError): + nx.dot(v, v) + with pytest.raises(NotImplementedError): + nx.norm(M) + with pytest.raises(NotImplementedError): + nx.exp(M) + with pytest.raises(NotImplementedError): + nx.any(M) + with pytest.raises(NotImplementedError): + nx.isnan(M) + with pytest.raises(NotImplementedError): + nx.isinf(M) + with pytest.raises(NotImplementedError): + nx.einsum('ij->i', M) + with pytest.raises(NotImplementedError): + nx.sort(M) + with pytest.raises(NotImplementedError): + nx.argsort(M) + with pytest.raises(NotImplementedError): + nx.flip(M) + + +@pytest.mark.parametrize('backend', backend_list) +def test_func_backends(backend): + + rnd = np.random.RandomState(0) + M = rnd.randn(10, 3) + v = rnd.randn(3) + val = np.array([1.0]) + + lst_tot = [] + + for nx in [ot.backend.NumpyBackend(), backend]: + + print('Backend: ', nx.__name__) + + lst_b = [] + lst_name = [] + + Mb = nx.from_numpy(M) + vb = nx.from_numpy(v) + val = nx.from_numpy(val) + + A = nx.set_gradients(val, v, v) + lst_b.append(nx.to_numpy(A)) + lst_name.append('set_gradients') + + A = nx.zeros((10, 3)) + A = nx.zeros((10, 3), type_as=Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('zeros') + + A = nx.ones((10, 3)) + A = nx.ones((10, 3), type_as=Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('ones') + + A = nx.arange(10, 1, 2) + lst_b.append(nx.to_numpy(A)) + lst_name.append('arange') + + A = nx.full((10, 3), 3.14) + A = nx.full((10, 3), 3.14, type_as=Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('full') + + A = nx.eye(10, 3) + A = nx.eye(10, 3, type_as=Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('eye') + + A = nx.sum(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('sum') + + A = nx.sum(Mb, axis=1, keepdims=True) + lst_b.append(nx.to_numpy(A)) + lst_name.append('sum(axis)') + + A = nx.cumsum(Mb, 0) + lst_b.append(nx.to_numpy(A)) + lst_name.append('cumsum(axis)') + + A = nx.max(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('max') + + A = nx.max(Mb, axis=1, keepdims=True) + lst_b.append(nx.to_numpy(A)) + lst_name.append('max(axis)') + + A = nx.min(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('min') + + A = nx.min(Mb, axis=1, keepdims=True) + lst_b.append(nx.to_numpy(A)) + lst_name.append('min(axis)') + + A = nx.maximum(vb, 0) + lst_b.append(nx.to_numpy(A)) + lst_name.append('maximum') + + A = nx.minimum(vb, 0) + lst_b.append(nx.to_numpy(A)) + lst_name.append('minimum') + + A = nx.abs(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('abs') + + A = nx.log(A) + lst_b.append(nx.to_numpy(A)) + lst_name.append('log') + + A = nx.exp(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('exp') + + A = nx.sqrt(nx.abs(Mb)) + lst_b.append(nx.to_numpy(A)) + lst_name.append('sqrt') + + A = nx.dot(vb, vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('dot(v,v)') + + A = nx.dot(Mb, vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('dot(M,v)') + + A = nx.dot(Mb, Mb.T) + lst_b.append(nx.to_numpy(A)) + lst_name.append('dot(M,M)') + + A = nx.norm(vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('norm') + + A = nx.any(vb > 0) + lst_b.append(nx.to_numpy(A)) + lst_name.append('any') + + A = nx.isnan(vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('isnan') + + A = nx.isinf(vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('isinf') + + A = nx.einsum('ij->i', Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('einsum(ij->i)') + + A = nx.einsum('ij,j->i', Mb, vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('nx.einsum(ij,j->i)') + + A = nx.einsum('ij->i', Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('nx.einsum(ij->i)') + + A = nx.sort(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('sort') + + A = nx.argsort(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('argsort') + + A = nx.flip(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('flip') + + lst_tot.append(lst_b) + + lst_np = lst_tot[0] + lst_b = lst_tot[1] + + for a1, a2, name in zip(lst_np, lst_b, lst_name): + if not np.allclose(a1, a2): + print('Assert fail on: ', name) + assert np.allclose(a1, a2, atol=1e-7) + + +def test_gradients_backends(): + + rnd = np.random.RandomState(0) + v = rnd.randn(10) + c = rnd.randn(1) + + if torch: + + nx = ot.backend.TorchBackend() + + v2 = torch.tensor(v, requires_grad=True) + c2 = torch.tensor(c, requires_grad=True) + + val = c2 * torch.sum(v2 * v2) + + val2 = nx.set_gradients(val, (v2, c2), (v2, c2)) + + val2.backward() + + assert torch.equal(v2.grad, v2) + assert torch.equal(c2.grad, c2) diff --git a/test/test_bregman.py b/test/test_bregman.py index 1ebd21f..7c5162a 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -9,6 +9,10 @@ import numpy as np import pytest import ot +from ot.backend import get_backend_list +from ot.backend import torch + +backend_list = get_backend_list() def test_sinkhorn(): @@ -30,6 +34,76 @@ def test_sinkhorn(): u, G.sum(0), atol=1e-05) # cf convergence sinkhorn +@pytest.mark.parametrize('nx', backend_list) +def test_sinkhorn_backends(nx): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + G = ot.sinkhorn(a, a, M, 1) + + ab = nx.from_numpy(a) + Mb = nx.from_numpy(M) + + Gb = ot.sinkhorn(ab, ab, Mb, 1) + + np.allclose(G, nx.to_numpy(Gb)) + + +@pytest.mark.parametrize('nx', backend_list) +def test_sinkhorn2_backends(nx): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + G = ot.sinkhorn(a, a, M, 1) + + ab = nx.from_numpy(a) + Mb = nx.from_numpy(M) + + Gb = ot.sinkhorn2(ab, ab, Mb, 1) + + np.allclose(G, nx.to_numpy(Gb)) + + +def test_sinkhorn2_gradients(): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + if torch: + + a1 = torch.tensor(a, requires_grad=True) + b1 = torch.tensor(a, requires_grad=True) + M1 = torch.tensor(M, requires_grad=True) + + val = ot.sinkhorn2(a1, b1, M1, 1) + + val.backward() + + assert a1.shape == a1.grad.shape + assert b1.shape == b1.grad.shape + assert M1.shape == M1.grad.shape + + def test_sinkhorn_empty(): # test sinkhorn n = 100 diff --git a/test/test_gromov.py b/test/test_gromov.py index 43da9fc..81138ca 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -181,7 +181,7 @@ def test_fgw(): M = ot.dist(ys, yt) M /= M.max() - G = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5) + G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True) # check constratints np.testing.assert_allclose( @@ -242,9 +242,9 @@ def test_fgw_barycenter(): init_X = np.random.randn(n_samples, ys.shape[1]) - X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5, - fixed_structure=False, fixed_features=True, init_X=init_X, - p=ot.unif(n_samples), loss_fun='square_loss', - max_iter=100, tol=1e-3) + X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5, + fixed_structure=False, fixed_features=True, init_X=init_X, + p=ot.unif(n_samples), loss_fun='square_loss', + max_iter=100, tol=1e-3, log=True) np.testing.assert_allclose(C.shape, (n_samples, n_samples)) np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) diff --git a/test/test_ot.py b/test/test_ot.py index f45e4c9..3e953dc 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -12,9 +12,12 @@ from scipy.stats import wasserstein_distance import ot from ot.datasets import make_1D_gauss as gauss +from ot.backend import get_backend_list, torch +backend_list = get_backend_list() -def test_emd_dimension_mismatch(): + +def test_emd_dimension_and_mass_mismatch(): # test emd and emd2 for dimension mismatch n_samples = 100 n_features = 2 @@ -29,6 +32,80 @@ def test_emd_dimension_mismatch(): np.testing.assert_raises(AssertionError, ot.emd2, a, a, M) + b = a.copy() + a[0] = 100 + np.testing.assert_raises(AssertionError, ot.emd, a, b, M) + + +@pytest.mark.parametrize('nx', backend_list) +def test_emd_backends(nx): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + G = ot.emd(a, a, M) + + ab = nx.from_numpy(a) + Mb = nx.from_numpy(M) + + Gb = ot.emd(ab, ab, Mb) + + np.allclose(G, nx.to_numpy(Gb)) + + +@pytest.mark.parametrize('nx', backend_list) +def test_emd2_backends(nx): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + val = ot.emd2(a, a, M) + + ab = nx.from_numpy(a) + Mb = nx.from_numpy(M) + + valb = ot.emd2(ab, ab, Mb) + + np.allclose(val, nx.to_numpy(valb)) + + +def test_emd2_gradients(): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + if torch: + + a1 = torch.tensor(a, requires_grad=True) + b1 = torch.tensor(a, requires_grad=True) + M1 = torch.tensor(M, requires_grad=True) + + val = ot.emd2(a1, b1, M1) + + val.backward() + + assert a1.shape == a1.grad.shape + assert b1.shape == b1.grad.shape + assert M1.shape == M1.grad.shape + def test_emd_emd2(): # test emd and emd2 for simple identity @@ -83,7 +160,7 @@ def test_emd_1d_emd2_1d(): np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0)) # check G is similar - np.testing.assert_allclose(G, G_1d) + np.testing.assert_allclose(G, G_1d, atol=1e-15) # check AssertionError is raised if called on non 1d arrays u = np.random.randn(n, 2) @@ -292,16 +369,6 @@ def test_warnings(): ot.emd(a, b, M, numItermax=1) assert "numItermax" in str(w[-1].message) #assert len(w) == 1 - a[0] = 100 - print('Computing {} EMD '.format(2)) - ot.emd(a, b, M) - assert "infeasible" in str(w[-1].message) - #assert len(w) == 2 - a[0] = -1 - print('Computing {} EMD '.format(2)) - ot.emd(a, b, M) - assert "infeasible" in str(w[-1].message) - #assert len(w) == 3 def test_dual_variables(): diff --git a/test/test_partial.py b/test/test_partial.py index 121f345..3571e2a 100755 --- a/test/test_partial.py +++ b/test/test_partial.py @@ -129,9 +129,9 @@ def test_partial_wasserstein(): # check constratints np.testing.assert_equal( - G.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein + G.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein np.testing.assert_equal( - G.sum(0) <= q, [True] * len(q)) # cf convergence wasserstein + G.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein np.testing.assert_allclose( np.sum(G), m, atol=1e-04) diff --git a/test/test_utils.py b/test/test_utils.py index db9cda6..76b1faa 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -4,11 +4,47 @@ # # License: MIT License - +import pytest import ot import numpy as np import sys +from ot.backend import get_backend_list + +backend_list = get_backend_list() + + +@pytest.mark.parametrize('nx', backend_list) +def test_proj_simplex(nx): + n = 10 + rng = np.random.RandomState(0) + + # test on matrix when projection is done on axis 0 + x = rng.randn(n, 2) + x1 = nx.from_numpy(x) + + # all projections should sum to 1 + proj = ot.utils.proj_simplex(x1) + l1 = np.sum(nx.to_numpy(proj), axis=0) + l2 = np.ones(2) + np.testing.assert_allclose(l1, l2, atol=1e-5) + + # all projections should sum to 3 + proj = ot.utils.proj_simplex(x1, 3) + l1 = np.sum(nx.to_numpy(proj), axis=0) + l2 = 3 * np.ones(2) + np.testing.assert_allclose(l1, l2, atol=1e-5) + + # tets on vector + x = rng.randn(n) + x1 = nx.from_numpy(x) + + # all projections should sum to 1 + proj = ot.utils.proj_simplex(x1) + l1 = np.sum(nx.to_numpy(proj), axis=0) + l2 = np.ones(2) + np.testing.assert_allclose(l1, l2, atol=1e-5) + def test_parmap(): @@ -45,8 +81,8 @@ def test_tic_toc(): def test_kernel(): n = 100 - - x = np.random.randn(n, 2) + rng = np.random.RandomState(0) + x = rng.randn(n, 2) K = ot.utils.kernel(x, x) @@ -67,7 +103,8 @@ def test_dist(): n = 100 - x = np.random.randn(n, 2) + rng = np.random.RandomState(0) + x = rng.randn(n, 2) D = np.zeros((n, n)) for i in range(n): @@ -78,8 +115,27 @@ def test_dist(): D3 = ot.dist(x) # dist shoul return squared euclidean - np.testing.assert_allclose(D, D2) - np.testing.assert_allclose(D, D3) + np.testing.assert_allclose(D, D2, atol=1e-14) + np.testing.assert_allclose(D, D3, atol=1e-14) + + +@ pytest.mark.parametrize('nx', backend_list) +def test_dist_backends(nx): + + n = 100 + rng = np.random.RandomState(0) + x = rng.randn(n, 2) + x1 = nx.from_numpy(x) + + lst_metric = ['euclidean', 'sqeuclidean'] + + for metric in lst_metric: + + D = ot.dist(x, x, metric=metric) + D1 = ot.dist(x1, x1, metric=metric) + + # low atol because jax forces float32 + np.testing.assert_allclose(D, nx.to_numpy(D1), atol=1e-5) def test_dist0(): @@ -95,9 +151,11 @@ def test_dots(): n1, n2, n3, n4 = 100, 50, 200, 100 - A = np.random.randn(n1, n2) - B = np.random.randn(n2, n3) - C = np.random.randn(n3, n4) + rng = np.random.RandomState(0) + + A = rng.randn(n1, n2) + B = rng.randn(n2, n3) + C = rng.randn(n3, n4) X1 = ot.utils.dots(A, B, C) -- cgit v1.2.3 From 2dbeeda9308029a8e8db56bed07d48f4d5718efb Mon Sep 17 00:00:00 2001 From: Huy Tran Date: Mon, 14 Jun 2021 13:06:40 +0200 Subject: [MRG] Batch/Lazy Log Sinkhorn Knopp on samples (#259) * Add batch implementation of Sinkhorn * Reformat to pep8 and modify parameter * Fix error in batch size * Code review and add test * Fix accidental typo in test_empirical_sinkhorn * Remove whitespace * Edit config.yml --- .circleci/config.yml | 1 + ot/bregman.py | 134 +++++++++++++++++++++++++++++++++++++++++++-------- test/test_bregman.py | 44 +++++++++++++++++ 3 files changed, 158 insertions(+), 21 deletions(-) (limited to 'test') diff --git a/.circleci/config.yml b/.circleci/config.yml index 29c9a07..e4c71dd 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -73,6 +73,7 @@ jobs: command: | cd docs; make html; + no_output_timeout: 30m # Save the outputs - store_artifacts: diff --git a/ot/bregman.py b/ot/bregman.py index b10effd..105b38b 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -11,6 +11,7 @@ Bregman projections solvers for entropic regularized OT # Mokhtar Z. Alaya # Alexander Tong # Ievgen Redko +# Quang Huy Tran # # License: MIT License @@ -18,6 +19,7 @@ import warnings import numpy as np from scipy.optimize import fmin_l_bfgs_b +from scipy.special import logsumexp from ot.utils import unif, dist, list_to_array from .backend import get_backend @@ -1684,7 +1686,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', - numIterMax=10000, stopThr=1e-9, verbose=False, + numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False, log=False, **kwargs): r''' Solve the entropic regularization optimal transport problem and return the @@ -1723,6 +1725,12 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', Max number of iterations stopThr : float, optional Stop threshol on error (>0) + isLazy: boolean, optional + If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory) + If False, calculate full cost matrix and return outputs of sinkhorn function. + batchSize: int or tuple of 2 int, optional + Size of the batcheses used to compute the sinkhorn update without memory overhead. + When a tuple is provided it sets the size of the left/right batches. verbose : bool, optional Print information along iterations log : bool, optional @@ -1758,24 +1766,78 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. ''' - + ns, nt = X_s.shape[0], X_t.shape[0] if a is None: - a = unif(np.shape(X_s)[0]) + a = unif(ns) if b is None: - b = unif(np.shape(X_t)[0]) + b = unif(nt) + + if isLazy: + if log: + dict_log = {"err": []} - M = dist(X_s, X_t, metric=metric) + log_a, log_b = np.log(a), np.log(b) + f, g = np.zeros(ns), np.zeros(nt) + + if isinstance(batchSize, int): + bs, bt = batchSize, batchSize + elif isinstance(batchSize, tuple) and len(batchSize) == 2: + bs, bt = batchSize[0], batchSize[1] + else: + raise ValueError("Batch size must be in integer or a tuple of two integers") + + range_s, range_t = range(0, ns, bs), range(0, nt, bt) + + lse_f = np.zeros(ns) + lse_g = np.zeros(nt) + + for i_ot in range(numIterMax): + + for i in range_s: + M = dist(X_s[i:i + bs, :], X_t, metric=metric) + lse_f[i:i + bs] = logsumexp(g[None, :] - M / reg, axis=1) + f = log_a - lse_f + + for j in range_t: + M = dist(X_s, X_t[j:j + bt, :], metric=metric) + lse_g[j:j + bt] = logsumexp(f[:, None] - M / reg, axis=0) + g = log_b - lse_g + + if (i_ot + 1) % 10 == 0: + m1 = np.zeros_like(a) + for i in range_s: + M = dist(X_s[i:i + bs, :], X_t, metric=metric) + m1[i:i + bs] = np.exp(f[i:i + bs, None] + g[None, :] - M / reg).sum(1) + err = np.abs(m1 - a).sum() + if log: + dict_log["err"].append(err) + + if verbose and (i_ot + 1) % 100 == 0: + print("Error in marginal at iteration {} = {}".format(i_ot + 1, err)) + + if err <= stopThr: + break + + if log: + dict_log["u"] = f + dict_log["v"] = g + return (f, g, dict_log) + else: + return (f, g) - if log: - pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs) - return pi, log else: - pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs) - return pi + M = dist(X_s, X_t, metric=metric) + + if log: + pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs) + return pi, log + else: + pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs) + return pi def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, - verbose=False, log=False, **kwargs): + isLazy=False, batchSize=100, verbose=False, log=False, **kwargs): r''' Solve the entropic regularization optimal transport problem from empirical data and return the OT loss @@ -1814,6 +1876,12 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num Max number of iterations stopThr : float, optional Stop threshol on error (>0) + isLazy: boolean, optional + If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory) + If False, calculate full cost matrix and return outputs of sinkhorn function. + batchSize: int or tuple of 2 int, optional + Size of the batcheses used to compute the sinkhorn update without memory overhead. + When a tuple is provided it sets the size of the left/right batches. verbose : bool, optional Print information along iterations log : bool, optional @@ -1850,21 +1918,45 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. ''' + ns, nt = X_s.shape[0], X_t.shape[0] if a is None: - a = unif(np.shape(X_s)[0]) + a = unif(ns) if b is None: - b = unif(np.shape(X_t)[0]) + b = unif(nt) - M = dist(X_s, X_t, metric=metric) + if isLazy: + if log: + f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr, + isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log) + else: + f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr, + isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log) + + bs = batchSize if isinstance(batchSize, int) else batchSize[0] + range_s = range(0, ns, bs) + + loss = 0 + for i in range_s: + M_block = dist(X_s[i:i + bs, :], X_t, metric=metric) + pi_block = np.exp(f[i:i + bs, None] + g[None, :] - M_block / reg) + loss += np.sum(M_block * pi_block) + + if log: + return loss, dict_log + else: + return loss - if log: - sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, - **kwargs) - return sinkhorn_loss, log else: - sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, - **kwargs) - return sinkhorn_loss + M = dist(X_s, X_t, metric=metric) + + if log: + sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + return sinkhorn_loss, log + else: + sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + return sinkhorn_loss def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, diff --git a/test/test_bregman.py b/test/test_bregman.py index 7c5162a..9665229 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -2,6 +2,7 @@ # Author: Remi Flamary # Kilian Fatras +# Quang Huy Tran # # License: MIT License @@ -329,6 +330,49 @@ def test_empirical_sinkhorn(): np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) +def test_lazy_empirical_sinkhorn(): + # test sinkhorn + n = 100 + a = ot.unif(n) + b = ot.unif(n) + numIterMax = 1000 + + X_s = np.reshape(np.arange(n), (n, 1)) + X_t = np.reshape(np.arange(0, n), (n, 1)) + M = ot.dist(X_s, X_t) + M_m = ot.dist(X_s, X_t, metric='minkowski') + + f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 1), verbose=True) + G_sqe = np.exp(f[:, None] + g[None, :] - M / 1) + sinkhorn_sqe = ot.sinkhorn(a, b, M, 1) + + f, g, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) + G_log = np.exp(f[:, None] + g[None, :] - M / 0.1) + sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True) + + f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski', numIterMax=numIterMax, isLazy=True, batchSize=1) + G_m = np.exp(f[:, None] + g[None, :] - M_m / 1) + sinkhorn_m = ot.sinkhorn(a, b, M_m, 1) + + loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) + loss_sinkhorn = ot.sinkhorn2(a, b, M, 1) + + # check constratints + np.testing.assert_allclose( + sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log + np.testing.assert_allclose( + sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log + np.testing.assert_allclose( + sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05) # metric euclidian + np.testing.assert_allclose( + sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian + np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) + + def test_empirical_sinkhorn_divergence(): # Test sinkhorn divergence n = 10 -- cgit v1.2.3 From 8ef3341a472909f223ec0f678f11f136f55c1406 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Thu, 17 Jun 2021 11:46:37 +0200 Subject: [MRG] Speedup tests (#262) * speedup tests * add color to tests and timings * add test unbalanced * stupid missing - --- .github/workflows/build_tests.yml | 8 ++++---- Makefile | 4 ++-- test/test_bregman.py | 7 ++++--- test/test_da.py | 8 ++++---- test/test_gromov.py | 15 +++++++++------ test/test_optim.py | 6 +++--- test/test_stochastic.py | 40 +++++++++++++++++++-------------------- test/test_unbalanced.py | 33 ++++++++++++++++++++++++++++++-- 8 files changed, 77 insertions(+), 44 deletions(-) (limited to 'test') diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index 92a07b5..fd0ade6 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -40,7 +40,7 @@ jobs: pip install -e . - name: Run tests run: | - python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes + python -m pytest --durations=20 -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes - name: Upload codecov run: | codecov @@ -95,7 +95,7 @@ jobs: pip install -e . - name: Run tests run: | - python -m pytest -v test/ ot/ --ignore ot/gpu/ + python -m pytest --durations=20 -v test/ ot/ --ignore ot/gpu/ --color=yes macos: @@ -122,7 +122,7 @@ jobs: pip install -e . - name: Run tests run: | - python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot + python -m pytest --durations=20 -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes windows: @@ -150,4 +150,4 @@ jobs: python -m pip install -e . - name: Run tests run: | - python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot + python -m pytest --durations=20 -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes diff --git a/Makefile b/Makefile index 32332b4..315218d 100644 --- a/Makefile +++ b/Makefile @@ -45,10 +45,10 @@ pep8 : flake8 examples/ ot/ test/ test : FORCE pep8 - $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/ + $(PYTHON) -m pytest --durations=20 -v test/ --doctest-modules --ignore ot/gpu/ pytest : FORCE - $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/ + $(PYTHON) -m pytest --durations=20 -v test/ --doctest-modules --ignore ot/gpu/ release : twine upload dist/* diff --git a/test/test_bregman.py b/test/test_bregman.py index 9665229..88166a5 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -293,7 +293,7 @@ def test_unmix(): def test_empirical_sinkhorn(): # test sinkhorn - n = 100 + n = 10 a = ot.unif(n) b = ot.unif(n) @@ -332,7 +332,7 @@ def test_empirical_sinkhorn(): def test_lazy_empirical_sinkhorn(): # test sinkhorn - n = 100 + n = 10 a = ot.unif(n) b = ot.unif(n) numIterMax = 1000 @@ -342,7 +342,7 @@ def test_lazy_empirical_sinkhorn(): M = ot.dist(X_s, X_t) M_m = ot.dist(X_s, X_t, metric='minkowski') - f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 1), verbose=True) + f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True) G_sqe = np.exp(f[:, None] + g[None, :] - M / 1) sinkhorn_sqe = ot.sinkhorn(a, b, M, 1) @@ -458,6 +458,7 @@ def test_implemented_methods(): ot.bregman.sinkhorn2(a, b, M, epsilon, method=method) +@pytest.mark.filterwarnings("ignore:Bottleneck") def test_screenkhorn(): # test screenkhorn rng = np.random.RandomState(0) diff --git a/test/test_da.py b/test/test_da.py index 52c6a48..44bb2e9 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -106,8 +106,8 @@ def test_sinkhorn_l1l2_transport_class(): """test_sinkhorn_transport """ - ns = 150 - nt = 200 + ns = 50 + nt = 100 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -448,8 +448,8 @@ def test_mapping_transport_class(): """test_mapping_transport """ - ns = 60 - nt = 120 + ns = 20 + nt = 30 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) diff --git a/test/test_gromov.py b/test/test_gromov.py index 81138ca..56414a8 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -9,6 +9,8 @@ import numpy as np import ot +import pytest + def test_gromov(): n_samples = 50 # nb samples @@ -128,9 +130,10 @@ def test_gromov_barycenter(): np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples)) +@pytest.mark.filterwarnings("ignore:divide") def test_gromov_entropic_barycenter(): - ns = 50 - nt = 60 + ns = 20 + nt = 30 Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) @@ -138,19 +141,19 @@ def test_gromov_entropic_barycenter(): C1 = ot.dist(Xs) C2 = ot.dist(Xt) - n_samples = 3 + n_samples = 2 Cb = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2], [ot.unif(ns), ot.unif(nt) ], ot.unif(n_samples), [.5, .5], - 'square_loss', 2e-3, - max_iter=100, tol=1e-3, + 'square_loss', 1e-3, + max_iter=50, tol=1e-5, verbose=True) np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) Cb2 = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2], [ot.unif(ns), ot.unif(nt) ], ot.unif(n_samples), [.5, .5], - 'kl_loss', 2e-3, + 'kl_loss', 1e-3, max_iter=100, tol=1e-3) np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples)) diff --git a/test/test_optim.py b/test/test_optim.py index 48de38a..fd194c2 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -37,8 +37,8 @@ def test_conditional_gradient(): np.testing.assert_allclose(b, G.sum(0)) -def test_conditional_gradient2(): - n = 1000 # nb samples +def test_conditional_gradient_itermax(): + n = 100 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -63,7 +63,7 @@ def test_conditional_gradient2(): reg = 1e-1 - G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=200000, + G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=10000, verbose=True, log=True) np.testing.assert_allclose(a, G.sum(1)) diff --git a/test/test_stochastic.py b/test/test_stochastic.py index 155622c..98e93ec 100644 --- a/test/test_stochastic.py +++ b/test/test_stochastic.py @@ -30,7 +30,7 @@ import ot def test_stochastic_sag(): # test sag - n = 15 + n = 10 reg = 1 numItermax = 30000 rng = np.random.RandomState(0) @@ -45,9 +45,9 @@ def test_stochastic_sag(): # check constratints np.testing.assert_allclose( - u, G.sum(1), atol=1e-04) # cf convergence sag + u, G.sum(1), atol=1e-03) # cf convergence sag np.testing.assert_allclose( - u, G.sum(0), atol=1e-04) # cf convergence sag + u, G.sum(0), atol=1e-03) # cf convergence sag ############################################################################# @@ -60,9 +60,9 @@ def test_stochastic_sag(): def test_stochastic_asgd(): # test asgd - n = 15 + n = 10 reg = 1 - numItermax = 100000 + numItermax = 10000 rng = np.random.RandomState(0) x = rng.randn(n, 2) @@ -75,9 +75,9 @@ def test_stochastic_asgd(): # check constratints np.testing.assert_allclose( - u, G.sum(1), atol=1e-03) # cf convergence asgd + u, G.sum(1), atol=1e-02) # cf convergence asgd np.testing.assert_allclose( - u, G.sum(0), atol=1e-03) # cf convergence asgd + u, G.sum(0), atol=1e-02) # cf convergence asgd ############################################################################# @@ -90,9 +90,9 @@ def test_stochastic_asgd(): def test_sag_asgd_sinkhorn(): # test all algorithms - n = 15 + n = 10 reg = 1 - nb_iter = 100000 + nb_iter = 10000 rng = np.random.RandomState(0) x = rng.randn(n, 2) @@ -107,17 +107,17 @@ def test_sag_asgd_sinkhorn(): # check constratints np.testing.assert_allclose( - G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-03) + G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-02) np.testing.assert_allclose( - G_sag.sum(0), G_sinkhorn.sum(0), atol=1e-03) + G_sag.sum(0), G_sinkhorn.sum(0), atol=1e-02) np.testing.assert_allclose( - G_asgd.sum(1), G_sinkhorn.sum(1), atol=1e-03) + G_asgd.sum(1), G_sinkhorn.sum(1), atol=1e-02) np.testing.assert_allclose( - G_asgd.sum(0), G_sinkhorn.sum(0), atol=1e-03) + G_asgd.sum(0), G_sinkhorn.sum(0), atol=1e-02) np.testing.assert_allclose( - G_sag, G_sinkhorn, atol=1e-03) # cf convergence sag + G_sag, G_sinkhorn, atol=1e-02) # cf convergence sag np.testing.assert_allclose( - G_asgd, G_sinkhorn, atol=1e-03) # cf convergence asgd + G_asgd, G_sinkhorn, atol=1e-02) # cf convergence asgd ############################################################################# @@ -136,7 +136,7 @@ def test_stochastic_dual_sgd(): # test sgd n = 10 reg = 1 - numItermax = 15000 + numItermax = 5000 batch_size = 10 rng = np.random.RandomState(0) @@ -167,7 +167,7 @@ def test_dual_sgd_sinkhorn(): # test all dual algorithms n = 10 reg = 1 - nb_iter = 15000 + nb_iter = 5000 batch_size = 10 rng = np.random.RandomState(0) @@ -183,11 +183,11 @@ def test_dual_sgd_sinkhorn(): # check constratints np.testing.assert_allclose( - G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03) + G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-02) np.testing.assert_allclose( - G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03) + G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-02) np.testing.assert_allclose( - G_sgd, G_sinkhorn, atol=1e-03) # cf convergence sgd + G_sgd, G_sinkhorn, atol=1e-02) # cf convergence sgd # Test gaussian n = 30 diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index dfeaad9..e8349d1 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -115,7 +115,8 @@ def test_stabilized_vs_sinkhorn(): G, log = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=epsilon, method="sinkhorn_stabilized", reg_m=reg_m, - log=True) + log=True, + verbose=True) G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, method="sinkhorn", log=True) @@ -138,7 +139,7 @@ def test_unbalanced_barycenter(method): reg_m = 1. q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, - method=method, log=True) + method=method, log=True, verbose=True) # check fixed point equations fi = reg_m / (reg_m + epsilon) logA = np.log(A + 1e-16) @@ -173,6 +174,7 @@ def test_barycenter_stabilized_vs_sinkhorn(): reg_m=reg_m, log=True, tau=100, method="sinkhorn_stabilized", + verbose=True ) q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, method="sinkhorn", @@ -182,6 +184,33 @@ def test_barycenter_stabilized_vs_sinkhorn(): q, qstable, atol=1e-05) +def test_wrong_method(): + + n = 10 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + a = ot.utils.unif(n) + + # make dists unbalanced + b = ot.utils.unif(n) * 1.5 + + M = ot.dist(x, x) + epsilon = 1. + reg_m = 1. + + with pytest.raises(ValueError): + ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, + reg_m=reg_m, + method='badmethod', + log=True, + verbose=True) + with pytest.raises(ValueError): + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, + method='badmethod', + verbose=True) + + def test_implemented_methods(): IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized'] TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling'] -- cgit v1.2.3 From 96bf1a46e74d6985419e14222afb0b9241a7bb36 Mon Sep 17 00:00:00 2001 From: Minhui Huang <32522773+mhhuang95@users.noreply.github.com> Date: Mon, 6 Sep 2021 08:06:50 -0700 Subject: [MRG] Projection Robust Wasserstein (#267) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ot.dr: PRW code; text.text_dr: PRW test code. * ot.dr: PRW code; test.test_dr: PRW test code. * fix errors: pep8(3.8) * fix errors: pep8(3.8) * modified readme; prw code review * fix pep error * edit comment * modified math comment Co-authored-by: Rémi Flamary --- README.md | 3 ++ ot/dr.py | 114 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ test/test_dr.py | 37 ++++++++++++++++++ 3 files changed, 154 insertions(+) (limited to 'test') diff --git a/README.md b/README.md index 20e0606..6a2cf15 100644 --- a/README.md +++ b/README.md @@ -198,6 +198,7 @@ The contributors to this library are * [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn) * [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT) * [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance) +* [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance) This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages): @@ -283,3 +284,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t [30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). [Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching](https://remi.flamary.com/biblio/flamary2014optlaplace.pdf), NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. [31] Bonneel, Nicolas, et al. [Sliced and radon wasserstein barycenters of measures](https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf), Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 + +[32] Huang, M., Ma S., Lai, L. (2021). [A Riemannian Block Coordinate Descent Method for Computing the Projection Robust Wasserstein Distance](http://proceedings.mlr.press/v139/huang21e.html), Proceedings of the 38th International Conference on Machine Learning (ICML). diff --git a/ot/dr.py b/ot/dr.py index b7a1af0..64588cf 100644 --- a/ot/dr.py +++ b/ot/dr.py @@ -10,6 +10,7 @@ Dimension reduction with OT """ # Author: Remi Flamary +# Minhui Huang # # License: MIT License @@ -198,3 +199,116 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None): return (X - mx.reshape((1, -1))).dot(Popt) return Popt, proj + + +def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0): + r""" + Projection Robust Wasserstein Distance [32] + + The function solves the following optimization problem: + + .. math:: + \max_{U \in St(d, k)} \min_{\pi \in \Pi(\mu,\nu)} \sum_{i,j} \pi_{i,j} \|U^T(x_i - y_j)\|^2 - reg * H(\pi) + + - :math:`U` is a linear projection operator in the Stiefel(d, k) manifold + - :math:`H(\pi)` is entropy regularizer + - :math:`x_i`, :math:`y_j` are samples of measures \mu and \nu respectively + + Parameters + ---------- + X : ndarray, shape (n, d) + Samples from measure \mu + Y : ndarray, shape (n, d) + Samples from measure \nu + a : ndarray, shape (n, ) + weights for measure \mu + b : ndarray, shape (n, ) + weights for measure \nu + tau : float + stepsize for Riemannian Gradient Descent + U0 : ndarray, shape (d, p) + Initial starting point for projection. + reg : float, optional + Regularization term >0 (entropic regularization) + k : int + Subspace dimension + stopThr : float, optional + Stop threshold on error (>0) + verbose : int, optional + Print information along iterations. + + Returns + ------- + pi : ndarray, shape (n, n) + Optimal transportation matrix for the given parameters + U : ndarray, shape (d, k) + Projection operator. + + References + ---------- + .. [32] Huang, M. , Ma S. & Lai L. (2021). + A Riemannian Block Coordinate Descent Method for Computing + the Projection Robust Wasserstein Distance, ICML. + """ # noqa + + # initialization + n, d = X.shape + m, d = Y.shape + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + u = np.ones(n) / n + v = np.ones(m) / m + ones = np.ones((n, m)) + + assert d > k + + if U0 is None: + U = np.random.randn(d, k) + U, _ = np.linalg.qr(U) + else: + U = U0 + + def Vpi(X, Y, a, b, pi): + # Return the second order matrix of the displacements: sum_ij { (pi)_ij (X_i-Y_j)(X_i-Y_j)^T }. + A = X.T.dot(pi).dot(Y) + return X.T.dot(np.diag(a)).dot(X) + Y.T.dot(np.diag(np.sum(pi, 0))).dot(Y) - A - A.T + + err = 1 + iter = 0 + + while err > stopThr and iter < maxiter: + + # Projected cost matrix + UUT = U.dot(U.T) + M = np.diag(np.diag(X.dot(UUT.dot(X.T)))).dot(ones) + ones.dot( + np.diag(np.diag(Y.dot(UUT.dot(Y.T))))) - 2 * X.dot(UUT.dot(Y.T)) + + A = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=A) + np.exp(A, out=A) + + # Sinkhorn update + Ap = (1 / a).reshape(-1, 1) * A + AtransposeU = np.dot(A.T, u) + v = np.divide(b, AtransposeU) + u = 1. / np.dot(Ap, v) + pi = u.reshape((-1, 1)) * A * v.reshape((1, -1)) + + V = Vpi(X, Y, a, b, pi) + + # Riemannian gradient descent + G = 2 / reg * V.dot(U) + GTU = G.T.dot(U) + xi = G - U.dot(GTU + GTU.T) / 2 # Riemannian gradient + U, _ = np.linalg.qr(U + tau * xi) # Retraction by QR decomposition + + grad_norm = np.linalg.norm(xi) + err = max(reg * grad_norm, np.linalg.norm(np.sum(pi, 0) - b, 1)) + + f_val = np.trace(U.T.dot(V.dot(U))) + if verbose: + print('RBCD Iteration: ', iter, ' error', err, '\t fval: ', f_val) + + iter = iter + 1 + + return pi, U diff --git a/test/test_dr.py b/test/test_dr.py index c5df287..fa75a18 100644 --- a/test/test_dr.py +++ b/test/test_dr.py @@ -1,6 +1,7 @@ """Tests for module dr on Dimensionality Reduction """ # Author: Remi Flamary +# Minhui Huang # # License: MIT License @@ -57,3 +58,39 @@ def test_wda(): projwda(xs) np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p)) + + +@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)") +def test_prw(): + d = 100 # Dimension + n = 100 # Number samples + k = 3 # Subspace dimension + dim = 3 + + def fragmented_hypercube(n, d, dim): + assert dim <= d + assert dim >= 1 + assert dim == int(dim) + + a = (1. / n) * np.ones(n) + b = (1. / n) * np.ones(n) + + # First measure : uniform on the hypercube + X = np.random.uniform(-1, 1, size=(n, d)) + + # Second measure : fragmentation + tmp_y = np.random.uniform(-1, 1, size=(n, d)) + Y = tmp_y + 2 * np.sign(tmp_y) * np.array(dim * [1] + (d - dim) * [0]) + return a, b, X, Y + + a, b, X, Y = fragmented_hypercube(n, d, dim) + + tau = 0.002 + reg = 0.2 + + pi, U = ot.dr.projection_robust_wasserstein(X, Y, a, b, tau, reg=reg, k=k, maxiter=1000, verbose=1) + + U0 = np.random.randn(d, k) + U0, _ = np.linalg.qr(U0) + + pi, U = ot.dr.projection_robust_wasserstein(X, Y, a, b, tau, U0=U0, reg=reg, k=k, maxiter=1000, verbose=1) -- cgit v1.2.3 From e0ba31ce39a7d9e65e50ea970a574b3db54e4207 Mon Sep 17 00:00:00 2001 From: Tanguy Date: Fri, 17 Sep 2021 18:36:33 +0200 Subject: [MRG] Implementation of two news algorithms: SaGroW and PoGroW. (#275) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add two new algorithms to solve Gromov Wasserstein: Sampled Gromov Wasserstein and Pointwise Gromov Wasserstein. * Correct some lines in SaGroW and PoGroW to follow pep8 guide. * Change nb_samples name. Use rdm state. Change symmetric check. * Change names of len(p) and len(q) in SaGroW and PoGroW. * Re-add some deleted lines in the comments of gromov.py Co-authored-by: Rémi Flamary --- README.md | 4 + examples/gromov/plot_gromov.py | 34 ++++ ot/gromov.py | 376 +++++++++++++++++++++++++++++++++++++++++ test/test_gromov.py | 88 +++++++++- 4 files changed, 496 insertions(+), 6 deletions(-) (limited to 'test') diff --git a/README.md b/README.md index 6a2cf15..266d847 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ POT provides the following generic OT solvers (links to examples): * [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]) * [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24] * [Stochastic solver](https://pythonot.github.io/auto_examples/plot_stochastic.html) for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19]) +* [Stochastic solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33] * Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) [20]. * [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] @@ -198,6 +199,7 @@ The contributors to this library are * [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn) * [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT) * [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance) +* [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein) * [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance) This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages): @@ -286,3 +288,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t [31] Bonneel, Nicolas, et al. [Sliced and radon wasserstein barycenters of measures](https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf), Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 [32] Huang, M., Ma S., Lai, L. (2021). [A Riemannian Block Coordinate Descent Method for Computing the Projection Robust Wasserstein Distance](http://proceedings.mlr.press/v139/huang21e.html), Proceedings of the 38th International Conference on Machine Learning (ICML). + +[33] Kerdoncuff T., Emonet R., Marc S. [Sampled Gromov Wasserstein](https://hal.archives-ouvertes.fr/hal-03232509/document), Machine Learning Journal (MJL), 2021 diff --git a/examples/gromov/plot_gromov.py b/examples/gromov/plot_gromov.py index deb2f86..5a362cf 100644 --- a/examples/gromov/plot_gromov.py +++ b/examples/gromov/plot_gromov.py @@ -104,3 +104,37 @@ pl.imshow(gw, cmap='jet') pl.title('Entropic Gromov Wasserstein') pl.show() + +############################################################################# +# +# Compute GW with a scalable stochastic method with any loss function +# ---------------------------------------------------------------------- + + +def loss(x, y): + return np.abs(x - y) + + +pgw, plog = ot.gromov.pointwise_gromov_wasserstein(C1, C2, p, q, loss, max_iter=100, + log=True) + +sgw, slog = ot.gromov.sampled_gromov_wasserstein(C1, C2, p, q, loss, epsilon=0.1, max_iter=100, + log=True) + +print('Pointwise Gromov-Wasserstein distance estimated: ' + str(plog['gw_dist_estimated'])) +print('Variance estimated: ' + str(plog['gw_dist_std'])) +print('Sampled Gromov-Wasserstein distance: ' + str(slog['gw_dist_estimated'])) +print('Variance estimated: ' + str(slog['gw_dist_std'])) + + +pl.figure(1, (10, 5)) + +pl.subplot(1, 2, 1) +pl.imshow(pgw.toarray(), cmap='jet') +pl.title('Pointwise Gromov Wasserstein') + +pl.subplot(1, 2, 2) +pl.imshow(sgw, cmap='jet') +pl.title('Sampled Gromov Wasserstein') + +pl.show() diff --git a/ot/gromov.py b/ot/gromov.py index 8f457e9..a27217a 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -16,6 +16,10 @@ import numpy as np from .bregman import sinkhorn from .utils import dist, UndefinedParameter from .optim import cg +from .lp import emd_1d, emd +from .utils import check_random_state + +from scipy.sparse import issparse def init_matrix(C1, C2, p, q, loss_fun='square_loss'): @@ -572,6 +576,378 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 return log['fgw_dist'] +def GW_distance_estimation(C1, C2, p, q, loss_fun, T, + nb_samples_p=None, nb_samples_q=None, std=True, random_state=None): + r""" + Returns an approximation of the gromov-wasserstein cost between (C1,p) and (C2,q) + with a fixed transport plan T. + + The function gives an unbiased approximation of the following equation: + + .. math:: + GW = \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + + Where : + + - C1 : Metric cost matrix in the source space + - C2 : Metric cost matrix in the target space + - L : Loss function to account for the misfit between the similarity matrices + - T : Matrix with marginal p and q + + Parameters + ---------- + C1 : ndarray, shape (ns, ns) + Metric cost matrix in the source space + C2 : ndarray, shape (nt, nt) + Metric costfr matrix in the target space + p : ndarray, shape (ns,) + Distribution in the source space + q : ndarray, shape (nt,) + Distribution in the target space + loss_fun : function: \mathcal{R} \times \mathcal{R} \shortarrow \mathcal{R} + Loss function used for the distance, the transport plan does not depend on the loss function + T : csr or ndarray, shape (ns, nt) + Transport plan matrix, either a sparse csr matrix or + nb_samples_p : int, optional + nb_samples_p is the number of samples (without replacement) along the first dimension of T. + nb_samples_q : int, optional + nb_samples_q is the number of samples along the second dimension of T, for each sample along the first. + std : bool, optional + Standard deviation associated with the prediction of the gromov-wasserstein cost. + random_state : int or RandomState instance, optional + Fix the seed for to allow reproducibility + + Returns + ------- + : float + Gromov-wasserstein cost + + References + ---------- + .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc + "Sampled Gromov Wasserstein." + Machine Learning Journal (MLJ). 2021. + + """ + generator = check_random_state(random_state) + + len_p = len(p) + len_q = len(q) + + # It is always better to sample from the biggest distribution first. + if len_p < len_q: + p, q = q, p + len_p, len_q = len_q, len_p + C1, C2 = C2, C1 + T = T.T + + if nb_samples_p is None: + if issparse(T): + # If T is sparse, it probably mean that PoGroW was used, thus the number of sample is reduced + nb_samples_p = min(int(5 * (len_p * np.log(len_p)) ** 0.5), len_p) + else: + nb_samples_p = len_p + else: + # The number of sample along the first dimension is without replacement. + nb_samples_p = min(nb_samples_p, len_p) + if nb_samples_q is None: + nb_samples_q = 1 + if std: + nb_samples_q = max(2, nb_samples_q) + + index_k = np.zeros((nb_samples_p, nb_samples_q), dtype=int) + index_l = np.zeros((nb_samples_p, nb_samples_q), dtype=int) + list_value_sample = np.zeros((nb_samples_p, nb_samples_p, nb_samples_q)) + + index_i = generator.choice(len_p, size=nb_samples_p, p=p, replace=False) + index_j = generator.choice(len_p, size=nb_samples_p, p=p, replace=False) + + for i in range(nb_samples_p): + if issparse(T): + T_indexi = T[index_i[i], :].toarray()[0] + T_indexj = T[index_j[i], :].toarray()[0] + else: + T_indexi = T[index_i[i], :] + T_indexj = T[index_j[i], :] + # For each of the row sampled, the column is sampled. + index_k[i] = generator.choice(len_q, size=nb_samples_q, p=T_indexi / T_indexi.sum(), replace=True) + index_l[i] = generator.choice(len_q, size=nb_samples_q, p=T_indexj / T_indexj.sum(), replace=True) + + for n in range(nb_samples_q): + list_value_sample[:, :, n] = loss_fun(C1[np.ix_(index_i, index_j)], C2[np.ix_(index_k[:, n], index_l[:, n])]) + + if std: + std_value = np.sum(np.std(list_value_sample, axis=2) ** 2) ** 0.5 + return np.mean(list_value_sample), std_value / (nb_samples_p * nb_samples_p) + else: + return np.mean(list_value_sample) + + +def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, + alpha=1, max_iter=100, threshold_plan=0, log=False, verbose=False, random_state=None): + r""" + Returns the gromov-wasserstein transport between (C1,p) and (C2,q) using a stochastic Frank-Wolfe. + This method as a O(max_iter \times PN^2) time complexity with P the number of Sinkhorn iterations. + + The function solves the following optimization problem: + + .. math:: + GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + + s.t. T 1 = p + + T^T 1= q + + T\geq 0 + + Where : + + - C1 : Metric cost matrix in the source space + - C2 : Metric cost matrix in the target space + - p : distribution in the source space + - q : distribution in the target space + - L : loss function to account for the misfit between the similarity matrices + + Parameters + ---------- + C1 : ndarray, shape (ns, ns) + Metric cost matrix in the source space + C2 : ndarray, shape (nt, nt) + Metric costfr matrix in the target space + p : ndarray, shape (ns,) + Distribution in the source space + q : ndarray, shape (nt,) + Distribution in the target space + loss_fun : function: \mathcal{R} \times \mathcal{R} \shortarrow \mathcal{R} + Loss function used for the distance, the transport plan does not depend on the loss function + alpha : float + Step of the Frank-Wolfe algorithm, should be between 0 and 1 + max_iter : int, optional + Max number of iterations + threshold_plan : float, optional + Deleting very small value in the transport plan. If above zero, it violate the marginal constraints. + verbose : bool, optional + Print information along iterations + log : bool, optional + Gives the distance estimated and the standard deviation + random_state : int or RandomState instance, optional + Fix the seed for to allow reproducibility + + Returns + ------- + T : ndarray, shape (ns, nt) + Optimal coupling between the two spaces + + References + ---------- + .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc + "Sampled Gromov Wasserstein." + Machine Learning Journal (MLJ). 2021. + + """ + C1 = np.asarray(C1, dtype=np.float64) + C2 = np.asarray(C2, dtype=np.float64) + p = np.asarray(p, dtype=np.float64) + q = np.asarray(q, dtype=np.float64) + len_p = len(p) + len_q = len(q) + + generator = check_random_state(random_state) + + index = np.zeros(2, dtype=int) + + # Initialize with default marginal + index[0] = generator.choice(len_p, size=1, p=p) + index[1] = generator.choice(len_q, size=1, p=q) + T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr() + + best_gw_dist_estimated = np.inf + for cpt in range(max_iter): + index[0] = generator.choice(len_p, size=1, p=p) + T_index0 = T[index[0], :].toarray()[0] + index[1] = generator.choice(len_q, size=1, p=T_index0 / T_index0.sum()) + + if alpha == 1: + T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr() + else: + new_T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr() + T = (1 - alpha) * T + alpha * new_T + # To limit the number of non 0, the values bellow the threshold are set to 0. + T.data[T.data < threshold_plan] = 0 + T.eliminate_zeros() + + if cpt % 10 == 0 or cpt == (max_iter - 1): + gw_dist_estimated = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun, + p=p, q=q, T=T, std=False, random_state=generator) + + if gw_dist_estimated < best_gw_dist_estimated: + best_gw_dist_estimated = gw_dist_estimated + best_T = T.copy() + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Best gw estimated') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, best_gw_dist_estimated)) + + if log: + log = {} + log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun, + p=p, q=q, T=best_T, + random_state=generator) + return best_T, log + return best_T + + +def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, + nb_samples_grad=100, epsilon=1, max_iter=500, log=False, verbose=False, + random_state=None): + r""" + Returns the gromov-wasserstein transport between (C1,p) and (C2,q) using a 1-stochastic Frank-Wolfe. + This method as a O(max_iter \times Nlog(N)) time complexity by relying on the 1D Optimal Transport solver. + + The function solves the following optimization problem: + + .. math:: + GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + + s.t. T 1 = p + + T^T 1= q + + T\geq 0 + + Where : + + - C1 : Metric cost matrix in the source space + - C2 : Metric cost matrix in the target space + - p : distribution in the source space + - q : distribution in the target space + - L : loss function to account for the misfit between the similarity matrices + + Parameters + ---------- + C1 : ndarray, shape (ns, ns) + Metric cost matrix in the source space + C2 : ndarray, shape (nt, nt) + Metric costfr matrix in the target space + p : ndarray, shape (ns,) + Distribution in the source space + q : ndarray, shape (nt,) + Distribution in the target space + loss_fun : function: \mathcal{R} \times \mathcal{R} \shortarrow \mathcal{R} + Loss function used for the distance, the transport plan does not depend on the loss function + nb_samples_grad : int + Number of samples to approximate the gradient + epsilon : float + Weight of the Kullback-Leiber regularization + max_iter : int, optional + Max number of iterations + verbose : bool, optional + Print information along iterations + log : bool, optional + Gives the distance estimated and the standard deviation + random_state : int or RandomState instance, optional + Fix the seed for to allow reproducibility + + Returns + ------- + T : ndarray, shape (ns, nt) + Optimal coupling between the two spaces + + References + ---------- + .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc + "Sampled Gromov Wasserstein." + Machine Learning Journal (MLJ). 2021. + + """ + C1 = np.asarray(C1, dtype=np.float64) + C2 = np.asarray(C2, dtype=np.float64) + p = np.asarray(p, dtype=np.float64) + q = np.asarray(q, dtype=np.float64) + len_p = len(p) + len_q = len(q) + + generator = check_random_state(random_state) + + # The most natural way to define nb_sample is with a simple integer. + if isinstance(nb_samples_grad, int): + if nb_samples_grad > len_p: + # As the sampling along the first dimension is done without replacement, the rest is reported to the second + # dimension. + nb_samples_grad_p, nb_samples_grad_q = len_p, nb_samples_grad // len_p + else: + nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad, 1 + else: + nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad + T = np.outer(p, q) + # continue_loop allows to stop the loop if there is several successive small modification of T. + continue_loop = 0 + + # The gradient of GW is more complex if the two matrices are not symmetric. + C_are_symmetric = np.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and np.allclose(C2, C2.T, rtol=1e-10, atol=1e-10) + + for cpt in range(max_iter): + index0 = generator.choice(len_p, size=nb_samples_grad_p, p=p, replace=False) + Lik = 0 + for i, index0_i in enumerate(index0): + index1 = generator.choice(len_q, + size=nb_samples_grad_q, + p=T[index0_i, :] / T[index0_i, :].sum(), + replace=False) + # If the matrices C are not symmetric, the gradient has 2 terms, thus the term is chosen randomly. + if (not C_are_symmetric) and generator.rand(1) > 0.5: + Lik += np.mean(loss_fun(np.expand_dims(C1[:, np.repeat(index0[i], nb_samples_grad_q)], 1), + np.expand_dims(C2[:, index1], 0)), + axis=2) + else: + Lik += np.mean(loss_fun(np.expand_dims(C1[np.repeat(index0[i], nb_samples_grad_q), :], 2), + np.expand_dims(C2[index1, :], 1)), + axis=0) + + max_Lik = np.max(Lik) + if max_Lik == 0: + continue + # This division by the max is here to facilitate the choice of epsilon. + Lik /= max_Lik + + if epsilon > 0: + # Set to infinity all the numbers bellow exp(-200) to avoid log of 0. + log_T = np.log(np.clip(T, np.exp(-200), 1)) + log_T[log_T == -200] = -np.inf + Lik = Lik - epsilon * log_T + + try: + new_T = sinkhorn(a=p, b=q, M=Lik, reg=epsilon) + except (RuntimeWarning, UserWarning): + print("Warning catched in Sinkhorn: Return last stable T") + break + else: + new_T = emd(a=p, b=q, M=Lik) + + change_T = ((T - new_T) ** 2).mean() + if change_T <= 10e-20: + continue_loop += 1 + if continue_loop > 100: # Number max of low modifications of T + T = new_T.copy() + break + else: + continue_loop = 0 + + if verbose and cpt % 10 == 0: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format('It.', '||T_n - T_{n+1}||') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, change_T)) + T = new_T.copy() + + if log: + log = {} + log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun, + p=p, q=q, T=T, random_state=generator) + return T, log + return T + + def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False): r""" diff --git a/test/test_gromov.py b/test/test_gromov.py index 56414a8..19d61b1 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -33,7 +33,7 @@ def test_gromov(): G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True) - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( @@ -54,7 +54,7 @@ def test_gromov(): np.testing.assert_allclose(gw, gw_val, atol=1e-1, rtol=1e-1) # cf log=False - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( @@ -83,7 +83,7 @@ def test_entropic_gromov(): G = ot.gromov.entropic_gromov_wasserstein( C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True) - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( @@ -96,13 +96,89 @@ def test_entropic_gromov(): np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( q, G.sum(0), atol=1e-04) # cf convergence gromov +def test_pointwise_gromov(): + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + def loss(x, y): + return np.abs(x - y) + + G, log = ot.gromov.pointwise_gromov_wasserstein( + C1, C2, p, q, loss, max_iter=100, log=True, verbose=True, random_state=42) + + # check constraints + np.testing.assert_allclose( + p[:, np.newaxis], G.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q[np.newaxis, :], G.sum(0), atol=1e-04) # cf convergence gromov + + assert log['gw_dist_estimated'] == 0.0 + assert log['gw_dist_std'] == 0.0 + + G, log = ot.gromov.pointwise_gromov_wasserstein( + C1, C2, p, q, loss, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42) + + assert log['gw_dist_estimated'] == 0.10342276348494964 + assert log['gw_dist_std'] == 0.0015952535464736394 + + +def test_sampled_gromov(): + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + def loss(x, y): + return np.abs(x - y) + + G, log = ot.gromov.sampled_gromov_wasserstein( + C1, C2, p, q, loss, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42) + + # check constraints + np.testing.assert_allclose( + p, G.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, G.sum(0), atol=1e-04) # cf convergence gromov + + assert log['gw_dist_estimated'] == 0.05679474884977278 + assert log['gw_dist_std'] == 0.0005986592106971995 + + def test_gromov_barycenter(): ns = 50 nt = 60 @@ -186,7 +262,7 @@ def test_fgw(): G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True) - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence fgw np.testing.assert_allclose( @@ -203,7 +279,7 @@ def test_fgw(): np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( -- cgit v1.2.3 From 7dde9e8e4b6aae756e103d49198caaa4f24150e3 Mon Sep 17 00:00:00 2001 From: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> Date: Tue, 28 Sep 2021 16:34:28 +0200 Subject: [MRG] Regularized OT (optim.cg) bug solve (#286) * Line search stops when derphi is 0 instead of bugging out like in some instances * pep8 compliance * Tests --- ot/optim.py | 10 ++++++---- test/test_da.py | 8 ++++++++ test/test_optim.py | 25 +++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 4 deletions(-) (limited to 'test') diff --git a/ot/optim.py b/ot/optim.py index abe9e6a..0359343 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -178,9 +178,9 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, numItermaxEmd : int, optional Max number of iterations for emd stopThr : float, optional - Stop threshol on the relative variation (>0) + Stop threshold on the relative variation (>0) stopThr2 : float, optional - Stop threshol on the absolute variation (>0) + Stop threshold on the absolute variation (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -249,6 +249,8 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, # line search alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs) + if alpha is None: + alpha = 0.0 G = G + alpha * deltaG @@ -320,9 +322,9 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, numInnerItermax : int, optional Max number of iterations of Sinkhorn stopThr : float, optional - Stop threshol on the relative variation (>0) + Stop threshold on the relative variation (>0) stopThr2 : float, optional - Stop threshol on the absolute variation (>0) + Stop threshold on the absolute variation (>0) verbose : bool, optional Print information along iterations log : bool, optional diff --git a/test/test_da.py b/test/test_da.py index 44bb2e9..9f2bb50 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -565,6 +565,14 @@ def test_mapping_transport_class(): otda.fit(Xs=Xs, Xt=Xt) assert len(otda.log_.keys()) != 0 + # check that it does not crash when derphi is very close to 0 + np.random.seed(39) + Xs, ys = make_data_classif('3gauss', ns) + Xt, yt = make_data_classif('3gauss2', nt) + otda = ot.da.MappingTransport(kernel="gaussian", bias=False) + otda.fit(Xs=Xs, Xt=Xt) + np.random.seed(None) + def test_linear_mapping(): ns = 150 diff --git a/test/test_optim.py b/test/test_optim.py index fd194c2..94995d5 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -114,3 +114,28 @@ def test_line_search_armijo(): # Should not throw an exception and return None for alpha alpha, _, _ = ot.optim.line_search_armijo(lambda x: 1, xk, pk, gfk, old_fval) assert alpha is None + + # check line search armijo + def f(x): + return np.sum((x - 5.0) ** 2) + + def grad(x): + return 2 * (x - 5.0) + + xk = np.array([[[-5.0, -5.0]]]) + pk = np.array([[[100.0, 100.0]]]) + gfk = grad(xk) + old_fval = f(xk) + + # chech the case where the optimum is on the direction + alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval) + np.testing.assert_allclose(alpha, 0.1) + + # check the case where the direction is not far enough + pk = np.array([[[3.0, 3.0]]]) + alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval, alpha0=1.0) + np.testing.assert_allclose(alpha, 1.0) + + # check the case where the checking the wrong direction + alpha, _, _ = ot.optim.line_search_armijo(f, xk, -pk, gfk, old_fval) + assert alpha <= 0 -- cgit v1.2.3 From 1c7e7ce2da8bb362c184fb6eae71fe7e36356494 Mon Sep 17 00:00:00 2001 From: kguerda-idris <84066930+kguerda-idris@users.noreply.github.com> Date: Wed, 29 Sep 2021 15:29:31 +0200 Subject: [MRG] OpenMP support (#260) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Added : OpenMP support Restored : Epsilon and Debug mode Replaced : parmap => multiprocessing is now replace by multithreading * Commit clean up * Number of CPUs correctly calculated on SLURM clusters * Corrected number of processes for cluster slurm * Mistake corrected * parmap is now deprecated * Now a different solver is used depending on the requested number of threads * Tiny mistake corrected * Folders are now in the ot library instead of at the root * Helpers is now correctly placed * Attempt to make compilation work smoothly * OS compatible path * NumThreads now defaults to 1 * Better flags * Mistake corrected in case of OpenMP unavailability * Revert OpenMP flags modification, which do not compile on Windows * Test helper functions * Helpers comments * Documentation update * File title corrected * Warning no longer using print * Last attempt for macos compilation * pls work * atempt * solving a type error * TypeError OpenMP * Compilation finally working on Windows * Bug solve, number of threads now correctly selected * 64 bits solver to avoid overflows for bigger problems * 64 bits EMD corrected Co-authored-by: kguerda-idris Co-authored-by: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> Co-authored-by: ncassereau Co-authored-by: Rémi Flamary --- ot/helpers/openmp_helpers.py | 85 ++ ot/helpers/pre_build_helpers.py | 87 ++ ot/lp/EMD.h | 5 +- ot/lp/EMD_wrapper.cpp | 124 ++- ot/lp/__init__.py | 71 +- ot/lp/emd_wrap.pyx | 9 +- ot/lp/full_bipartitegraph.h | 27 +- ot/lp/full_bipartitegraph_omp.h | 234 +++++ ot/lp/network_simplex_simple.h | 210 ++--- ot/lp/network_simplex_simple_omp.h | 1699 ++++++++++++++++++++++++++++++++++++ ot/utils.py | 38 +- setup.py | 12 +- test/test_helpers.py | 26 + 13 files changed, 2442 insertions(+), 185 deletions(-) create mode 100644 ot/helpers/openmp_helpers.py create mode 100644 ot/helpers/pre_build_helpers.py create mode 100644 ot/lp/full_bipartitegraph_omp.h create mode 100644 ot/lp/network_simplex_simple_omp.h create mode 100644 test/test_helpers.py (limited to 'test') diff --git a/ot/helpers/openmp_helpers.py b/ot/helpers/openmp_helpers.py new file mode 100644 index 0000000..a6ad38b --- /dev/null +++ b/ot/helpers/openmp_helpers.py @@ -0,0 +1,85 @@ +"""Helpers for OpenMP support during the build.""" + +# This code is adapted for a large part from the astropy openmp helpers, which +# can be found at: https://github.com/astropy/extension-helpers/blob/master/extension_helpers/_openmp_helpers.py # noqa + + +import os +import sys +import textwrap +import subprocess + +from distutils.errors import CompileError, LinkError + +from pre_build_helpers import compile_test_program + + +def get_openmp_flag(compiler): + """Get openmp flags for a given compiler""" + + if hasattr(compiler, 'compiler'): + compiler = compiler.compiler[0] + else: + compiler = compiler.__class__.__name__ + + if sys.platform == "win32" and ('icc' in compiler or 'icl' in compiler): + omp_flag = ['/Qopenmp'] + elif sys.platform == "win32": + omp_flag = ['/openmp'] + elif sys.platform in ("darwin", "linux") and "icc" in compiler: + omp_flag = ['-qopenmp'] + elif sys.platform == "darwin" and 'openmp' in os.getenv('CPPFLAGS', ''): + omp_flag = [] + else: + # Default flag for GCC and clang: + omp_flag = ['-fopenmp'] + if sys.platform.startswith("darwin"): + omp_flag += ["-Xpreprocessor", "-lomp"] + return omp_flag + + +def check_openmp_support(): + """Check whether OpenMP test code can be compiled and run""" + + code = textwrap.dedent( + """\ + #include + #include + int main(void) { + #pragma omp parallel + printf("nthreads=%d\\n", omp_get_num_threads()); + return 0; + } + """) + + extra_preargs = os.getenv('LDFLAGS', None) + if extra_preargs is not None: + extra_preargs = extra_preargs.strip().split(" ") + extra_preargs = [ + flag for flag in extra_preargs + if flag.startswith(('-L', '-Wl,-rpath', '-l'))] + + extra_postargs = get_openmp_flag + + try: + output, compile_flags = compile_test_program( + code, + extra_preargs=extra_preargs, + extra_postargs=extra_postargs + ) + + if output and 'nthreads=' in output[0]: + nthreads = int(output[0].strip().split('=')[1]) + openmp_supported = len(output) == nthreads + elif "PYTHON_CROSSENV" in os.environ: + # Since we can't run the test program when cross-compiling + # assume that openmp is supported if the program can be + # compiled. + openmp_supported = True + else: + openmp_supported = False + + except (CompileError, LinkError, subprocess.CalledProcessError): + openmp_supported = False + compile_flags = [] + return openmp_supported, compile_flags diff --git a/ot/helpers/pre_build_helpers.py b/ot/helpers/pre_build_helpers.py new file mode 100644 index 0000000..93ecd6a --- /dev/null +++ b/ot/helpers/pre_build_helpers.py @@ -0,0 +1,87 @@ +"""Helpers to check build environment before actual build of POT""" + +import os +import sys +import glob +import tempfile +import setuptools # noqa +import subprocess + +from distutils.dist import Distribution +from distutils.sysconfig import customize_compiler +from numpy.distutils.ccompiler import new_compiler +from numpy.distutils.command.config_compiler import config_cc + + +def _get_compiler(): + """Get a compiler equivalent to the one that will be used to build POT + Handles compiler specified as follows: + - python setup.py build_ext --compiler= + - CC= python setup.py build_ext + """ + dist = Distribution({'script_name': os.path.basename(sys.argv[0]), + 'script_args': sys.argv[1:], + 'cmdclass': {'config_cc': config_cc}}) + + cmd_opts = dist.command_options.get('build_ext') + if cmd_opts is not None and 'compiler' in cmd_opts: + compiler = cmd_opts['compiler'][1] + else: + compiler = None + + ccompiler = new_compiler(compiler=compiler) + customize_compiler(ccompiler) + + return ccompiler + + +def compile_test_program(code, extra_preargs=[], extra_postargs=[]): + """Check that some C code can be compiled and run""" + ccompiler = _get_compiler() + + # extra_(pre/post)args can be a callable to make it possible to get its + # value from the compiler + if callable(extra_preargs): + extra_preargs = extra_preargs(ccompiler) + if callable(extra_postargs): + extra_postargs = extra_postargs(ccompiler) + + start_dir = os.path.abspath('.') + + with tempfile.TemporaryDirectory() as tmp_dir: + try: + os.chdir(tmp_dir) + + # Write test program + with open('test_program.c', 'w') as f: + f.write(code) + + os.mkdir('objects') + + # Compile, test program + ccompiler.compile(['test_program.c'], output_dir='objects', + extra_postargs=extra_postargs) + + # Link test program + objects = glob.glob( + os.path.join('objects', '*' + ccompiler.obj_extension)) + ccompiler.link_executable(objects, 'test_program', + extra_preargs=extra_preargs, + extra_postargs=extra_postargs) + + if "PYTHON_CROSSENV" not in os.environ: + # Run test program if not cross compiling + # will raise a CalledProcessError if return code was non-zero + output = subprocess.check_output('./test_program') + output = output.decode( + sys.stdout.encoding or 'utf-8').splitlines() + else: + # Return an empty output if we are cross compiling + # as we cannot run the test_program + output = [] + except Exception: + raise + finally: + os.chdir(start_dir) + + return output, extra_postargs diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h index c0fe7a3..8a1f9ac 100644 --- a/ot/lp/EMD.h +++ b/ot/lp/EMD.h @@ -18,19 +18,18 @@ #include #include -#include "network_simplex_simple.h" -using namespace lemon; typedef unsigned int node_id_type; enum ProblemType { INFEASIBLE, OPTIMAL, UNBOUNDED, - MAX_ITER_REACHED + MAX_ITER_REACHED }; int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter); +int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads); diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index bc873ed..2bdc172 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -12,16 +12,22 @@ * */ + +#include "network_simplex_simple.h" +#include "network_simplex_simple_omp.h" #include "EMD.h" +#include int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) { - // beware M and C anre strored in row major C style!!! - int n, m, i, cur; + // beware M and C are stored in row major C style!!! + + using namespace lemon; + int n, m, cur; typedef FullBipartiteDigraph Digraph; - DIGRAPH_TYPEDEFS(FullBipartiteDigraph); + DIGRAPH_TYPEDEFS(Digraph); // Get the number of non zero coordinates for r and c n=0; @@ -48,7 +54,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, std::vector indI(n), indJ(m); std::vector weights1(n), weights2(m); Digraph di(n, m); - NetworkSimplexSimple net(di, true, n+m, n*m, maxIter); + NetworkSimplexSimple net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter); // Set supply and demand, don't account for 0 values (faster) @@ -76,10 +82,12 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, net.supplyMap(&weights1[0], n, &weights2[0], m); // Set the cost of each edge + int64_t idarc = 0; for (int i=0; i0) { + n++; + }else if(val<0){ + return INFEASIBLE; + } + } + m=0; + for (int i=0; i0) { + m++; + }else if(val<0){ + return INFEASIBLE; + } + } + + // Define the graph + + std::vector indI(n), indJ(m); + std::vector weights1(n), weights2(m); + Digraph di(n, m); + NetworkSimplexSimple net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter, numThreads); + + // Set supply and demand, don't account for 0 values (faster) + + cur=0; + for (int i=0; i0) { + weights1[ cur ] = val; + indI[cur++]=i; + } + } + + // Demand is actually negative supply... + + cur=0; + for (int i=0; i0) { + weights2[ cur ] = -val; + indJ[cur++]=i; + } + } + + + net.supplyMap(&weights1[0], n, &weights2[0], m); + + // Set the cost of each edge + int64_t idarc = 0; + for (int i=0; i 1: - res = parmap(f, [b[:, i].copy() for i in range(nb)], processes) - else: - res = list(map(f, [b[:, i].copy() for i in range(nb)])) + warnings.warn( + "The 'processes' parameter has been deprecated. " + "Multiprocessing should be done outside of POT." + ) + res = list(map(f, [b[:, i].copy() for i in range(nb)])) return res def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, - stopThr=1e-7, verbose=False, log=None): + stopThr=1e-7, verbose=False, log=None, numThreads=1): r""" Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance), formally: @@ -512,6 +541,10 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None Print information along iterations log : bool, optional record log if True + numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) + If compiled with OpenMP, chooses the number of threads to parallelize. + "max" selects the highest number possible. + Returns ------- @@ -551,7 +584,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()): M_i = dist(X, measure_locations_i) - T_i = emd(b, measure_weights_i, M_i) + T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads) T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i) displacement_square_norm = np.sum(np.square(T_sum - X)) diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index de9a700..42e08f4 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -20,6 +20,7 @@ import warnings cdef extern from "EMD.h": int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) nogil + int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads) nogil cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED @@ -38,7 +39,7 @@ def check_result(result_code): @cython.boundscheck(False) @cython.wraparound(False) -def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter): +def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, int numThreads): """ Solves the Earth Movers distance problem and returns the optimal transport matrix @@ -109,8 +110,10 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod # calling the function with nogil: - result_code = EMD_wrap(n1, n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, max_iter) - + if numThreads == 1: + result_code = EMD_wrap(n1, n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, max_iter) + else: + result_code = EMD_wrap_omp(n1, n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, max_iter, numThreads) return G, cost, alpha, beta, result_code diff --git a/ot/lp/full_bipartitegraph.h b/ot/lp/full_bipartitegraph.h index 87a1bec..713ccb5 100644 --- a/ot/lp/full_bipartitegraph.h +++ b/ot/lp/full_bipartitegraph.h @@ -23,10 +23,10 @@ * */ -#ifndef LEMON_FULL_BIPARTITE_GRAPH_H -#define LEMON_FULL_BIPARTITE_GRAPH_H +#pragma once #include "core.h" +#include ///\ingroup graphs ///\file @@ -44,16 +44,16 @@ namespace lemon { //class Node; typedef int Node; //class Arc; - typedef long long Arc; + typedef int64_t Arc; protected: int _node_num; - long long _arc_num; + int64_t _arc_num; FullBipartiteDigraphBase() {} - void construct(int n1, int n2) { _node_num = n1+n2; _arc_num = n1 * n2; _n1=n1; _n2=n2;} + void construct(int n1, int n2) { _node_num = n1+n2; _arc_num = (int64_t)n1 * (int64_t)n2; _n1=n1; _n2=n2;} public: @@ -65,25 +65,25 @@ namespace lemon { Arc arc(const Node& s, const Node& t) const { if (s<_n1 && t>=_n1) - return Arc(s * _n2 + (t-_n1) ); + return Arc((int64_t)s * (int64_t)_n2 + (int64_t)(t-_n1) ); else return Arc(-1); } int nodeNum() const { return _node_num; } - long long arcNum() const { return _arc_num; } + int64_t arcNum() const { return _arc_num; } int maxNodeId() const { return _node_num - 1; } - long long maxArcId() const { return _arc_num - 1; } + int64_t maxArcId() const { return _arc_num - 1; } Node source(Arc arc) const { return arc / _n2; } Node target(Arc arc) const { return (arc % _n2) + _n1; } static int id(Node node) { return node; } - static long long id(Arc arc) { return arc; } + static int64_t id(Arc arc) { return arc; } static Node nodeFromId(int id) { return Node(id);} - static Arc arcFromId(int id) { return Arc(id);} + static Arc arcFromId(int64_t id) { return Arc(id);} Arc findArc(Node s, Node t, Arc prev = -1) const { @@ -136,7 +136,7 @@ namespace lemon { /// /// \brief A directed full graph class. /// - /// FullBipartiteDigraph is a simple and fast implmenetation of directed full + /// FullBipartiteDigraph is a simple and fast implementation of directed full /// (complete) graphs. It contains an arc from each node to each node /// (including a loop for each node), therefore the number of arcs /// is the square of the number of nodes. @@ -203,13 +203,10 @@ namespace lemon { /// \brief Number of nodes. int nodeNum() const { return Parent::nodeNum(); } /// \brief Number of arcs. - long long arcNum() const { return Parent::arcNum(); } + int64_t arcNum() const { return Parent::arcNum(); } }; } //namespace lemon - - -#endif //LEMON_FULL_GRAPH_H diff --git a/ot/lp/full_bipartitegraph_omp.h b/ot/lp/full_bipartitegraph_omp.h new file mode 100644 index 0000000..8cbed0b --- /dev/null +++ b/ot/lp/full_bipartitegraph_omp.h @@ -0,0 +1,234 @@ +/* -*- mode: C++; indent-tabs-mode: nil; -*- + * + * This file has been adapted by Nicolas Bonneel (2013), + * from full_graph.h from LEMON, a generic C++ optimization library, + * to implement a lightweight fully connected bipartite graph. A previous + * version of this file is used as part of the Displacement Interpolation + * project, + * Web: http://www.cs.ubc.ca/labs/imager/tr/2011/DisplacementInterpolation/ + * + * + **** Original file Copyright Notice : + * Copyright (C) 2003-2010 + * Egervary Jeno Kombinatorikus Optimalizalasi Kutatocsoport + * (Egervary Research Group on Combinatorial Optimization, EGRES). + * + * Permission to use, modify and distribute this software is granted + * provided that this copyright notice appears in all copies. For + * precise terms see the accompanying LICENSE file. + * + * This software is provided "AS IS" with no warranty of any kind, + * express or implied, and with no claim as to its suitability for any + * purpose. + * + */ + +#pragma once + +#include + +///\ingroup graphs +///\file +///\brief FullBipartiteDigraph and FullBipartiteGraph classes. + + +namespace lemon_omp { + + ///This \c \#define creates convenient type definitions for the following + ///types of \c Digraph: \c Node, \c NodeIt, \c Arc, \c ArcIt, \c InArcIt, + ///\c OutArcIt, \c BoolNodeMap, \c IntNodeMap, \c DoubleNodeMap, + ///\c BoolArcMap, \c IntArcMap, \c DoubleArcMap. + /// + ///\note If the graph type is a dependent type, ie. the graph type depend + ///on a template parameter, then use \c TEMPLATE_DIGRAPH_TYPEDEFS() + ///macro. +#define DIGRAPH_TYPEDEFS(Digraph) \ + typedef Digraph::Node Node; \ + typedef Digraph::Arc Arc; \ + + + ///Create convenience typedefs for the digraph types and iterators + + ///\see DIGRAPH_TYPEDEFS + /// + ///\note Use this macro, if the graph type is a dependent type, + ///ie. the graph type depend on a template parameter. +#define TEMPLATE_DIGRAPH_TYPEDEFS(Digraph) \ + typedef typename Digraph::Node Node; \ + typedef typename Digraph::Arc Arc; \ + + + class FullBipartiteDigraphBase { + public: + + typedef FullBipartiteDigraphBase Digraph; + + //class Node; + typedef int Node; + //class Arc; + typedef int64_t Arc; + + protected: + + int _node_num; + int64_t _arc_num; + + FullBipartiteDigraphBase() {} + + void construct(int n1, int n2) { _node_num = n1+n2; _arc_num = (int64_t)n1 * (int64_t)n2; _n1=n1; _n2=n2;} + + public: + + int _n1, _n2; + + + Node operator()(int ix) const { return Node(ix); } + static int index(const Node& node) { return node; } + + Arc arc(const Node& s, const Node& t) const { + if (s<_n1 && t>=_n1) + return Arc((int64_t)s * (int64_t)_n2 + (int64_t)(t-_n1) ); + else + return Arc(-1); + } + + int nodeNum() const { return _node_num; } + int64_t arcNum() const { return _arc_num; } + + int maxNodeId() const { return _node_num - 1; } + int64_t maxArcId() const { return _arc_num - 1; } + + Node source(Arc arc) const { return arc / _n2; } + Node target(Arc arc) const { return (arc % _n2) + _n1; } + + static int id(Node node) { return node; } + static int64_t id(Arc arc) { return arc; } + + static Node nodeFromId(int id) { return Node(id);} + static Arc arcFromId(int64_t id) { return Arc(id);} + + + Arc findArc(Node s, Node t, Arc prev = -1) const { + return prev == -1 ? arc(s, t) : -1; + } + + void first(Node& node) const { + node = _node_num - 1; + } + + static void next(Node& node) { + --node; + } + + void first(Arc& arc) const { + arc = _arc_num - 1; + } + + static void next(Arc& arc) { + --arc; + } + + void firstOut(Arc& arc, const Node& node) const { + if (node>=_n1) + arc = -1; + else + arc = (node + 1) * _n2 - 1; + } + + void nextOut(Arc& arc) const { + if (arc % _n2 == 0) arc = 0; + --arc; + } + + void firstIn(Arc& arc, const Node& node) const { + if (node<_n1) + arc = -1; + else + arc = _arc_num + node - _node_num; + } + + void nextIn(Arc& arc) const { + arc -= _n2; + if (arc < 0) arc = -1; + } + + }; + + /// \ingroup graphs + /// + /// \brief A directed full graph class. + /// + /// FullBipartiteDigraph is a simple and fast implmenetation of directed full + /// (complete) graphs. It contains an arc from each node to each node + /// (including a loop for each node), therefore the number of arcs + /// is the square of the number of nodes. + /// This class is completely static and it needs constant memory space. + /// Thus you can neither add nor delete nodes or arcs, however + /// the structure can be resized using resize(). + /// + /// This type fully conforms to the \ref concepts::Digraph "Digraph concept". + /// Most of its member functions and nested classes are documented + /// only in the concept class. + /// + /// This class provides constant time counting for nodes and arcs. + /// + /// \note FullBipartiteDigraph and FullBipartiteGraph classes are very similar, + /// but there are two differences. While this class conforms only + /// to the \ref concepts::Digraph "Digraph" concept, FullBipartiteGraph + /// conforms to the \ref concepts::Graph "Graph" concept, + /// moreover FullBipartiteGraph does not contain a loop for each + /// node as this class does. + /// + /// \sa FullBipartiteGraph + class FullBipartiteDigraph : public FullBipartiteDigraphBase { + typedef FullBipartiteDigraphBase Parent; + + public: + + /// \brief Default constructor. + /// + /// Default constructor. The number of nodes and arcs will be zero. + FullBipartiteDigraph() { construct(0,0); } + + /// \brief Constructor + /// + /// Constructor. + /// \param n The number of the nodes. + FullBipartiteDigraph(int n1, int n2) { construct(n1, n2); } + + + /// \brief Returns the node with the given index. + /// + /// Returns the node with the given index. Since this structure is + /// completely static, the nodes can be indexed with integers from + /// the range [0..nodeNum()-1]. + /// The index of a node is the same as its ID. + /// \sa index() + Node operator()(int ix) const { return Parent::operator()(ix); } + + /// \brief Returns the index of the given node. + /// + /// Returns the index of the given node. Since this structure is + /// completely static, the nodes can be indexed with integers from + /// the range [0..nodeNum()-1]. + /// The index of a node is the same as its ID. + /// \sa operator()() + static int index(const Node& node) { return Parent::index(node); } + + /// \brief Returns the arc connecting the given nodes. + /// + /// Returns the arc connecting the given nodes. + /*Arc arc(Node u, Node v) const { + return Parent::arc(u, v); + }*/ + + /// \brief Number of nodes. + int nodeNum() const { return Parent::nodeNum(); } + /// \brief Number of arcs. + int64_t arcNum() const { return Parent::arcNum(); } + }; + + + + +} //namespace lemon_omp diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h index 630b595..3b46b9b 100644 --- a/ot/lp/network_simplex_simple.h +++ b/ot/lp/network_simplex_simple.h @@ -25,15 +25,17 @@ * */ -#ifndef LEMON_NETWORK_SIMPLEX_SIMPLE_H -#define LEMON_NETWORK_SIMPLEX_SIMPLE_H +#pragma once +#undef DEBUG_LVL #define DEBUG_LVL 0 #if DEBUG_LVL>0 #include #endif - +#undef EPSILON +#undef _EPSILON +#undef MAX_DEBUG_ITER #define EPSILON 2.2204460492503131e-15 #define _EPSILON 1e-8 #define MAX_DEBUG_ITER 100000 @@ -50,6 +52,7 @@ #include #include #include +#include #include #ifdef HASHMAP #include @@ -63,6 +66,8 @@ //#include "sparse_array_n.h" #include "full_bipartitegraph.h" +#undef INVALIDNODE +#undef INVALID #define INVALIDNODE -1 #define INVALID (-1) @@ -76,16 +81,16 @@ namespace lemon { class SparseValueVector { public: - SparseValueVector(int n=0) + SparseValueVector(size_t n=0) { } - void resize(int n=0){}; - T operator[](const int id) const + void resize(size_t n=0){}; + T operator[](const size_t id) const { #ifdef HASHMAP - typename stdext::hash_map::const_iterator it = data.find(id); + typename stdext::hash_map::const_iterator it = data.find(id); #else - typename std::map::const_iterator it = data.find(id); + typename std::map::const_iterator it = data.find(id); #endif if (it==data.end()) return 0; @@ -93,16 +98,16 @@ namespace lemon { return it->second; } - ProxyObject operator[](const int id) + ProxyObject operator[](const size_t id) { return ProxyObject( this, id ); } //private: #ifdef HASHMAP - stdext::hash_map data; + stdext::hash_map data; #else - std::map data; + std::map data; #endif }; @@ -110,7 +115,7 @@ namespace lemon { template class ProxyObject { public: - ProxyObject( SparseValueVector *v, int idx ){_v=v; _idx=idx;}; + ProxyObject( SparseValueVector *v, size_t idx ){_v=v; _idx=idx;}; ProxyObject & operator=( const T &v ) { // If we get here, we know that operator[] was called to perform a write access, // so we can insert an item in the vector if needed @@ -123,9 +128,9 @@ namespace lemon { // If we get here, we know that operator[] was called to perform a read access, // so we can simply return the existing object #ifdef HASHMAP - typename stdext::hash_map::iterator it = _v->data.find(_idx); + typename stdext::hash_map::iterator it = _v->data.find(_idx); #else - typename std::map::iterator it = _v->data.find(_idx); + typename std::map::iterator it = _v->data.find(_idx); #endif if (it==_v->data.end()) return 0; @@ -137,9 +142,9 @@ namespace lemon { { if (val==0) return; #ifdef HASHMAP - typename stdext::hash_map::iterator it = _v->data.find(_idx); + typename stdext::hash_map::iterator it = _v->data.find(_idx); #else - typename std::map::iterator it = _v->data.find(_idx); + typename std::map::iterator it = _v->data.find(_idx); #endif if (it==_v->data.end()) _v->data[_idx] = val; @@ -156,9 +161,9 @@ namespace lemon { { if (val==0) return; #ifdef HASHMAP - typename stdext::hash_map::iterator it = _v->data.find(_idx); + typename stdext::hash_map::iterator it = _v->data.find(_idx); #else - typename std::map::iterator it = _v->data.find(_idx); + typename std::map::iterator it = _v->data.find(_idx); #endif if (it==_v->data.end()) _v->data[_idx] = -val; @@ -173,7 +178,7 @@ namespace lemon { } SparseValueVector *_v; - int _idx; + size_t _idx; }; @@ -204,7 +209,7 @@ namespace lemon { /// /// \tparam GR The digraph type the algorithm runs on. /// \tparam V The number type used for flow amounts, capacity bounds - /// and supply values in the algorithm. By default, it is \c int. + /// and supply values in the algorithm. By default, it is \c int64_t. /// \tparam C The number type used for costs and potentials in the /// algorithm. By default, it is the same as \c V. /// @@ -214,7 +219,7 @@ namespace lemon { /// \note %NetworkSimplexSimple provides five different pivot rule /// implementations, from which the most efficient one is used /// by default. For more information, see \ref PivotRule. - template + template class NetworkSimplexSimple { public: @@ -228,7 +233,7 @@ namespace lemon { /// mixed order in the internal data structure. /// In special cases, it could lead to better overall performance, /// but it is usually slower. Therefore it is disabled by default. - NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, long long nb_arcs,int maxiters) : + NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters) : _graph(graph), //_arc_id(graph), _arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs), MAX(std::numeric_limits::max()), @@ -288,11 +293,11 @@ namespace lemon { private: - int max_iter; + size_t max_iter; TEMPLATE_DIGRAPH_TYPEDEFS(GR); typedef std::vector IntVector; - typedef std::vector UHalfIntVector; + typedef std::vector ArcVector; typedef std::vector ValueVector; typedef std::vector CostVector; // typedef SparseValueVector CostVector; @@ -315,9 +320,9 @@ namespace lemon { // Data related to the underlying digraph const GR &_graph; int _node_num; - int _arc_num; - int _all_arc_num; - int _search_arc_num; + ArcsType _arc_num; + ArcsType _all_arc_num; + ArcsType _search_arc_num; // Parameters of the problem SupplyType _stype; @@ -325,9 +330,9 @@ namespace lemon { inline int _node_id(int n) const {return _node_num-n-1;} ; - //IntArcMap _arc_id; - UHalfIntVector _source; - UHalfIntVector _target; +// IntArcMap _arc_id; + IntVector _source; // keep nodes as integers + IntVector _target; bool _arc_mixing; public: // Node and arc data @@ -341,7 +346,7 @@ namespace lemon { private: // Data for storing the spanning tree structure IntVector _parent; - IntVector _pred; + ArcVector _pred; IntVector _thread; IntVector _rev_thread; IntVector _succ_num; @@ -349,17 +354,17 @@ namespace lemon { IntVector _dirty_revs; BoolVector _forward; StateVector _state; - int _root; + ArcsType _root; // Temporary data used in the current pivot iteration - int in_arc, join, u_in, v_in, u_out, v_out; - int first, second, right, last; - int stem, par_stem, new_stem; + ArcsType in_arc, join, u_in, v_in, u_out, v_out; + ArcsType first, second, right, last; + ArcsType stem, par_stem, new_stem; Value delta; const Value MAX; - int mixingCoeff; + ArcsType mixingCoeff; public: @@ -373,27 +378,27 @@ namespace lemon { private: // thank you to DVK and MizardX from StackOverflow for this function! - inline int sequence(int k) const { - int smallv = (k > num_total_big_subsequence_numbers) & 1; + inline ArcsType sequence(ArcsType k) const { + ArcsType smallv = (k > num_total_big_subsequence_numbers) & 1; k -= num_total_big_subsequence_numbers * smallv; - int subsequence_length2 = subsequence_length- smallv; - int subsequence_num = (k / subsequence_length2) + num_big_subseqiences * smallv; - int subsequence_offset = (k % subsequence_length2) * mixingCoeff; + ArcsType subsequence_length2 = subsequence_length- smallv; + ArcsType subsequence_num = (k / subsequence_length2) + num_big_subseqiences * smallv; + ArcsType subsequence_offset = (k % subsequence_length2) * mixingCoeff; return subsequence_offset + subsequence_num; } - int subsequence_length; - int num_big_subseqiences; - int num_total_big_subsequence_numbers; + ArcsType subsequence_length; + ArcsType num_big_subseqiences; + ArcsType num_total_big_subsequence_numbers; - inline int getArcID(const Arc &arc) const + inline ArcsType getArcID(const Arc &arc) const { //int n = _arc_num-arc._id-1; - int n = _arc_num-GR::id(arc)-1; + ArcsType n = _arc_num-GR::id(arc)-1; - //int a = mixingCoeff*(n%mixingCoeff) + n/mixingCoeff; - //int b = _arc_id[arc]; + //ArcsType a = mixingCoeff*(n%mixingCoeff) + n/mixingCoeff; + //ArcsType b = _arc_id[arc]; if (_arc_mixing) return sequence(n); else @@ -401,16 +406,16 @@ namespace lemon { } // finally unused because too slow - inline int getSource(const int arc) const + inline ArcsType getSource(const ArcsType arc) const { - //int a = _source[arc]; + //ArcsType a = _source[arc]; //return a; - int n = _arc_num-arc-1; + ArcsType n = _arc_num-arc-1; if (_arc_mixing) n = mixingCoeff*(n%mixingCoeff) + n/mixingCoeff; - int b; + ArcsType b; if (n>=0) b = _node_id(_graph.source(GR::arcFromId( n ) )); else @@ -436,17 +441,17 @@ namespace lemon { private: // References to the NetworkSimplexSimple class - const UHalfIntVector &_source; - const UHalfIntVector &_target; + const IntVector &_source; + const IntVector &_target; const CostVector &_cost; const StateVector &_state; const CostVector &_pi; - int &_in_arc; - int _search_arc_num; + ArcsType &_in_arc; + ArcsType _search_arc_num; // Pivot rule data - int _block_size; - int _next_arc; + ArcsType _block_size; + ArcsType _next_arc; NetworkSimplexSimple &_ns; public: @@ -460,17 +465,16 @@ namespace lemon { { // The main parameters of the pivot rule const double BLOCK_SIZE_FACTOR = 1.0; - const int MIN_BLOCK_SIZE = 10; + const ArcsType MIN_BLOCK_SIZE = 10; - _block_size = std::max( int(BLOCK_SIZE_FACTOR * - std::sqrt(double(_search_arc_num))), - MIN_BLOCK_SIZE ); + _block_size = std::max(ArcsType(BLOCK_SIZE_FACTOR * std::sqrt(double(_search_arc_num))), MIN_BLOCK_SIZE); } + // Find next entering arc bool findEnteringArc() { Cost c, min = 0; - int e; - int cnt = _block_size; + ArcsType e; + ArcsType cnt = _block_size; double a; for (e = _next_arc; e != _search_arc_num; ++e) { c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]); @@ -516,7 +520,7 @@ namespace lemon { int _init_nb_nodes; - long long _init_nb_arcs; + ArcsType _init_nb_arcs; /// \name Parameters /// The parameters of the algorithm can be specified using these @@ -736,7 +740,7 @@ namespace lemon { for (int i = 0; i != _node_num; ++i) { _supply[i] = 0; } - for (int i = 0; i != _arc_num; ++i) { + for (ArcsType i = 0; i != _arc_num; ++i) { _cost[i] = 1; } _stype = GEQ; @@ -745,7 +749,7 @@ namespace lemon { - int divid (int x, int y) + int64_t divid (int64_t x, int64_t y) { return (x-x%y)/y; } @@ -775,7 +779,7 @@ namespace lemon { _node_num = _init_nb_nodes; _arc_num = _init_nb_arcs; int all_node_num = _node_num + 1; - int max_arc_num = _arc_num + 2 * _node_num; + ArcsType max_arc_num = _arc_num + 2 * _node_num; _source.resize(max_arc_num); _target.resize(max_arc_num); @@ -798,13 +802,13 @@ namespace lemon { //_arc_mixing=false; if (_arc_mixing) { // Store the arcs in a mixed order - int k = std::max(int(std::sqrt(double(_arc_num))), 10); + const ArcsType k = std::max(ArcsType(std::sqrt(double(_arc_num))), ArcsType(10)); mixingCoeff = k; subsequence_length = _arc_num / mixingCoeff + 1; num_big_subseqiences = _arc_num % mixingCoeff; num_total_big_subsequence_numbers = subsequence_length * num_big_subseqiences; - int i = 0, j = 0; + ArcsType i = 0, j = 0; Arc a; _graph.first(a); for (; a != INVALID; _graph.next(a)) { _source[i] = _node_id(_graph.source(a)); @@ -814,7 +818,7 @@ namespace lemon { } } else { // Store the arcs in the original order - int i = 0; + ArcsType i = 0; Arc a; _graph.first(a); for (; a != INVALID; _graph.next(a), ++i) { _source[i] = _node_id(_graph.source(a)); @@ -856,7 +860,7 @@ namespace lemon { Number totalCost() const { Number c = 0; for (ArcIt a(_graph); a != INVALID; ++a) { - int i = getArcID(a); + int64_t i = getArcID(a); c += Number(_flow[i]) * Number(_cost[i]); } return c; @@ -867,15 +871,15 @@ namespace lemon { Number c = 0; /*#ifdef HASHMAP - typename stdext::hash_map::const_iterator it; + typename stdext::hash_map::const_iterator it; #else - typename std::map::const_iterator it; + typename std::map::const_iterator it; #endif for (it = _flow.data.begin(); it!=_flow.data.end(); ++it) c += Number(it->second) * Number(_cost[it->first]); return c;*/ - for (unsigned long i=0; i<_flow.size(); i++) + for (ArcsType i=0; i<_flow.size(); i++) c += _flow[i] * Number(_cost[i]); return c; @@ -944,14 +948,14 @@ namespace lemon { // Initialize internal data structures bool init() { if (_node_num == 0) return false; - + // Check the sum of supply values _sum_supply = 0; for (int i = 0; i != _node_num; ++i) { _sum_supply += _supply[i]; } if ( fabs(_sum_supply) > _EPSILON ) return false; - + _sum_supply = 0; // Initialize artifical cost @@ -960,14 +964,14 @@ namespace lemon { ART_COST = std::numeric_limits::max() / 2 + 1; } else { ART_COST = 0; - for (int i = 0; i != _arc_num; ++i) { + for (ArcsType i = 0; i != _arc_num; ++i) { if (_cost[i] > ART_COST) ART_COST = _cost[i]; } ART_COST = (ART_COST + 1) * _node_num; } // Initialize arc maps - for (int i = 0; i != _arc_num; ++i) { + for (ArcsType i = 0; i != _arc_num; ++i) { //_flow[i] = 0; //by default, the sparse matrix is empty _state[i] = STATE_LOWER; } @@ -988,7 +992,7 @@ namespace lemon { // EQ supply constraints _search_arc_num = _arc_num; _all_arc_num = _arc_num + _node_num; - for (int u = 0, e = _arc_num; u != _node_num; ++u, ++e) { + for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) { _parent[u] = _root; _pred[u] = e; _thread[u] = u + 1; @@ -1016,8 +1020,8 @@ namespace lemon { else if (_sum_supply > 0) { // LEQ supply constraints _search_arc_num = _arc_num + _node_num; - int f = _arc_num + _node_num; - for (int u = 0, e = _arc_num; u != _node_num; ++u, ++e) { + ArcsType f = _arc_num + _node_num; + for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) { _parent[u] = _root; _thread[u] = u + 1; _rev_thread[u + 1] = u; @@ -1054,8 +1058,8 @@ namespace lemon { else { // GEQ supply constraints _search_arc_num = _arc_num + _node_num; - int f = _arc_num + _node_num; - for (int u = 0, e = _arc_num; u != _node_num; ++u, ++e) { + ArcsType f = _arc_num + _node_num; + for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) { _parent[u] = _root; _thread[u] = u + 1; _rev_thread[u + 1] = u; @@ -1120,9 +1124,9 @@ namespace lemon { second = _source[in_arc]; } delta = INF; - int result = 0; + char result = 0; Value d; - int e; + ArcsType e; // Search the cycle along the path form the first node to the root for (int u = first; u != join; u = _parent[u]) { @@ -1239,7 +1243,7 @@ namespace lemon { // Update _rev_thread using the new _thread values for (int i = 0; i != int(_dirty_revs.size()); ++i) { - u = _dirty_revs[i]; + int u = _dirty_revs[i]; _rev_thread[_thread[u]] = u; } @@ -1257,7 +1261,7 @@ namespace lemon { u = w; } _pred[u_in] = in_arc; - _forward[u_in] = ((unsigned int)u_in == _source[in_arc]); + _forward[u_in] = (u_in == _source[in_arc]); _succ_num[u_in] = old_succ_num; // Set limits for updating _last_succ form v_in and v_out @@ -1328,7 +1332,7 @@ namespace lemon { if (_sum_supply > 0) total -= _sum_supply; if (total <= 0) return true; - IntVector arc_vector; + ArcVector arc_vector; if (_sum_supply >= 0) { if (supply_nodes.size() == 1 && demand_nodes.size() == 1) { // Perform a reverse graph search from the sink to the source @@ -1345,7 +1349,7 @@ namespace lemon { Arc a; _graph.firstIn(a, v); for (; a != INVALID; _graph.nextIn(a)) { if (reached[u = _graph.source(a)]) continue; - int j = getArcID(a); + ArcsType j = getArcID(a); if (INF >= total) { arc_vector.push_back(j); reached[u] = true; @@ -1355,7 +1359,7 @@ namespace lemon { } } else { // Find the min. cost incomming arc for each demand node - for (int i = 0; i != int(demand_nodes.size()); ++i) { + for (int i = 0; i != demand_nodes.size(); ++i) { Node v = demand_nodes[i]; Cost c, min_cost = std::numeric_limits::max(); Arc min_arc = INVALID; @@ -1393,7 +1397,7 @@ namespace lemon { } // Perform heuristic initial pivots - for (int i = 0; i != int(arc_vector.size()); ++i) { + for (ArcsType i = 0; i != arc_vector.size(); ++i) { in_arc = arc_vector[i]; // l'erreur est probablement ici... if (_state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] - @@ -1423,7 +1427,7 @@ namespace lemon { // Perform heuristic initial pivots if (!initialPivots()) return UNBOUNDED; - int iter_number=0; + size_t iter_number=0; //pivot.setDantzig(true); // Execute the Network Simplex algorithm while (pivot.findEnteringArc()) { @@ -1443,7 +1447,7 @@ namespace lemon { double a; a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]); a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]); - for (int i=0; i<_flow.size(); i++) { + for (int64_t i=0; i<_flow.size(); i++) { sumFlow+=_state[i]*_flow[i]; } std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << iter_number << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n"; @@ -1482,12 +1486,12 @@ namespace lemon { double a; a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]); a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]); - for (int i=0; i<_flow.size(); i++) { + for (int64_t i=0; i<_flow.size(); i++) { sumFlow+=_state[i]*_flow[i]; } - + std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << niter << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n"; - + std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n"; std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n"; @@ -1505,7 +1509,7 @@ namespace lemon { #endif // Check feasibility if( retVal == OPTIMAL){ - for (int e = _search_arc_num; e != _all_arc_num; ++e) { + for (ArcsType e = _search_arc_num; e != _all_arc_num; ++e) { if (_flow[e] != 0){ if (fabs(_flow[e]) > _EPSILON) // change of the original code following issue #126 return INFEASIBLE; @@ -1521,20 +1525,20 @@ namespace lemon { if (_sum_supply == 0) { if (_stype == GEQ) { Cost max_pot = -std::numeric_limits::max(); - for (int i = 0; i != _node_num; ++i) { + for (ArcsType i = 0; i != _node_num; ++i) { if (_pi[i] > max_pot) max_pot = _pi[i]; } if (max_pot > 0) { - for (int i = 0; i != _node_num; ++i) + for (ArcsType i = 0; i != _node_num; ++i) _pi[i] -= max_pot; } } else { Cost min_pot = std::numeric_limits::max(); - for (int i = 0; i != _node_num; ++i) { + for (ArcsType i = 0; i != _node_num; ++i) { if (_pi[i] < min_pot) min_pot = _pi[i]; } if (min_pot < 0) { - for (int i = 0; i != _node_num; ++i) + for (ArcsType i = 0; i != _node_num; ++i) _pi[i] -= min_pot; } } @@ -1548,5 +1552,3 @@ namespace lemon { ///@} } //namespace lemon - -#endif //LEMON_NETWORK_SIMPLEX_H diff --git a/ot/lp/network_simplex_simple_omp.h b/ot/lp/network_simplex_simple_omp.h new file mode 100644 index 0000000..87e4c05 --- /dev/null +++ b/ot/lp/network_simplex_simple_omp.h @@ -0,0 +1,1699 @@ +/* -*- mode: C++; indent-tabs-mode: nil; -*- +* +* +* This file has been adapted by Nicolas Bonneel (2013), +* from network_simplex.h from LEMON, a generic C++ optimization library, +* to implement a lightweight network simplex for mass transport, more +* memory efficient than the original file. A previous version of this file +* is used as part of the Displacement Interpolation project, +* Web: http://www.cs.ubc.ca/labs/imager/tr/2011/DisplacementInterpolation/ +* +* Revisions: +* March 2015: added OpenMP parallelization +* March 2017: included Antoine Rolet's trick to make it more robust +* April 2018: IMPORTANT bug fix + uses 64bit integers (slightly slower but less risks of overflows), updated to a newer version of the algo by LEMON, sparse flow by default + minor edits. +* +* +**** Original file Copyright Notice : +* +* Copyright (C) 2003-2010 +* Egervary Jeno Kombinatorikus Optimalizalasi Kutatocsoport +* (Egervary Research Group on Combinatorial Optimization, EGRES). +* +* Permission to use, modify and distribute this software is granted +* provided that this copyright notice appears in all copies. For +* precise terms see the accompanying LICENSE file. +* +* This software is provided "AS IS" with no warranty of any kind, +* express or implied, and with no claim as to its suitability for any +* purpose. +* +*/ + +#pragma once +#undef DEBUG_LVL +#define DEBUG_LVL 0 + +#if DEBUG_LVL>0 +#include +#endif + +#undef EPSILON +#undef _EPSILON +#undef MAX_DEBUG_ITER +#define EPSILON std::numeric_limits::epsilon()*10 +#define _EPSILON 1e-8 +#define MAX_DEBUG_ITER 100000 + +/// \ingroup min_cost_flow_algs +/// +/// \file +/// \brief Network Simplex algorithm for finding a minimum cost flow. + +// if your compiler has troubles with unorderedmaps, just comment the following line to use a slower std::map instead +#define HASHMAP // now handled with unorderedmaps instead of stdext::hash_map. Should be better supported. + +#define SPARSE_FLOW // a sparse flow vector will be 10-15% slower for small problems but uses less memory and becomes faster for large problems (40k total nodes) + +#include +#include +#include +#include +#ifdef HASHMAP +#include +#else +#include +#endif +//#include "core.h" +//#include "lmath.h" + +#ifdef OMP +#include +#endif +#include + + +//#include "sparse_array_n.h" +#include "full_bipartitegraph_omp.h" + +#undef INVALIDNODE +#undef INVALID +#define INVALIDNODE -1 +#define INVALID (-1) + +namespace lemon_omp { + + int64_t max_threads = -1; + + template + class ProxyObject; + + template + class SparseValueVector + { + public: + SparseValueVector(size_t n = 0) // parameter n for compatibility with standard vectors + { + } + void resize(size_t n = 0) {}; + T operator[](const size_t id) const + { +#ifdef HASHMAP + typename std::unordered_map::const_iterator it = data.find(id); +#else + typename std::map::const_iterator it = data.find(id); +#endif + if (it == data.end()) + return 0; + else + return it->second; + } + + ProxyObject operator[](const size_t id) + { + return ProxyObject(this, id); + } + + //private: +#ifdef HASHMAP + std::unordered_map data; +#else + std::map data; +#endif + + }; + + template + class ProxyObject { + public: + ProxyObject(SparseValueVector *v, size_t idx) { _v = v; _idx = idx; }; + ProxyObject & operator=(const T &v) { + // If we get here, we know that operator[] was called to perform a write access, + // so we can insert an item in the vector if needed + if (v != 0) + _v->data[_idx] = v; + return *this; + } + + operator T() { + // If we get here, we know that operator[] was called to perform a read access, + // so we can simply return the existing object +#ifdef HASHMAP + typename std::unordered_map::iterator it = _v->data.find(_idx); +#else + typename std::map::iterator it = _v->data.find(_idx); +#endif + if (it == _v->data.end()) + return 0; + else + return it->second; + } + + void operator+=(T val) + { + if (val == 0) return; +#ifdef HASHMAP + typename std::unordered_map::iterator it = _v->data.find(_idx); +#else + typename std::map::iterator it = _v->data.find(_idx); +#endif + if (it == _v->data.end()) + _v->data[_idx] = val; + else + { + T sum = it->second + val; + if (sum == 0) + _v->data.erase(it); + else + it->second = sum; + } + } + void operator-=(T val) + { + if (val == 0) return; +#ifdef HASHMAP + typename std::unordered_map::iterator it = _v->data.find(_idx); +#else + typename std::map::iterator it = _v->data.find(_idx); +#endif + if (it == _v->data.end()) + _v->data[_idx] = -val; + else + { + T sum = it->second - val; + if (sum == 0) + _v->data.erase(it); + else + it->second = sum; + } + } + + SparseValueVector *_v; + size_t _idx; + }; + + + + /// \addtogroup min_cost_flow_algs + /// @{ + + /// \brief Implementation of the primal Network Simplex algorithm + /// for finding a \ref min_cost_flow "minimum cost flow". + /// + /// \ref NetworkSimplexSimple implements the primal Network Simplex algorithm + /// for finding a \ref min_cost_flow "minimum cost flow" + /// \ref amo93networkflows, \ref dantzig63linearprog, + /// \ref kellyoneill91netsimplex. + /// This algorithm is a highly efficient specialized version of the + /// linear programming simplex method directly for the minimum cost + /// flow problem. + /// + /// In general, %NetworkSimplexSimple is the fastest implementation available + /// in LEMON for this problem. + /// Moreover, it supports both directions of the supply/demand inequality + /// constraints. For more information, see \ref SupplyType. + /// + /// Most of the parameters of the problem (except for the digraph) + /// can be given using separate functions, and the algorithm can be + /// executed using the \ref run() function. If some parameters are not + /// specified, then default values will be used. + /// + /// \tparam GR The digraph type the algorithm runs on. + /// \tparam V The number type used for flow amounts, capacity bounds + /// and supply values in the algorithm. By default, it is \c int. + /// \tparam C The number type used for costs and potentials in the + /// algorithm. By default, it is the same as \c V. + /// + /// \warning Both number types must be signed and all input data must + /// be integer. + /// + /// \note %NetworkSimplexSimple provides five different pivot rule + /// implementations, from which the most efficient one is used + /// by default. For more information, see \ref PivotRule. + template + class NetworkSimplexSimple + { + public: + + /// \brief Constructor. + /// + /// The constructor of the class. + /// + /// \param graph The digraph the algorithm runs on. + /// \param arc_mixing Indicate if the arcs have to be stored in a + /// mixed order in the internal data structure. + /// In special cases, it could lead to better overall performance, + /// but it is usually slower. Therefore it is disabled by default. + NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters = 0, int numThreads=-1) : + _graph(graph), //_arc_id(graph), + _arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs), + MAX(std::numeric_limits::max()), + INF(std::numeric_limits::has_infinity ? + std::numeric_limits::infinity() : MAX) + { + // Reset data structures + reset(); + max_iter = maxiters; +#ifdef OMP + if (max_threads < 0) { + max_threads = omp_get_max_threads(); + } + if (numThreads > 0 && numThreads<=max_threads){ + num_threads = numThreads; + } else if (numThreads == -1 || numThreads>max_threads) { + num_threads = max_threads; + } else { + num_threads = 1; + } + omp_set_num_threads(num_threads); +#else + num_threads = 1; +#endif + } + + /// The type of the flow amounts, capacity bounds and supply values + typedef V Value; + /// The type of the arc costs + typedef C Cost; + + public: + /// \brief Problem type constants for the \c run() function. + /// + /// Enum type containing the problem type constants that can be + /// returned by the \ref run() function of the algorithm. + enum ProblemType { + /// The problem has no feasible solution (flow). + INFEASIBLE, + /// The problem has optimal solution (i.e. it is feasible and + /// bounded), and the algorithm has found optimal flow and node + /// potentials (primal and dual solutions). + OPTIMAL, + /// The objective function of the problem is unbounded, i.e. + /// there is a directed cycle having negative total cost and + /// infinite upper bound. + UNBOUNDED, + // The maximum number of iteration has been reached + MAX_ITER_REACHED + }; + + /// \brief Constants for selecting the type of the supply constraints. + /// + /// Enum type containing constants for selecting the supply type, + /// i.e. the direction of the inequalities in the supply/demand + /// constraints of the \ref min_cost_flow "minimum cost flow problem". + /// + /// The default supply type is \c GEQ, the \c LEQ type can be + /// selected using \ref supplyType(). + /// The equality form is a special case of both supply types. + enum SupplyType { + /// This option means that there are "greater or equal" + /// supply/demand constraints in the definition of the problem. + GEQ, + /// This option means that there are "less or equal" + /// supply/demand constraints in the definition of the problem. + LEQ + }; + + + + private: + size_t max_iter; + int num_threads; + TEMPLATE_DIGRAPH_TYPEDEFS(GR); + + typedef std::vector IntVector; + typedef std::vector ArcVector; + typedef std::vector ValueVector; + typedef std::vector CostVector; + // typedef SparseValueVector CostVector; + typedef std::vector BoolVector; + // Note: vector is used instead of vector for efficiency reasons + + // State constants for arcs + enum ArcState { + STATE_UPPER = -1, + STATE_TREE = 0, + STATE_LOWER = 1 + }; + + typedef std::vector StateVector; + // Note: vector is used instead of vector for + // efficiency reasons + + private: + + // Data related to the underlying digraph + const GR &_graph; + int _node_num; + ArcsType _arc_num; + ArcsType _all_arc_num; + ArcsType _search_arc_num; + + // Parameters of the problem + SupplyType _stype; + Value _sum_supply; + + inline int _node_id(int n) const { return _node_num - n - 1; }; + + //IntArcMap _arc_id; + IntVector _source; // keep nodes as integers + IntVector _target; + bool _arc_mixing; + + // Node and arc data + CostVector _cost; + ValueVector _supply; +#ifdef SPARSE_FLOW + SparseValueVector _flow; +#else + ValueVector _flow; +#endif + + CostVector _pi; + + // Data for storing the spanning tree structure + IntVector _parent; + ArcVector _pred; + IntVector _thread; + IntVector _rev_thread; + IntVector _succ_num; + IntVector _last_succ; + IntVector _dirty_revs; + BoolVector _forward; + StateVector _state; + ArcsType _root; + + // Temporary data used in the current pivot iteration + ArcsType in_arc, join, u_in, v_in, u_out, v_out; + ArcsType first, second, right, last; + ArcsType stem, par_stem, new_stem; + Value delta; + + const Value MAX; + + ArcsType mixingCoeff; + + public: + + /// \brief Constant for infinite upper bounds (capacities). + /// + /// Constant for infinite upper bounds (capacities). + /// It is \c std::numeric_limits::infinity() if available, + /// \c std::numeric_limits::max() otherwise. + const Value INF; + + private: + + // thank you to DVK and MizardX from StackOverflow for this function! + inline ArcsType sequence(ArcsType k) const { + ArcsType smallv = (k > num_total_big_subsequence_numbers) & 1; + + k -= num_total_big_subsequence_numbers * smallv; + ArcsType subsequence_length2 = subsequence_length - smallv; + ArcsType subsequence_num = (k / subsequence_length2) + num_big_subsequences * smallv; + ArcsType subsequence_offset = (k % subsequence_length2) * mixingCoeff; + + return subsequence_offset + subsequence_num; + } + ArcsType subsequence_length; + ArcsType num_big_subsequences; + ArcsType num_total_big_subsequence_numbers; + + inline ArcsType getArcID(const Arc &arc) const + { + //int n = _arc_num-arc._id-1; + ArcsType n = _arc_num - GR::id(arc) - 1; + + //ArcsType a = mixingCoeff*(n%mixingCoeff) + n/mixingCoeff; + //ArcsType b = _arc_id[arc]; + if (_arc_mixing) + return sequence(n); + else + return n; + } + + // finally unused because too slow + inline ArcsType getSource(const ArcsType arc) const + { + //ArcsType a = _source[arc]; + //return a; + + ArcsType n = _arc_num - arc - 1; + if (_arc_mixing) + n = mixingCoeff*(n%mixingCoeff) + n / mixingCoeff; + + ArcsType b; + if (n >= 0) + b = _node_id(_graph.source(GR::arcFromId(n))); + else + { + n = arc + 1 - _arc_num; + if (n <= _node_num) + b = _node_num; + else + if (n >= _graph._n1) + b = _graph._n1; + else + b = _graph._n1 - n; + } + + return b; + } + + + + // Implementation of the Block Search pivot rule + class BlockSearchPivotRule + { + private: + + // References to the NetworkSimplexSimple class + const IntVector &_source; + const IntVector &_target; + const CostVector &_cost; + const StateVector &_state; + const CostVector &_pi; + ArcsType &_in_arc; + ArcsType _search_arc_num; + + // Pivot rule data + ArcsType _block_size; + ArcsType _next_arc; + NetworkSimplexSimple &_ns; + + public: + + // Constructor + BlockSearchPivotRule(NetworkSimplexSimple &ns) : + _source(ns._source), _target(ns._target), + _cost(ns._cost), _state(ns._state), _pi(ns._pi), + _in_arc(ns.in_arc), _search_arc_num(ns._search_arc_num), + _next_arc(0), _ns(ns) + { + // The main parameters of the pivot rule + const double BLOCK_SIZE_FACTOR = 1; + const ArcsType MIN_BLOCK_SIZE = 10; + + _block_size = std::max(ArcsType(BLOCK_SIZE_FACTOR * std::sqrt(double(_search_arc_num))), MIN_BLOCK_SIZE); + } + + // Find next entering arc + bool findEnteringArc() { + Cost min_val = 0; + + ArcsType N = _ns.num_threads; + + std::vector minArray(N, 0); + std::vector arcId(N); + ArcsType bs = (ArcsType)ceil(_block_size / (double)N); + + for (ArcsType i = 0; i < _search_arc_num; i += _block_size) { + + ArcsType e; + int j; +#pragma omp parallel + { +#ifdef OMP + int t = omp_get_thread_num(); +#else + int t = 0; +#endif + +#pragma omp for schedule(static, bs) lastprivate(e) + for (j = 0; j < std::min(i + _block_size, _search_arc_num) - i; j++) { + e = (_next_arc + i + j); if (e >= _search_arc_num) e -= _search_arc_num; + Cost c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]); + if (c < minArray[t]) { + minArray[t] = c; + arcId[t] = e; + } + } + } + for (int j = 0; j < N; j++) { + if (minArray[j] < min_val) { + min_val = minArray[j]; + _in_arc = arcId[j]; + } + } + Cost a = std::abs(_pi[_source[_in_arc]]) > std::abs(_pi[_target[_in_arc]]) ? std::abs(_pi[_source[_in_arc]]) : std::abs(_pi[_target[_in_arc]]); + a = a > std::abs(_cost[_in_arc]) ? a : std::abs(_cost[_in_arc]); + if (min_val < -EPSILON*a) { + _next_arc = e; + return true; + } + } + + Cost a = fabs(_pi[_source[_in_arc]]) > fabs(_pi[_target[_in_arc]]) ? fabs(_pi[_source[_in_arc]]) : fabs(_pi[_target[_in_arc]]); + a = a > fabs(_cost[_in_arc]) ? a : fabs(_cost[_in_arc]); + if (min_val >= -EPSILON*a) return false; + + return true; + } + + + // Find next entering arc + /*bool findEnteringArc() { + Cost min_val = 0; + int N = omp_get_max_threads(); + std::vector minArray(N); + std::vector arcId(N); + + ArcsType bs = (ArcsType)ceil(_block_size / (double)N); + for (ArcsType i = 0; i < _search_arc_num; i += _block_size) { + + ArcsType maxJ = std::min(i + _block_size, _search_arc_num) - i; + ArcsType j; +#pragma omp parallel + { + int t = omp_get_thread_num(); + Cost minV = 0; + ArcsType arcStart = _next_arc + i; + ArcsType arc = -1; +#pragma omp for schedule(static, bs) + for (j = 0; j < maxJ; j++) { + ArcsType e = arcStart + j; if (e >= _search_arc_num) e -= _search_arc_num; + Cost c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]); + if (c < minV) { + minV = c; + arc = e; + } + } + + minArray[t] = minV; + arcId[t] = arc; + } + for (int j = 0; j < N; j++) { + if (minArray[j] < min_val) { + min_val = minArray[j]; + _in_arc = arcId[j]; + } + } + + //FIX by Antoine Rolet to avoid precision issues + Cost a = std::max(std::abs(_cost[_in_arc]), std::max(std::abs(_pi[_source[_in_arc]]), std::abs(_pi[_target[_in_arc]]))); + if (min_val <-std::numeric_limits::epsilon()*a) { + _next_arc = _next_arc + i + maxJ - 1; + if (_next_arc >= _search_arc_num) _next_arc -= _search_arc_num; + return true; + } + } + + if (min_val >= 0) { + return false; + } + + return true; + }*/ + + + /*bool findEnteringArc() { + Cost c, min = 0; + int cnt = _block_size; + int e, min_arc = _next_arc; + for (e = _next_arc; e < _search_arc_num; ++e) { + c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]); + if (c < min) { + min = c; + min_arc = e; + + } + if (--cnt == 0) { + if (min < 0) break; + cnt = _block_size; + + } + + } + if (min == 0 || cnt > 0) { + for (e = 0; e < _next_arc; ++e) { + c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]); + if (c < min) { + min = c; + min_arc = e; + + } + if (--cnt == 0) { + if (min < 0) break; + cnt = _block_size; + + } + + } + + } + if (min >= 0) return false; + _in_arc = min_arc; + _next_arc = e; + return true; + }*/ + + + + }; //class BlockSearchPivotRule + + + + public: + + + + int _init_nb_nodes; + ArcsType _init_nb_arcs; + + /// \name Parameters + /// The parameters of the algorithm can be specified using these + /// functions. + + /// @{ + + + /// \brief Set the costs of the arcs. + /// + /// This function sets the costs of the arcs. + /// If it is not used before calling \ref run(), the costs + /// will be set to \c 1 on all arcs. + /// + /// \param map An arc map storing the costs. + /// Its \c Value type must be convertible to the \c Cost type + /// of the algorithm. + /// + /// \return (*this) + template + NetworkSimplexSimple& costMap(const CostMap& map) { + Arc a; _graph.first(a); + for (; a != INVALID; _graph.next(a)) { + _cost[getArcID(a)] = map[a]; + } + return *this; + } + + + /// \brief Set the costs of one arc. + /// + /// This function sets the costs of one arcs. + /// Done for memory reasons + /// + /// \param arc An arc. + /// \param arc A cost + /// + /// \return (*this) + template + NetworkSimplexSimple& setCost(const Arc& arc, const Value cost) { + _cost[getArcID(arc)] = cost; + return *this; + } + + + /// \brief Set the supply values of the nodes. + /// + /// This function sets the supply values of the nodes. + /// If neither this function nor \ref stSupply() is used before + /// calling \ref run(), the supply of each node will be set to zero. + /// + /// \param map A node map storing the supply values. + /// Its \c Value type must be convertible to the \c Value type + /// of the algorithm. + /// + /// \return (*this) + template + NetworkSimplexSimple& supplyMap(const SupplyMap& map) { + Node n; _graph.first(n); + for (; n != INVALIDNODE; _graph.next(n)) { + _supply[_node_id(n)] = map[n]; + } + return *this; + } + template + NetworkSimplexSimple& supplyMap(const SupplyMap* map1, int n1, const SupplyMap* map2, int n2) { + Node n; _graph.first(n); + for (; n != INVALIDNODE; _graph.next(n)) { + if (n + NetworkSimplexSimple& supplyMapAll(SupplyMap val1, int n1, SupplyMap val2, int n2) { + Node n; _graph.first(n); + for (; n != INVALIDNODE; _graph.next(n)) { + if (n(*this) + NetworkSimplexSimple& stSupply(const Node& s, const Node& t, Value k) { + for (int i = 0; i != _node_num; ++i) { + _supply[i] = 0; + } + _supply[_node_id(s)] = k; + _supply[_node_id(t)] = -k; + return *this; + } + + /// \brief Set the type of the supply constraints. + /// + /// This function sets the type of the supply/demand constraints. + /// If it is not used before calling \ref run(), the \ref GEQ supply + /// type will be used. + /// + /// For more information, see \ref SupplyType. + /// + /// \return (*this) + NetworkSimplexSimple& supplyType(SupplyType supply_type) { + _stype = supply_type; + return *this; + } + + /// @} + + /// \name Execution Control + /// The algorithm can be executed using \ref run(). + + /// @{ + + /// \brief Run the algorithm. + /// + /// This function runs the algorithm. + /// The paramters can be specified using functions \ref lowerMap(), + /// \ref upperMap(), \ref costMap(), \ref supplyMap(), \ref stSupply(), + /// \ref supplyType(). + /// For example, + /// \code + /// NetworkSimplexSimple ns(graph); + /// ns.lowerMap(lower).upperMap(upper).costMap(cost) + /// .supplyMap(sup).run(); + /// \endcode + /// + /// This function can be called more than once. All the given parameters + /// are kept for the next call, unless \ref resetParams() or \ref reset() + /// is used, thus only the modified parameters have to be set again. + /// If the underlying digraph was also modified after the construction + /// of the class (or the last \ref reset() call), then the \ref reset() + /// function must be called. + /// + /// \param pivot_rule The pivot rule that will be used during the + /// algorithm. For more information, see \ref PivotRule. + /// + /// \return \c INFEASIBLE if no feasible flow exists, + /// \n \c OPTIMAL if the problem has optimal solution + /// (i.e. it is feasible and bounded), and the algorithm has found + /// optimal flow and node potentials (primal and dual solutions), + /// \n \c UNBOUNDED if the objective function of the problem is + /// unbounded, i.e. there is a directed cycle having negative total + /// cost and infinite upper bound. + /// + /// \see ProblemType, PivotRule + /// \see resetParams(), reset() + ProblemType run() { +#if DEBUG_LVL>0 + std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "\nUNBOUNDED = " << UNBOUNDED << "\nMAX_ITER_REACHED" << MAX_ITER_REACHED << "\n" ; +#endif + if (!init()) return INFEASIBLE; +#if DEBUG_LVL>0 + std::cout << "Init done, starting iterations\n"; +#endif + + return start(); + } + + /// \brief Reset all the parameters that have been given before. + /// + /// This function resets all the paramaters that have been given + /// before using functions \ref lowerMap(), \ref upperMap(), + /// \ref costMap(), \ref supplyMap(), \ref stSupply(), \ref supplyType(). + /// + /// It is useful for multiple \ref run() calls. Basically, all the given + /// parameters are kept for the next \ref run() call, unless + /// \ref resetParams() or \ref reset() is used. + /// If the underlying digraph was also modified after the construction + /// of the class or the last \ref reset() call, then the \ref reset() + /// function must be used, otherwise \ref resetParams() is sufficient. + /// + /// For example, + /// \code + /// NetworkSimplexSimple ns(graph); + /// + /// // First run + /// ns.lowerMap(lower).upperMap(upper).costMap(cost) + /// .supplyMap(sup).run(); + /// + /// // Run again with modified cost map (resetParams() is not called, + /// // so only the cost map have to be set again) + /// cost[e] += 100; + /// ns.costMap(cost).run(); + /// + /// // Run again from scratch using resetParams() + /// // (the lower bounds will be set to zero on all arcs) + /// ns.resetParams(); + /// ns.upperMap(capacity).costMap(cost) + /// .supplyMap(sup).run(); + /// \endcode + /// + /// \return (*this) + /// + /// \see reset(), run() + NetworkSimplexSimple& resetParams() { + for (int i = 0; i != _node_num; ++i) { + _supply[i] = 0; + } + for (ArcsType i = 0; i != _arc_num; ++i) { + _cost[i] = 1; + } + _stype = GEQ; + return *this; + } + + + /// \brief Reset the internal data structures and all the parameters + /// that have been given before. + /// + /// This function resets the internal data structures and all the + /// paramaters that have been given before using functions \ref lowerMap(), + /// \ref upperMap(), \ref costMap(), \ref supplyMap(), \ref stSupply(), + /// \ref supplyType(). + /// + /// It is useful for multiple \ref run() calls. Basically, all the given + /// parameters are kept for the next \ref run() call, unless + /// \ref resetParams() or \ref reset() is used. + /// If the underlying digraph was also modified after the construction + /// of the class or the last \ref reset() call, then the \ref reset() + /// function must be used, otherwise \ref resetParams() is sufficient. + /// + /// See \ref resetParams() for examples. + /// + /// \return (*this) + /// + /// \see resetParams(), run() + NetworkSimplexSimple& reset() { + // Resize vectors + _node_num = _init_nb_nodes; + _arc_num = _init_nb_arcs; + int all_node_num = _node_num + 1; + ArcsType max_arc_num = _arc_num + 2 * _node_num; + + _source.resize(max_arc_num); + _target.resize(max_arc_num); + + _cost.resize(max_arc_num); + _supply.resize(all_node_num); + _flow.resize(max_arc_num); + _pi.resize(all_node_num); + + _parent.resize(all_node_num); + _pred.resize(all_node_num); + _forward.resize(all_node_num); + _thread.resize(all_node_num); + _rev_thread.resize(all_node_num); + _succ_num.resize(all_node_num); + _last_succ.resize(all_node_num); + _state.resize(max_arc_num); + + + //_arc_mixing=false; + if (_arc_mixing && _node_num > 1) { + // Store the arcs in a mixed order + //ArcsType k = std::max(ArcsType(std::sqrt(double(_arc_num))), ArcsType(10)); + const ArcsType k = std::max(ArcsType(_arc_num / _node_num), ArcsType(3)); + mixingCoeff = k; + subsequence_length = _arc_num / mixingCoeff + 1; + num_big_subsequences = _arc_num % mixingCoeff; + num_total_big_subsequence_numbers = subsequence_length * num_big_subsequences; + +#pragma omp parallel for schedule(static) + for (Arc a = 0; a <= _graph.maxArcId(); a++) { // --a <=> _graph.next(a) , -1 == INVALID + ArcsType i = sequence(_graph.maxArcId()-a); + _source[i] = _node_id(_graph.source(a)); + _target[i] = _node_id(_graph.target(a)); + } + } else { + // Store the arcs in the original order + ArcsType i = 0; + Arc a; _graph.first(a); + for (; a != INVALID; _graph.next(a), ++i) { + _source[i] = _node_id(_graph.source(a)); + _target[i] = _node_id(_graph.target(a)); + //_arc_id[a] = i; + } + } + + // Reset parameters + resetParams(); + return *this; + } + + /// @} + + /// \name Query Functions + /// The results of the algorithm can be obtained using these + /// functions.\n + /// The \ref run() function must be called before using them. + + /// @{ + + /// \brief Return the total cost of the found flow. + /// + /// This function returns the total cost of the found flow. + /// Its complexity is O(e). + /// + /// \note The return type of the function can be specified as a + /// template parameter. For example, + /// \code + /// ns.totalCost(); + /// \endcode + /// It is useful if the total cost cannot be stored in the \c Cost + /// type of the algorithm, which is the default return type of the + /// function. + /// + /// \pre \ref run() must be called before using this function. + /*template + Number totalCost() const { + Number c = 0; + for (ArcIt a(_graph); a != INVALID; ++a) { + int i = getArcID(a); + c += Number(_flow[i]) * Number(_cost[i]); + } + return c; + }*/ + + template + Number totalCost() const { + Number c = 0; + +#ifdef SPARSE_FLOW + #ifdef HASHMAP + typename std::unordered_map::const_iterator it; + #else + typename std::map::const_iterator it; + #endif + for (it = _flow.data.begin(); it!=_flow.data.end(); ++it) + c += Number(it->second) * Number(_cost[it->first]); + return c; +#else + for (ArcsType i = 0; i<_flow.size(); i++) + c += _flow[i] * Number(_cost[i]); + return c; +#endif + } + +#ifndef DOXYGEN + Cost totalCost() const { + return totalCost(); + } +#endif + + /// \brief Return the flow on the given arc. + /// + /// This function returns the flow on the given arc. + /// + /// \pre \ref run() must be called before using this function. + Value flow(const Arc& a) const { + return _flow[getArcID(a)]; + } + + /// \brief Return the flow map (the primal solution). + /// + /// This function copies the flow value on each arc into the given + /// map. The \c Value type of the algorithm must be convertible to + /// the \c Value type of the map. + /// + /// \pre \ref run() must be called before using this function. + template + void flowMap(FlowMap &map) const { + Arc a; _graph.first(a); + for (; a != INVALID; _graph.next(a)) { + map.set(a, _flow[getArcID(a)]); + } + } + + /// \brief Return the potential (dual value) of the given node. + /// + /// This function returns the potential (dual value) of the + /// given node. + /// + /// \pre \ref run() must be called before using this function. + Cost potential(const Node& n) const { + return _pi[_node_id(n)]; + } + + /// \brief Return the potential map (the dual solution). + /// + /// This function copies the potential (dual value) of each node + /// into the given map. + /// The \c Cost type of the algorithm must be convertible to the + /// \c Value type of the map. + /// + /// \pre \ref run() must be called before using this function. + template + void potentialMap(PotentialMap &map) const { + Node n; _graph.first(n); + for (; n != INVALID; _graph.next(n)) { + map.set(n, _pi[_node_id(n)]); + } + } + + /// @} + + private: + + // Initialize internal data structures + bool init() { + if (_node_num == 0) return false; + + // Check the sum of supply values + _sum_supply = 0; + for (int i = 0; i != _node_num; ++i) { + _sum_supply += _supply[i]; + } + /*if (!((_stype == GEQ && _sum_supply <= 0) || + (_stype == LEQ && _sum_supply >= 0))) return false;*/ + + + // Initialize artifical cost + Cost ART_COST; + if (std::numeric_limits::is_exact) { + ART_COST = std::numeric_limits::max() / 2 + 1; + } else { + ART_COST = 0; + for (ArcsType i = 0; i != _arc_num; ++i) { + if (_cost[i] > ART_COST) ART_COST = _cost[i]; + } + ART_COST = (ART_COST + 1) * _node_num; + } + + // Initialize arc maps + for (ArcsType i = 0; i != _arc_num; ++i) { +#ifndef SPARSE_FLOW + _flow[i] = 0; //by default, the sparse matrix is empty +#endif + _state[i] = STATE_LOWER; + } +#ifdef SPARSE_FLOW + _flow = SparseValueVector(); +#endif + + // Set data for the artificial root node + _root = _node_num; + _parent[_root] = -1; + _pred[_root] = -1; + _thread[_root] = 0; + _rev_thread[0] = _root; + _succ_num[_root] = _node_num + 1; + _last_succ[_root] = _root - 1; + _supply[_root] = -_sum_supply; + _pi[_root] = 0; + + // Add artificial arcs and initialize the spanning tree data structure + if (_sum_supply == 0) { + // EQ supply constraints + _search_arc_num = _arc_num; + _all_arc_num = _arc_num + _node_num; + for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) { + _parent[u] = _root; + _pred[u] = e; + _thread[u] = u + 1; + _rev_thread[u + 1] = u; + _succ_num[u] = 1; + _last_succ[u] = u; + _state[e] = STATE_TREE; + if (_supply[u] >= 0) { + _forward[u] = true; + _pi[u] = 0; + _source[e] = u; + _target[e] = _root; + _flow[e] = _supply[u]; + _cost[e] = 0; + } else { + _forward[u] = false; + _pi[u] = ART_COST; + _source[e] = _root; + _target[e] = u; + _flow[e] = -_supply[u]; + _cost[e] = ART_COST; + } + } + } else if (_sum_supply > 0) { + // LEQ supply constraints + _search_arc_num = _arc_num + _node_num; + ArcsType f = _arc_num + _node_num; + for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) { + _parent[u] = _root; + _thread[u] = u + 1; + _rev_thread[u + 1] = u; + _succ_num[u] = 1; + _last_succ[u] = u; + if (_supply[u] >= 0) { + _forward[u] = true; + _pi[u] = 0; + _pred[u] = e; + _source[e] = u; + _target[e] = _root; + _flow[e] = _supply[u]; + _cost[e] = 0; + _state[e] = STATE_TREE; + } else { + _forward[u] = false; + _pi[u] = ART_COST; + _pred[u] = f; + _source[f] = _root; + _target[f] = u; + _flow[f] = -_supply[u]; + _cost[f] = ART_COST; + _state[f] = STATE_TREE; + _source[e] = u; + _target[e] = _root; + //_flow[e] = 0; //by default, the sparse matrix is empty + _cost[e] = 0; + _state[e] = STATE_LOWER; + ++f; + } + } + _all_arc_num = f; + } else { + // GEQ supply constraints + _search_arc_num = _arc_num + _node_num; + ArcsType f = _arc_num + _node_num; + for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) { + _parent[u] = _root; + _thread[u] = u + 1; + _rev_thread[u + 1] = u; + _succ_num[u] = 1; + _last_succ[u] = u; + if (_supply[u] <= 0) { + _forward[u] = false; + _pi[u] = 0; + _pred[u] = e; + _source[e] = _root; + _target[e] = u; + _flow[e] = -_supply[u]; + _cost[e] = 0; + _state[e] = STATE_TREE; + } else { + _forward[u] = true; + _pi[u] = -ART_COST; + _pred[u] = f; + _source[f] = u; + _target[f] = _root; + _flow[f] = _supply[u]; + _state[f] = STATE_TREE; + _cost[f] = ART_COST; + _source[e] = _root; + _target[e] = u; + //_flow[e] = 0; //by default, the sparse matrix is empty + _cost[e] = 0; + _state[e] = STATE_LOWER; + ++f; + } + } + _all_arc_num = f; + } + + return true; + } + + // Find the join node + void findJoinNode() { + int u = _source[in_arc]; + int v = _target[in_arc]; + while (u != v) { + if (_succ_num[u] < _succ_num[v]) { + u = _parent[u]; + } else { + v = _parent[v]; + } + } + join = u; + } + + // Find the leaving arc of the cycle and returns true if the + // leaving arc is not the same as the entering arc + bool findLeavingArc() { + // Initialize first and second nodes according to the direction + // of the cycle + if (_state[in_arc] == STATE_LOWER) { + first = _source[in_arc]; + second = _target[in_arc]; + } else { + first = _target[in_arc]; + second = _source[in_arc]; + } + delta = INF; + char result = 0; + Value d; + ArcsType e; + + // Search the cycle along the path form the first node to the root + for (int u = first; u != join; u = _parent[u]) { + e = _pred[u]; + d = _forward[u] ? _flow[e] : INF; + if (d < delta) { + delta = d; + u_out = u; + result = 1; + } + } + // Search the cycle along the path form the second node to the root + for (int u = second; u != join; u = _parent[u]) { + e = _pred[u]; + d = _forward[u] ? INF : _flow[e]; + if (d <= delta) { + delta = d; + u_out = u; + result = 2; + } + } + + if (result == 1) { + u_in = first; + v_in = second; + } else { + u_in = second; + v_in = first; + } + return result != 0; + } + + // Change _flow and _state vectors + void changeFlow(bool change) { + // Augment along the cycle + if (delta > 0) { + Value val = _state[in_arc] * delta; + _flow[in_arc] += val; + for (int u = _source[in_arc]; u != join; u = _parent[u]) { + _flow[_pred[u]] += _forward[u] ? -val : val; + } + for (int u = _target[in_arc]; u != join; u = _parent[u]) { + _flow[_pred[u]] += _forward[u] ? val : -val; + } + } + // Update the state of the entering and leaving arcs + if (change) { + _state[in_arc] = STATE_TREE; + _state[_pred[u_out]] = + (_flow[_pred[u_out]] == 0) ? STATE_LOWER : STATE_UPPER; + } else { + _state[in_arc] = -_state[in_arc]; + } + } + + // Update the tree structure + void updateTreeStructure() { + int old_rev_thread = _rev_thread[u_out]; + int old_succ_num = _succ_num[u_out]; + int old_last_succ = _last_succ[u_out]; + v_out = _parent[u_out]; + + // Check if u_in and u_out coincide + if (u_in == u_out) { + // Update _parent, _pred, _pred_dir + _parent[u_in] = v_in; + _pred[u_in] = in_arc; + _forward[u_in] = (u_in == _source[in_arc]); + + // Update _thread and _rev_thread + if (_thread[v_in] != u_out) { + ArcsType after = _thread[old_last_succ]; + _thread[old_rev_thread] = after; + _rev_thread[after] = old_rev_thread; + after = _thread[v_in]; + _thread[v_in] = u_out; + _rev_thread[u_out] = v_in; + _thread[old_last_succ] = after; + _rev_thread[after] = old_last_succ; + } + } else { + // Handle the case when old_rev_thread equals to v_in + // (it also means that join and v_out coincide) + int thread_continue = old_rev_thread == v_in ? + _thread[old_last_succ] : _thread[v_in]; + + // Update _thread and _parent along the stem nodes (i.e. the nodes + // between u_in and u_out, whose parent have to be changed) + int stem = u_in; // the current stem node + int par_stem = v_in; // the new parent of stem + int next_stem; // the next stem node + int last = _last_succ[u_in]; // the last successor of stem + int before, after = _thread[last]; + _thread[v_in] = u_in; + _dirty_revs.clear(); + _dirty_revs.push_back(v_in); + while (stem != u_out) { + // Insert the next stem node into the thread list + next_stem = _parent[stem]; + _thread[last] = next_stem; + _dirty_revs.push_back(last); + + // Remove the subtree of stem from the thread list + before = _rev_thread[stem]; + _thread[before] = after; + _rev_thread[after] = before; + + // Change the parent node and shift stem nodes + _parent[stem] = par_stem; + par_stem = stem; + stem = next_stem; + + // Update last and after + last = _last_succ[stem] == _last_succ[par_stem] ? + _rev_thread[par_stem] : _last_succ[stem]; + after = _thread[last]; + } + _parent[u_out] = par_stem; + _thread[last] = thread_continue; + _rev_thread[thread_continue] = last; + _last_succ[u_out] = last; + + // Remove the subtree of u_out from the thread list except for + // the case when old_rev_thread equals to v_in + if (old_rev_thread != v_in) { + _thread[old_rev_thread] = after; + _rev_thread[after] = old_rev_thread; + } + + // Update _rev_thread using the new _thread values + for (int i = 0; i != int(_dirty_revs.size()); ++i) { + int u = _dirty_revs[i]; + _rev_thread[_thread[u]] = u; + } + + // Update _pred, _pred_dir, _last_succ and _succ_num for the + // stem nodes from u_out to u_in + int tmp_sc = 0, tmp_ls = _last_succ[u_out]; + for (int u = u_out, p = _parent[u]; u != u_in; u = p, p = _parent[u]) { + _pred[u] = _pred[p]; + _forward[u] = !_forward[p]; + tmp_sc += _succ_num[u] - _succ_num[p]; + _succ_num[u] = tmp_sc; + _last_succ[p] = tmp_ls; + } + _pred[u_in] = in_arc; + _forward[u_in] = (u_in == _source[in_arc]); + _succ_num[u_in] = old_succ_num; + } + + // Update _last_succ from v_in towards the root + int up_limit_out = _last_succ[join] == v_in ? join : -1; + int last_succ_out = _last_succ[u_out]; + for (int u = v_in; u != -1 && _last_succ[u] == v_in; u = _parent[u]) { + _last_succ[u] = last_succ_out; + } + + // Update _last_succ from v_out towards the root + if (join != old_rev_thread && v_in != old_rev_thread) { + for (int u = v_out; u != up_limit_out && _last_succ[u] == old_last_succ; + u = _parent[u]) { + _last_succ[u] = old_rev_thread; + } + } else if (last_succ_out != old_last_succ) { + for (int u = v_out; u != up_limit_out && _last_succ[u] == old_last_succ; + u = _parent[u]) { + _last_succ[u] = last_succ_out; + } + } + + // Update _succ_num from v_in to join + for (int u = v_in; u != join; u = _parent[u]) { + _succ_num[u] += old_succ_num; + } + // Update _succ_num from v_out to join + for (int u = v_out; u != join; u = _parent[u]) { + _succ_num[u] -= old_succ_num; + } + } + + void updatePotential() { + Cost sigma = _pi[v_in] - _pi[u_in] - + ((_forward[u_in])?_cost[in_arc]:(-_cost[in_arc])); + int end = _thread[_last_succ[u_in]]; + for (int u = u_in; u != end; u = _thread[u]) { + _pi[u] += sigma; + } + } + + + // Heuristic initial pivots + bool initialPivots() { + Value curr, total = 0; + std::vector supply_nodes, demand_nodes; + Node u; _graph.first(u); + for (; u != INVALIDNODE; _graph.next(u)) { + curr = _supply[_node_id(u)]; + if (curr > 0) { + total += curr; + supply_nodes.push_back(u); + } else if (curr < 0) { + demand_nodes.push_back(u); + } + } + if (_sum_supply > 0) total -= _sum_supply; + if (total <= 0) return true; + + ArcVector arc_vector; + if (_sum_supply >= 0) { + if (supply_nodes.size() == 1 && demand_nodes.size() == 1) { + // Perform a reverse graph search from the sink to the source + //typename GR::template NodeMap reached(_graph, false); + BoolVector reached(_node_num, false); + Node s = supply_nodes[0], t = demand_nodes[0]; + std::vector stack; + reached[t] = true; + stack.push_back(t); + while (!stack.empty()) { + Node u, v = stack.back(); + stack.pop_back(); + if (v == s) break; + Arc a; _graph.firstIn(a, v); + for (; a != INVALID; _graph.nextIn(a)) { + if (reached[u = _graph.source(a)]) continue; + ArcsType j = getArcID(a); + arc_vector.push_back(j); + reached[u] = true; + stack.push_back(u); + } + } + } else { + arc_vector.resize(demand_nodes.size()); + // Find the min. cost incomming arc for each demand node +#pragma omp parallel for + for (int i = 0; i < demand_nodes.size(); ++i) { + Node v = demand_nodes[i]; + Cost min_cost = std::numeric_limits::max(); + Arc min_arc = INVALID; + Arc a; _graph.firstIn(a, v); + for (; a != INVALID; _graph.nextIn(a)) { + Cost c = _cost[getArcID(a)]; + if (c < min_cost) { + min_cost = c; + min_arc = a; + } + } + arc_vector[i] = getArcID(min_arc); + } + arc_vector.erase(std::remove(arc_vector.begin(), arc_vector.end(), INVALID), arc_vector.end()); + } + } else { + arc_vector.resize(supply_nodes.size()); + // Find the min. cost outgoing arc for each supply node +#pragma omp parallel for + for (int i = 0; i < int(supply_nodes.size()); ++i) { + Node u = supply_nodes[i]; + Cost min_cost = std::numeric_limits::max(); + Arc min_arc = INVALID; + Arc a; _graph.firstOut(a, u); + for (; a != INVALID; _graph.nextOut(a)) { + Cost c = _cost[getArcID(a)]; + if (c < min_cost) { + min_cost = c; + min_arc = a; + } + } + arc_vector[i] = getArcID(min_arc); + } + arc_vector.erase(std::remove(arc_vector.begin(), arc_vector.end(), INVALID), arc_vector.end()); + } + + // Perform heuristic initial pivots + for (ArcsType i = 0; i != ArcsType(arc_vector.size()); ++i) { + in_arc = arc_vector[i]; + if (_state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] - + _pi[_target[in_arc]]) >= 0) continue; + findJoinNode(); + bool change = findLeavingArc(); + if (delta >= MAX) return false; + changeFlow(change); + if (change) { + updateTreeStructure(); + updatePotential(); + } + } + return true; + } + + // Execute the algorithm + ProblemType start() { + return start(); + } + + template + ProblemType start() { + PivotRuleImpl pivot(*this); + ProblemType retVal = OPTIMAL; + + // Perform heuristic initial pivots + if (!initialPivots()) return UNBOUNDED; + + size_t iter_number = 0; + // Execute the Network Simplex algorithm + while (pivot.findEnteringArc()) { + if ((++iter_number <= max_iter&&max_iter > 0) || max_iter<=0) { +#if DEBUG_LVL>0 + if(iter_number>MAX_DEBUG_ITER) + break; + if(iter_number%1000==0||iter_number%1000==1){ + Cost curCost=totalCost(); + Value sumFlow=0; + Cost a; + a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]); + a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]); + for (int i=0; i<_flow.size(); i++) { + sumFlow+=_state[i]*_flow[i]; + } + std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << iter_number << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n"; + std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n"; + std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n"; + std::cout << _cost[in_arc] << "\n"; + std::cout << _pi[_source[in_arc]] << "\n"; + std::cout << _pi[_target[in_arc]] << "\n"; + std::cout << a << "\n"; + } +#endif + + findJoinNode(); + bool change = findLeavingArc(); + if (delta >= MAX) return UNBOUNDED; + changeFlow(change); + if (change) { + updateTreeStructure(); + updatePotential(); + } + +#if DEBUG_LVL>0 + else{ + std::cout << "No change\n"; + } +#endif + +#if DEBUG_LVL>1 + std::cout << "Arc in = (" << _source[in_arc] << ", " << _target[in_arc] << ")\n"; +#endif + + + } else { + char errMess[1000]; + sprintf( errMess, "RESULT MIGHT BE INACURATE\nMax number of iteration reached, currently \%d. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher\n",iter_number ); + std::cerr << errMess; + retVal = MAX_ITER_REACHED; + break; + } + + } + + + +#if DEBUG_LVL>0 + Cost curCost=totalCost(); + Value sumFlow=0; + Cost a; + a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]); + a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]); + for (int i=0; i<_flow.size(); i++) { + sumFlow+=_state[i]*_flow[i]; + } + + std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << niter << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n"; + + std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n"; + std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n"; + +#endif + + + +#if DEBUG_LVL>1 + sumFlow=0; + for (int i=0; i<_flow.size(); i++) { + sumFlow+=_state[i]*_flow[i]; + if (_state[i]==STATE_TREE) { + std::cout << "Non zero value at (" << _node_num+1-_source[i] << ", " << _node_num+1-_target[i] << ")\n"; + } + } + std::cout << "Sum of the flow " << sumFlow << "\n"<< niter <<" iterations, current cost=" << totalCost() << "\n"; +#endif + + + + //Check feasibility + if(retVal == OPTIMAL){ + for (ArcsType e = _search_arc_num; e != _all_arc_num; ++e) { + if (_flow[e] != 0){ + if (fabs(_flow[e]) > _EPSILON) // change of the original code following issue #126 + return INFEASIBLE; + else + _flow[e]=0; + } + } + } + + // Shift potentials to meet the requirements of the GEQ/LEQ type + // optimality conditions + if (_sum_supply == 0) { + if (_stype == GEQ) { + Cost max_pot = -std::numeric_limits::max(); + for (ArcsType i = 0; i != _node_num; ++i) { + if (_pi[i] > max_pot) max_pot = _pi[i]; + } + if (max_pot > 0) { + for (ArcsType i = 0; i != _node_num; ++i) + _pi[i] -= max_pot; + } + } else { + Cost min_pot = std::numeric_limits::max(); + for (ArcsType i = 0; i != _node_num; ++i) { + if (_pi[i] < min_pot) min_pot = _pi[i]; + } + if (min_pot < 0) { + for (ArcsType i = 0; i != _node_num; ++i) + _pi[i] -= min_pot; + } + } + } + + return retVal; + } + + }; //class NetworkSimplexSimple + + ///@} + +} //namespace lemon_omp diff --git a/ot/utils.py b/ot/utils.py index 4dac0c5..6a782e6 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -7,7 +7,6 @@ Various useful functions # # License: MIT License -import multiprocessing from functools import reduce import time @@ -311,38 +310,11 @@ def label_normalization(y, start=0): return y -def fun(f, q_in, q_out): - """ Utility function for parmap with no serializing problems """ - while True: - i, x = q_in.get() - if i is None: - break - q_out.put((i, f(x))) - - -def parmap(f, X, nprocs=multiprocessing.cpu_count()): - """ paralell map for multiprocessing (only map on windows)""" - - if not sys.platform.endswith('win32') and not sys.platform.endswith('darwin'): - - q_in = multiprocessing.Queue(1) - q_out = multiprocessing.Queue() - - proc = [multiprocessing.Process(target=fun, args=(f, q_in, q_out)) - for _ in range(nprocs)] - for p in proc: - p.daemon = True - p.start() - - sent = [q_in.put((i, x)) for i, x in enumerate(X)] - [q_in.put((None, None)) for _ in range(nprocs)] - res = [q_out.get() for _ in range(len(sent))] - - [p.join() for p in proc] - - return [x for i, x in sorted(res)] - else: - return list(map(f, X)) +def parmap(f, X, nprocs="default"): + """ paralell map for multiprocessing. + The function has been deprecated and only performs a regular map. + """ + return list(map(f, X)) def check_params(**kwargs): diff --git a/setup.py b/setup.py index 37a5824..ef95eaf 100644 --- a/setup.py +++ b/setup.py @@ -11,6 +11,8 @@ from setuptools.extension import Extension import numpy from Cython.Build import cythonize +sys.path.append(os.path.join("ot", "helpers")) +from openmp_helpers import check_openmp_support # dirty but working __version__ = re.search( @@ -30,7 +32,14 @@ if 'clean' in sys.argv[1:]: os.remove('ot/lp/emd_wrap.cpp') # add platform dependant optional compilation argument -compile_args = ["-O3"] +openmp_supported, flags = check_openmp_support() +compile_args = ["/O2" if sys.platform == "win32" else "-O3"] +link_args = [] + +if openmp_supported: + compile_args += flags + ["/DOMP" if sys.platform == 'win32' else "-DOMP"] + link_args += flags + if sys.platform.startswith('darwin'): compile_args.append("-stdlib=libc++") sdk_path = subprocess.check_output(['xcrun', '--show-sdk-path']) @@ -52,6 +61,7 @@ setup( language="c++", include_dirs=[numpy.get_include(), os.path.join(ROOT, 'ot/lp')], extra_compile_args=compile_args, + extra_link_args=link_args )), platforms=['linux', 'macosx', 'windows'], download_url='https://github.com/PythonOT/POT/archive/{}.tar.gz'.format(__version__), diff --git a/test/test_helpers.py b/test/test_helpers.py new file mode 100644 index 0000000..8bd0015 --- /dev/null +++ b/test/test_helpers.py @@ -0,0 +1,26 @@ +"""Tests for helpers functions """ + +# Author: Remi Flamary +# +# License: MIT License + +import os +import sys + +sys.path.append(os.path.join("ot", "helpers")) + +from openmp_helpers import get_openmp_flag, check_openmp_support # noqa +from pre_build_helpers import _get_compiler, compile_test_program # noqa + + +def test_helpers(): + + compiler = _get_compiler() + + get_openmp_flag(compiler) + + s = '#include \n#include \n\nint main(void) {\n\tprintf("Hello world!\\n");\n\treturn 0;\n}' + output, _ = compile_test_program(s) + assert len(output) == 1 and output[0] == "Hello world!" + + check_openmp_support() -- cgit v1.2.3 From d50d8145a5c0cf69d438b018cd5f1b914905e784 Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Fri, 22 Oct 2021 15:05:14 +0300 Subject: Add set_gradients method for JAX backend. (#278) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Rémi Flamary --- ot/backend.py | 16 ++++++++-------- test/test_backend.py | 15 ++++++++++++++- 2 files changed, 22 insertions(+), 9 deletions(-) (limited to 'test') diff --git a/ot/backend.py b/ot/backend.py index 8f46900..2ed40af 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -287,16 +287,16 @@ class JaxBackend(Backend): return jnp.array(a).astype(type_as.dtype) def set_gradients(self, val, inputs, grads): - # no gradients for jax because it is functional + from jax.flatten_util import ravel_pytree + val, = jax.lax.stop_gradient((val,)) - # does not work - # from jax import custom_jvp - # @custom_jvp - # def f(*inputs): - # return val - # f.defjvps(*grads) - # return f(*inputs) + ravelled_inputs, _ = ravel_pytree(inputs) + ravelled_grads, _ = ravel_pytree(grads) + aux = jnp.sum(ravelled_inputs * ravelled_grads) / 2 + aux = aux - jax.lax.stop_gradient(aux) + + val, = jax.tree_map(lambda z: z + aux, (val,)) return val def zeros(self, shape, type_as=None): diff --git a/test/test_backend.py b/test/test_backend.py index bc5b00c..cbfaf94 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -345,7 +345,8 @@ def test_gradients_backends(): rnd = np.random.RandomState(0) v = rnd.randn(10) - c = rnd.randn(1) + c = rnd.randn() + e = rnd.randn() if torch: @@ -362,3 +363,15 @@ def test_gradients_backends(): assert torch.equal(v2.grad, v2) assert torch.equal(c2.grad, c2) + + if jax: + nx = ot.backend.JaxBackend() + with jax.checking_leaks(): + def fun(a, b, d): + val = b * nx.sum(a ** 4) + d + return nx.set_gradients(val, (a, b, d), (a, b, 2 * d)) + grad_val = jax.grad(fun, argnums=(0, 1, 2))(v, c, e) + + np.testing.assert_almost_equal(fun(v, c, e), c * np.sum(v ** 4) + e, decimal=4) + np.testing.assert_allclose(grad_val[0], v, atol=1e-4) + np.testing.assert_allclose(grad_val[2], 2 * e, atol=1e-4) -- cgit v1.2.3 From 7af8c2147d61349f4d99ca33318a8a125e4569aa Mon Sep 17 00:00:00 2001 From: haoran010 <62598274+haoran010@users.noreply.github.com> Date: Mon, 25 Oct 2021 10:47:22 +0200 Subject: [MRG] Regularization path for l2 UOT (#274) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add reg path * debug examples and verify pep8 * pep8 and move the reg path examples in unbalanced folder Co-authored-by: haoran010 Co-authored-by: Rémi Flamary --- examples/unbalanced-partial/plot_regpath.py | 135 +++++ ot/__init__.py | 3 +- ot/regpath.py | 827 ++++++++++++++++++++++++++++ test/test_regpath.py | 64 +++ 4 files changed, 1028 insertions(+), 1 deletion(-) create mode 100644 examples/unbalanced-partial/plot_regpath.py create mode 100644 ot/regpath.py create mode 100644 test/test_regpath.py (limited to 'test') diff --git a/examples/unbalanced-partial/plot_regpath.py b/examples/unbalanced-partial/plot_regpath.py new file mode 100644 index 0000000..4a51c2d --- /dev/null +++ b/examples/unbalanced-partial/plot_regpath.py @@ -0,0 +1,135 @@ +# -*- coding: utf-8 -*- +""" +================================================================ +Regularization path of l2-penalized unbalanced optimal transport +================================================================ +This example illustrate the regularization path for 2D unbalanced +optimal transport. We present here both the fully relaxed case +and the semi-relaxed case. + +[Chapel et al., 2021] Chapel, L., Flamary, R., Wu, H., Févotte, C., +and Gasso, G. (2021). Unbalanced optimal transport through non-negative +penalized linear regression. +""" + +# Author: Haoran Wu +# License: MIT License + + +import numpy as np +import matplotlib.pylab as pl +import ot + +############################################################################## +# Generate data +# ------------- + +#%% parameters and data generation + +n = 50 # nb samples + +mu_s = np.array([-1, -1]) +cov_s = np.array([[1, 0], [0, 1]]) + +mu_t = np.array([4, 4]) +cov_t = np.array([[1, -.8], [-.8, 1]]) + +np.random.seed(0) +xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) +xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) + +a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples + +# loss matrix +M = ot.dist(xs, xt) +M /= M.max() + +############################################################################## +# Plot data +# --------- + +#%% plot 2 distribution samples + +pl.figure(1) +pl.scatter(xs[:, 0], xs[:, 1], c='C0', label='Source') +pl.scatter(xt[:, 0], xt[:, 1], c='C1', label='Target') +pl.legend(loc=2) +pl.title('Source and target distributions') +pl.show() + +############################################################################## +# Compute semi-relaxed and fully relaxed regularization paths +# ----------- + +#%% +final_gamma = 1e-8 +t, t_list, g_list = ot.regpath.regularization_path(a, b, M, reg=final_gamma, + semi_relaxed=False) +t2, t_list2, g_list2 = ot.regpath.regularization_path(a, b, M, reg=final_gamma, + semi_relaxed=True) + + +############################################################################## +# Plot the regularization path +# ---------------- + +#%% fully relaxed l2-penalized UOT + +pl.figure(2) +selected_gamma = [2e-1, 1e-1, 5e-2, 1e-3] +for p in range(4): + tp = ot.regpath.compute_transport_plan(selected_gamma[p], g_list, + t_list) + P = tp.reshape((n, n)) + pl.subplot(2, 2, p + 1) + if P.sum() > 0: + P = P / P.max() + for i in range(n): + for j in range(n): + if P[i, j] > 0: + pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2', + alpha=P[i, j] * 0.3) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 2, + label='Re-weighted source', alpha=1) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 2, + label='Re-weighted target', alpha=1) + pl.plot([], [], color='C2', alpha=0.8, label='OT plan') + pl.title(r'$\ell_2$ UOT $\gamma$={}'.format(selected_gamma[p]), + fontsize=11) + if p < 2: + pl.xticks(()) +pl.show() + + +############################################################################## +# Plot the semi-relaxed regularization path +# ------------------- + +#%% semi-relaxed l2-penalized UOT + +pl.figure(3) +selected_gamma = [10, 1, 1e-1, 1e-2] +for p in range(4): + tp = ot.regpath.compute_transport_plan(selected_gamma[p], g_list2, + t_list2) + P = tp.reshape((n, n)) + pl.subplot(2, 2, p + 1) + if P.sum() > 0: + P = P / P.max() + for i in range(n): + for j in range(n): + if P[i, j] > 0: + pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2', + alpha=P[i, j] * 0.3) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=1, label='Target marginal') + pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * 2 * (1 + p), + label='Source marginal', alpha=1) + pl.plot([], [], color='C2', alpha=0.8, label='OT plan') + pl.title(r'Semi-relaxed $l_2$ UOT $\gamma$={}'.format(selected_gamma[p]), + fontsize=11) + if p < 2: + pl.xticks(()) +pl.show() diff --git a/ot/__init__.py b/ot/__init__.py index 3b072c6..5bd4bab 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -34,6 +34,7 @@ from . import stochastic from . import unbalanced from . import partial from . import backend +from . import regpath # OT functions from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d @@ -54,4 +55,4 @@ __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim', 'sinkhorn_unbalanced', 'barycenter_unbalanced', 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', - 'smooth', 'stochastic', 'unbalanced', 'partial'] + 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath'] diff --git a/ot/regpath.py b/ot/regpath.py new file mode 100644 index 0000000..269937a --- /dev/null +++ b/ot/regpath.py @@ -0,0 +1,827 @@ +# -*- coding: utf-8 -*- +""" +Regularization path OT solvers +""" + +# Author: Haoran Wu +# License: MIT License + +import numpy as np +import scipy.sparse as sp + + +def recast_ot_as_lasso(a, b, C): + r"""This function recasts the l2-penalized UOT problem as a Lasso problem + + Recall the l2-penalized UOT problem defined in [Chapel et al., 2021] + .. math:: + UOT = \min_T + \lambda \|T 1_m - a\|_2^2 + + \lambda \|T^T 1_n - b\|_2^2 + s.t. + T \geq 0 + where : + - C is the (dim_a, dim_b) metric cost matrix + - :math:`\lambda` is the l2-regularization coefficient + - a and b are source and target distributions + - T is the transport plan to optimize + + The problem above can be reformulated to a non-negative penalized + linear regression problem, particularly Lasso + .. math:: + UOT2 = \min_t \gamma c^T t + 0.5 * \|H t - y\|_2^2 + s.t. + t \geq 0 + where : + - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C) + - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient + - y is the concatenation of vectors a and b, defined as y^T = [a^T b^T] + - H is a (dim_a + dim_b, dim_a * dim_b) metric matrix, + see [Chapel et al., 2021] for the design of H. The matrix product H t + computes both the source marginal and the target marginal. + - t is a (dim_a * dim_b, ) metric vector (flattened version of T) + Parameters + ---------- + a : np.ndarray (dim_a,) + Histogram of dimension dim_a + b : np.ndarray (dim_b,) + Histogram of dimension dim_b + C : np.ndarray, shape (dim_a, dim_b) + Cost matrix + Returns + ------- + H : np.ndarray (dim_a+dim_b, dim_a*dim_b) + Auxiliary matrix constituted by 0 and 1 + y : np.ndarray (ns + nt, ) + Concatenation of histogram a and histogram b + c : np.ndarray (ns * nt, ) + Flattened array of cost matrix + Examples + -------- + >>> import ot + >>> a = np.array([0.2, 0.3, 0.5]) + >>> b = np.array([0.1, 0.9]) + >>> C = np.array([[16., 25.], [28., 16.], [40., 36.]]) + >>> H, y, c = ot.regpath.recast_ot_as_lasso(a, b, C) + >>> H.toarray() + array([[1., 1., 0., 0., 0., 0.], + [0., 0., 1., 1., 0., 0.], + [0., 0., 0., 0., 1., 1.], + [1., 0., 1., 0., 1., 0.], + [0., 1., 0., 1., 0., 1.]]) + >>> y + array([0.2, 0.3, 0.5, 0.1, 0.9]) + >>> c + array([16., 25., 28., 16., 40., 36.]) + + References + ---------- + [Chapel et al., 2021]: + Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. + """ + + dim_a = np.shape(a)[0] + dim_b = np.shape(b)[0] + y = np.concatenate((a, b)) + c = C.flatten() + jHa = np.arange(dim_a * dim_b) + iHa = np.repeat(np.arange(dim_a), dim_b) + jHb = np.arange(dim_a * dim_b) + iHb = np.tile(np.arange(dim_b), dim_a) + dim_a + j = np.concatenate((jHa, jHb)) + i = np.concatenate((iHa, iHb)) + H = sp.csc_matrix((np.ones(dim_a * dim_b * 2), (i, j)), + shape=(dim_a + dim_b, dim_a * dim_b)) + return H, y, c + + +def recast_semi_relaxed_as_lasso(a, b, C): + r"""This function recasts the semi-relaxed l2-UOT problem as Lasso problem + + .. math:: + semi-relaxed UOT = \min_T + \lambda \|T 1_m - a\|_2^2 + s.t. + T^T 1_n = b + t \geq 0 + where : + - C is the (dim_a, dim_b) metric cost matrix + - :math:`\lambda` is the l2-regularization coefficient + - a and b are source and target distributions + - T is the transport plan to optimize + + The problem above can be reformulated as follows + .. math:: + semi-relaxed UOT2 = \min_t \gamma c^T t + 0.5 * \|H_r t - a\|_2^2 + s.t. + H_c t = b + t \geq 0 + where : + - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C) + - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient + - H_r is a (dim_a, dim_a * dim_b) metric matrix, + which computes the sum along the rows of transport plan T + - H_c is a (dim_b, dim_a * dim_b) metric matrix, + which computes the sum along the columns of transport plan T + - t is a (dim_a * dim_b, ) metric vector (flattened version of T) + Parameters + ---------- + a : np.ndarray (dim_a,) + Histogram of dimension dim_a + b : np.ndarray (dim_b,) + Histogram of dimension dim_b + C : np.ndarray, shape (dim_a, dim_b) + Cost matrix + Returns + ------- + Hr : np.ndarray (dim_a, dim_a * dim_b) + Auxiliary matrix constituted by 0 and 1, which computes + the sum along the rows of transport plan T + Hc : np.ndarray (dim_b, dim_a * dim_b) + Auxiliary matrix constituted by 0 and 1, which computes + the sum along the columns of transport plan T + c : np.ndarray (ns * nt, ) + Flattened array of cost matrix + Examples + -------- + >>> import ot + >>> a = np.array([0.2, 0.3, 0.5]) + >>> b = np.array([0.1, 0.9]) + >>> C = np.array([[16., 25.], [28., 16.], [40., 36.]]) + >>> Hr,Hc,c = ot.regpath.recast_semi_relaxed_as_lasso(a, b, C) + >>> Hr.toarray() + array([[1., 1., 0., 0., 0., 0.], + [0., 0., 1., 1., 0., 0.], + [0., 0., 0., 0., 1., 1.]]) + >>> Hc.toarray() + array([[1., 0., 1., 0., 1., 0.], + [0., 1., 0., 1., 0., 1.]]) + >>> c + array([16., 25., 28., 16., 40., 36.]) + """ + + dim_a = np.shape(a)[0] + dim_b = np.shape(b)[0] + + c = C.flatten() + jHr = np.arange(dim_a * dim_b) + iHr = np.repeat(np.arange(dim_a), dim_b) + jHc = np.arange(dim_a * dim_b) + iHc = np.tile(np.arange(dim_b), dim_a) + + Hr = sp.csc_matrix((np.ones(dim_a * dim_b), (iHr, jHr)), + shape=(dim_a, dim_a * dim_b)) + Hc = sp.csc_matrix((np.ones(dim_a * dim_b), (iHc, jHc)), + shape=(dim_b, dim_a * dim_b)) + + return Hr, Hc, c + + +def ot_next_gamma(phi, delta, HtH, Hty, c, active_index, current_gamma): + r""" This function computes the next value of gamma if a variable + will be added in next iteration of the regularization path + + We look for the largest value of gamma such that + the gradient of an inactive variable vanishes + .. math:: + \max_{i \in \bar{A}} \frac{h_i^T(H_A \phi - y)}{h_i^T H_A \delta - c_i} + where : + - A is the current active set + - h_i is the ith column of auxiliary matrix H + - H_A is the sub-matrix constructed by the columns of H + whose indices belong to the active set A + - c_i is the ith element of cost vector c + - y is the concatenation of source and target distribution + - :math:`\phi` is the intercept of the solutions in current iteration + - :math:`\delta` is the slope of the solutions in current iteration + Parameters + ---------- + phi : np.ndarray (|A|, ) + Intercept of the solutions in current iteration (t is piecewise linear) + delta : np.ndarray (|A|, ) + Slope of the solutions in current iteration (t is piecewise linear) + HtH : np.ndarray (dim_a * dim_b, dim_a * dim_b) + Matrix product of H^T H + Hty : np.ndarray (dim_a + dim_b, ) + Matrix product of H^T y + c: np.ndarray (dim_a * dim_b, ) + Flattened array of cost matrix C + active_index : list + Indices of active variables + current_gamma : float + Value of regularization coefficient at the start of current iteration + Returns + ------- + next_gamma : float + Value of gamma if a variable is added to active set in next iteration + next_active_index : int + Index of variable to be activated + References + ---------- + [Chapel et al., 2021]: + Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. + """ + M = (HtH[:, active_index].dot(phi) - Hty) / \ + (HtH[:, active_index].dot(delta) - c + 1e-16) + M[active_index] = 0 + M[M > (current_gamma - 1e-10 * current_gamma)] = 0 + return np.max(M), np.argmax(M) + + +def semi_relaxed_next_gamma(phi, delta, phi_u, delta_u, HrHr, Hc, Hra, + c, active_index, current_gamma): + r""" This function computes the next value of gamma when a variable is + active in the regularization path of semi-relaxed UOT. + + By taking the Lagrangian form of the problem, we obtain a similar update + as the two-sided relaxed UOT + .. math:: + \max_{i \in \bar{A}} \frac{h_{r i}^T(H_{r A} \phi - a) + h_{c i}^T + \phi_u}{h_{r i}^T H_{r A} \delta + h_{c i} \delta_u - c_i} + where : + - A is the current active set + - h_{r i} is the ith column of the matrix H_r + - h_{c i} is the ith column of the matrix H_c + - H_{r A} is the sub-matrix constructed by the columns of H_r + whose indices belong to the active set A + - c_i is the ith element of cost vector c + - y is the concatenation of source and target distribution + - :math:`\phi` is the intercept of the solutions in current iteration + - :math:`\delta` is the slope of the solutions in current iteration + - :math:`\phi_u` is the intercept of Lagrange parameter in current + iteration + - :math:`\delta_u` is the slope of Lagrange parameter in current iteration + Parameters + ---------- + phi : np.ndarray (|A|, ) + Intercept of the solutions in current iteration (t is piecewise linear) + delta : np.ndarray (|A|, ) + Slope of the solutions in current iteration (t is piecewise linear) + phi_u : np.ndarray (dim_b, ) + Intercept of the Lagrange parameter in current iteration (also linear) + delta_u : np.ndarray (dim_b, ) + Slope of the Lagrange parameter in current iteration (also linear) + HrHr : np.ndarray (dim_a * dim_b, dim_a * dim_b) + Matrix product of H_r^T H_r + Hc : np.ndarray (dim_b, dim_a * dim_b) + Matrix that computes the sum along the columns of transport plan T + Hra : np.ndarray (dim_a * dim_b, ) + Matrix product of H_r^T a + c: np.ndarray (dim_a * dim_b, ) + Flattened array of cost matrix C + active_index : list + Indices of active variables + current_gamma : float + Value of regularization coefficient at the start of current iteration + Returns + ------- + next_gamma : float + Value of gamma if a variable is added to active set in next iteration + next_active_index : int + Index of variable to be activated + References + ---------- + [Chapel et al., 2021]: + Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. + """ + + M = (HrHr[:, active_index].dot(phi) - Hra + Hc.T.dot(phi_u)) / \ + (HrHr[:, active_index].dot(delta) - c + Hc.T.dot(delta_u) + 1e-16) + M[active_index] = 0 + M[M > (current_gamma - 1e-10 * current_gamma)] = 0 + return np.max(M), np.argmax(M) + + +def compute_next_removal(phi, delta, current_gamma): + r""" This function computes the next value of gamma if a variable + is removed in next iteration of regularization path + + We look for the largest value of gamma such that + an element of current solution vanishes + .. math:: + \max_{j \in A} \frac{\phi_j}{\delta_j} + where : + - A is the current active set + - phi_j is the jth element of the intercept of current solution + - delta_j is the jth elemnt of the slope of current solution + Parameters + ---------- + phi : np.ndarray (|A|, ) + Intercept of the solutions in current iteration (t is piecewise linear) + delta : np.ndarray (|A|, ) + Slope of the solutions in current iteration (t is piecewise linear) + current_gamma : float + Value of regularization coefficient at the start of current iteration + Returns + ------- + next_removal_gamma : float + Value of gamma if a variable is removed in next iteration + next_removal_index : int + Index of the variable to remove in next iteration + References + ---------- + [Chapel et al., 2021]: + Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. + """ + r_candidate = phi / (delta - 1e-16) + r_candidate[r_candidate >= (1 - 1e-8) * current_gamma] = 0 + return np.max(r_candidate), np.argmax(r_candidate) + + +def complement_schur(M_current, b, d, id_pop): + r""" This function computes the inverse of matrix in regularization path + using Schur complement + + Two cases may arise: Firstly one variable is added to the active set + .. math:: + M_{k+1}^{-1} = + \begin{bmatrix} + M_{k}^{-1} + s^{-1} M_{k}^{-1} b b^T M_{k}^{-1} & -s^{-1} \\ + - s^{-1} b^T M_{k}^{-1} & s^{-1} + \end{bmatrix} + where : + - :math:`M_k^{-1}` is the inverse of matrix in previous iteration and + :math:`M_k` is the upper left block matrix in Schur formulation + - b is the upper right block matrix in Schur formulation. In our case, + b is reduced to a column vector and b^T is the lower left block matrix + - s is the Schur complement, given by + :math:`s = d - b^T M_{k}^{-1} b` in our case + + Secondly, one variable is removed from the active set + .. math:: + M_{k+1}^{-1} = M^{-1}_{A_k \backslash q} - + \frac{r_{-q,q} r^{T}_{-q,q}}{r_{q,q}} + where : + - q is the index of column and row to delete + - :math:`M^{-1}_{A_k \backslash q}` is the previous inverse matrix + without qth column and qth row + - r_{-q,q} is the qth column of :math:`M^{-1}_{k}` without the qth element + - r_{q, q} is the element of qth column and qth row in :math:`M^{-1}_{k}` + Parameters + ---------- + M_current : np.ndarray (|A|-1, |A|-1) + Inverse matrix in previous iteration + b : np.ndarray (|A|-1, ) + Upper right matrix in Schur complement, a column vector in our case + d : float + Lower right matrix in Schur complement, a scalar in our case + id_pop + Index of the variable to be removed, equal to -1 + if none of the variables is deleted in current iteration + Returns + ------- + M : np.ndarray (|A|, |A|) + Inverse matrix needed in current iteration + References + ---------- + [Chapel et al., 2021]: + Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. + """ + if b is None: + b = M_current[id_pop, :] + b = np.delete(b, id_pop) + M_del = np.delete(M_current, id_pop, 0) + a = M_del[:, id_pop] + M_del = np.delete(M_del, id_pop, 1) + M = M_del - np.outer(a, b) / M_current[id_pop, id_pop] + else: + n = b.shape[0] + 1 + if np.shape(b)[0] == 0: + M = np.array([[0.5]]) + else: + X = M_current.dot(b) + s = d - b.T.dot(X) + M = np.zeros((n, n)) + M[:-1, :-1] = M_current + X.dot(X.T) / s + X_ravel = X.ravel() + M[-1, :-1] = -X_ravel / s + M[:-1, -1] = -X_ravel / s + M[-1, -1] = 1 / s + return M + + +def construct_augmented_H(active_index, m, Hc, HrHr): + r""" This function construct an augmented matrix for the first iteration of + semi-relaxed regularization path + + .. math:: + Augmented_H = + \begin{bmatrix} + 0 & H_{c A} \\ + H_{c A}^T & H_{r A}^T H_{r A} + \end{bmatrix} + where : + - H_{r A} is the sub-matrix constructed by the columns of H_r + whose indices belong to the active set A + - H_{c A} is the sub-matrix constructed by the columns of H_c + whose indices belong to the active set A + Parameters + ---------- + active_index : list + Indices of active variables + m : int + Length of the target distribution + Hc : np.ndarray (dim_b, dim_a * dim_b) + Matrix that computes the sum along the columns of transport plan T + HrHr : np.ndarray (dim_a * dim_b, dim_a * dim_b) + Matrix product of H_r^T H_r + Returns + ------- + H_augmented : np.ndarray (dim_b + |A|, dim_b + |A|) + Augmented matrix for the first iteration of the semi-relaxed + regularization path + """ + Hc_sub = Hc[:, active_index].toarray() + HrHr_sub = HrHr[:, active_index] + HrHr_sub = HrHr_sub[active_index, :].toarray() + H_augmented = np.block([[np.zeros((m, m)), Hc_sub], [Hc_sub.T, HrHr_sub]]) + return H_augmented + + +def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, + itmax=50000): + r"""This function gives the regularization path of l2-penalized UOT problem + + The problem to optimize is the Lasso reformulation of the l2-penalized UOT: + .. math:: + \min_t \gamma c^T t + 0.5 * \|H t - y\|_2^2 + s.t. + t \geq 0 + where : + - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C) + - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient + - y is the concatenation of vectors a and b, defined as y^T = [a^T b^T] + - H is a (dim_a + dim_b, dim_a * dim_b) metric matrix, + see [Chapel et al., 2021] for the design of H. The matrix product Ht + computes both the source marginal and the target marginal. + - t is a (dim_a * dim_b, ) metric vector (flattened version of T) + Parameters + ---------- + a : np.ndarray (dim_a,) + Histogram of dimension dim_a + b : np.ndarray (dim_b,) + Histogram of dimension dim_b + C : np.ndarray, shape (dim_a, dim_b) + Cost matrix + reg: float + l2-regularization coefficient + itmax: int + Maximum number of iteration + Returns + ------- + t : np.ndarray (dim_a*dim_b, ) + Flattened vector of optimal transport matrix + t_list : list + List of solutions in regularization path + gamma_list : list + List of regularization coefficient in regularization path + Examples + -------- + >>> import ot + >>> import numpy as np + >>> n = 3 + >>> xs = np.array([1., 2., 3.]).reshape((n, 1)) + >>> xt = np.array([5., 6., 7.]).reshape((n, 1)) + >>> C = ot.dist(xs, xt) + >>> C /= C.max() + >>> a = np.array([0.2, 0.5, 0.3]) + >>> b = np.array([0.2, 0.5, 0.3]) + >>> t, _, _ = ot.regpath.fully_relaxed_path(a, b, C, 1e-4) + >>> t + array([1.99958333e-01, 0.00000000e+00, 0.00000000e+00, 3.88888889e-05, + 4.99938889e-01, 0.00000000e+00, 0.00000000e+00, 3.88888889e-05, + 2.99958333e-01]) + + References + ---------- + [Chapel et al., 2021]: + Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. + """ + + n = np.shape(a)[0] + m = np.shape(b)[0] + H, y, c = recast_ot_as_lasso(a, b, C) + HtH = H.T.dot(H) + Hty = H.T.dot(y) + n_iter = 1 + + # initialization + M0 = Hty / c + gamma_list = [np.max(M0)] + active_index = [np.argmax(M0)] + t_list = [np.zeros((n * m,))] + H_inv = np.array([[]]) + add_col = np.array([]) + id_pop = -1 + + while n_iter < itmax and gamma_list[-1] > reg: + H_inv = complement_schur(H_inv, add_col, 2., id_pop) + current_gamma = gamma_list[-1] + + # compute the intercept and slope of solutions in current iteration + # t = phi - gamma * delta + phi = H_inv.dot(Hty[active_index]) + delta = H_inv.dot(c[active_index]) + gamma, ik = ot_next_gamma(phi, delta, HtH, Hty, c, active_index, + current_gamma) + + # compute the next lambda when removing a point from the active set + alt_gamma, id_pop = compute_next_removal(phi, delta, current_gamma) + + # if the positivity constraint is violated, we remove id_pop + # from active set, otherwise we add ik to active set + if alt_gamma > gamma: + gamma = alt_gamma + else: + id_pop = -1 + + # compute the solution of current segment + tA = phi - gamma * delta + sol = np.zeros((n * m, )) + sol[active_index] = tA + + if id_pop != -1: + active_index.pop(id_pop) + add_col = None + else: + active_index.append(ik) + add_col = HtH[active_index[:-1], ik].toarray() + + gamma_list.append(gamma) + t_list.append(sol) + n_iter += 1 + + if itmax <= n_iter: + print('maximum iteration has been reached !') + + # correct the last solution and gamma + if len(t_list) > 1: + t_final = (t_list[-2] + (t_list[-1] - t_list[-2]) * + (reg - gamma_list[-2]) / (gamma_list[-1] - gamma_list[-2])) + t_list[-1] = t_final + gamma_list[-1] = reg + else: + gamma_list[-1] = reg + print('Regularization path does not exist !') + + return t_list[-1], t_list, gamma_list + + +def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, + itmax=50000): + r"""This function gives the regularization path of semi-relaxed + l2-UOT problem + + The problem to optimize is the Lasso reformulation of the l2-penalized UOT: + .. math:: + \min_t \gamma c^T t + 0.5 * \|H_r t - a\|_2^2 + s.t. + H_c t = b + t \geq 0 + where : + - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C) + - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient + - H_r is a (dim_a, dim_a * dim_b) metric matrix, + which computes the sum along the rows of transport plan T + - H_c is a (dim_b, dim_a * dim_b) metric matrix, + which computes the sum along the columns of transport plan T + - t is a (dim_a * dim_b, ) metric vector (flattened version of T) + Parameters + ---------- + a : np.ndarray (dim_a,) + Histogram of dimension dim_a + b : np.ndarray (dim_b,) + Histogram of dimension dim_b + C : np.ndarray, shape (dim_a, dim_b) + Cost matrix + reg: float (optional) + l2-regularization coefficient + itmax: int (optional) + Maximum number of iteration + Returns + ------- + t : np.ndarray (dim_a*dim_b, ) + Flattened vector of optimal transport matrix + t_list : list + List of solutions in regularization path + gamma_list : list + List of regularization coefficient in regularization path + Examples + -------- + >>> import ot + >>> import numpy as np + >>> n = 3 + >>> xs = np.array([1., 2., 3.]).reshape((n, 1)) + >>> xt = np.array([5., 6., 7.]).reshape((n, 1)) + >>> C = ot.dist(xs, xt) + >>> C /= C.max() + >>> a = np.array([0.2, 0.5, 0.3]) + >>> b = np.array([0.2, 0.5, 0.3]) + >>> t, _, _ = ot.regpath.semi_relaxed_path(a, b, C, 1e-4) + >>> t + array([1.99980556e-01, 0.00000000e+00, 0.00000000e+00, 1.94444444e-05, + 4.99980556e-01, 0.00000000e+00, 0.00000000e+00, 1.94444444e-05, + 3.00000000e-01]) + + References + ---------- + [Chapel et al., 2021]: + Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. + """ + + n = np.shape(a)[0] + m = np.shape(b)[0] + Hr, Hc, c = recast_semi_relaxed_as_lasso(a, b, C) + Hra = Hr.T.dot(a) + HrHr = Hr.T.dot(Hr) + n_iter = 1 + active_index = [] + + # initialization + for j in range(np.shape(C)[1]): + i = np.argmin(C[:, j]) + active_index.append(i * m + j) + gamma_list = [] + t_list = [] + current_gamma = np.Inf + augmented_H0 = construct_augmented_H(active_index, m, Hc, HrHr) + add_col = np.array([]) + id_pop = -1 + + while n_iter < itmax and current_gamma > reg: + if n_iter == 1: + H_inv = np.linalg.inv(augmented_H0) + else: + H_inv = complement_schur(H_inv, add_col, 1., id_pop + m) + # compute the intercept and slope of solutions in current iteration + augmented_phi = H_inv.dot(np.concatenate((b, Hra[active_index]))) + augmented_delta = H_inv[:, m:].dot(c[active_index]) + phi = augmented_phi[m:] + delta = augmented_delta[m:] + phi_u = augmented_phi[0:m] + delta_u = augmented_delta[0:m] + gamma, ik = semi_relaxed_next_gamma(phi, delta, phi_u, delta_u, + HrHr, Hc, Hra, c, active_index, + current_gamma) + + # compute the next lambda when removing a point from the active set + alt_gamma, id_pop = compute_next_removal(phi, delta, current_gamma) + + # if the positivity constraint is violated, we remove id_pop + # from active set, otherwise we add ik to active set + if alt_gamma > gamma: + gamma = alt_gamma + else: + id_pop = -1 + + # compute the solution of current segment + tA = phi - gamma * delta + sol = np.zeros((n * m, )) + sol[active_index] = tA + if id_pop != -1: + active_index.pop(id_pop) + add_col = None + else: + active_index.append(ik) + add_col = np.concatenate((Hc.toarray()[:, ik], + HrHr.toarray()[active_index[:-1], ik])) + add_col = add_col[:, np.newaxis] + + gamma_list.append(gamma) + t_list.append(sol) + current_gamma = gamma + n_iter += 1 + + if itmax <= n_iter: + print('maximum iteration has been reached !') + + # correct the last solution and gamma + if len(t_list) > 1: + t_final = (t_list[-2] + (t_list[-1] - t_list[-2]) * + (reg - gamma_list[-2]) / (gamma_list[-1] - gamma_list[-2])) + t_list[-1] = t_final + gamma_list[-1] = reg + else: + gamma_list[-1] = reg + print('Regularization path does not exist !') + + return t_list[-1], t_list, gamma_list + + +def regularization_path(a: np.array, b: np.array, C: np.array, reg=1e-4, + semi_relaxed=False, itmax=50000): + r"""This function combines both the semi-relaxed and the fully-relaxed + regularization paths of l2-UOT problem + + Parameters + ---------- + a : np.ndarray (dim_a,) + Histogram of dimension dim_a + b : np.ndarray (dim_b,) + Histogram of dimension dim_b + C : np.ndarray, shape (dim_a, dim_b) + Cost matrix + reg: float (optional) + l2-regularization coefficient + semi_relaxed : bool (optional) + Give the semi-relaxed path if true + itmax: int (optional) + Maximum number of iteration + Returns + ------- + t : np.ndarray (dim_a*dim_b, ) + Flattened vector of optimal transport matrix + t_list : list + List of solutions in regularization path + gamma_list : list + List of regularization coefficient in regularization path + References + ---------- + [Chapel et al., 2021]: + Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. + """ + if semi_relaxed: + t, t_list, gamma_list = semi_relaxed_path(a, b, C, reg=reg, + itmax=itmax) + else: + t, t_list, gamma_list = fully_relaxed_path(a, b, C, reg=reg, + itmax=itmax) + return t, t_list, gamma_list + + +def compute_transport_plan(gamma, gamma_list, Pi_list): + r""" Given the regularization path, this function computes the transport + plan for any value of gamma by the piecewise linearity of the path + + .. math:: + t(\gamma) = \phi(\gamma) - \gamma \delta(\gamma) + where : + - :math:`\gamma` is the regularization coefficient + - :math:`\phi(\gamma)` is the corresponding intercept + - :math:`\delta(\gamma)` is the corresponding slope + - t is a (dim_a * dim_b, ) vector (flattened version of transport matrix) + Parameters + ---------- + gamma : float + Regularization coefficient + gamma_list : list + List of regularization coefficients in regularization path + Pi_list : list + List of solutions in regularization path + Returns + ------- + t : np.ndarray (dim_a*dim_b, ) + Transport vector corresponding to the given value of gamma + Examples + -------- + >>> import ot + >>> import numpy as np + >>> n = 3 + >>> xs = np.array([1., 2., 3.]).reshape((n, 1)) + >>> xt = np.array([5., 6., 7.]).reshape((n, 1)) + >>> C = ot.dist(xs, xt) + >>> C /= C.max() + >>> a = np.array([0.2, 0.5, 0.3]) + >>> b = np.array([0.2, 0.5, 0.3]) + >>> t, pi_list, g_list = ot.regpath.regularization_path(a, b, C, reg=1e-4) + >>> gamma = 1 + >>> t2 = ot.regpath.compute_transport_plan(gamma, g_list, pi_list) + >>> t2 + array([0. , 0. , 0. , 0.19722222, 0.05555556, + 0. , 0. , 0.24722222, 0. ]) + + References + ---------- + [Chapel et al., 2021]: + Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. + """ + + if gamma >= gamma_list[0]: + Pi = Pi_list[0] + elif gamma <= gamma_list[-1]: + Pi = Pi_list[-1] + else: + idx = np.where(gamma <= np.array(gamma_list))[0][-1] + gamma_k0 = gamma_list[idx] + gamma_k1 = gamma_list[idx + 1] + pi_k0 = Pi_list[idx] + pi_k1 = Pi_list[idx + 1] + Pi = pi_k0 + (pi_k1 - pi_k0) * (gamma - gamma_k0) \ + / (gamma_k1 - gamma_k0) + return Pi diff --git a/test/test_regpath.py b/test/test_regpath.py new file mode 100644 index 0000000..967c27b --- /dev/null +++ b/test/test_regpath.py @@ -0,0 +1,64 @@ +"""Tests for module regularization path""" + +# Author: Haoran Wu +# +# License: MIT License + +import numpy as np +import ot + + +def test_fully_relaxed_path(): + + n_source = 50 # nb source samples (gaussian) + n_target = 40 # nb target samples (gaussian) + + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 2]]) + + np.random.seed(0) + xs = ot.datasets.make_2D_samples_gauss(n_source, mu, cov) + xt = ot.datasets.make_2D_samples_gauss(n_target, mu, cov) + + # source and target distributions + a = ot.utils.unif(n_source) + b = ot.utils.unif(n_target) + + # loss matrix + M = ot.dist(xs, xt) + M /= M.max() + + t, _, _ = ot.regpath.regularization_path(a, b, M, reg=1e-8, + semi_relaxed=False) + + G = t.reshape((n_source, n_target)) + np.testing.assert_allclose(a, G.sum(1), atol=1e-05) + np.testing.assert_allclose(b, G.sum(0), atol=1e-05) + + +def test_semi_relaxed_path(): + + n_source = 50 # nb source samples (gaussian) + n_target = 40 # nb target samples (gaussian) + + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 2]]) + + np.random.seed(0) + xs = ot.datasets.make_2D_samples_gauss(n_source, mu, cov) + xt = ot.datasets.make_2D_samples_gauss(n_target, mu, cov) + + # source and target distributions + a = ot.utils.unif(n_source) + b = ot.utils.unif(n_target) + + # loss matrix + M = ot.dist(xs, xt) + M /= M.max() + + t, _, _ = ot.regpath.regularization_path(a, b, M, reg=1e-8, + semi_relaxed=True) + + G = t.reshape((n_source, n_target)) + np.testing.assert_allclose(a, G.sum(1), atol=1e-05) + np.testing.assert_allclose(b, G.sum(0), atol=1e-10) -- cgit v1.2.3 From 7a65086dd340265d0223eb8ffb5c9a5152a82dff Mon Sep 17 00:00:00 2001 From: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> Date: Mon, 25 Oct 2021 11:36:21 +0200 Subject: [MRG] Bregman backend (#280) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Bregman * Resolve conflicts * Bug solve * Bregman updated for JAX compatibility * Tests coherence between backend improved * No longer enforcing 64 bits operations on Jax except for tests * Now using mixtures, to make backend dependent tests with less code * Better test skipping code * Pep8 + test optimizations * redundancy removed * Docs * Typo corrected * Typo * Typo * Docs * Docs * pep8 * Backend docs * Prettier docs * Mistake corrected * small changes * Better wording Co-authored-by: Rémi Flamary --- docs/source/all.rst | 1 + ot/backend.py | 581 ++++++++++++++++++++++++++++- ot/bregman.py | 970 ++++++++++++++++++++++++++---------------------- ot/gromov.py | 6 +- ot/smooth.py | 4 +- ot/unbalanced.py | 14 +- test/conftest.py | 49 +++ test/test_backend.py | 102 +++++ test/test_bregman.py | 217 +++++++---- test/test_partial.py | 6 +- test/test_smooth.py | 12 +- test/test_stochastic.py | 12 +- 12 files changed, 1423 insertions(+), 551 deletions(-) create mode 100644 test/conftest.py (limited to 'test') diff --git a/docs/source/all.rst b/docs/source/all.rst index f1f7075..6a07599 100644 --- a/docs/source/all.rst +++ b/docs/source/all.rst @@ -14,6 +14,7 @@ API and modules :template: module.rst lp + backend bregman smooth gromov diff --git a/ot/backend.py b/ot/backend.py index 2ed40af..a4a4757 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1,6 +1,22 @@ # -*- coding: utf-8 -*- """ Multi-lib backend for POT + +The goal is to write backend-agnostic code. Whether you're using Numpy, PyTorch, +or Jax, POT code should work nonetheless. +To achieve that, POT provides backend classes which implements functions in their respective backend +imitating Numpy API. As a convention, we use nx instead of np to refer to the backend. + +Examples +-------- + +>>> from ot.utils import list_to_array +>>> from ot.backend import get_backend +>>> def f(a, b): # the function does not know which backend to use +... a, b = list_to_array(a, b) # if a list in given, make it an array +... nx = get_backend(a, b) # infer the backend from the arguments +... c = nx.dot(a, b) # now use the backend to do any calculation +... return c """ # Author: Remi Flamary @@ -9,6 +25,7 @@ Multi-lib backend for POT # License: MIT License import numpy as np +import scipy.special as scipy try: import torch @@ -20,6 +37,7 @@ except ImportError: try: import jax import jax.numpy as jnp + import jax.scipy.special as jscipy jax_type = jax.numpy.ndarray except ImportError: jax = False @@ -29,7 +47,7 @@ str_type_error = "All array should be from the same type/backend. Current types def get_backend_list(): - """ returns the list of available backends)""" + """Returns the list of available backends""" lst = [NumpyBackend(), ] if torch: @@ -42,7 +60,7 @@ def get_backend_list(): def get_backend(*args): - """returns the proper backend for a list of input arrays + """Returns the proper backend for a list of input arrays Also raises TypeError if all arrays are not from the same backend """ @@ -50,14 +68,12 @@ def get_backend(*args): if not len(args) > 0: raise ValueError(" The function takes at least one parameter") # check all same type + if not len(set(type(a) for a in args)) == 1: + raise ValueError(str_type_error.format([type(a) for a in args])) if isinstance(args[0], np.ndarray): - if not len(set(type(a) for a in args)) == 1: - raise ValueError(str_type_error.format([type(a) for a in args])) return NumpyBackend() - elif torch and isinstance(args[0], torch_type): - if not len(set(type(a) for a in args)) == 1: - raise ValueError(str_type_error.format([type(a) for a in args])) + elif isinstance(args[0], torch_type): return TorchBackend() elif isinstance(args[0], jax_type): return JaxBackend() @@ -66,7 +82,7 @@ def get_backend(*args): def to_numpy(*args): - """returns numpy arrays from any compatible backend""" + """Returns numpy arrays from any compatible backend""" if len(args) == 1: return get_backend(args[0]).to_numpy(args[0]) @@ -75,6 +91,13 @@ def to_numpy(*args): class Backend(): + """ + Backend abstract class. + Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend` + + - The `__name__` class attribute refers to the name of the backend. + - The `__type__` class attribute refers to the data structure used by the backend. + """ __name__ = None __type__ = None @@ -84,90 +107,426 @@ class Backend(): # convert to numpy def to_numpy(self, a): + """Returns the numpy version of a tensor""" raise NotImplementedError() # convert from numpy def from_numpy(self, a, type_as=None): + """Creates a tensor cloning a numpy array, with the given precision (defaulting to input's precision) and the given device (in case of GPUs)""" raise NotImplementedError() def set_gradients(self, val, inputs, grads): - """ define the gradients for the value val wrt the inputs """ + """Define the gradients for the value val wrt the inputs """ raise NotImplementedError() def zeros(self, shape, type_as=None): + r""" + Creates a tensor full of zeros. + + This function follow the api from :any:`numpy.zeros` + + See: https://numpy.org/doc/stable/reference/generated/numpy.zeros.html + """ raise NotImplementedError() def ones(self, shape, type_as=None): + r""" + Creates a tensor full of ones. + + This function follow the api from :any:`numpy.ones` + + See: https://numpy.org/doc/stable/reference/generated/numpy.ones.html + """ raise NotImplementedError() def arange(self, stop, start=0, step=1, type_as=None): + r""" + Returns evenly spaced values within a given interval. + + This function follow the api from :any:`numpy.arange` + + See: https://numpy.org/doc/stable/reference/generated/numpy.arange.html + """ raise NotImplementedError() def full(self, shape, fill_value, type_as=None): + r""" + Creates a tensor with given shape, filled with given value. + + This function follow the api from :any:`numpy.full` + + See: https://numpy.org/doc/stable/reference/generated/numpy.full.html + """ raise NotImplementedError() def eye(self, N, M=None, type_as=None): + r""" + Creates the identity matrix of given size. + + This function follow the api from :any:`numpy.eye` + + See: https://numpy.org/doc/stable/reference/generated/numpy.eye.html + """ raise NotImplementedError() def sum(self, a, axis=None, keepdims=False): + r""" + Sums tensor elements over given dimensions. + + This function follow the api from :any:`numpy.sum` + + See: https://numpy.org/doc/stable/reference/generated/numpy.sum.html + """ raise NotImplementedError() def cumsum(self, a, axis=None): + r""" + Returns the cumulative sum of tensor elements over given dimensions. + + This function follow the api from :any:`numpy.cumsum` + + See: https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html + """ raise NotImplementedError() def max(self, a, axis=None, keepdims=False): + r""" + Returns the maximum of an array or maximum along given dimensions. + + This function follow the api from :any:`numpy.amax` + + See: https://numpy.org/doc/stable/reference/generated/numpy.amax.html + """ raise NotImplementedError() def min(self, a, axis=None, keepdims=False): + r""" + Returns the maximum of an array or maximum along given dimensions. + + This function follow the api from :any:`numpy.amin` + + See: https://numpy.org/doc/stable/reference/generated/numpy.amin.html + """ raise NotImplementedError() def maximum(self, a, b): + r""" + Returns element-wise maximum of array elements. + + This function follow the api from :any:`numpy.maximum` + + See: https://numpy.org/doc/stable/reference/generated/numpy.maximum.html + """ raise NotImplementedError() def minimum(self, a, b): + r""" + Returns element-wise minimum of array elements. + + This function follow the api from :any:`numpy.minimum` + + See: https://numpy.org/doc/stable/reference/generated/numpy.minimum.html + """ raise NotImplementedError() def dot(self, a, b): + r""" + Returns the dot product of two tensors. + + This function follow the api from :any:`numpy.dot` + + See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html + """ raise NotImplementedError() def abs(self, a): + r""" + Computes the absolute value element-wise. + + This function follow the api from :any:`numpy.absolute` + + See: https://numpy.org/doc/stable/reference/generated/numpy.absolute.html + """ raise NotImplementedError() def exp(self, a): + r""" + Computes the exponential value element-wise. + + This function follow the api from :any:`numpy.exp` + + See: https://numpy.org/doc/stable/reference/generated/numpy.exp.html + """ raise NotImplementedError() def log(self, a): + r""" + Computes the natural logarithm, element-wise. + + This function follow the api from :any:`numpy.log` + + See: https://numpy.org/doc/stable/reference/generated/numpy.log.html + """ raise NotImplementedError() def sqrt(self, a): + r""" + Returns the non-ngeative square root of a tensor, element-wise. + + This function follow the api from :any:`numpy.sqrt` + + See: https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html + """ + raise NotImplementedError() + + def power(self, a, exponents): + r""" + First tensor elements raised to powers from second tensor, element-wise. + + This function follow the api from :any:`numpy.power` + + See: https://numpy.org/doc/stable/reference/generated/numpy.power.html + """ raise NotImplementedError() def norm(self, a): + r""" + Computes the matrix frobenius norm. + + This function follow the api from :any:`numpy.linalg.norm` + + See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html + """ raise NotImplementedError() def any(self, a): + r""" + Tests whether any tensor element along given dimensions evaluates to True. + + This function follow the api from :any:`numpy.any` + + See: https://numpy.org/doc/stable/reference/generated/numpy.any.html + """ raise NotImplementedError() def isnan(self, a): + r""" + Tests element-wise for NaN and returns result as a boolean tensor. + + This function follow the api from :any:`numpy.isnan` + + See: https://numpy.org/doc/stable/reference/generated/numpy.isnan.html + """ raise NotImplementedError() def isinf(self, a): + r""" + Tests element-wise for positive or negative infinity and returns result as a boolean tensor. + + This function follow the api from :any:`numpy.isinf` + + See: https://numpy.org/doc/stable/reference/generated/numpy.isinf.html + """ raise NotImplementedError() def einsum(self, subscripts, *operands): + r""" + Evaluates the Einstein summation convention on the operands. + + This function follow the api from :any:`numpy.einsum` + + See: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html + """ raise NotImplementedError() def sort(self, a, axis=-1): + r""" + Returns a sorted copy of a tensor. + + This function follow the api from :any:`numpy.sort` + + See: https://numpy.org/doc/stable/reference/generated/numpy.sort.html + """ raise NotImplementedError() def argsort(self, a, axis=None): + r""" + Returns the indices that would sort a tensor. + + This function follow the api from :any:`numpy.argsort` + + See: https://numpy.org/doc/stable/reference/generated/numpy.argsort.html + """ + raise NotImplementedError() + + def searchsorted(self, a, v, side='left'): + r""" + Finds indices where elements should be inserted to maintain order in given tensor. + + This function follow the api from :any:`numpy.searchsorted` + + See: https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html + """ raise NotImplementedError() def flip(self, a, axis=None): + r""" + Reverses the order of elements in a tensor along given dimensions. + + This function follow the api from :any:`numpy.flip` + + See: https://numpy.org/doc/stable/reference/generated/numpy.flip.html + """ + raise NotImplementedError() + + def clip(self, a, a_min, a_max): + """ + Limits the values in a tensor. + + This function follow the api from :any:`numpy.clip` + + See: https://numpy.org/doc/stable/reference/generated/numpy.clip.html + """ + raise NotImplementedError() + + def repeat(self, a, repeats, axis=None): + r""" + Repeats elements of a tensor. + + This function follow the api from :any:`numpy.repeat` + + See: https://numpy.org/doc/stable/reference/generated/numpy.repeat.html + """ + raise NotImplementedError() + + def take_along_axis(self, arr, indices, axis): + r""" + Gathers elements of a tensor along given dimensions. + + This function follow the api from :any:`numpy.take_along_axis` + + See: https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html + """ + raise NotImplementedError() + + def concatenate(self, arrays, axis=0): + r""" + Joins a sequence of tensors along an existing dimension. + + This function follow the api from :any:`numpy.concatenate` + + See: https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html + """ + raise NotImplementedError() + + def zero_pad(self, a, pad_width): + r""" + Pads a tensor. + + This function follow the api from :any:`numpy.pad` + + See: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + """ + raise NotImplementedError() + + def argmax(self, a, axis=None): + r""" + Returns the indices of the maximum values of a tensor along given dimensions. + + This function follow the api from :any:`numpy.argmax` + + See: https://numpy.org/doc/stable/reference/generated/numpy.argmax.html + """ + raise NotImplementedError() + + def mean(self, a, axis=None): + r""" + Computes the arithmetic mean of a tensor along given dimensions. + + This function follow the api from :any:`numpy.mean` + + See: https://numpy.org/doc/stable/reference/generated/numpy.mean.html + """ + raise NotImplementedError() + + def std(self, a, axis=None): + r""" + Computes the standard deviation of a tensor along given dimensions. + + This function follow the api from :any:`numpy.std` + + See: https://numpy.org/doc/stable/reference/generated/numpy.std.html + """ + raise NotImplementedError() + + def linspace(self, start, stop, num): + r""" + Returns a specified number of evenly spaced values over a given interval. + + This function follow the api from :any:`numpy.linspace` + + See: https://numpy.org/doc/stable/reference/generated/numpy.linspace.html + """ + raise NotImplementedError() + + def meshgrid(self, a, b): + r""" + Returns coordinate matrices from coordinate vectors (Numpy convention). + + This function follow the api from :any:`numpy.meshgrid` + + See: https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html + """ + raise NotImplementedError() + + def diag(self, a, k=0): + r""" + Extracts or constructs a diagonal tensor. + + This function follow the api from :any:`numpy.diag` + + See: https://numpy.org/doc/stable/reference/generated/numpy.diag.html + """ + raise NotImplementedError() + + def unique(self, a): + r""" + Finds unique elements of given tensor. + + This function follow the api from :any:`numpy.unique` + + See: https://numpy.org/doc/stable/reference/generated/numpy.unique.html + """ + raise NotImplementedError() + + def logsumexp(self, a, axis=None): + r""" + Computes the log of the sum of exponentials of input elements. + + This function follow the api from :any:`scipy.special.logsumexp` + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html + """ + raise NotImplementedError() + + def stack(self, arrays, axis=0): + r""" + Joins a sequence of tensors along a new dimension. + + This function follow the api from :any:`numpy.stack` + + See: https://numpy.org/doc/stable/reference/generated/numpy.stack.html + """ raise NotImplementedError() class NumpyBackend(Backend): + """ + NumPy implementation of the backend + + - `__name__` is "numpy" + - `__type__` is np.ndarray + """ __name__ = 'numpy' __type__ = np.ndarray @@ -184,7 +543,7 @@ class NumpyBackend(Backend): return a.astype(type_as.dtype) def set_gradients(self, val, inputs, grads): - # no gradients for numpy + # No gradients for numpy return val def zeros(self, shape, type_as=None): @@ -247,6 +606,9 @@ class NumpyBackend(Backend): def sqrt(self, a): return np.sqrt(a) + def power(self, a, exponents): + return np.power(a, exponents) + def norm(self, a): return np.sqrt(np.sum(np.square(a))) @@ -268,11 +630,70 @@ class NumpyBackend(Backend): def argsort(self, a, axis=-1): return np.argsort(a, axis) + def searchsorted(self, a, v, side='left'): + if a.ndim == 1: + return np.searchsorted(a, v, side) + else: + # this is a not very efficient way to make numpy + # searchsorted work on 2d arrays + ret = np.empty(v.shape, dtype=int) + for i in range(a.shape[0]): + ret[i, :] = np.searchsorted(a[i, :], v[i, :], side) + return ret + def flip(self, a, axis=None): return np.flip(a, axis) + def clip(self, a, a_min, a_max): + return np.clip(a, a_min, a_max) + + def repeat(self, a, repeats, axis=None): + return np.repeat(a, repeats, axis) + + def take_along_axis(self, arr, indices, axis): + return np.take_along_axis(arr, indices, axis) + + def concatenate(self, arrays, axis=0): + return np.concatenate(arrays, axis) + + def zero_pad(self, a, pad_width): + return np.pad(a, pad_width) + + def argmax(self, a, axis=None): + return np.argmax(a, axis=axis) + + def mean(self, a, axis=None): + return np.mean(a, axis=axis) + + def std(self, a, axis=None): + return np.std(a, axis=axis) + + def linspace(self, start, stop, num): + return np.linspace(start, stop, num) + + def meshgrid(self, a, b): + return np.meshgrid(a, b) + + def diag(self, a, k=0): + return np.diag(a, k) + + def unique(self, a): + return np.unique(a) + + def logsumexp(self, a, axis=None): + return scipy.logsumexp(a, axis=axis) + + def stack(self, arrays, axis=0): + return np.stack(arrays, axis) + class JaxBackend(Backend): + """ + JAX implementation of the backend + + - `__name__` is "jax" + - `__type__` is jax.numpy.ndarray + """ __name__ = 'jax' __type__ = jax_type @@ -359,6 +780,9 @@ class JaxBackend(Backend): def sqrt(self, a): return jnp.sqrt(a) + def power(self, a, exponents): + return jnp.power(a, exponents) + def norm(self, a): return jnp.sqrt(jnp.sum(jnp.square(a))) @@ -380,11 +804,67 @@ class JaxBackend(Backend): def argsort(self, a, axis=-1): return jnp.argsort(a, axis) + def searchsorted(self, a, v, side='left'): + if a.ndim == 1: + return jnp.searchsorted(a, v, side) + else: + # this is a not very efficient way to make jax numpy + # searchsorted work on 2d arrays + return jnp.array([jnp.searchsorted(a[i, :], v[i, :], side) for i in range(a.shape[0])]) + def flip(self, a, axis=None): return jnp.flip(a, axis) + def clip(self, a, a_min, a_max): + return jnp.clip(a, a_min, a_max) + + def repeat(self, a, repeats, axis=None): + return jnp.repeat(a, repeats, axis) + + def take_along_axis(self, arr, indices, axis): + return jnp.take_along_axis(arr, indices, axis) + + def concatenate(self, arrays, axis=0): + return jnp.concatenate(arrays, axis) + + def zero_pad(self, a, pad_width): + return jnp.pad(a, pad_width) + + def argmax(self, a, axis=None): + return jnp.argmax(a, axis=axis) + + def mean(self, a, axis=None): + return jnp.mean(a, axis=axis) + + def std(self, a, axis=None): + return jnp.std(a, axis=axis) + + def linspace(self, start, stop, num): + return jnp.linspace(start, stop, num) + + def meshgrid(self, a, b): + return jnp.meshgrid(a, b) + + def diag(self, a, k=0): + return jnp.diag(a, k) + + def unique(self, a): + return jnp.unique(a) + + def logsumexp(self, a, axis=None): + return jscipy.logsumexp(a, axis=axis) + + def stack(self, arrays, axis=0): + return jnp.stack(arrays, axis) + class TorchBackend(Backend): + """ + PyTorch implementation of the backend + + - `__name__` is "torch" + - `__type__` is torch.Tensor + """ __name__ = 'torch' __type__ = torch_type @@ -487,22 +967,23 @@ class TorchBackend(Backend): a = torch.tensor([float(a)], dtype=b.dtype, device=b.device) if isinstance(b, int) or isinstance(b, float): b = torch.tensor([float(b)], dtype=a.dtype, device=a.device) - return torch.maximum(a, b) + if torch.__version__ >= '1.7.0': + return torch.maximum(a, b) + else: + return torch.max(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0] def minimum(self, a, b): if isinstance(a, int) or isinstance(a, float): a = torch.tensor([float(a)], dtype=b.dtype, device=b.device) if isinstance(b, int) or isinstance(b, float): b = torch.tensor([float(b)], dtype=a.dtype, device=a.device) - return torch.minimum(a, b) + if torch.__version__ >= '1.7.0': + return torch.minimum(a, b) + else: + return torch.min(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0] def dot(self, a, b): - if len(a.shape) == len(b.shape) == 1: - return torch.dot(a, b) - elif len(a.shape) == 2 and len(b.shape) == 1: - return torch.mv(a, b) - else: - return torch.mm(a, b) + return torch.matmul(a, b) def abs(self, a): return torch.abs(a) @@ -516,6 +997,9 @@ class TorchBackend(Backend): def sqrt(self, a): return torch.sqrt(a) + def power(self, a, exponents): + return torch.pow(a, exponents) + def norm(self, a): return torch.sqrt(torch.sum(torch.square(a))) @@ -539,6 +1023,10 @@ class TorchBackend(Backend): sorted, indices = torch.sort(a, dim=axis) return indices + def searchsorted(self, a, v, side='left'): + right = (side != 'left') + return torch.searchsorted(a, v, right=right) + def flip(self, a, axis=None): if axis is None: return torch.flip(a, tuple(i for i in range(len(a.shape)))) @@ -546,3 +1034,60 @@ class TorchBackend(Backend): return torch.flip(a, (axis,)) else: return torch.flip(a, dims=axis) + + def clip(self, a, a_min, a_max): + return torch.clamp(a, a_min, a_max) + + def repeat(self, a, repeats, axis=None): + return torch.repeat_interleave(a, repeats, dim=axis) + + def take_along_axis(self, arr, indices, axis): + return torch.gather(arr, axis, indices) + + def concatenate(self, arrays, axis=0): + return torch.cat(arrays, dim=axis) + + def zero_pad(self, a, pad_width): + from torch.nn.functional import pad + # pad_width is an array of ndim tuples indicating how many 0 before and after + # we need to add. We first need to make it compliant with torch syntax, that + # starts with the last dim, then second last, etc. + how_pad = tuple(element for tupl in pad_width[::-1] for element in tupl) + return pad(a, how_pad) + + def argmax(self, a, axis=None): + return torch.argmax(a, dim=axis) + + def mean(self, a, axis=None): + if axis is not None: + return torch.mean(a, dim=axis) + else: + return torch.mean(a) + + def std(self, a, axis=None): + if axis is not None: + return torch.std(a, dim=axis, unbiased=False) + else: + return torch.std(a, unbiased=False) + + def linspace(self, start, stop, num): + return torch.linspace(start, stop, num, dtype=torch.float64) + + def meshgrid(self, a, b): + X, Y = torch.meshgrid(a, b) + return X.T, Y.T + + def diag(self, a, k=0): + return torch.diag(a, diagonal=k) + + def unique(self, a): + return torch.unique(a) + + def logsumexp(self, a, axis=None): + if axis is not None: + return torch.logsumexp(a, dim=axis) + else: + return torch.logsumexp(a, dim=tuple(range(len(a.shape)))) + + def stack(self, arrays, axis=0): + return torch.stack(arrays, dim=axis) diff --git a/ot/bregman.py b/ot/bregman.py index 317c902..b59ee1b 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -19,7 +19,6 @@ import warnings import numpy as np from scipy.optimize import fmin_l_bfgs_b -from scipy.special import logsumexp from ot.utils import unif, dist, list_to_array from .backend import get_backend @@ -35,36 +34,36 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, .. math:: \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma 1 = a \gamma^T 1= b \gamma\geq 0 where : - - M is the (dim_a, dim_b) metric cost matrix + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (histograms, both sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) .. note:: This function is backend-compatible and will work on arrays from all compatible backends. The algorithm used for solving the problem is the Sinkhorn-Knopp matrix - scaling algorithm as proposed in [2]_ + scaling algorithm as proposed in :ref:`[2] ` **Choosing a Sinkhorn solver** By default and when using a regularization parameter that is not too small the default sinkhorn solver should be enough. If you need to use a small regularization to get sharper OT matrices, you should use the - :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical + :py:func:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical errors. This last solver can be very slow in practice and might not even converge to a reasonable OT matrix in a finite time. This is why - :any:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value + :py:func:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value of the regularization (and using warm start) sometimes leads to better solutions. Note that the greedy version of the sinkhorn - :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening - version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a + :py:func:`ot.bregman.greenkhorn` can also lead to a speedup and the screening + version of the sinkhorn :py:func:`ot.bregman.screenkhorn` aim at providing a fast approximation of the Sinkhorn problem. @@ -74,7 +73,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, samples weights in the source domain b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) M : array-like, shape (dim_a, dim_b) loss matrix reg : float @@ -85,7 +84,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -109,7 +108,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, array([[0.36552929, 0.13447071], [0.13447071, 0.36552929]]) - + .. _references-sinkhorn: References ---------- @@ -125,9 +124,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, -------- ot.lp.emd : Unregularized OT ot.optim.cg : General regularized OT - ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2] - ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] - ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10] + ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] ` + ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn :ref:`[9] ` :ref:`[10] ` + ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling :ref:`[9] ` :ref:`[10] ` """ @@ -161,21 +160,21 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, .. math:: W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma 1 = a \gamma^T 1= b \gamma\geq 0 where : - - M is the (dim_a, dim_b) metric cost matrix + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (histograms, both sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) .. note:: This function is backend-compatible and will work on arrays from all compatible backends. - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[2] ` **Choosing a Sinkhorn solver** @@ -199,17 +198,17 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, samples weights in the source domain b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) M : array-like, shape (dim_a, dim_b) loss matrix reg : float Regularization term >0 method : str - method used for the solver either 'sinkhorn', 'sinkhorn_stabilized', see those function for specific parameters + method used for the solver either 'sinkhorn', 'sinkhorn_stabilized', see those function for specific parameters numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -234,7 +233,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, array([0.26894142]) - + .. _references-sinkhorn2: References ---------- @@ -244,7 +243,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. - [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 + .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 @@ -252,9 +251,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, -------- ot.lp.emd : Unregularized OT ot.optim.cg : General regularized OT - ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2] - ot.bregman.greenkhorn : Greenkhorn [21] - ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] + ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] ` + ot.bregman.greenkhorn : Greenkhorn :ref:`[21] ` + ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn :ref:`[9] ` :ref:`[10] ` """ @@ -291,21 +290,21 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, \gamma\geq 0 where : - - M is the (dim_a, dim_b) metric cost matrix + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (histograms, both sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[2] ` Parameters ---------- - a : ndarray, shape (dim_a,) + a : array-like, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists) + b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) - M : ndarray, shape (dim_a, dim_b) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) + M : array-like, shape (dim_a, dim_b) loss matrix reg : float Regularization term >0 @@ -320,7 +319,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, Returns ------- - gamma : ndarray, shape (dim_a, dim_b) + gamma : array-like, shape (dim_a, dim_b) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -337,6 +336,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, [0.13447071, 0.36552929]]) + .. _references-sinkhorn-knopp: References ---------- @@ -388,7 +388,6 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, while (err > stopThr and cpt < numItermax): uprev = u vprev = v - KtransposeU = nx.dot(K.T, u) v = b / KtransposeU u = 1. / nx.dot(Kp, v) @@ -444,53 +443,46 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, r""" Solve the entropic regularization optimal transport problem and return the OT matrix - The algorithm used is based on the paper - - Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration - by Jason Altschuler, Jonathan Weed, Philippe Rigollet - appeared at NIPS 2017 - - which is a stochastic version of the Sinkhorn-Knopp algorithm [2]. + The algorithm used is based on the paper :ref:`[22] ` which is a stochastic version of the Sinkhorn-Knopp algorithm :ref:`[2] ` The function solves the following optimization problem: .. math:: \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma 1 = a \gamma^T 1= b \gamma\geq 0 where : - - M is the (dim_a, dim_b) metric cost matrix + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (histograms, both sum to 1) - + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) Parameters ---------- - a : ndarray, shape (dim_a,) + a : array-like, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists) + b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) - M : ndarray, shape (dim_a, dim_b) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) + M : array-like, shape (dim_a, dim_b) loss matrix reg : float Regularization term >0 numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) log : bool, optional record log if True Returns ------- - gamma : ndarray, shape (dim_a, dim_b) + gamma : array-like, shape (dim_a, dim_b) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -507,11 +499,13 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, [0.13447071, 0.36552929]]) + .. _references-greenkhorn: References ---------- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 - [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 + + .. [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 See Also @@ -521,60 +515,58 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, """ - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) + if nx.__name__ == "jax": + raise TypeError("JAX arrays have been received. Greenkhorn is not compatible with JAX") if len(a) == 0: - a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + a = nx.ones((M.shape[0],), type_as=M) / M.shape[0] if len(b) == 0: - b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + b = nx.ones((M.shape[1],), type_as=M) / M.shape[1] dim_a = a.shape[0] dim_b = b.shape[0] - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty_like(M) - np.divide(M, -reg, out=K) - np.exp(K, out=K) + K = nx.exp(-M / reg) - u = np.full(dim_a, 1. / dim_a) - v = np.full(dim_b, 1. / dim_b) - G = u[:, np.newaxis] * K * v[np.newaxis, :] + u = nx.full((dim_a,), 1. / dim_a, type_as=K) + v = nx.full((dim_b,), 1. / dim_b, type_as=K) + G = u[:, None] * K * v[None, :] - viol = G.sum(1) - a - viol_2 = G.sum(0) - b + viol = nx.sum(G, axis=1) - a + viol_2 = nx.sum(G, axis=0) - b stopThr_val = 1 - if log: log = dict() log['u'] = u log['v'] = v for i in range(numItermax): - i_1 = np.argmax(np.abs(viol)) - i_2 = np.argmax(np.abs(viol_2)) - m_viol_1 = np.abs(viol[i_1]) - m_viol_2 = np.abs(viol_2[i_2]) - stopThr_val = np.maximum(m_viol_1, m_viol_2) + i_1 = nx.argmax(nx.abs(viol)) + i_2 = nx.argmax(nx.abs(viol_2)) + m_viol_1 = nx.abs(viol[i_1]) + m_viol_2 = nx.abs(viol_2[i_2]) + stopThr_val = nx.maximum(m_viol_1, m_viol_2) if m_viol_1 > m_viol_2: old_u = u[i_1] - u[i_1] = a[i_1] / (K[i_1, :].dot(v)) - G[i_1, :] = u[i_1] * K[i_1, :] * v - - viol[i_1] = u[i_1] * K[i_1, :].dot(v) - a[i_1] - viol_2 += (K[i_1, :].T * (u[i_1] - old_u) * v) + new_u = a[i_1] / (K[i_1, :].dot(v)) + G[i_1, :] = new_u * K[i_1, :] * v + viol[i_1] = new_u * K[i_1, :].dot(v) - a[i_1] + viol_2 += (K[i_1, :].T * (new_u - old_u) * v) + u[i_1] = new_u else: old_v = v[i_2] - v[i_2] = b[i_2] / (K[:, i_2].T.dot(u)) - G[:, i_2] = u * K[:, i_2] * v[i_2] + new_v = b[i_2] / (K[:, i_2].T.dot(u)) + G[:, i_2] = u * K[:, i_2] * new_v # aviol = (G@one_m - a) # aviol_2 = (G.T@one_n - b) - viol += (-old_v + v[i_2]) * K[:, i_2] * u - viol_2[i_2] = v[i_2] * K[:, i_2].dot(u) - b[i_2] - + viol += (-old_v + new_v) * K[:, i_2] * u + viol_2[i_2] = new_v * K[:, i_2].dot(u) - b[i_2] + v[i_2] = new_v # print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2))) if stopThr_val <= stopThr: @@ -603,41 +595,41 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, .. math:: \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma 1 = a \gamma^T 1= b \gamma\geq 0 where : - - M is the (dim_a, dim_b) metric cost matrix + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (histograms, both sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) The algorithm used for solving the problem is the Sinkhorn-Knopp matrix - scaling algorithm as proposed in [2]_ but with the log stabilization - proposed in [10]_ an defined in [9]_ (Algo 3.1) . + scaling algorithm as proposed in :ref:`[2] ` but with the log stabilization + proposed in :ref:`[10] ` an defined in :ref:`[9] ` (Algo 3.1) . Parameters ---------- - a : ndarray, shape (dim_a,) + a : array-like, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (dim_b,) + b : array-like, shape (dim_b,) samples in the target domain - M : ndarray, shape (dim_a, dim_b) + M : array-like, shape (dim_a, dim_b) loss matrix reg : float Regularization term >0 tau : float - thershold for max value in u or v for log scaling - warmstart : tible of vectors - if given then sarting values for alpha an beta log scalings + threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` for log scaling + warmstart : table of vectors + if given then starting values for alpha and beta log scalings numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -645,7 +637,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, Returns ------- - gamma : ndarray, shape (dim_a, dim_b) + gamma : array-like, shape (dim_a, dim_b) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -662,6 +654,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, [0.13447071, 0.36552929]]) + .. _references-sinkhorn-stabilized: References ---------- @@ -679,19 +672,19 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, """ - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) if len(a) == 0: - a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + a = nx.ones((M.shape[0],), type_as=M) / M.shape[0] if len(b) == 0: - b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + b = nx.ones((M.shape[1],), type_as=M) / M.shape[1] # test if multiple target if len(b.shape) > 1: n_hists = b.shape[1] - a = a[:, np.newaxis] + a = a[:, None] else: n_hists = 0 @@ -706,25 +699,25 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, # we assume that no distances are null except those of the diagonal of # distances if warmstart is None: - alpha, beta = np.zeros(dim_a), np.zeros(dim_b) + alpha, beta = nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M) else: alpha, beta = warmstart if n_hists: - u = np.ones((dim_a, n_hists)) / dim_a - v = np.ones((dim_b, n_hists)) / dim_b + u = nx.ones((dim_a, n_hists), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b else: - u, v = np.ones(dim_a) / dim_a, np.ones(dim_b) / dim_b + u, v = nx.ones(dim_a, type_as=M) / dim_a, nx.ones(dim_b, type_as=M) / dim_b def get_K(alpha, beta): """log space computation""" - return np.exp(-(M - alpha.reshape((dim_a, 1)) + return nx.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b))) / reg) def get_Gamma(alpha, beta, u, v): """log space gamma computation""" - return np.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b))) - / reg + np.log(u.reshape((dim_a, 1))) + np.log(v.reshape((1, dim_b)))) + return nx.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b))) + / reg + nx.log(u.reshape((dim_a, 1))) + nx.log(v.reshape((1, dim_b)))) # print(np.min(K)) @@ -739,33 +732,35 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, vprev = v # sinkhorn update - v = b / (np.dot(K.T, u) + 1e-16) - u = a / (np.dot(K, v) + 1e-16) + v = b / (nx.dot(K.T, u) + 1e-16) + u = a / (nx.dot(K, v) + 1e-16) # remove numerical problems and store them in K - if np.abs(u).max() > tau or np.abs(v).max() > tau: + if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau: if n_hists: - alpha, beta = alpha + reg * np.max(np.log(u), 1), beta + reg * np.max(np.log(v)) + alpha, beta = alpha + reg * nx.max(nx.log(u), 1), beta + reg * nx.max(np.log(v)) else: - alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v) + alpha, beta = alpha + reg * nx.log(u), beta + reg * nx.log(v) if n_hists: - u, v = np.ones((dim_a, n_hists)) / dim_a, np.ones((dim_b, n_hists)) / dim_b + u = nx.ones((dim_a, n_hists), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b else: - u, v = np.ones(dim_a) / dim_a, np.ones(dim_b) / dim_b + u = nx.ones(dim_a, type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b K = get_K(alpha, beta) if cpt % print_period == 0: # we can speed up the process by checking for the error only all # the 10th iterations if n_hists: - err_u = abs(u - uprev).max() - err_u /= max(abs(u).max(), abs(uprev).max(), 1.) - err_v = abs(v - vprev).max() - err_v /= max(abs(v).max(), abs(vprev).max(), 1.) + err_u = nx.max(nx.abs(u - uprev)) + err_u /= max(nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.0) + err_v = nx.max(nx.abs(v - vprev)) + err_v /= max(nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1.0) err = 0.5 * (err_u + err_v) else: transp = get_Gamma(alpha, beta, u, v) - err = np.linalg.norm((np.sum(transp, axis=0) - b)) + err = nx.norm(nx.sum(transp, axis=0) - b) if log: log['err'].append(err) @@ -781,7 +776,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, if cpt >= numItermax: loop = False - if np.any(np.isnan(u)) or np.any(np.isnan(v)): + if nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)): # we have reached the machine precision # come back to previous solution and quit loop print('Warning: numerical errors at iteration', cpt) @@ -795,26 +790,28 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, if n_hists: alpha = alpha[:, None] beta = beta[:, None] - logu = alpha / reg + np.log(u) - logv = beta / reg + np.log(v) + logu = alpha / reg + nx.log(u) + logv = beta / reg + nx.log(v) log['logu'] = logu log['logv'] = logv - log['alpha'] = alpha + reg * np.log(u) - log['beta'] = beta + reg * np.log(v) + log['alpha'] = alpha + reg * nx.log(u) + log['beta'] = beta + reg * nx.log(v) log['warmstart'] = (log['alpha'], log['beta']) if n_hists: - res = np.zeros((n_hists)) - for i in range(n_hists): - res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) + res = nx.stack([ + nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) + for i in range(n_hists) + ]) return res, log else: return get_Gamma(alpha, beta, u, v), log else: if n_hists: - res = np.zeros((n_hists)) - for i in range(n_hists): - res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) + res = nx.stack([ + nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) + for i in range(n_hists) + ]) return res else: return get_Gamma(alpha, beta, u, v) @@ -833,45 +830,45 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, .. math:: \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma 1 = a \gamma^T 1= b \gamma\geq 0 where : - - M is the (dim_a, dim_b) metric cost matrix + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (histograms, both sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) The algorithm used for solving the problem is the Sinkhorn-Knopp matrix - scaling algorithm as proposed in [2]_ but with the log stabilization - proposed in [10]_ and the log scaling proposed in [9]_ algorithm 3.2 + scaling algorithm as proposed in :ref:`[2] ` but with the log stabilization + proposed in :ref:`[10] ` and the log scaling proposed in :ref:`[9] ` algorithm 3.2 Parameters ---------- - a : ndarray, shape (dim_a,) + a : array-like, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (dim_b,) + b : array-like, shape (dim_b,) samples in the target domain - M : ndarray, shape (dim_a, dim_b) + M : array-like, shape (dim_a, dim_b) loss matrix reg : float Regularization term >0 tau : float - thershold for max value in u or v for log scaling + threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{b}` for log scaling warmstart : tuple of vectors - if given then sarting values for alpha an beta log scalings + if given then starting values for alpha and beta log scalings numItermax : int, optional Max number of iterations numInnerItermax : int, optional - Max number of iterationsin the inner slog stabilized sinkhorn + Max number of iterations in the inner slog stabilized sinkhorn epsilon0 : int, optional first epsilon regularization value (then exponential decrease to reg) stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -879,7 +876,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, Returns ------- - gamma : ndarray, shape (dim_a, dim_b) + gamma : array-like, shape (dim_a, dim_b) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -895,7 +892,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, array([[0.36552929, 0.13447071], [0.13447071, 0.36552929]]) - + .. _references-sinkhorn-epsilon-scaling: References ---------- @@ -903,6 +900,9 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + + See Also -------- ot.lp.emd : Unregularized OT @@ -910,14 +910,14 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, """ - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) if len(a) == 0: - a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + a = nx.ones((M.shape[0],), type_as=M) / M.shape[0] if len(b) == 0: - b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + b = nx.ones((M.shape[1],), type_as=M) / M.shape[1] # init data dim_a = len(a) @@ -934,7 +934,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, # we assume that no distances are null except those of the diagonal of # distances if warmstart is None: - alpha, beta = np.zeros(dim_a), np.zeros(dim_b) + alpha, beta = nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M) else: alpha, beta = warmstart @@ -964,15 +964,13 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, # we can speed up the process by checking for the error only all # the 10th iterations transp = G - err = np.linalg.norm( - (np.sum(transp, axis=0) - b)) ** 2 + np.linalg.norm((np.sum(transp, axis=1) - a)) ** 2 + err = nx.norm(nx.sum(transp, axis=0) - b) ** 2 + nx.norm(nx.sum(transp, axis=1) - a) ** 2 if log: log['err'].append(err) if verbose: if cpt % (print_period * 10) == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(cpt, err)) if err <= stopThr and cpt > numItermin: @@ -991,23 +989,31 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, def geometricBar(weights, alldistribT): """return the weighted geometric mean of distributions""" + weights, alldistribT = list_to_array(weights, alldistribT) + nx = get_backend(weights, alldistribT) assert (len(weights) == alldistribT.shape[1]) - return np.exp(np.dot(np.log(alldistribT), weights.T)) + return nx.exp(nx.dot(nx.log(alldistribT), weights.T)) def geometricMean(alldistribT): """return the geometric mean of distributions""" - return np.exp(np.mean(np.log(alldistribT), axis=1)) + alldistribT = list_to_array(alldistribT) + nx = get_backend(alldistribT) + return nx.exp(nx.mean(nx.log(alldistribT), axis=1)) def projR(gamma, p): """return the KL projection on the row constrints """ - return np.multiply(gamma.T, p / np.maximum(np.sum(gamma, axis=1), 1e-10)).T + gamma, p = list_to_array(gamma, p) + nx = get_backend(gamma, p) + return (gamma.T * p / nx.maximum(nx.sum(gamma, axis=1), 1e-10)).T def projC(gamma, q): """return the KL projection on the column constrints """ - return np.multiply(gamma, q / np.maximum(np.sum(gamma, axis=0), 1e-10)) + gamma, q = list_to_array(gamma, q) + nx = get_backend(gamma, q) + return gamma * q / nx.maximum(nx.sum(gamma, axis=0), 1e-10) def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, @@ -1021,28 +1027,28 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_ + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3] ` Parameters ---------- - A : ndarray, shape (dim, n_hists) - n_hists training distributions a_i of size dim - M : ndarray, shape (dim, dim) + A : array-like, shape (dim, n_hists) + `n_hists` training distributions :math:`a_i` of size `dim` + M : array-like, shape (dim, dim) loss matrix for OT reg : float Regularization term > 0 method : str (optional) method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized' - weights : ndarray, shape (n_hists,) - Weights of each histogram a_i on the simplex (barycentric coodinates) + weights : array-like, shape (n_hists,) + Weights of each histogram :math:`a_i` on the simplex (barycentric coodinates) numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -1051,12 +1057,13 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, Returns ------- - a : (dim,) ndarray + a : (dim,) array-like Wasserstein barycenter log : dict log dictionary return only if log==True in parameters + .. _references-barycenter: References ---------- @@ -1089,26 +1096,26 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_ + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3] ` Parameters ---------- - A : ndarray, shape (dim, n_hists) - n_hists training distributions a_i of size dim - M : ndarray, shape (dim, dim) + A : array-like, shape (dim, n_hists) + `n_hists` training distributions :math:`a_i` of size `dim` + M : array-like, shape (dim, dim) loss matrix for OT reg : float Regularization term > 0 - weights : ndarray, shape (n_hists,) - Weights of each histogram a_i on the simplex (barycentric coodinates) + weights : array-like, shape (n_hists,) + Weights of each histogram :math:`a_i` on the simplex (barycentric coodinates) numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -1117,12 +1124,13 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, Returns ------- - a : (dim,) ndarray + a : (dim,) array-like Wasserstein barycenter log : dict log dictionary return only if log==True in parameters + .. _references-barycenter-sinkhorn: References ---------- @@ -1130,8 +1138,12 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, """ + A, M = list_to_array(A, M) + + nx = get_backend(A, M) + if weights is None: - weights = np.ones(A.shape[1]) / A.shape[1] + weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1] else: assert (len(weights) == A.shape[1]) @@ -1139,21 +1151,22 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, log = {'err': []} # M = M/np.median(M) # suggested by G. Peyre - K = np.exp(-M / reg) + K = nx.exp(-M / reg) cpt = 0 err = 1 - UKv = np.dot(K, np.divide(A.T, np.sum(K, axis=0)).T) + UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T) + u = (geometricMean(UKv) / UKv.T).T while (err > stopThr and cpt < numItermax): cpt = cpt + 1 - UKv = u * np.dot(K, np.divide(A, np.dot(K, u))) + UKv = u * nx.dot(K, A / nx.dot(K, u)) u = (u.T * geometricBar(weights, UKv)).T / UKv if cpt % 10 == 1: - err = np.sum(np.std(UKv, axis=1)) + err = nx.sum(nx.std(UKv, axis=1)) # log and verbose print if log: @@ -1174,8 +1187,7 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False): - r"""Compute the entropic regularized wasserstein barycenter of distributions A - with stabilization. + r"""Compute the entropic regularized wasserstein barycenter of distributions A with stabilization. The function solves the following optimization problem: @@ -1184,28 +1196,28 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_ + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3] ` Parameters ---------- - A : ndarray, shape (dim, n_hists) - n_hists training distributions a_i of size dim - M : ndarray, shape (dim, dim) + A : array-like, shape (dim, n_hists) + `n_hists` training distributions :math:`a_i` of size `dim` + M : array-like, shape (dim, dim) loss matrix for OT reg : float Regularization term > 0 tau : float - thershold for max value in u or v for log scaling - weights : ndarray, shape (n_hists,) - Weights of each histogram a_i on the simplex (barycentric coodinates) + threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` for log scaling + weights : array-like, shape (n_hists,) + Weights of each histogram :math:`a_i` on the simplex (barycentric coodinates) numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -1214,12 +1226,13 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, Returns ------- - a : (dim,) ndarray + a : (dim,) array-like Wasserstein barycenter log : dict log dictionary return only if log==True in parameters + .. _references-barycenter-stabilized: References ---------- @@ -1227,49 +1240,48 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, """ + A, M = list_to_array(A, M) + + nx = get_backend(A, M) + dim, n_hists = A.shape if weights is None: - weights = np.ones(n_hists) / n_hists + weights = nx.ones((n_hists,), type_as=M) / n_hists else: assert (len(weights) == A.shape[1]) if log: log = {'err': []} - u = np.ones((dim, n_hists)) / dim - v = np.ones((dim, n_hists)) / dim + u = nx.ones((dim, n_hists), type_as=M) / dim + v = nx.ones((dim, n_hists), type_as=M) / dim - # print(reg) - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) + K = nx.exp(-M / reg) cpt = 0 err = 1. - alpha = np.zeros(dim) - beta = np.zeros(dim) - q = np.ones(dim) / dim + alpha = nx.zeros((dim,), type_as=M) + beta = nx.zeros((dim,), type_as=M) + q = nx.ones((dim,), type_as=M) / dim while (err > stopThr and cpt < numItermax): qprev = q - Kv = K.dot(v) + Kv = nx.dot(K, v) u = A / (Kv + 1e-16) - Ktu = K.T.dot(u) + Ktu = nx.dot(K.T, u) q = geometricBar(weights, Ktu) Q = q[:, None] v = Q / (Ktu + 1e-16) absorbing = False - if (u > tau).any() or (v > tau).any(): + if nx.any(u > tau) or nx.any(v > tau): absorbing = True - alpha = alpha + reg * np.log(np.max(u, 1)) - beta = beta + reg * np.log(np.max(v, 1)) - K = np.exp((alpha[:, None] + beta[None, :] - - M) / reg) - v = np.ones_like(v) - Kv = K.dot(v) - if (np.any(Ktu == 0.) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + alpha += reg * nx.log(nx.max(u, 1)) + beta += reg * nx.log(nx.max(v, 1)) + K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg) + v = nx.ones(tuple(v.shape), type_as=v) + Kv = nx.dot(K, v) + if (nx.any(Ktu == 0.) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop warnings.warn('Numerical errors at iteration %s' % cpt) @@ -1278,7 +1290,7 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, if (cpt % 10 == 0 and not absorbing) or cpt == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = abs(u * Kv - A).max() + err = nx.max(nx.abs(u * Kv - A)) if log: log['err'].append(err) if verbose: @@ -1314,24 +1326,24 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions of matrix :math:`\mathbf{A}` - - reg is the regularization strength scalar value + - `reg` is the regularization strength scalar value - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [21]_ + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[21] ` Parameters ---------- - A : ndarray, shape (n_hists, width, height) - n distributions (2D images) of size width x height + A : array-like, shape (n_hists, width, height) + `n` distributions (2D images) of size `width` x `height` reg : float Regularization term >0 - weights : ndarray, shape (n_hists,) + weights : array-like, shape (n_hists,) Weights of each image on the simplex (barycentric coodinates) numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (> 0) + Stop threshold on error (> 0) stabThr : float, optional Stabilization threshold to avoid numerical precision issue verbose : bool, optional @@ -1341,64 +1353,73 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, Returns ------- - a : ndarray, shape (width, height) + a : array-like, shape (width, height) 2D Wasserstein barycenter log : dict log dictionary return only if log==True in parameters + + .. _references-convolutional-barycenter-2d: References ---------- - .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). - Convolutional wasserstein distances: Efficient optimal transportation on geometric domains - ACM Transactions on Graphics (TOG), 34(4), 66 + .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). Convolutional wasserstein distances: Efficient optimal transportation on geometric domains. ACM Transactions on Graphics (TOG), 34(4), 66 """ + A = list_to_array(A) + + nx = get_backend(A) + if weights is None: - weights = np.ones(A.shape[0]) / A.shape[0] + weights = nx.ones((A.shape[0],), type_as=A) / A.shape[0] else: assert (len(weights) == A.shape[0]) if log: log = {'err': []} - b = np.zeros_like(A[0, :, :]) - U = np.ones_like(A) - KV = np.ones_like(A) + b = nx.zeros(A.shape[1:], type_as=A) + U = nx.ones(A.shape, type_as=A) + KV = nx.ones(A.shape, type_as=A) cpt = 0 err = 1 # build the convolution operator # this is equivalent to blurring on horizontal then vertical directions - t = np.linspace(0, 1, A.shape[1]) - [Y, X] = np.meshgrid(t, t) - xi1 = np.exp(-(X - Y) ** 2 / reg) + t = nx.linspace(0, 1, A.shape[1]) + [Y, X] = nx.meshgrid(t, t) + xi1 = nx.exp(-(X - Y) ** 2 / reg) - t = np.linspace(0, 1, A.shape[2]) - [Y, X] = np.meshgrid(t, t) - xi2 = np.exp(-(X - Y) ** 2 / reg) + t = nx.linspace(0, 1, A.shape[2]) + [Y, X] = nx.meshgrid(t, t) + xi2 = nx.exp(-(X - Y) ** 2 / reg) def K(x): - return np.dot(np.dot(xi1, x), xi2) + return nx.dot(nx.dot(xi1, x), xi2) while (err > stopThr and cpt < numItermax): bold = b cpt = cpt + 1 - b = np.zeros_like(A[0, :, :]) - for r in range(A.shape[0]): - KV[r, :, :] = K(A[r, :, :] / np.maximum(stabThr, K(U[r, :, :]))) - b += weights[r] * np.log(np.maximum(stabThr, U[r, :, :] * KV[r, :, :])) - b = np.exp(b) + b = nx.zeros(A.shape[1:], type_as=A) + KV_cols = [] for r in range(A.shape[0]): - U[r, :, :] = b / np.maximum(stabThr, KV[r, :, :]) - + KV_col_r = K(A[r, :, :] / nx.maximum(stabThr, K(U[r, :, :]))) + b += weights[r] * nx.log(nx.maximum(stabThr, U[r, :, :] * KV_col_r)) + KV_cols.append(KV_col_r) + KV = nx.stack(KV_cols) + b = nx.exp(b) + + U = nx.stack([ + b / nx.maximum(stabThr, KV[r, :, :]) + for r in range(A.shape[0]) + ]) if cpt % 10 == 1: - err = np.sum(np.abs(bold - b)) + err = nx.sum(nx.abs(bold - b)) # log and verbose print if log: log['err'].append(err) @@ -1424,34 +1445,35 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, The function solve the following optimization problem: .. math:: - \mathbf{h} = arg\min_\mathbf{h} (1- \\alpha) W_{M,reg}(\mathbf{a},\mathbf{Dh})+\\alpha W_{M0,reg0}(\mathbf{h}_0,\mathbf{h}) + + \mathbf{h} = arg\min_\mathbf{h} (1- \alpha) W_{M,reg}(\mathbf{a},\mathbf{Dh})+\alpha W_{M_0,reg_0}(\mathbf{h}_0,\mathbf{h}) where : - - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with M loss matrix (see ot.bregman.sinkhorn) - - :math: `\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, its expected shape is `(dim_a, n_atoms)` + - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with M loss matrix (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, its expected shape is `(dim_a, n_atoms)` - :math:`\mathbf{h}` is the estimated unmixing of dimension `n_atoms` - :math:`\mathbf{a}` is an observed distribution of dimension `dim_a` - - :math:`\mathbf{h}_0` is a prior on `h` of dimension `dim_prior` - - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix (dim_a, dim_a) for OT data fitting - - reg0 and :math:`\mathbf{M0}` are respectively the regularization term and the cost matrix (dim_prior, n_atoms) regularization - - :math:`\\alpha`weight data fitting and regularization + - :math:`\mathbf{h}_0` is a prior on :math:`\mathbf{h}` of dimension `dim_prior` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix (`dim_a`, `dim_a`) for OT data fitting + - `reg`:math:`_0` and :math:`\mathbf{M_0}` are respectively the regularization term and the cost matrix (`dim_prior`, `n_atoms`) regularization + - :math:`\\alpha` weight data fitting and regularization - The optimization problem is solved suing the algorithm described in [4] + The optimization problem is solved following the algorithm described in :ref:`[4] ` Parameters ---------- - a : ndarray, shape (dim_a) + a : array-like, shape (dim_a) observed distribution (histogram, sums to 1) - D : ndarray, shape (dim_a, n_atoms) + D : array-like, shape (dim_a, n_atoms) dictionary matrix - M : ndarray, shape (dim_a, dim_a) + M : array-like, shape (dim_a, dim_a) loss matrix - M0 : ndarray, shape (n_atoms, dim_prior) + M0 : array-like, shape (n_atoms, dim_prior) loss matrix - h0 : ndarray, shape (n_atoms,) + h0 : array-like, shape (n_atoms,) prior on the estimated unmixing h reg : float Regularization term >0 (Wasserstein data fitting) @@ -1462,7 +1484,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -1471,11 +1493,13 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, Returns ------- - h : ndarray, shape (n_atoms,) + h : array-like, shape (n_atoms,) Wasserstein barycenter log : dict log dictionary return only if log==True in parameters + + .. _references-unmix: References ---------- @@ -1483,11 +1507,15 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, """ + a, D, M, M0, h0 = list_to_array(a, D, M, M0, h0) + + nx = get_backend(a, D, M, M0, h0) + # M = M/np.median(M) - K = np.exp(-M / reg) + K = nx.exp(-M / reg) # M0 = M0/np.median(M0) - K0 = np.exp(-M0 / reg0) + K0 = nx.exp(-M0 / reg0) old = h0 err = 1 @@ -1499,16 +1527,16 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, while (err > stopThr and cpt < numItermax): K = projC(K, a) K0 = projC(K0, h0) - new = np.sum(K0, axis=1) + new = nx.sum(K0, axis=1) # we recombine the current selection from dictionnary - inv_new = np.dot(D, new) - other = np.sum(K, axis=1) + inv_new = nx.dot(D, new) + other = nx.sum(K, axis=1) # geometric interpolation - delta = np.exp(alpha * np.log(other) + (1 - alpha) * np.log(inv_new)) + delta = nx.exp(alpha * nx.log(other) + (1 - alpha) * nx.log(inv_new)) K = projR(K, delta) - K0 = np.dot(np.diag(np.dot(D.T, delta / inv_new)), K0) + K0 = nx.dot(nx.diag(nx.dot(D.T, delta / inv_new)), K0) - err = np.linalg.norm(np.sum(K0, axis=1) - old) + err = nx.norm(nx.sum(K0, axis=1) - old) old = new if log: log['err'].append(err) @@ -1522,14 +1550,14 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, if log: log['niter'] = cpt - return np.sum(K0, axis=1), log + return nx.sum(K0, axis=1), log else: - return np.sum(K0, axis=1) + return nx.sum(K0, axis=1) def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, stopThr=1e-6, verbose=False, log=False, **kwargs): - r'''Joint OT and proportion estimation for multi-source target shift as proposed in [27] + r'''Joint OT and proportion estimation for multi-source target shift as proposed in :ref:`[27] ` The function solves the following optimization problem: @@ -1542,12 +1570,12 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, where : - - :math:`\lambda_k` is the weight of k-th source domain - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) - - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to k-th source domain defined as in [p. 5, 27], its expected shape is `(n_k, C)` where `n_k` is the number of elements in the k-th source domain and `C` is the number of classes - - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size C + - :math:`\lambda_k` is the weight of `k`-th source domain + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to `k`-th source domain defined as in [p. 5, :ref:`27 `], its expected shape is :math:`(n_k, C)` where :math:`n_k` is the number of elements in the `k`-th source domain and `C` is the number of classes + - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size `C` - :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n` - - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, 27], its expected shape is `(n_k, C)` + - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, :ref:`27 `], its expected shape is :math:`(n_k, C)` The problem consist in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain. @@ -1556,11 +1584,11 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, Parameters ---------- - Xs : list of K np.ndarray(nsk,d) + Xs : list of K array-like(nsk,d) features of all source domains' samples - Ys : list of K np.ndarray(nsk,) + Ys : list of K array-like(nsk,) labels of all source domains' samples - Xt : np.ndarray (nt,d) + Xt : array-like (nt,d) samples in the target domain reg : float Regularization term > 0 @@ -1577,12 +1605,13 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, Returns ------- - h : (C,) ndarray + h : (C,) array-like proportion estimation in the target domain log : dict log dictionary return only if log==True in parameters + .. _references-jcpot-barycenter: References ---------- @@ -1591,7 +1620,14 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. ''' - nbclasses = len(np.unique(Ys[0])) + + Xs = list_to_array(*Xs) + Ys = list_to_array(*Ys) + Xt = list_to_array(Xt) + + nx = get_backend(*Xs, *Ys, Xt) + + nbclasses = len(nx.unique(Ys[0])) nbdomains = len(Xs) # log dictionary @@ -1608,19 +1644,19 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, dom = {} nsk = Xs[d].shape[0] # get number of elements for this domain dom['nbelem'] = nsk - classes = np.unique(Ys[d]) # get number of classes for this domain + classes = nx.unique(Ys[d]) # get number of classes for this domain # format classes to start from 0 for convenience - if np.min(classes) != 0: - Ys[d] = Ys[d] - np.min(classes) - classes = np.unique(Ys[d]) + if nx.min(classes) != 0: + Ys[d] -= nx.min(classes) + classes = nx.unique(Ys[d]) # build the corresponding D_1 and D_2 matrices - Dtmp1 = np.zeros((nbclasses, nsk)) - Dtmp2 = np.zeros((nbclasses, nsk)) + Dtmp1 = nx.zeros((nbclasses, nsk), type_as=Xs[0]) + Dtmp2 = nx.zeros((nbclasses, nsk), type_as=Xs[0]) for c in classes: - nbelemperclass = np.sum(Ys[d] == c) + nbelemperclass = nx.sum(Ys[d] == c) if nbelemperclass != 0: Dtmp1[int(c), Ys[d] == c] = 1. Dtmp2[int(c), Ys[d] == c] = 1. / (nbelemperclass) @@ -1631,36 +1667,34 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, Mtmp = dist(Xs[d], Xt, metric=metric) M.append(Mtmp) - Ktmp = np.empty(Mtmp.shape, dtype=Mtmp.dtype) - np.divide(Mtmp, -reg, out=Ktmp) - np.exp(Ktmp, out=Ktmp) + Ktmp = nx.exp(-Mtmp / reg) K.append(Ktmp) # uniform target distribution - a = unif(np.shape(Xt)[0]) + a = nx.from_numpy(unif(np.shape(Xt)[0])) cpt = 0 # iterations count err = 1 - old_bary = np.ones((nbclasses)) + old_bary = nx.ones((nbclasses,), type_as=Xs[0]) while (err > stopThr and cpt < numItermax): - bary = np.zeros((nbclasses)) + bary = nx.zeros((nbclasses,), type_as=Xs[0]) # update coupling matrices for marginal constraints w.r.t. uniform target distribution for d in range(nbdomains): K[d] = projC(K[d], a) - other = np.sum(K[d], axis=1) - bary = bary + np.log(np.dot(D1[d], other)) / nbdomains + other = nx.sum(K[d], axis=1) + bary += nx.log(nx.dot(D1[d], other)) / nbdomains - bary = np.exp(bary) + bary = nx.exp(bary) # update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27] for d in range(nbdomains): - new = np.dot(D2[d].T, bary) + new = nx.dot(D2[d].T, bary) K[d] = projR(K[d], new) - err = np.linalg.norm(bary - old_bary) + err = nx.norm(bary - old_bary) cpt = cpt + 1 old_bary = bary @@ -1672,7 +1706,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(cpt, err)) - bary = bary / np.sum(bary) + bary = bary / nx.sum(bary) if log: log['niter'] = cpt @@ -1697,39 +1731,38 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', .. math:: \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma 1 = a \gamma^T 1= b \gamma\geq 0 where : - - :math:`M` is the (n_samples_a, n_samples_b) metric cost matrix + - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`a` and :math:`b` are source and target weights (sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) Parameters ---------- - X_s : ndarray, shape (n_samples_a, dim) + X_s : array-like, shape (n_samples_a, dim) samples in the source domain - X_t : ndarray, shape (n_samples_b, dim) + X_t : array-like, shape (n_samples_b, dim) samples in the target domain reg : float Regularization term >0 - a : ndarray, shape (n_samples_a,) + a : array-like, shape (n_samples_a,) samples weights in the source domain - b : ndarray, shape (n_samples_b,) + b : array-like, shape (n_samples_b,) samples weights in the target domain numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) isLazy: boolean, optional - If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory) - If False, calculate full cost matrix and return outputs of sinkhorn function. + If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory). If False, calculate full cost matrix and return outputs of sinkhorn function. batchSize: int or tuple of 2 int, optional - Size of the batcheses used to compute the sinkhorn update without memory overhead. + Size of the batches used to compute the sinkhorn update without memory overhead. When a tuple is provided it sets the size of the left/right batches. verbose : bool, optional Print information along iterations @@ -1739,7 +1772,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', Returns ------- - gamma : ndarray, shape (n_samples_a, n_samples_b) + gamma : array-like, shape (n_samples_a, n_samples_b) Regularized optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -1766,18 +1799,23 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. ''' + + X_s, X_t = list_to_array(X_s, X_t) + + nx = get_backend(X_s, X_t) + ns, nt = X_s.shape[0], X_t.shape[0] if a is None: - a = unif(ns) + a = nx.from_numpy(unif(ns)) if b is None: - b = unif(nt) + b = nx.from_numpy(unif(nt)) if isLazy: if log: dict_log = {"err": []} - log_a, log_b = np.log(a), np.log(b) - f, g = np.zeros(ns), np.zeros(nt) + log_a, log_b = nx.log(a), nx.log(b) + f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a) if isinstance(batchSize, int): bs, bt = batchSize, batchSize @@ -1788,27 +1826,44 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', range_s, range_t = range(0, ns, bs), range(0, nt, bt) - lse_f = np.zeros(ns) - lse_g = np.zeros(nt) + lse_f = nx.zeros((ns,), type_as=a) + lse_g = nx.zeros((nt,), type_as=a) + + X_s_np = nx.to_numpy(X_s) + X_t_np = nx.to_numpy(X_t) for i_ot in range(numIterMax): + lse_f_cols = [] for i in range_s: - M = dist(X_s[i:i + bs, :], X_t, metric=metric) - lse_f[i:i + bs] = logsumexp(g[None, :] - M / reg, axis=1) + M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric) + M = nx.from_numpy(M, type_as=a) + lse_f_cols.append( + nx.logsumexp(g[None, :] - M / reg, axis=1) + ) + lse_f = nx.concatenate(lse_f_cols, axis=0) f = log_a - lse_f + lse_g_cols = [] for j in range_t: - M = dist(X_s, X_t[j:j + bt, :], metric=metric) - lse_g[j:j + bt] = logsumexp(f[:, None] - M / reg, axis=0) + M = dist(X_s_np, X_t_np[j:j + bt, :], metric=metric) + M = nx.from_numpy(M, type_as=a) + lse_g_cols.append( + nx.logsumexp(f[:, None] - M / reg, axis=0) + ) + lse_g = nx.concatenate(lse_g_cols, axis=0) g = log_b - lse_g if (i_ot + 1) % 10 == 0: - m1 = np.zeros_like(a) + m1_cols = [] for i in range_s: - M = dist(X_s[i:i + bs, :], X_t, metric=metric) - m1[i:i + bs] = np.exp(f[i:i + bs, None] + g[None, :] - M / reg).sum(1) - err = np.abs(m1 - a).sum() + M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric) + M = nx.from_numpy(M, type_as=a) + m1_cols.append( + nx.sum(nx.exp(f[i:i + bs, None] + g[None, :] - M / reg), axis=1) + ) + m1 = nx.concatenate(m1_cols, axis=0) + err = nx.sum(nx.abs(m1 - a)) if log: dict_log["err"].append(err) @@ -1826,8 +1881,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', return (f, g) else: - M = dist(X_s, X_t, metric=metric) - + M = dist(nx.to_numpy(X_s), nx.to_numpy(X_t), metric=metric) + M = nx.from_numpy(M, type_as=a) if log: pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs) return pi, log @@ -1848,39 +1903,38 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num .. math:: W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma 1 = a \gamma^T 1= b \gamma\geq 0 where : - - :math:`M` is the (n_samples_a, n_samples_b) metric cost matrix + - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`a` and :math:`b` are source and target weights (sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) Parameters ---------- - X_s : ndarray, shape (n_samples_a, dim) + X_s : array-like, shape (n_samples_a, dim) samples in the source domain - X_t : ndarray, shape (n_samples_b, dim) + X_t : array-like, shape (n_samples_b, dim) samples in the target domain reg : float Regularization term >0 - a : ndarray, shape (n_samples_a,) + a : array-like, shape (n_samples_a,) samples weights in the source domain - b : ndarray, shape (n_samples_b,) + b : array-like, shape (n_samples_b,) samples weights in the target domain numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) isLazy: boolean, optional - If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory) - If False, calculate full cost matrix and return outputs of sinkhorn function. + If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory). If False, calculate full cost matrix and return outputs of sinkhorn function. batchSize: int or tuple of 2 int, optional - Size of the batcheses used to compute the sinkhorn update without memory overhead. + Size of the batches used to compute the sinkhorn update without memory overhead. When a tuple is provided it sets the size of the left/right batches. verbose : bool, optional Print information along iterations @@ -1890,7 +1944,7 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num Returns ------- - W : (n_hists) ndarray or float + W : (n_hists) array-like or float Optimal transportation loss for the given parameters log : dict log dictionary return only if log==True in parameters @@ -1918,11 +1972,15 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. ''' + X_s, X_t = list_to_array(X_s, X_t) + + nx = get_backend(X_s, X_t) + ns, nt = X_s.shape[0], X_t.shape[0] if a is None: - a = unif(ns) + a = nx.from_numpy(unif(ns)) if b is None: - b = unif(nt) + b = nx.from_numpy(unif(nt)) if isLazy: if log: @@ -1936,10 +1994,15 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num range_s = range(0, ns, bs) loss = 0 + + X_s_np = nx.to_numpy(X_s) + X_t_np = nx.to_numpy(X_t) + for i in range_s: - M_block = dist(X_s[i:i + bs, :], X_t, metric=metric) - pi_block = np.exp(f[i:i + bs, None] + g[None, :] - M_block / reg) - loss += np.sum(M_block * pi_block) + M_block = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric) + M_block = nx.from_numpy(M_block, type_as=a) + pi_block = nx.exp(f[i:i + bs, None] + g[None, :] - M_block / reg) + loss += nx.sum(M_block * pi_block) if log: return loss, dict_log @@ -1947,7 +2010,8 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num return loss else: - M = dist(X_s, X_t, metric=metric) + M = dist(nx.to_numpy(X_s), nx.to_numpy(X_t), metric=metric) + M = nx.from_numpy(M, type_as=a) if log: sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, @@ -1975,10 +2039,10 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli W_b &= \min_{\gamma_b} <\gamma_b,M_b>_F + reg\cdot\Omega(\gamma_b) - S &= W - 1/2 * (W_a + W_b) + S &= W - \frac{W_a + W_b}{2} .. math:: - s.t. \gamma 1 = a + s.t. \ \gamma 1 = a \gamma^T 1= b @@ -1997,27 +2061,27 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli \gamma_b\geq 0 where : - - :math:`M` (resp. :math:`M_a, M_b`) is the (n_samples_a, n_samples_b) metric cost matrix (resp (n_samples_a, n_samples_a) and (n_samples_b, n_samples_b)) + - :math:`\mathbf{M}` (resp. :math:`\mathbf{M_a}`, :math:`\mathbf{M_b}`) is the (`n_samples_a`, `n_samples_b`) metric cost matrix (resp (`n_samples_a, n_samples_a`) and (`n_samples_b`, `n_samples_b`)) - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`a` and :math:`b` are source and target weights (sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) Parameters ---------- - X_s : ndarray, shape (n_samples_a, dim) + X_s : array-like, shape (n_samples_a, dim) samples in the source domain - X_t : ndarray, shape (n_samples_b, dim) + X_t : array-like, shape (n_samples_b, dim) samples in the target domain reg : float Regularization term >0 - a : ndarray, shape (n_samples_a,) + a : array-like, shape (n_samples_a,) samples weights in the source domain - b : ndarray, shape (n_samples_b,) + b : array-like, shape (n_samples_b,) samples weights in the target domain numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -2025,7 +2089,7 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli Returns ------- - W : (1,) ndarray + W : (1,) array-like Optimal transportation symmetrized loss for the given parameters log : dict log dictionary return only if log==True in parameters @@ -2083,47 +2147,54 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, restricted=True, maxiter=10000, maxfun=10000, pgtol=1e-09, verbose=False, log=False): - r"""" + r""" Screening Sinkhorn Algorithm for Regularized Optimal Transport - The function solves an approximate dual of Sinkhorn divergence [2] which is written as the following optimization problem: + The function solves an approximate dual of Sinkhorn divergence :ref:`[2] ` which is written as the following optimization problem: - ..math:: - (u, v) = \argmin_{u, v} 1_{ns}^T B(u,v) 1_{nt} - <\kappa u, a> - + .. math:: - where B(u,v) = \diag(e^u) K \diag(e^v), with K = e^{-M/reg} and + (u, v) = arg\min_{u, v} 1_{ns}^T B(u,v) 1_{nt} - <\kappa u, a> - - s.t. e^{u_i} \geq \epsilon / \kappa, for all i \in {1, ..., ns} + where: - e^{v_j} \geq \epsilon \kappa, for all j \in {1, ..., nt} + .. math:: - The parameters \kappa and \epsilon are determined w.r.t the couple number budget of points (ns_budget, nt_budget), see Equation (5) in [26] + B(u,v) = \mathrm{diag}(e^u) K \mathrm{diag}(e^v) \text{, with } K = e^{-M/reg} \text{ and} + + .. math:: + + s.t. \ e^{u_i} \geq \epsilon / \kappa, \forall i \in \{1, \ldots, ns\} + + e^{v_j} \geq \epsilon \kappa, \forall j \in \{1, \ldots, nt\} + + The parameters `kappa` and `epsilon` are determined w.r.t the couple number budget of points (`ns_budget`, `nt_budget`), see Equation (5) in :ref:`[26] ` Parameters ---------- - a : `numpy.ndarray`, shape=(ns,) + a : array-like, shape=(ns,) samples weights in the source domain - b : `numpy.ndarray`, shape=(nt,) + b : array-like, shape=(nt,) samples weights in the target domain - M : `numpy.ndarray`, shape=(ns, nt) + M : array-like, shape=(ns, nt) Cost matrix reg : `float` Level of the entropy regularisation - ns_budget : `int`, deafult=None - Number budget of points to be keeped in the source domain - If it is None then 50% of the source sample points will be keeped + ns_budget : `int`, default=None + Number budget of points to be kept in the source domain. + If it is None then 50% of the source sample points will be kept - nt_budget : `int`, deafult=None - Number budget of points to be keeped in the target domain - If it is None then 50% of the target sample points will be keeped + nt_budget : `int`, default=None + Number budget of points to be kept in the target domain. + If it is None then 50% of the target sample points will be kept uniform : `bool`, default=False - If `True`, the source and target distribution are supposed to be uniform, i.e., a_i = 1 / ns and b_j = 1 / nt + If `True`, the source and target distribution are supposed to be uniform, i.e., :math:`a_i = 1 / ns` and :math:`b_j = 1 / nt` restricted : `bool`, default=True If `True`, a warm-start initialization for the L-BFGS-B solver @@ -2133,15 +2204,16 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res Maximum number of iterations in LBFGS solver maxfun : `int`, default=10000 - Maximum number of function evaluations in LBFGS solver + Maximum number of function evaluations in LBFGS solver pgtol : `float`, default=1e-09 Final objective function accuracy in LBFGS solver verbose : `bool`, default=False - If `True`, dispaly informations about the cardinals of the active sets and the paramerters kappa + If `True`, display informations about the cardinals of the active sets and the parameters kappa and epsilon + Dependency ---------- To gain more efficiency, screenkhorn needs to call the "Bottleneck" package (https://pypi.org/project/Bottleneck/) @@ -2151,15 +2223,19 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res Returns ------- - gamma : `numpy.ndarray`, shape=(ns, nt) + gamma : array-like, shape=(ns, nt) Screened optimal transportation matrix for the given parameters log : `dict`, default=False Log dictionary return only if log==True in parameters + .. _references-screenkhorn: References ----------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). Screening Sinkhorn Algorithm for Regularized Optimal Transport (NIPS) 33, 2019 """ @@ -2171,9 +2247,12 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res "Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance.") bottleneck = np - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) + if nx.__name__ == "jax": + raise TypeError("JAX arrays have been received but screenkhorn is not compatible with JAX.") + ns, nt = M.shape # by default, we keep only 50% of the sample data points @@ -2183,9 +2262,7 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res nt_budget = int(np.floor(0.5 * nt)) # calculate the Gibbs kernel - K = np.empty_like(M) - np.divide(M, -reg, out=K) - np.exp(K, out=K) + K = nx.exp(-M / reg) def projection(u, epsilon): u[u <= epsilon] = epsilon @@ -2197,8 +2274,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res if ns_budget == ns and nt_budget == nt: # full number of budget points (ns, nt) = (ns_budget, nt_budget) - Isel = np.ones(ns, dtype=bool) - Jsel = np.ones(nt, dtype=bool) + Isel = nx.from_numpy(np.ones(ns, dtype=bool)) + Jsel = nx.from_numpy(np.ones(nt, dtype=bool)) epsilon = 0.0 kappa = 1.0 @@ -2214,57 +2291,61 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res K_IJc = [] K_IcJ = [] - vec_eps_IJc = np.zeros(nt) - vec_eps_IcJ = np.zeros(ns) + vec_eps_IJc = nx.zeros((nt,), type_as=M) + vec_eps_IcJ = nx.zeros((ns,), type_as=M) else: # sum of rows and columns of K - K_sum_cols = K.sum(axis=1) - K_sum_rows = K.sum(axis=0) + K_sum_cols = nx.sum(K, axis=1) + K_sum_rows = nx.sum(K, axis=0) if uniform: if ns / ns_budget < 4: - aK_sort = np.sort(K_sum_cols) + aK_sort = nx.sort(K_sum_cols) epsilon_u_square = a[0] / aK_sort[ns_budget - 1] else: - aK_sort = bottleneck.partition(K_sum_cols, ns_budget - 1)[ns_budget - 1] + aK_sort = nx.from_numpy( + bottleneck.partition(nx.to_numpy(K_sum_cols), ns_budget - 1)[ns_budget - 1] + ) epsilon_u_square = a[0] / aK_sort if nt / nt_budget < 4: - bK_sort = np.sort(K_sum_rows) + bK_sort = nx.sort(K_sum_rows) epsilon_v_square = b[0] / bK_sort[nt_budget - 1] else: - bK_sort = bottleneck.partition(K_sum_rows, nt_budget - 1)[nt_budget - 1] + bK_sort = nx.from_numpy( + bottleneck.partition(nx.to_numpy(K_sum_rows), nt_budget - 1)[nt_budget - 1] + ) epsilon_v_square = b[0] / bK_sort else: aK = a / K_sum_cols bK = b / K_sum_rows - aK_sort = np.sort(aK)[::-1] + aK_sort = nx.flip(nx.sort(aK), axis=0) epsilon_u_square = aK_sort[ns_budget - 1] - bK_sort = np.sort(bK)[::-1] + bK_sort = nx.flip(nx.sort(bK), axis=0) epsilon_v_square = bK_sort[nt_budget - 1] # active sets I and J (see Lemma 1 in [26]) Isel = a >= epsilon_u_square * K_sum_cols Jsel = b >= epsilon_v_square * K_sum_rows - if sum(Isel) != ns_budget: + if nx.sum(Isel) != ns_budget: if uniform: aK = a / K_sum_cols - aK_sort = np.sort(aK)[::-1] - epsilon_u_square = aK_sort[ns_budget - 1:ns_budget + 1].mean() + aK_sort = nx.flip(nx.sort(aK), axis=0) + epsilon_u_square = nx.mean(aK_sort[ns_budget - 1:ns_budget + 1]) Isel = a >= epsilon_u_square * K_sum_cols - ns_budget = sum(Isel) + ns_budget = nx.sum(Isel) - if sum(Jsel) != nt_budget: + if nx.sum(Jsel) != nt_budget: if uniform: bK = b / K_sum_rows - bK_sort = np.sort(bK)[::-1] - epsilon_v_square = bK_sort[nt_budget - 1:nt_budget + 1].mean() + bK_sort = nx.flip(nx.sort(bK), axis=0) + epsilon_v_square = nx.mean(bK_sort[nt_budget - 1:nt_budget + 1]) Jsel = b >= epsilon_v_square * K_sum_rows - nt_budget = sum(Jsel) + nt_budget = nx.sum(Jsel) epsilon = (epsilon_u_square * epsilon_v_square) ** (1 / 4) kappa = (epsilon_v_square / epsilon_u_square) ** (1 / 2) @@ -2282,7 +2363,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res K_IcJ = K[np.ix_(Ic, Jsel)] K_IJc = K[np.ix_(Isel, Jc)] - K_min = K_IJ.min() + #K_min = K_IJ.min() + K_min = nx.min(K_IJ) if K_min == 0: K_min = np.finfo(float).tiny @@ -2290,10 +2372,10 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res a_I = a[Isel] b_J = b[Jsel] if not uniform: - a_I_min = a_I.min() - a_I_max = a_I.max() - b_J_max = b_J.max() - b_J_min = b_J.min() + a_I_min = nx.min(a_I) + a_I_max = nx.max(a_I) + b_J_max = nx.max(b_J) + b_J_min = nx.min(b_J) else: a_I_min = a_I[0] a_I_max = a_I[0] @@ -2309,24 +2391,30 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget # pre-calculated constants for the objective - vec_eps_IJc = epsilon * kappa * (K_IJc * np.ones(nt - nt_budget).reshape((1, -1))).sum(axis=1) - vec_eps_IcJ = (epsilon / kappa) * (np.ones(ns - ns_budget).reshape((-1, 1)) * K_IcJ).sum(axis=0) + vec_eps_IJc = epsilon * kappa * nx.sum( + K_IJc * nx.ones((nt - nt_budget,), type_as=M)[None, :], + axis=1 + ) + vec_eps_IcJ = (epsilon / kappa) * nx.sum( + nx.ones((ns - ns_budget,), type_as=M)[:, None] * K_IcJ, + axis=0 + ) # initialisation - u0 = np.full(ns_budget, (1. / ns_budget) + epsilon / kappa) - v0 = np.full(nt_budget, (1. / nt_budget) + epsilon * kappa) + u0 = nx.full((ns_budget,), 1. / ns_budget + epsilon / kappa, type_as=M) + v0 = nx.full((nt_budget,), 1. / nt_budget + epsilon * kappa, type_as=M) # pre-calculed constants for Restricted Sinkhorn (see Algorithm 1 in supplementary of [26]) if restricted: if ns_budget != ns or nt_budget != nt: - cst_u = kappa * epsilon * K_IJc.sum(axis=1) - cst_v = epsilon * K_IcJ.sum(axis=0) / kappa + cst_u = kappa * epsilon * nx.sum(K_IJc, axis=1) + cst_v = epsilon * nx.sum(K_IcJ, axis=0) / kappa cpt = 1 while cpt < 5: # 5 iterations - K_IJ_v = np.dot(K_IJ.T, u0) + cst_v + K_IJ_v = nx.dot(K_IJ.T, u0) + cst_v v0 = b_J / (kappa * K_IJ_v) - KIJ_u = np.dot(K_IJ, v0) + cst_u + KIJ_u = nx.dot(K_IJ, v0) + cst_u u0 = (kappa * a_I) / KIJ_u cpt += 1 @@ -2343,9 +2431,9 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res """ cpt = 1 while cpt < max_iter: - K_IJ_v = np.dot(K_IJ.T, usc) + cst_v + K_IJ_v = nx.dot(K_IJ.T, usc) + cst_v vsc = b_J / (kappa * K_IJ_v) - KIJ_u = np.dot(K_IJ, vsc) + cst_u + KIJ_u = nx.dot(K_IJ, vsc) + cst_u usc = (kappa * a_I) / KIJ_u cpt += 1 @@ -2355,17 +2443,20 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res return usc, vsc def screened_obj(usc, vsc): - part_IJ = np.dot(np.dot(usc, K_IJ), vsc) - kappa * np.dot(a_I, np.log(usc)) - (1. / kappa) * np.dot(b_J, - np.log(vsc)) - part_IJc = np.dot(usc, vec_eps_IJc) - part_IcJ = np.dot(vec_eps_IcJ, vsc) + part_IJ = ( + nx.dot(nx.dot(usc, K_IJ), vsc) + - kappa * nx.dot(a_I, nx.log(usc)) + - (1. / kappa) * nx.dot(b_J, nx.log(vsc)) + ) + part_IJc = nx.dot(usc, vec_eps_IJc) + part_IcJ = nx.dot(vec_eps_IcJ, vsc) psi_epsilon = part_IJ + part_IJc + part_IcJ return psi_epsilon def screened_grad(usc, vsc): # gradients of Psi_(kappa,epsilon) w.r.t u and v - grad_u = np.dot(K_IJ, vsc) + vec_eps_IJc - kappa * a_I / usc - grad_v = np.dot(K_IJ.T, usc) + vec_eps_IcJ - (1. / kappa) * b_J / vsc + grad_u = nx.dot(K_IJ, vsc) + vec_eps_IJc - kappa * a_I / usc + grad_v = nx.dot(K_IJ.T, usc) + vec_eps_IcJ - (1. / kappa) * b_J / vsc return grad_u, grad_v def bfgspost(theta): @@ -2375,20 +2466,20 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res f = screened_obj(u, v) # gradient g_u, g_v = screened_grad(u, v) - g = np.hstack([g_u, g_v]) - return f, g + g = nx.concatenate([g_u, g_v], axis=0) + return nx.to_numpy(f), nx.to_numpy(g) # ----------------------------------------------------------------------------------------------------------------# # Step 2: L-BFGS-B solver # # ----------------------------------------------------------------------------------------------------------------# u0, v0 = restricted_sinkhorn(u0, v0) - theta0 = np.hstack([u0, v0]) + theta0 = nx.concatenate([u0, v0], axis=0) bounds = bounds_u + bounds_v # constraint bounds def obj(theta): - return bfgspost(theta) + return bfgspost(nx.from_numpy(theta, type_as=M)) theta, _, _ = fmin_l_bfgs_b(func=obj, x0=theta0, @@ -2396,12 +2487,13 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res maxfun=maxfun, pgtol=pgtol, maxiter=maxiter) + theta = nx.from_numpy(theta) usc = theta[:ns_budget] vsc = theta[ns_budget:] - usc_full = np.full(ns, epsilon / kappa) - vsc_full = np.full(nt, epsilon * kappa) + usc_full = nx.full((ns,), epsilon / kappa, type_as=M) + vsc_full = nx.full((nt,), epsilon * kappa, type_as=M) usc_full[Isel] = usc vsc_full[Jsel] = vsc @@ -2413,7 +2505,7 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res log['Jsel'] = Jsel gamma = usc_full[:, None] * K * vsc_full[None, :] - gamma = gamma / gamma.sum() + gamma = gamma / nx.sum(gamma) if log: return gamma, log diff --git a/ot/gromov.py b/ot/gromov.py index a27217a..85b1549 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -1161,7 +1161,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter : int, optional Max number of iterations tol : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations. log : bool, optional @@ -1267,7 +1267,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, max_iter : int, optional Max number of iterations tol : float, optional - Stop threshol on error (>0). + Stop threshold on error (>0). verbose : bool, optional Print information along iterations. log : bool, optional @@ -1365,7 +1365,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ max_iter : int, optional Max number of iterations tol : float, optional - Stop threshol on error (>0). + Stop threshold on error (>0). verbose : bool, optional Print information along iterations. log : bool, optional diff --git a/ot/smooth.py b/ot/smooth.py index 81f6a3e..ea26bae 100644 --- a/ot/smooth.py +++ b/ot/smooth.py @@ -458,7 +458,7 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -552,7 +552,7 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr= numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional diff --git a/ot/unbalanced.py b/ot/unbalanced.py index e37f10c..6a61aa1 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -58,7 +58,7 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -186,7 +186,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -300,7 +300,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (> 0) + Stop threshold on error (> 0) verbose : bool, optional Print information along iterations log : bool, optional @@ -482,7 +482,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -691,7 +691,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (> 0) + Stop threshold on error (> 0) verbose : bool, optional Print information along iterations log : bool, optional @@ -841,7 +841,7 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (> 0) + Stop threshold on error (> 0) verbose : bool, optional Print information along iterations log : bool, optional @@ -971,7 +971,7 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (> 0) + Stop threshold on error (> 0) verbose : bool, optional Print information along iterations log : bool, optional diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 0000000..876b525 --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- + +# Configuration file for pytest + +# License: MIT License + +import pytest +from ot.backend import jax +from ot.backend import get_backend_list +import functools + +if jax: + from jax.config import config + +backend_list = get_backend_list() + + +@pytest.fixture(params=backend_list) +def nx(request): + backend = request.param + if backend.__name__ == "jax": + config.update("jax_enable_x64", True) + + yield backend + + if backend.__name__ == "jax": + config.update("jax_enable_x64", False) + + +def skip_arg(arg, value, reason=None, getter=lambda x: x): + if reason is None: + reason = f"Param {arg} should be skipped for value {value}" + + def wrapper(function): + + @functools.wraps(function) + def wrapped(*args, **kwargs): + if arg in kwargs.keys() and getter(kwargs[arg]) == value: + pytest.skip(reason) + return function(*args, **kwargs) + + return wrapped + + return wrapper + + +def pytest_configure(config): + pytest.skip_arg = skip_arg + pytest.skip_backend = functools.partial(skip_arg, "nx", getter=str) diff --git a/test/test_backend.py b/test/test_backend.py index cbfaf94..859da5a 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -1,6 +1,7 @@ """Tests for backend module """ # Author: Remi Flamary +# Nicolas Courty # # License: MIT License @@ -155,6 +156,8 @@ def test_empty_backend(): nx.exp(M) with pytest.raises(NotImplementedError): nx.sqrt(M) + with pytest.raises(NotImplementedError): + nx.power(v, 2) with pytest.raises(NotImplementedError): nx.dot(v, v) with pytest.raises(NotImplementedError): @@ -173,8 +176,38 @@ def test_empty_backend(): nx.sort(M) with pytest.raises(NotImplementedError): nx.argsort(M) + with pytest.raises(NotImplementedError): + nx.searchsorted(v, v) with pytest.raises(NotImplementedError): nx.flip(M) + with pytest.raises(NotImplementedError): + nx.clip(M, -1, 1) + with pytest.raises(NotImplementedError): + nx.repeat(M, 0, 1) + with pytest.raises(NotImplementedError): + nx.take_along_axis(M, v, 0) + with pytest.raises(NotImplementedError): + nx.concatenate([v, v]) + with pytest.raises(NotImplementedError): + nx.zero_pad(M, v) + with pytest.raises(NotImplementedError): + nx.argmax(M) + with pytest.raises(NotImplementedError): + nx.mean(M) + with pytest.raises(NotImplementedError): + nx.std(M) + with pytest.raises(NotImplementedError): + nx.linspace(0, 1, 50) + with pytest.raises(NotImplementedError): + nx.meshgrid(v, v) + with pytest.raises(NotImplementedError): + nx.diag(M) + with pytest.raises(NotImplementedError): + nx.unique([M, M]) + with pytest.raises(NotImplementedError): + nx.logsumexp(M) + with pytest.raises(NotImplementedError): + nx.stack([M, M]) @pytest.mark.parametrize('backend', backend_list) @@ -278,6 +311,10 @@ def test_func_backends(backend): lst_b.append(nx.to_numpy(A)) lst_name.append('sqrt') + A = nx.power(Mb, 2) + lst_b.append(nx.to_numpy(A)) + lst_name.append('power') + A = nx.dot(vb, vb) lst_b.append(nx.to_numpy(A)) lst_name.append('dot(v,v)') @@ -326,10 +363,75 @@ def test_func_backends(backend): lst_b.append(nx.to_numpy(A)) lst_name.append('argsort') + A = nx.searchsorted(Mb, Mb, 'right') + lst_b.append(nx.to_numpy(A)) + lst_name.append('searchsorted') + A = nx.flip(Mb) lst_b.append(nx.to_numpy(A)) lst_name.append('flip') + A = nx.clip(vb, 0, 1) + lst_b.append(nx.to_numpy(A)) + lst_name.append('clip') + + A = nx.repeat(Mb, 0) + A = nx.repeat(Mb, 2, -1) + lst_b.append(nx.to_numpy(A)) + lst_name.append('repeat') + + A = nx.take_along_axis(vb, nx.arange(3), -1) + lst_b.append(nx.to_numpy(A)) + lst_name.append('take_along_axis') + + A = nx.concatenate((Mb, Mb), -1) + lst_b.append(nx.to_numpy(A)) + lst_name.append('concatenate') + + A = nx.zero_pad(Mb, len(Mb.shape) * [(3, 3)]) + lst_b.append(nx.to_numpy(A)) + lst_name.append('zero_pad') + + A = nx.argmax(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('argmax') + + A = nx.mean(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('mean') + + A = nx.std(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('std') + + A = nx.linspace(0, 1, 50) + lst_b.append(nx.to_numpy(A)) + lst_name.append('linspace') + + X, Y = nx.meshgrid(vb, vb) + lst_b.append(np.stack([nx.to_numpy(X), nx.to_numpy(Y)])) + lst_name.append('meshgrid') + + A = nx.diag(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('diag2D') + + A = nx.diag(vb, 1) + lst_b.append(nx.to_numpy(A)) + lst_name.append('diag1D') + + A = nx.unique(nx.from_numpy(np.stack([M, M]))) + lst_b.append(nx.to_numpy(A)) + lst_name.append('unique') + + A = nx.logsumexp(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('logsumexp') + + A = nx.stack([Mb, Mb]) + lst_b.append(nx.to_numpy(A)) + lst_name.append('stack') + lst_tot.append(lst_b) lst_np = lst_tot[0] diff --git a/test/test_bregman.py b/test/test_bregman.py index 88166a5..942cb6d 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -10,11 +10,8 @@ import numpy as np import pytest import ot -from ot.backend import get_backend_list from ot.backend import torch -backend_list = get_backend_list() - def test_sinkhorn(): # test sinkhorn @@ -28,14 +25,13 @@ def test_sinkhorn(): G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10) - # check constratints + # check constraints np.testing.assert_allclose( u, G.sum(1), atol=1e-05) # cf convergence sinkhorn np.testing.assert_allclose( u, G.sum(0), atol=1e-05) # cf convergence sinkhorn -@pytest.mark.parametrize('nx', backend_list) def test_sinkhorn_backends(nx): n_samples = 100 n_features = 2 @@ -57,7 +53,6 @@ def test_sinkhorn_backends(nx): np.allclose(G, nx.to_numpy(Gb)) -@pytest.mark.parametrize('nx', backend_list) def test_sinkhorn2_backends(nx): n_samples = 100 n_features = 2 @@ -116,20 +111,20 @@ def test_sinkhorn_empty(): M = ot.dist(x, x) G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, verbose=True, log=True) - # check constratints + # check constraints np.testing.assert_allclose(u, G.sum(1), atol=1e-05) np.testing.assert_allclose(u, G.sum(0), atol=1e-05) G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, method='sinkhorn_stabilized', verbose=True, log=True) - # check constratints + # check constraints np.testing.assert_allclose(u, G.sum(1), atol=1e-05) np.testing.assert_allclose(u, G.sum(0), atol=1e-05) G, log = ot.sinkhorn( [], [], M, 1, stopThr=1e-10, method='sinkhorn_epsilon_scaling', verbose=True, log=True) - # check constratints + # check constraints np.testing.assert_allclose(u, G.sum(1), atol=1e-05) np.testing.assert_allclose(u, G.sum(0), atol=1e-05) @@ -137,7 +132,8 @@ def test_sinkhorn_empty(): ot.sinkhorn([], [], M, 1, method='greenkhorn', stopThr=1e-10, log=True) -def test_sinkhorn_variants(): +@pytest.skip_backend("jax") +def test_sinkhorn_variants(nx): # test sinkhorn n = 100 rng = np.random.RandomState(0) @@ -147,13 +143,18 @@ def test_sinkhorn_variants(): M = ot.dist(x, x) - G0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10) - Gs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10) - Ges = ot.sinkhorn( - u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10) - G_green = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10) + ub = nx.from_numpy(u) + Mb = nx.from_numpy(M) + + G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10) + G0 = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + Ges = nx.to_numpy(ot.sinkhorn( + ub, ub, Mb, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)) + G_green = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='greenkhorn', stopThr=1e-10)) # check values + np.testing.assert_allclose(G, G0, atol=1e-05) np.testing.assert_allclose(G0, Gs, atol=1e-05) np.testing.assert_allclose(G0, Ges, atol=1e-05) np.testing.assert_allclose(G0, G_green, atol=1e-5) @@ -184,7 +185,7 @@ def test_sinkhorn_variants_log(): @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) -def test_barycenter(method): +def test_barycenter(nx, method): n_bins = 100 # nb bins # Gaussian distributions @@ -201,16 +202,23 @@ def test_barycenter(method): alpha = 0.5 # 0<=alpha<=1 weights = np.array([1 - alpha, alpha]) + Ab = nx.from_numpy(A) + Mb = nx.from_numpy(M) + weightsb = nx.from_numpy(weights) + # wasserstein reg = 1e-2 - bary_wass, log = ot.bregman.barycenter(A, M, reg, weights, method=method, log=True) + bary_wass_np, log = ot.bregman.barycenter(A, M, reg, weights, method=method, log=True) + bary_wass, _ = ot.bregman.barycenter(Ab, Mb, reg, weightsb, method=method, log=True) + bary_wass = nx.to_numpy(bary_wass) np.testing.assert_allclose(1, np.sum(bary_wass)) + np.testing.assert_allclose(bary_wass, bary_wass_np) - ot.bregman.barycenter(A, M, reg, log=True, verbose=True) + ot.bregman.barycenter(Ab, Mb, reg, log=True, verbose=True) -def test_barycenter_stabilization(): +def test_barycenter_stabilization(nx): n_bins = 100 # nb bins # Gaussian distributions @@ -227,17 +235,26 @@ def test_barycenter_stabilization(): alpha = 0.5 # 0<=alpha<=1 weights = np.array([1 - alpha, alpha]) + Ab = nx.from_numpy(A) + Mb = nx.from_numpy(M) + weights_b = nx.from_numpy(weights) + # wasserstein reg = 1e-2 - bar_stable = ot.bregman.barycenter(A, M, reg, weights, - method="sinkhorn_stabilized", - stopThr=1e-8, verbose=True) - bar = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", - stopThr=1e-8, verbose=True) + bar_np = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True) + bar_stable = nx.to_numpy(ot.bregman.barycenter( + Ab, Mb, reg, weights_b, method="sinkhorn_stabilized", + stopThr=1e-8, verbose=True + )) + bar = nx.to_numpy(ot.bregman.barycenter( + Ab, Mb, reg, weights_b, method="sinkhorn", + stopThr=1e-8, verbose=True + )) np.testing.assert_allclose(bar, bar_stable) + np.testing.assert_allclose(bar, bar_np) -def test_wasserstein_bary_2d(): +def test_wasserstein_bary_2d(nx): size = 100 # size of a square image a1 = np.random.randn(size, size) a1 += a1.min() @@ -250,17 +267,21 @@ def test_wasserstein_bary_2d(): A[0, :, :] = a1 A[1, :, :] = a2 + Ab = nx.from_numpy(A) + # wasserstein reg = 1e-2 - bary_wass = ot.bregman.convolutional_barycenter2d(A, reg) + bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg) + bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(Ab, reg)) np.testing.assert_allclose(1, np.sum(bary_wass)) + np.testing.assert_allclose(bary_wass, bary_wass_np) # help in checking if log and verbose do not bug the function ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) -def test_unmix(): +def test_unmix(nx): n_bins = 50 # nb bins # Gaussian distributions @@ -280,18 +301,26 @@ def test_unmix(): M0 /= M0.max() h0 = ot.unif(2) + ab = nx.from_numpy(a) + Db = nx.from_numpy(D) + Mb = nx.from_numpy(M) + M0b = nx.from_numpy(M0) + h0b = nx.from_numpy(h0) + # wasserstein reg = 1e-3 - um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01, ) + um_np = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01) + um = nx.to_numpy(ot.bregman.unmix(ab, Db, Mb, M0b, h0b, reg, 1, alpha=0.01)) np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03) np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03) + np.testing.assert_allclose(um, um_np) - ot.bregman.unmix(a, D, M, M0, h0, reg, + ot.bregman.unmix(ab, Db, Mb, M0b, h0b, reg, 1, alpha=0.01, log=True, verbose=True) -def test_empirical_sinkhorn(): +def test_empirical_sinkhorn(nx): # test sinkhorn n = 10 a = ot.unif(n) @@ -302,19 +331,28 @@ def test_empirical_sinkhorn(): M = ot.dist(X_s, X_t) M_m = ot.dist(X_s, X_t, metric='minkowski') - G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1) - sinkhorn_sqe = ot.sinkhorn(a, b, M, 1) + ab = nx.from_numpy(a) + bb = nx.from_numpy(b) + X_sb = nx.from_numpy(X_s) + X_tb = nx.from_numpy(X_t) + Mb = nx.from_numpy(M, type_as=ab) + M_mb = nx.from_numpy(M_m, type_as=ab) + + G_sqe = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1)) + sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1)) - G_log, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, log=True) - sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True) + G_log, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, log=True) + G_log = nx.to_numpy(G_log) + sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True) + sinkhorn_log = nx.to_numpy(sinkhorn_log) - G_m = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski') - sinkhorn_m = ot.sinkhorn(a, b, M_m, 1) + G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='minkowski')) + sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) - loss_emp_sinkhorn = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1) - loss_sinkhorn = ot.sinkhorn2(a, b, M, 1) + loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1)) + loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, Mb, 1)) - # check constratints + # check constraints np.testing.assert_allclose( sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian np.testing.assert_allclose( @@ -330,7 +368,7 @@ def test_empirical_sinkhorn(): np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) -def test_lazy_empirical_sinkhorn(): +def test_lazy_empirical_sinkhorn(nx): # test sinkhorn n = 10 a = ot.unif(n) @@ -342,22 +380,34 @@ def test_lazy_empirical_sinkhorn(): M = ot.dist(X_s, X_t) M_m = ot.dist(X_s, X_t, metric='minkowski') - f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True) + ab = nx.from_numpy(a) + bb = nx.from_numpy(b) + X_sb = nx.from_numpy(X_s) + X_tb = nx.from_numpy(X_t) + Mb = nx.from_numpy(M, type_as=ab) + M_mb = nx.from_numpy(M_m, type_as=ab) + + f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True) + f, g = nx.to_numpy(f), nx.to_numpy(g) G_sqe = np.exp(f[:, None] + g[None, :] - M / 1) - sinkhorn_sqe = ot.sinkhorn(a, b, M, 1) + sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1)) - f, g, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) + f, g, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) + f, g = nx.to_numpy(f), nx.to_numpy(g) G_log = np.exp(f[:, None] + g[None, :] - M / 0.1) - sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True) + sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True) + sinkhorn_log = nx.to_numpy(sinkhorn_log) - f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski', numIterMax=numIterMax, isLazy=True, batchSize=1) + f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='minkowski', numIterMax=numIterMax, isLazy=True, batchSize=1) + f, g = nx.to_numpy(f), nx.to_numpy(g) G_m = np.exp(f[:, None] + g[None, :] - M_m / 1) - sinkhorn_m = ot.sinkhorn(a, b, M_m, 1) + sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) - loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) - loss_sinkhorn = ot.sinkhorn2(a, b, M, 1) + loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) + loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn) + loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, Mb, 1)) - # check constratints + # check constraints np.testing.assert_allclose( sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian np.testing.assert_allclose( @@ -373,7 +423,7 @@ def test_lazy_empirical_sinkhorn(): np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) -def test_empirical_sinkhorn_divergence(): +def test_empirical_sinkhorn_divergence(nx): # Test sinkhorn divergence n = 10 a = np.linspace(1, n, n) @@ -385,22 +435,31 @@ def test_empirical_sinkhorn_divergence(): M_s = ot.dist(X_s, X_s) M_t = ot.dist(X_t, X_t) - emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b) - sinkhorn_div = (ot.sinkhorn2(a, b, M, 1) - 1 / 2 * ot.sinkhorn2(a, a, M_s, 1) - 1 / 2 * ot.sinkhorn2(b, b, M_t, 1)) + ab = nx.from_numpy(a) + bb = nx.from_numpy(b) + X_sb = nx.from_numpy(X_s) + X_tb = nx.from_numpy(X_t) + Mb = nx.from_numpy(M, type_as=ab) + M_sb = nx.from_numpy(M_s, type_as=ab) + M_tb = nx.from_numpy(M_t, type_as=ab) + + emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb)) + sinkhorn_div = nx.to_numpy( + ot.sinkhorn2(ab, bb, Mb, 1) + - 1 / 2 * ot.sinkhorn2(ab, ab, M_sb, 1) + - 1 / 2 * ot.sinkhorn2(bb, bb, M_tb, 1) + ) + emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b) - emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b, log=True) - sink_div_log_ab, log_s_ab = ot.sinkhorn2(a, b, M, 1, log=True) - sink_div_log_a, log_s_a = ot.sinkhorn2(a, a, M_s, 1, log=True) - sink_div_log_b, log_s_b = ot.sinkhorn2(b, b, M_t, 1, log=True) - sink_div_log = sink_div_log_ab - 1 / 2 * (sink_div_log_a + sink_div_log_b) # check constraints + np.testing.assert_allclose(emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05) np.testing.assert_allclose( emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn - np.testing.assert_allclose( - emp_sinkhorn_div_log, sink_div_log, atol=1e-05) # cf conv emp sinkhorn + ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb, log=True) -def test_stabilized_vs_sinkhorn_multidim(): + +def test_stabilized_vs_sinkhorn_multidim(nx): # test if stable version matches sinkhorn # for multidimensional inputs n = 100 @@ -416,12 +475,21 @@ def test_stabilized_vs_sinkhorn_multidim(): M = ot.utils.dist0(n) M /= np.median(M) epsilon = 0.1 - G, log = ot.bregman.sinkhorn(a, b, M, reg=epsilon, + + ab = nx.from_numpy(a) + bb = nx.from_numpy(b) + Mb = nx.from_numpy(M, type_as=ab) + + G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True) + G, log = ot.bregman.sinkhorn(ab, bb, Mb, reg=epsilon, method="sinkhorn_stabilized", log=True) - G2, log2 = ot.bregman.sinkhorn(a, b, M, epsilon, + G = nx.to_numpy(G) + G2, log2 = ot.bregman.sinkhorn(ab, bb, Mb, epsilon, method="sinkhorn", log=True) + G2 = nx.to_numpy(G2) + np.testing.assert_allclose(G_np, G2) np.testing.assert_allclose(G, G2) @@ -458,8 +526,9 @@ def test_implemented_methods(): ot.bregman.sinkhorn2(a, b, M, epsilon, method=method) +@pytest.skip_backend("jax") @pytest.mark.filterwarnings("ignore:Bottleneck") -def test_screenkhorn(): +def test_screenkhorn(nx): # test screenkhorn rng = np.random.RandomState(0) n = 100 @@ -468,17 +537,31 @@ def test_screenkhorn(): x = rng.randn(n, 2) M = ot.dist(x, x) + + ab = nx.from_numpy(a) + bb = nx.from_numpy(b) + Mb = nx.from_numpy(M, type_as=ab) + + # np sinkhorn + G_sink_np = ot.sinkhorn(a, b, M, 1e-03) # sinkhorn - G_sink = ot.sinkhorn(a, b, M, 1e-03) + G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1e-03)) # screenkhorn - G_screen = ot.bregman.screenkhorn(a, b, M, 1e-03, uniform=True, verbose=True) + G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, Mb, 1e-03, uniform=True, verbose=True)) # check marginals + np.testing.assert_allclose(G_sink_np, G_sink) np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02) np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02) -def test_convolutional_barycenter_non_square(): +def test_convolutional_barycenter_non_square(nx): # test for image with height not equal width A = np.ones((2, 2, 3)) / (2 * 3) - b = ot.bregman.convolutional_barycenter2d(A, 1e-03) + Ab = nx.from_numpy(A) + + b_np = ot.bregman.convolutional_barycenter2d(A, 1e-03) + b = nx.to_numpy(ot.bregman.convolutional_barycenter2d(Ab, 1e-03)) + + np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02) np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02) + np.testing.assert_allclose(b, b_np) diff --git a/test/test_partial.py b/test/test_partial.py index 3571e2a..97c611b 100755 --- a/test/test_partial.py +++ b/test/test_partial.py @@ -104,7 +104,7 @@ def test_partial_wasserstein(): w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m, log=True, verbose=True) - # check constratints + # check constraints np.testing.assert_equal( w0.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein np.testing.assert_equal( @@ -127,7 +127,7 @@ def test_partial_wasserstein(): np.testing.assert_allclose(w0, w0_val, atol=1e-1, rtol=1e-1) - # check constratints + # check constraints np.testing.assert_equal( G.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein np.testing.assert_equal( @@ -194,7 +194,7 @@ def test_partial_gromov_wasserstein(): 100, m=m, log=True) - # check constratints + # check constraints np.testing.assert_equal( res0.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein np.testing.assert_equal( diff --git a/test/test_smooth.py b/test/test_smooth.py index 2afa4f8..31e0b2e 100644 --- a/test/test_smooth.py +++ b/test/test_smooth.py @@ -25,16 +25,16 @@ def test_smooth_ot_dual(): Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10) - # check constratints + # check constraints np.testing.assert_allclose( u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn np.testing.assert_allclose( u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn - # kl regyularisation + # kl regularisation G = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='kl', stopThr=1e-10) - # check constratints + # check constraints np.testing.assert_allclose( u, G.sum(1), atol=1e-05) # cf convergence sinkhorn np.testing.assert_allclose( @@ -60,16 +60,16 @@ def test_smooth_ot_semi_dual(): Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10) - # check constratints + # check constraints np.testing.assert_allclose( u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn np.testing.assert_allclose( u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn - # kl regyularisation + # kl regularisation G = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='kl', stopThr=1e-10) - # check constratints + # check constraints np.testing.assert_allclose( u, G.sum(1), atol=1e-05) # cf convergence sinkhorn np.testing.assert_allclose( diff --git a/test/test_stochastic.py b/test/test_stochastic.py index 98e93ec..736df32 100644 --- a/test/test_stochastic.py +++ b/test/test_stochastic.py @@ -43,7 +43,7 @@ def test_stochastic_sag(): G = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "sag", numItermax=numItermax) - # check constratints + # check constraints np.testing.assert_allclose( u, G.sum(1), atol=1e-03) # cf convergence sag np.testing.assert_allclose( @@ -73,7 +73,7 @@ def test_stochastic_asgd(): G, log = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd", numItermax=numItermax, log=True) - # check constratints + # check constraints np.testing.assert_allclose( u, G.sum(1), atol=1e-02) # cf convergence asgd np.testing.assert_allclose( @@ -105,7 +105,7 @@ def test_sag_asgd_sinkhorn(): numItermax=nb_iter) G_sinkhorn = ot.sinkhorn(u, u, M, reg) - # check constratints + # check constraints np.testing.assert_allclose( G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-02) np.testing.assert_allclose( @@ -148,7 +148,7 @@ def test_stochastic_dual_sgd(): G, log = ot.stochastic.solve_dual_entropic(u, u, M, reg, batch_size, numItermax=numItermax, log=True) - # check constratints + # check constraints np.testing.assert_allclose( u, G.sum(1), atol=1e-03) # cf convergence sgd np.testing.assert_allclose( @@ -181,7 +181,7 @@ def test_dual_sgd_sinkhorn(): G_sinkhorn = ot.sinkhorn(u, u, M, reg) - # check constratints + # check constraints np.testing.assert_allclose( G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-02) np.testing.assert_allclose( @@ -206,7 +206,7 @@ def test_dual_sgd_sinkhorn(): G_sinkhorn = ot.sinkhorn(a, b, M, reg) - # check constratints + # check constraints np.testing.assert_allclose( G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03) np.testing.assert_allclose( -- cgit v1.2.3 From 76450dddf8dd62b9714b72e99ae075516246d433 Mon Sep 17 00:00:00 2001 From: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> Date: Mon, 25 Oct 2021 17:35:36 +0200 Subject: [MRG] Backend for optim (#282) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Backend for optim * Bug solve * Doc update * backend tests now with fixture * Unused imports removed * Docs * Docs * Docs * Outer product backend docs * Prettier docs * Pep8 * Mistakes corrected Co-authored-by: Rémi Flamary --- ot/backend.py | 118 ++++++++++++++++++++++++++------------- ot/lp/__init__.py | 4 +- ot/optim.py | 155 ++++++++++++++++++++++++++++++--------------------- test/test_backend.py | 22 +++++--- test/test_optim.py | 78 ++++++++++++++++++++------ test/test_ot.py | 6 +- test/test_utils.py | 7 --- 7 files changed, 250 insertions(+), 140 deletions(-) (limited to 'test') diff --git a/ot/backend.py b/ot/backend.py index a4a4757..876b96a 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -123,7 +123,7 @@ class Backend(): r""" Creates a tensor full of zeros. - This function follow the api from :any:`numpy.zeros` + This function follows the api from :any:`numpy.zeros` See: https://numpy.org/doc/stable/reference/generated/numpy.zeros.html """ @@ -133,7 +133,7 @@ class Backend(): r""" Creates a tensor full of ones. - This function follow the api from :any:`numpy.ones` + This function follows the api from :any:`numpy.ones` See: https://numpy.org/doc/stable/reference/generated/numpy.ones.html """ @@ -143,7 +143,7 @@ class Backend(): r""" Returns evenly spaced values within a given interval. - This function follow the api from :any:`numpy.arange` + This function follows the api from :any:`numpy.arange` See: https://numpy.org/doc/stable/reference/generated/numpy.arange.html """ @@ -153,7 +153,7 @@ class Backend(): r""" Creates a tensor with given shape, filled with given value. - This function follow the api from :any:`numpy.full` + This function follows the api from :any:`numpy.full` See: https://numpy.org/doc/stable/reference/generated/numpy.full.html """ @@ -163,7 +163,7 @@ class Backend(): r""" Creates the identity matrix of given size. - This function follow the api from :any:`numpy.eye` + This function follows the api from :any:`numpy.eye` See: https://numpy.org/doc/stable/reference/generated/numpy.eye.html """ @@ -173,7 +173,7 @@ class Backend(): r""" Sums tensor elements over given dimensions. - This function follow the api from :any:`numpy.sum` + This function follows the api from :any:`numpy.sum` See: https://numpy.org/doc/stable/reference/generated/numpy.sum.html """ @@ -183,7 +183,7 @@ class Backend(): r""" Returns the cumulative sum of tensor elements over given dimensions. - This function follow the api from :any:`numpy.cumsum` + This function follows the api from :any:`numpy.cumsum` See: https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html """ @@ -193,7 +193,7 @@ class Backend(): r""" Returns the maximum of an array or maximum along given dimensions. - This function follow the api from :any:`numpy.amax` + This function follows the api from :any:`numpy.amax` See: https://numpy.org/doc/stable/reference/generated/numpy.amax.html """ @@ -203,7 +203,7 @@ class Backend(): r""" Returns the maximum of an array or maximum along given dimensions. - This function follow the api from :any:`numpy.amin` + This function follows the api from :any:`numpy.amin` See: https://numpy.org/doc/stable/reference/generated/numpy.amin.html """ @@ -213,7 +213,7 @@ class Backend(): r""" Returns element-wise maximum of array elements. - This function follow the api from :any:`numpy.maximum` + This function follows the api from :any:`numpy.maximum` See: https://numpy.org/doc/stable/reference/generated/numpy.maximum.html """ @@ -223,7 +223,7 @@ class Backend(): r""" Returns element-wise minimum of array elements. - This function follow the api from :any:`numpy.minimum` + This function follows the api from :any:`numpy.minimum` See: https://numpy.org/doc/stable/reference/generated/numpy.minimum.html """ @@ -233,7 +233,7 @@ class Backend(): r""" Returns the dot product of two tensors. - This function follow the api from :any:`numpy.dot` + This function follows the api from :any:`numpy.dot` See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html """ @@ -243,7 +243,7 @@ class Backend(): r""" Computes the absolute value element-wise. - This function follow the api from :any:`numpy.absolute` + This function follows the api from :any:`numpy.absolute` See: https://numpy.org/doc/stable/reference/generated/numpy.absolute.html """ @@ -253,7 +253,7 @@ class Backend(): r""" Computes the exponential value element-wise. - This function follow the api from :any:`numpy.exp` + This function follows the api from :any:`numpy.exp` See: https://numpy.org/doc/stable/reference/generated/numpy.exp.html """ @@ -263,7 +263,7 @@ class Backend(): r""" Computes the natural logarithm, element-wise. - This function follow the api from :any:`numpy.log` + This function follows the api from :any:`numpy.log` See: https://numpy.org/doc/stable/reference/generated/numpy.log.html """ @@ -273,7 +273,7 @@ class Backend(): r""" Returns the non-ngeative square root of a tensor, element-wise. - This function follow the api from :any:`numpy.sqrt` + This function follows the api from :any:`numpy.sqrt` See: https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html """ @@ -283,7 +283,7 @@ class Backend(): r""" First tensor elements raised to powers from second tensor, element-wise. - This function follow the api from :any:`numpy.power` + This function follows the api from :any:`numpy.power` See: https://numpy.org/doc/stable/reference/generated/numpy.power.html """ @@ -293,7 +293,7 @@ class Backend(): r""" Computes the matrix frobenius norm. - This function follow the api from :any:`numpy.linalg.norm` + This function follows the api from :any:`numpy.linalg.norm` See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html """ @@ -303,7 +303,7 @@ class Backend(): r""" Tests whether any tensor element along given dimensions evaluates to True. - This function follow the api from :any:`numpy.any` + This function follows the api from :any:`numpy.any` See: https://numpy.org/doc/stable/reference/generated/numpy.any.html """ @@ -313,7 +313,7 @@ class Backend(): r""" Tests element-wise for NaN and returns result as a boolean tensor. - This function follow the api from :any:`numpy.isnan` + This function follows the api from :any:`numpy.isnan` See: https://numpy.org/doc/stable/reference/generated/numpy.isnan.html """ @@ -323,7 +323,7 @@ class Backend(): r""" Tests element-wise for positive or negative infinity and returns result as a boolean tensor. - This function follow the api from :any:`numpy.isinf` + This function follows the api from :any:`numpy.isinf` See: https://numpy.org/doc/stable/reference/generated/numpy.isinf.html """ @@ -333,7 +333,7 @@ class Backend(): r""" Evaluates the Einstein summation convention on the operands. - This function follow the api from :any:`numpy.einsum` + This function follows the api from :any:`numpy.einsum` See: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html """ @@ -343,7 +343,7 @@ class Backend(): r""" Returns a sorted copy of a tensor. - This function follow the api from :any:`numpy.sort` + This function follows the api from :any:`numpy.sort` See: https://numpy.org/doc/stable/reference/generated/numpy.sort.html """ @@ -353,7 +353,7 @@ class Backend(): r""" Returns the indices that would sort a tensor. - This function follow the api from :any:`numpy.argsort` + This function follows the api from :any:`numpy.argsort` See: https://numpy.org/doc/stable/reference/generated/numpy.argsort.html """ @@ -363,7 +363,7 @@ class Backend(): r""" Finds indices where elements should be inserted to maintain order in given tensor. - This function follow the api from :any:`numpy.searchsorted` + This function follows the api from :any:`numpy.searchsorted` See: https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html """ @@ -373,7 +373,7 @@ class Backend(): r""" Reverses the order of elements in a tensor along given dimensions. - This function follow the api from :any:`numpy.flip` + This function follows the api from :any:`numpy.flip` See: https://numpy.org/doc/stable/reference/generated/numpy.flip.html """ @@ -383,7 +383,7 @@ class Backend(): """ Limits the values in a tensor. - This function follow the api from :any:`numpy.clip` + This function follows the api from :any:`numpy.clip` See: https://numpy.org/doc/stable/reference/generated/numpy.clip.html """ @@ -393,7 +393,7 @@ class Backend(): r""" Repeats elements of a tensor. - This function follow the api from :any:`numpy.repeat` + This function follows the api from :any:`numpy.repeat` See: https://numpy.org/doc/stable/reference/generated/numpy.repeat.html """ @@ -403,7 +403,7 @@ class Backend(): r""" Gathers elements of a tensor along given dimensions. - This function follow the api from :any:`numpy.take_along_axis` + This function follows the api from :any:`numpy.take_along_axis` See: https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html """ @@ -413,7 +413,7 @@ class Backend(): r""" Joins a sequence of tensors along an existing dimension. - This function follow the api from :any:`numpy.concatenate` + This function follows the api from :any:`numpy.concatenate` See: https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html """ @@ -423,7 +423,7 @@ class Backend(): r""" Pads a tensor. - This function follow the api from :any:`numpy.pad` + This function follows the api from :any:`numpy.pad` See: https://numpy.org/doc/stable/reference/generated/numpy.pad.html """ @@ -433,7 +433,7 @@ class Backend(): r""" Returns the indices of the maximum values of a tensor along given dimensions. - This function follow the api from :any:`numpy.argmax` + This function follows the api from :any:`numpy.argmax` See: https://numpy.org/doc/stable/reference/generated/numpy.argmax.html """ @@ -443,7 +443,7 @@ class Backend(): r""" Computes the arithmetic mean of a tensor along given dimensions. - This function follow the api from :any:`numpy.mean` + This function follows the api from :any:`numpy.mean` See: https://numpy.org/doc/stable/reference/generated/numpy.mean.html """ @@ -453,7 +453,7 @@ class Backend(): r""" Computes the standard deviation of a tensor along given dimensions. - This function follow the api from :any:`numpy.std` + This function follows the api from :any:`numpy.std` See: https://numpy.org/doc/stable/reference/generated/numpy.std.html """ @@ -463,7 +463,7 @@ class Backend(): r""" Returns a specified number of evenly spaced values over a given interval. - This function follow the api from :any:`numpy.linspace` + This function follows the api from :any:`numpy.linspace` See: https://numpy.org/doc/stable/reference/generated/numpy.linspace.html """ @@ -473,7 +473,7 @@ class Backend(): r""" Returns coordinate matrices from coordinate vectors (Numpy convention). - This function follow the api from :any:`numpy.meshgrid` + This function follows the api from :any:`numpy.meshgrid` See: https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html """ @@ -483,7 +483,7 @@ class Backend(): r""" Extracts or constructs a diagonal tensor. - This function follow the api from :any:`numpy.diag` + This function follows the api from :any:`numpy.diag` See: https://numpy.org/doc/stable/reference/generated/numpy.diag.html """ @@ -493,7 +493,7 @@ class Backend(): r""" Finds unique elements of given tensor. - This function follow the api from :any:`numpy.unique` + This function follows the api from :any:`numpy.unique` See: https://numpy.org/doc/stable/reference/generated/numpy.unique.html """ @@ -503,7 +503,7 @@ class Backend(): r""" Computes the log of the sum of exponentials of input elements. - This function follow the api from :any:`scipy.special.logsumexp` + This function follows the api from :any:`scipy.special.logsumexp` See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html """ @@ -513,12 +513,32 @@ class Backend(): r""" Joins a sequence of tensors along a new dimension. - This function follow the api from :any:`numpy.stack` + This function follows the api from :any:`numpy.stack` See: https://numpy.org/doc/stable/reference/generated/numpy.stack.html """ raise NotImplementedError() + def outer(self, a, b): + r""" + Computes the outer product between two vectors. + + This function follows the api from :any:`numpy.outer` + + See: https://numpy.org/doc/stable/reference/generated/numpy.outer.html + """ + raise NotImplementedError() + + def reshape(self, a, shape): + r""" + Gives a new shape to a tensor without changing its data. + + This function follows the api from :any:`numpy.reshape` + + See: https://numpy.org/doc/stable/reference/generated/numpy.reshape.html + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -644,6 +664,9 @@ class NumpyBackend(Backend): def flip(self, a, axis=None): return np.flip(a, axis) + def outer(self, a, b): + return np.outer(a, b) + def clip(self, a, a_min, a_max): return np.clip(a, a_min, a_max) @@ -686,6 +709,9 @@ class NumpyBackend(Backend): def stack(self, arrays, axis=0): return np.stack(arrays, axis) + def reshape(self, a, shape): + return np.reshape(a, shape) + class JaxBackend(Backend): """ @@ -815,6 +841,9 @@ class JaxBackend(Backend): def flip(self, a, axis=None): return jnp.flip(a, axis) + def outer(self, a, b): + return jnp.outer(a, b) + def clip(self, a, a_min, a_max): return jnp.clip(a, a_min, a_max) @@ -857,6 +886,9 @@ class JaxBackend(Backend): def stack(self, arrays, axis=0): return jnp.stack(arrays, axis) + def reshape(self, a, shape): + return jnp.reshape(a, shape) + class TorchBackend(Backend): """ @@ -1035,6 +1067,9 @@ class TorchBackend(Backend): else: return torch.flip(a, dims=axis) + def outer(self, a, b): + return torch.outer(a, b) + def clip(self, a, a_min, a_max): return torch.clamp(a, a_min, a_max) @@ -1091,3 +1126,6 @@ class TorchBackend(Backend): def stack(self, arrays, axis=0): return torch.stack(arrays, dim=axis) + + def reshape(self, a, shape): + return torch.reshape(a, shape) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index b907b10..c6757d1 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -281,12 +281,12 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): a0, b0, M0 = a, b, M nx = get_backend(M0, a0, b0) - + # convert to numpy M = nx.to_numpy(M) a = nx.to_numpy(a) b = nx.to_numpy(b) - + # ensure float64 a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) diff --git a/ot/optim.py b/ot/optim.py index 0359343..6822e4e 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -12,6 +12,8 @@ import numpy as np from scipy.optimize.linesearch import scalar_search_armijo from .lp import emd from .bregman import sinkhorn +from ot.utils import list_to_array +from .backend import get_backend # The corresponding scipy function does not work for matrices @@ -21,25 +23,25 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, """ Armijo linesearch function that works with matrices - find an approximate minimum of f(xk+alpha*pk) that satifies the + Find an approximate minimum of :math:`f(x_k + \\alpha \cdot p_k)` that satisfies the armijo conditions. Parameters ---------- f : callable loss function - xk : ndarray + xk : array-like initial position - pk : ndarray + pk : array-like descent direction - gfk : ndarray - gradient of f at xk + gfk : array-like + gradient of `f` at :math:`x_k` old_fval : float - loss value at xk + loss value at :math:`x_k` args : tuple, optional - arguments given to f + arguments given to `f` c1 : float, optional - c1 const in armijo rule (>0) + :math:`c_1` const in armijo rule (>0) alpha0 : float, optional initial step (>0) @@ -53,7 +55,13 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, loss value at step alpha """ - xk = np.atleast_1d(xk) + + xk, pk, gfk = list_to_array(xk, pk, gfk) + nx = get_backend(xk, pk) + + if len(xk.shape) == 0: + xk = nx.reshape(xk, (-1,)) + fc = [0] def phi(alpha1): @@ -65,7 +73,7 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, else: phi0 = old_fval - derphi0 = np.sum(pk * gfk) # Quickfix for matrices + derphi0 = nx.sum(pk * gfk) # Quickfix for matrices alpha, phi1 = scalar_search_armijo( phi, phi0, derphi0, c1=c1, alpha0=alpha0) @@ -79,55 +87,64 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val, armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None): """ Solve the linesearch in the FW iterations + Parameters ---------- cost : method Cost in the FW for the linesearch - G : ndarray, shape(ns,nt) + G : array-like, shape(ns,nt) The transport map at a given iteration of the FW - deltaG : ndarray (ns,nt) + deltaG : array-like (ns,nt) Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration - Mi : ndarray (ns,nt) + Mi : array-like (ns,nt) Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost - f_val : float - Value of the cost at G + f_val : float + Value of the cost at `G` armijo : bool, optional - If True the steps of the line-search is found via an armijo research. Else closed form is used. - If there is convergence issues use False. - C1 : ndarray (ns,ns), optional + If True the steps of the line-search is found via an armijo research. Else closed form is used. + If there is convergence issues use False. + C1 : array-like (ns,ns), optional Structure matrix in the source domain. Only used and necessary when armijo=False - C2 : ndarray (nt,nt), optional + C2 : array-like (nt,nt), optional Structure matrix in the target domain. Only used and necessary when armijo=False reg : float, optional - Regularization parameter. Only used and necessary when armijo=False - Gc : ndarray (ns,nt) + Regularization parameter. Only used and necessary when armijo=False + Gc : array-like (ns,nt) Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False - constC : ndarray (ns,nt) - Constant for the gromov cost. See [24]. Only used and necessary when armijo=False - M : ndarray (ns,nt), optional + constC : array-like (ns,nt) + Constant for the gromov cost. See :ref:`[24] `. Only used and necessary when armijo=False + M : array-like (ns,nt), optional Cost matrix between the features. Only used and necessary when armijo=False + Returns ------- alpha : float - The optimal step size of the FW + The optimal step size of the FW fc : int - nb of function call. Useless here - f_val : float - The value of the cost for the next iteration + nb of function call. Useless here + f_val : float + The value of the cost for the next iteration + + + .. _references-solve-linesearch: References ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain - and Courty Nicolas + .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. """ if armijo: alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val) else: # requires symetric matrices - dot1 = np.dot(C1, deltaG) - dot12 = dot1.dot(C2) - a = -2 * reg * np.sum(dot12 * deltaG) - b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG)) + G, deltaG, C1, C2, constC, M = list_to_array(G, deltaG, C1, C2, constC, M) + if isinstance(M, int) or isinstance(M, float): + nx = get_backend(G, deltaG, C1, C2, constC) + else: + nx = get_backend(G, deltaG, C1, C2, constC, M) + + dot = nx.dot(nx.dot(C1, deltaG), C2) + a = -2 * reg * nx.sum(dot * deltaG) + b = nx.sum((M + reg * constC) * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2) * deltaG)) c = cost(G) alpha = solve_1d_linesearch_quad(a, b, c) @@ -145,33 +162,33 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg*f(\gamma) + \gamma = arg\min_\gamma <\gamma,M>_F + \mathrm{reg} \cdot f(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma 1 = a \gamma^T 1= b \gamma\geq 0 where : - - M is the (ns,nt) metric cost matrix - - :math:`f` is the regularization term ( and df is its gradient) - - a and b are source and target weights (sum to 1) + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`f` is the regularization term (and `df` is its gradient) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) - The algorithm used for solving the problem is conditional gradient as discussed in [1]_ + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] ` Parameters ---------- - a : ndarray, shape (ns,) + a : array-like, shape (ns,) samples weights in the source domain - b : ndarray, shape (nt,) + b : array-like, shape (nt,) samples in the target domain - M : ndarray, shape (ns, nt) + M : array-like, shape (ns, nt) loss matrix reg : float Regularization term >0 - G0 : ndarray, shape (ns,nt), optional + G0 : array-like, shape (ns,nt), optional initial guess (default is indep joint density) numItermax : int, optional Max number of iterations @@ -196,6 +213,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, log dictionary return only if log==True in parameters + .. _references-cg: References ---------- @@ -207,6 +225,11 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, ot.bregman.sinkhorn : Entropic regularized optimal transport """ + a, b, M, G0 = list_to_array(a, b, M, G0) + if isinstance(M, int) or isinstance(M, float): + nx = get_backend(a, b) + else: + nx = get_backend(a, b, M) loop = 1 @@ -214,12 +237,12 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, log = {'loss': []} if G0 is None: - G = np.outer(a, b) + G = nx.outer(a, b) else: G = G0 def cost(G): - return np.sum(M * G) + reg * f(G) + return nx.sum(M * G) + reg * f(G) f_val = cost(G) if log: @@ -240,7 +263,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, # problem linearization Mi = M + reg * df(G) # set M positive - Mi += Mi.min() + Mi += nx.min(Mi) # solve linear program Gc = emd(a, b, Mi, numItermax=numItermaxEmd) @@ -286,36 +309,36 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg1\cdot\Omega(\gamma) + reg2\cdot f(\gamma) + \gamma = arg\min_\gamma <\gamma,M>_F + \mathrm{reg_1}\cdot\Omega(\gamma) + \mathrm{reg_2}\cdot f(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma 1 = a \gamma^T 1= b \gamma\geq 0 where : - - M is the (ns,nt) metric cost matrix + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`f` is the regularization term ( and df is its gradient) - - a and b are source and target weights (sum to 1) + - :math:`f` is the regularization term (and `df` is its gradient) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) - The algorithm used for solving the problem is the generalized conditional gradient as discussed in [5,7]_ + The algorithm used for solving the problem is the generalized conditional gradient as discussed in :ref:`[5, 7] ` Parameters ---------- - a : ndarray, shape (ns,) + a : array-like, shape (ns,) samples weights in the source domain - b : ndarrayv (nt,) + b : array-like, (nt,) samples in the target domain - M : ndarray, shape (ns, nt) + M : array-like, shape (ns, nt) loss matrix reg1 : float Entropic Regularization term >0 reg2 : float Second Regularization term >0 - G0 : ndarray, shape (ns, nt), optional + G0 : array-like, shape (ns, nt), optional initial guess (default is indep joint density) numItermax : int, optional Max number of iterations @@ -337,9 +360,13 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, log : dict log dictionary return only if log==True in parameters + + .. _references-gcg: References ---------- + .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567. See Also @@ -347,6 +374,8 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, ot.optim.cg : conditional gradient """ + a, b, M, G0 = list_to_array(a, b, M, G0) + nx = get_backend(a, b, M) loop = 1 @@ -354,12 +383,12 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, log = {'loss': []} if G0 is None: - G = np.outer(a, b) + G = nx.outer(a, b) else: G = G0 def cost(G): - return np.sum(M * G) + reg1 * np.sum(G * np.log(G)) + reg2 * f(G) + return nx.sum(M * G) + reg1 * nx.sum(G * nx.log(G)) + reg2 * f(G) f_val = cost(G) if log: @@ -387,7 +416,7 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, deltaG = Gc - G # line search - dcost = Mi + reg1 * (1 + np.log(G)) # ?? + dcost = Mi + reg1 * (1 + nx.log(G)) # ?? alpha, fc, f_val = line_search_armijo(cost, G, deltaG, dcost, f_val) G = G + alpha * deltaG @@ -419,9 +448,11 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, def solve_1d_linesearch_quad(a, b, c): """ - For any convex or non-convex 1d quadratic function f, solve on [0,1] the following problem: + For any convex or non-convex 1d quadratic function `f`, solve the following problem: + .. math:: - \argmin f(x)=a*x^{2}+b*x+c + + arg\min_{0 \leq x \leq 1} f(x) = ax^{2} + bx + c Parameters ---------- diff --git a/test/test_backend.py b/test/test_backend.py index 859da5a..5853282 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -17,9 +17,6 @@ from numpy.testing import assert_array_almost_equal_nulp from ot.backend import get_backend, get_backend_list, to_numpy -backend_list = get_backend_list() - - def test_get_backend_list(): lst = get_backend_list() @@ -28,7 +25,6 @@ def test_get_backend_list(): assert isinstance(lst[0], ot.backend.NumpyBackend) -@pytest.mark.parametrize('nx', backend_list) def test_to_numpy(nx): v = nx.zeros(10) @@ -92,7 +88,6 @@ def test_get_backend(): get_backend(A, B2) -@pytest.mark.parametrize('nx', backend_list) def test_convert_between_backends(nx): A = np.zeros((3, 2)) @@ -180,6 +175,8 @@ def test_empty_backend(): nx.searchsorted(v, v) with pytest.raises(NotImplementedError): nx.flip(M) + with pytest.raises(NotImplementedError): + nx.outer(v, v) with pytest.raises(NotImplementedError): nx.clip(M, -1, 1) with pytest.raises(NotImplementedError): @@ -208,10 +205,11 @@ def test_empty_backend(): nx.logsumexp(M) with pytest.raises(NotImplementedError): nx.stack([M, M]) + with pytest.raises(NotImplementedError): + nx.reshape(M, (5, 3, 2)) -@pytest.mark.parametrize('backend', backend_list) -def test_func_backends(backend): +def test_func_backends(nx): rnd = np.random.RandomState(0) M = rnd.randn(10, 3) @@ -220,7 +218,7 @@ def test_func_backends(backend): lst_tot = [] - for nx in [ot.backend.NumpyBackend(), backend]: + for nx in [ot.backend.NumpyBackend(), nx]: print('Backend: ', nx.__name__) @@ -371,6 +369,10 @@ def test_func_backends(backend): lst_b.append(nx.to_numpy(A)) lst_name.append('flip') + A = nx.outer(vb, vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('outer') + A = nx.clip(vb, 0, 1) lst_b.append(nx.to_numpy(A)) lst_name.append('clip') @@ -432,6 +434,10 @@ def test_func_backends(backend): lst_b.append(nx.to_numpy(A)) lst_name.append('stack') + A = nx.reshape(Mb, (5, 3, 2)) + lst_b.append(nx.to_numpy(A)) + lst_name.append('reshape') + lst_tot.append(lst_b) lst_np = lst_tot[0] diff --git a/test/test_optim.py b/test/test_optim.py index 94995d5..4efd9b1 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -8,7 +8,7 @@ import numpy as np import ot -def test_conditional_gradient(): +def test_conditional_gradient(nx): n_bins = 100 # nb bins np.random.seed(0) @@ -29,15 +29,25 @@ def test_conditional_gradient(): def df(G): return G + def fb(G): + return 0.5 * nx.sum(G ** 2) + + ab = nx.from_numpy(a) + bb = nx.from_numpy(b) + Mb = nx.from_numpy(M, type_as=ab) + reg = 1e-1 G, log = ot.optim.cg(a, b, M, reg, f, df, verbose=True, log=True) + Gb, log = ot.optim.cg(ab, bb, Mb, reg, fb, df, verbose=True, log=True) + Gb = nx.to_numpy(Gb) - np.testing.assert_allclose(a, G.sum(1)) - np.testing.assert_allclose(b, G.sum(0)) + np.testing.assert_allclose(Gb, G) + np.testing.assert_allclose(a, Gb.sum(1)) + np.testing.assert_allclose(b, Gb.sum(0)) -def test_conditional_gradient_itermax(): +def test_conditional_gradient_itermax(nx): n = 100 # nb samples mu_s = np.array([0, 0]) @@ -61,16 +71,27 @@ def test_conditional_gradient_itermax(): def df(G): return G + def fb(G): + return 0.5 * nx.sum(G ** 2) + + ab = nx.from_numpy(a) + bb = nx.from_numpy(b) + Mb = nx.from_numpy(M, type_as=ab) + reg = 1e-1 G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=10000, verbose=True, log=True) + Gb, log = ot.optim.cg(ab, bb, Mb, reg, fb, df, numItermaxEmd=10000, + verbose=True, log=True) + Gb = nx.to_numpy(Gb) - np.testing.assert_allclose(a, G.sum(1)) - np.testing.assert_allclose(b, G.sum(0)) + np.testing.assert_allclose(Gb, G) + np.testing.assert_allclose(a, Gb.sum(1)) + np.testing.assert_allclose(b, Gb.sum(0)) -def test_generalized_conditional_gradient(): +def test_generalized_conditional_gradient(nx): n_bins = 100 # nb bins np.random.seed(0) @@ -91,13 +112,23 @@ def test_generalized_conditional_gradient(): def df(G): return G + def fb(G): + return 0.5 * nx.sum(G ** 2) + reg1 = 1e-3 reg2 = 1e-1 + ab = nx.from_numpy(a) + bb = nx.from_numpy(b) + Mb = nx.from_numpy(M, type_as=ab) + G, log = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True, log=True) + Gb, log = ot.optim.gcg(ab, bb, Mb, reg1, reg2, fb, df, verbose=True, log=True) + Gb = nx.to_numpy(Gb) - np.testing.assert_allclose(a, G.sum(1), atol=1e-05) - np.testing.assert_allclose(b, G.sum(0), atol=1e-05) + np.testing.assert_allclose(Gb, G) + np.testing.assert_allclose(a, Gb.sum(1), atol=1e-05) + np.testing.assert_allclose(b, Gb.sum(0), atol=1e-05) def test_solve_1d_linesearch_quad_funct(): @@ -106,24 +137,31 @@ def test_solve_1d_linesearch_quad_funct(): np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1) -def test_line_search_armijo(): +def test_line_search_armijo(nx): xk = np.array([[0.25, 0.25], [0.25, 0.25]]) pk = np.array([[-0.25, 0.25], [0.25, -0.25]]) gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]]) old_fval = -123 # Should not throw an exception and return None for alpha - alpha, _, _ = ot.optim.line_search_armijo(lambda x: 1, xk, pk, gfk, old_fval) + alpha, a, b = ot.optim.line_search_armijo( + lambda x: 1, nx.from_numpy(xk), nx.from_numpy(pk), nx.from_numpy(gfk), old_fval + ) + alpha_np, anp, bnp = ot.optim.line_search_armijo( + lambda x: 1, xk, pk, gfk, old_fval + ) + assert a == anp + assert b == bnp assert alpha is None # check line search armijo def f(x): - return np.sum((x - 5.0) ** 2) + return nx.sum((x - 5.0) ** 2) def grad(x): return 2 * (x - 5.0) - xk = np.array([[[-5.0, -5.0]]]) - pk = np.array([[[100.0, 100.0]]]) + xk = nx.from_numpy(np.array([[[-5.0, -5.0]]])) + pk = nx.from_numpy(np.array([[[100.0, 100.0]]])) gfk = grad(xk) old_fval = f(xk) @@ -132,10 +170,18 @@ def test_line_search_armijo(): np.testing.assert_allclose(alpha, 0.1) # check the case where the direction is not far enough - pk = np.array([[[3.0, 3.0]]]) + pk = nx.from_numpy(np.array([[[3.0, 3.0]]])) alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval, alpha0=1.0) np.testing.assert_allclose(alpha, 1.0) - # check the case where the checking the wrong direction + # check the case where checking the wrong direction alpha, _, _ = ot.optim.line_search_armijo(f, xk, -pk, gfk, old_fval) assert alpha <= 0 + + # check the case where the point is not a vector + xk = nx.from_numpy(np.array(-5.0)) + pk = nx.from_numpy(np.array(100.0)) + gfk = grad(xk) + old_fval = f(xk) + alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval) + np.testing.assert_allclose(alpha, 0.1) diff --git a/test/test_ot.py b/test/test_ot.py index 3e953dc..4dfc510 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -12,9 +12,7 @@ from scipy.stats import wasserstein_distance import ot from ot.datasets import make_1D_gauss as gauss -from ot.backend import get_backend_list, torch - -backend_list = get_backend_list() +from ot.backend import torch def test_emd_dimension_and_mass_mismatch(): @@ -37,7 +35,6 @@ def test_emd_dimension_and_mass_mismatch(): np.testing.assert_raises(AssertionError, ot.emd, a, b, M) -@pytest.mark.parametrize('nx', backend_list) def test_emd_backends(nx): n_samples = 100 n_features = 2 @@ -59,7 +56,6 @@ def test_emd_backends(nx): np.allclose(G, nx.to_numpy(Gb)) -@pytest.mark.parametrize('nx', backend_list) def test_emd2_backends(nx): n_samples = 100 n_features = 2 diff --git a/test/test_utils.py b/test/test_utils.py index 76b1faa..60ad5d3 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -4,17 +4,11 @@ # # License: MIT License -import pytest import ot import numpy as np import sys -from ot.backend import get_backend_list -backend_list = get_backend_list() - - -@pytest.mark.parametrize('nx', backend_list) def test_proj_simplex(nx): n = 10 rng = np.random.RandomState(0) @@ -119,7 +113,6 @@ def test_dist(): np.testing.assert_allclose(D, D3, atol=1e-14) -@ pytest.mark.parametrize('nx', backend_list) def test_dist_backends(nx): n = 100 -- cgit v1.2.3 From d7554331fc409fea48ee758fd630909dd9dc4827 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Wed, 27 Oct 2021 08:41:08 +0200 Subject: [WIP] Sinkhorn in log space (#290) * adda sinkhorn log and working sinkhorn2 function * more tests pass * more tests pass * it works but not by default yet * remove warningd * update circleci doc * update circleci doc * new sinkhorn implemeted but not by default * better * doctest pass * test doctest * new test utils * remove pep8 errors * remove pep8 errors * doc new implementtaion with log * test sinkhorn 2 * doc for log implementation --- .circleci/config.yml | 14 +-- README.md | 4 +- docs/source/quickstart.rst | 10 +- ot/bregman.py | 272 +++++++++++++++++++++++++++++++++++++++++---- ot/dr.py | 4 +- ot/gromov.py | 4 +- ot/optim.py | 4 +- ot/utils.py | 4 +- test/test_bregman.py | 120 ++++++++++++++++++-- test/test_gromov.py | 10 +- test/test_helpers.py | 4 +- test/test_utils.py | 15 +++ 12 files changed, 403 insertions(+), 62 deletions(-) (limited to 'test') diff --git a/.circleci/config.yml b/.circleci/config.yml index e4c71dd..379394a 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -4,7 +4,7 @@ version: 2 jobs: build_docs: docker: - - image: circleci/python:3.7-stretch + - image: cimg/python:3.9 steps: - checkout - run: @@ -34,18 +34,6 @@ jobs: - data-cache-0 - pip-cache - - run: - name: Spin up Xvfb - command: | - /sbin/start-stop-daemon --start --quiet --pidfile /tmp/custom_xvfb_99.pid --make-pidfile --background --exec /usr/bin/Xvfb -- :99 -screen 0 1400x900x24 -ac +extension GLX +render -noreset; - - # https://github.com/ContinuumIO/anaconda-issues/issues/9190#issuecomment-386508136 - # https://github.com/golemfactory/golem/issues/1019 - - run: - name: Fix libgcc_s.so.1 pthread_cancel bug - command: | - sudo apt-get install qt5-default - - run: name: Get Python running command: | diff --git a/README.md b/README.md index 266d847..ffad0bd 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ POT provides the following generic OT solvers (links to examples): * [OT Network Simplex solver](https://pythonot.github.io/auto_examples/plot_OT_1D.html) for the linear program/ Earth Movers Distance [1] . * [Conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) [6] and [Generalized conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) for regularized OT [7]. -* Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html). +* Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10] [34], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html). * Bregman projections for [Wasserstein barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) [3], [convolutional barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_convolutional_barycenter.html) [21] and unmixing [4]. * Sinkhorn divergence [23] and entropic regularization OT from empirical data. * [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17]. @@ -290,3 +290,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t [32] Huang, M., Ma S., Lai, L. (2021). [A Riemannian Block Coordinate Descent Method for Computing the Projection Robust Wasserstein Distance](http://proceedings.mlr.press/v139/huang21e.html), Proceedings of the 38th International Conference on Machine Learning (ICML). [33] Kerdoncuff T., Emonet R., Marc S. [Sampled Gromov Wasserstein](https://hal.archives-ouvertes.fr/hal-03232509/document), Machine Learning Journal (MJL), 2021 + +[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index fd046a1..232df7b 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -358,6 +358,11 @@ More details about the algorithms used are given in the following note. + :code:`method='sinkhorn'` calls :any:`ot.bregman.sinkhorn_knopp` the classic algorithm [2]_. + + :code:`method='sinkhorn_log'` calls :any:`ot.bregman.sinkhorn_log` the + sinkhorn algorithm in log space [2]_ that is more stable but can be + slower in numpy since `logsumexp` is not implmemented in parallel. + It is the recommended solver for applications that requires + differentiability with a small number of iterations. + :code:`method='sinkhorn_stabilized'` calls :any:`ot.bregman.sinkhorn_stabilized` the log stabilized version of the algorithm [9]_. + :code:`method='sinkhorn_epsilon_scaling'` calls @@ -389,7 +394,10 @@ More details about the algorithms used are given in the following note. solutions. Note that the greedy version of the Sinkhorn :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening version of the Sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a - fast approximation of the Sinkhorn problem. + fast approximation of the Sinkhorn problem. For use of GPU and gradient + computation with small number of iterations we strongly recommend the + :any:`ot.bregman.sinkhorn_log` solver that will no need to check for + numerical problems. diff --git a/ot/bregman.py b/ot/bregman.py index b59ee1b..2aa76ff 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -64,7 +64,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, solutions. Note that the greedy version of the sinkhorn :py:func:`ot.bregman.greenkhorn` can also lead to a speedup and the screening version of the sinkhorn :py:func:`ot.bregman.screenkhorn` aim at providing a - fast approximation of the Sinkhorn problem. + fast approximation of the Sinkhorn problem. For use of GPU and gradient + computation with small number of iterations we strongly recommend the + :any:`ot.bregman.sinkhorn_log` solver that will no need to check for + numerical problems. Parameters @@ -79,8 +82,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, reg : float Regularization term >0 method : str - method used for the solver either 'sinkhorn', 'greenkhorn', 'sinkhorn_stabilized' or - 'sinkhorn_epsilon_scaling', see those function for specific parameters + method used for the solver either 'sinkhorn','sinkhorn_log', + 'greenkhorn', 'sinkhorn_stabilized' or 'sinkhorn_epsilon_scaling', see + those function for specific parameters numItermax : int, optional Max number of iterations stopThr : float, optional @@ -118,6 +122,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. See Also @@ -134,6 +139,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) + elif method.lower() == 'sinkhorn_log': + return sinkhorn_log(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) elif method.lower() == 'greenkhorn': return greenkhorn(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log) @@ -182,7 +191,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, By default and when using a regularization parameter that is not too small the default sinkhorn solver should be enough. If you need to use a small regularization to get sharper OT matrices, you should use the - :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical + :any:`ot.bregman.sinkhorn_log` solver that will avoid numerical errors. This last solver can be very slow in practice and might not even converge to a reasonable OT matrix in a finite time. This is why :any:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value @@ -190,7 +199,10 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, solutions. Note that the greedy version of the sinkhorn :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a - fast approximation of the Sinkhorn problem. + fast approximation of the Sinkhorn problem. For use of GPU and gradient + computation with small number of iterations we strongly recommend the + :any:`ot.bregman.sinkhorn_log` solver that will no need to check for + numerical problems. Parameters ---------- @@ -204,7 +216,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, reg : float Regularization term >0 method : str - method used for the solver either 'sinkhorn', 'sinkhorn_stabilized', see those function for specific parameters + method used for the solver either 'sinkhorn','sinkhorn_log', + 'sinkhorn_stabilized', see those function for specific parameters numItermax : int, optional Max number of iterations stopThr : float, optional @@ -230,7 +243,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, >>> b=[.5, .5] >>> M=[[0., 1.], [1., 0.]] >>> ot.sinkhorn2(a, b, M, 1) - array([0.26894142]) + 0.26894142136999516 .. _references-sinkhorn2: @@ -243,7 +256,11 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. - .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 + .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation + algorithms for optimal transport via Sinkhorn iteration, Advances in Neural + Information Processing Systems (NIPS) 31, 2017 + + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. @@ -257,20 +274,45 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, """ - b = list_to_array(b) + M, a, b = list_to_array(M, a, b) + nx = get_backend(M, a, b) + if len(b.shape) < 2: - b = b[:, None] + if method.lower() == 'sinkhorn': + res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + elif method.lower() == 'sinkhorn_log': + res = sinkhorn_log(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + elif method.lower() == 'sinkhorn_stabilized': + res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) + if log: + return nx.sum(M * res[0]), res[1] + else: + return nx.sum(M * res) - if method.lower() == 'sinkhorn': - return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, - **kwargs) - elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, - **kwargs) else: - raise ValueError("Unknown method '%s'." % method) + + if method.lower() == 'sinkhorn': + return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + elif method.lower() == 'sinkhorn_log': + return sinkhorn_log(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + elif method.lower() == 'sinkhorn_stabilized': + return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) def sinkhorn_knopp(a, b, M, reg, numItermax=1000, @@ -361,7 +403,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, # init data dim_a = len(a) - dim_b = len(b) + dim_b = b.shape[0] if len(b.shape) > 1: n_hists = b.shape[1] @@ -438,6 +480,191 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, return u.reshape((-1, 1)) * K * v.reshape((1, -1)) +def sinkhorn_log(a, b, M, reg, numItermax=1000, + stopThr=1e-9, verbose=False, log=False, **kwargs): + r""" + Solve the entropic regularization optimal transport problem in log space + and return the OT matrix + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm :ref:`[2] ` with the + implementation from :ref:`[34] ` + + + Parameters + ---------- + a : array-like, shape (dim_a,) + samples weights in the source domain + b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists) + samples in the target domain, compute sinkhorn with multiple targets + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) + M : array-like, shape (dim_a, dim_b) + loss matrix + reg : float + Regularization term >0 + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + gamma : array-like, shape (dim_a, dim_b) + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[0., 1.], [1., 0.]] + >>> ot.sinkhorn(a, b, M, 1) + array([[0.36552929, 0.13447071], + [0.13447071, 0.36552929]]) + + + .. _references-sinkhorn-log: + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal + Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. + + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + + """ + + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) + + if len(a) == 0: + a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M) + if len(b) == 0: + b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M) + + # init data + dim_a = len(a) + dim_b = b.shape[0] + + if len(b.shape) > 1: + n_hists = b.shape[1] + else: + n_hists = 0 + + if n_hists: # we do not want to use tensors sor we do a loop + + lst_loss = [] + lst_u = [] + lst_v = [] + + for k in range(n_hists): + res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, **kwargs) + + if log: + lst_loss.append(nx.sum(M * res[0])) + lst_u.append(res[1]['log_u']) + lst_v.append(res[1]['log_v']) + else: + lst_loss.append(nx.sum(M * res)) + res = nx.stack(lst_loss) + if log: + log = {'log_u': nx.stack(lst_u, 1), + 'log_v': nx.stack(lst_v, 1), } + log['u'] = nx.exp(log['log_u']) + log['v'] = nx.exp(log['log_v']) + return res, log + else: + return res + + else: + + if log: + log = {'err': []} + + Mr = M / (-reg) + + # we assume that no distances are null except those of the diagonal of + # distances + + u = nx.zeros(dim_a, type_as=M) + v = nx.zeros(dim_b, type_as=M) + + def get_logT(u, v): + if n_hists: + return Mr[:, :, None] + u + v + else: + return Mr + u[:, None] + v[None, :] + + loga = nx.log(a) + logb = nx.log(b) + + cpt = 0 + err = 1 + while (err > stopThr and cpt < numItermax): + + v = logb - nx.logsumexp(Mr + u[:, None], 0) + u = loga - nx.logsumexp(Mr + v[None, :], 1) + + if cpt % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + + # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 + tmp2 = nx.sum(nx.exp(get_logT(u, v)), 0) + err = nx.norm(tmp2 - b) # violation of marginal + if log: + log['err'].append(err) + + if verbose: + if cpt % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + cpt = cpt + 1 + + if log: + log['log_u'] = u + log['log_v'] = v + log['u'] = nx.exp(u) + log['v'] = nx.exp(v) + + return nx.exp(get_logT(u, v)), log + + else: + return nx.exp(get_logT(u, v)) + + def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=False): r""" @@ -1881,8 +2108,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', return (f, g) else: - M = dist(nx.to_numpy(X_s), nx.to_numpy(X_t), metric=metric) - M = nx.from_numpy(M, type_as=a) + M = dist(X_s, X_t, metric=metric) if log: pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs) return pi, log @@ -2102,7 +2328,7 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1)) >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1)) >>> empirical_sinkhorn_divergence(X_s, X_t, reg) # doctest: +ELLIPSIS - array([1.499...]) + 1.499887176049052 References diff --git a/ot/dr.py b/ot/dr.py index 64588cf..de39662 100644 --- a/ot/dr.py +++ b/ot/dr.py @@ -209,11 +209,11 @@ def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopTh .. math:: \max_{U \in St(d, k)} \min_{\pi \in \Pi(\mu,\nu)} \sum_{i,j} \pi_{i,j} \|U^T(x_i - y_j)\|^2 - reg * H(\pi) - + - :math:`U` is a linear projection operator in the Stiefel(d, k) manifold - :math:`H(\pi)` is entropy regularizer - :math:`x_i`, :math:`y_j` are samples of measures \mu and \nu respectively - + Parameters ---------- X : ndarray, shape (n, d) diff --git a/ot/gromov.py b/ot/gromov.py index 85b1549..33b4453 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -1030,7 +1030,7 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, # compute the gradient tens = gwggrad(constC, hC1, hC2, T) - T = sinkhorn(p, q, tens, epsilon) + T = sinkhorn(p, q, tens, epsilon, method='sinkhorn') if cpt % 10 == 0: # we can speed up the process by checking for the error only all @@ -1204,7 +1204,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, Cprev = C T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon, - max_iter, 1e-5, verbose, log) for s in range(S)] + max_iter, 1e-4, verbose, log) for s in range(S)] if loss_fun == 'square_loss': C = update_square_loss(p, lambdas, T, Cs) diff --git a/ot/optim.py b/ot/optim.py index 6822e4e..34cbb17 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -20,7 +20,7 @@ from .backend import get_backend def line_search_armijo(f, xk, pk, gfk, old_fval, args=(), c1=1e-4, alpha0=0.99): - """ + r""" Armijo linesearch function that works with matrices Find an approximate minimum of :math:`f(x_k + \\alpha \cdot p_k)` that satisfies the @@ -447,7 +447,7 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, def solve_1d_linesearch_quad(a, b, c): - """ + r""" For any convex or non-convex 1d quadratic function `f`, solve the following problem: .. math:: diff --git a/ot/utils.py b/ot/utils.py index 6a782e6..0608aee 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -183,7 +183,7 @@ def euclidean_distances(X, Y, squared=False): return c -def dist(x1, x2=None, metric='sqeuclidean'): +def dist(x1, x2=None, metric='sqeuclidean', p=2): """Compute distance between samples in x1 and x2 .. note:: This function is backend-compatible and will work on arrays @@ -222,7 +222,7 @@ def dist(x1, x2=None, metric='sqeuclidean'): if not get_backend(x1, x2).__name__ == 'numpy': raise NotImplementedError() else: - return cdist(x1, x2, metric=metric) + return cdist(x1, x2, metric=metric, p=p) def dist0(n, method='lin_square'): diff --git a/test/test_bregman.py b/test/test_bregman.py index 942cb6d..c1120ba 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -32,6 +32,27 @@ def test_sinkhorn(): u, G.sum(0), atol=1e-05) # cf convergence sinkhorn +def test_sinkhorn_multi_b(): + # test sinkhorn + n = 10 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + b = rng.rand(n, 3) + b = b / np.sum(b, 0, keepdims=True) + + M = ot.dist(x, x) + + loss0, log = ot.sinkhorn(u, b, M, .1, stopThr=1e-10, log=True) + + loss = [ot.sinkhorn2(u, b[:, k], M, .1, stopThr=1e-10) for k in range(3)] + # check constraints + np.testing.assert_allclose( + loss0, loss, atol=1e-06) # cf convergence sinkhorn + + def test_sinkhorn_backends(nx): n_samples = 100 n_features = 2 @@ -147,6 +168,7 @@ def test_sinkhorn_variants(nx): Mb = nx.from_numpy(M) G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10) + Gl = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_log', stopThr=1e-10)) G0 = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn', stopThr=1e-10)) Gs = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10)) Ges = nx.to_numpy(ot.sinkhorn( @@ -155,15 +177,73 @@ def test_sinkhorn_variants(nx): # check values np.testing.assert_allclose(G, G0, atol=1e-05) + np.testing.assert_allclose(G, Gl, atol=1e-05) np.testing.assert_allclose(G0, Gs, atol=1e-05) np.testing.assert_allclose(G0, Ges, atol=1e-05) np.testing.assert_allclose(G0, G_green, atol=1e-5) - print(G0, G_green) + + +@pytest.skip_backend("jax") +def test_sinkhorn_variants_multi_b(nx): + # test sinkhorn + n = 50 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + b = rng.rand(n, 3) + b = b / np.sum(b, 0, keepdims=True) + + M = ot.dist(x, x) + + ub = nx.from_numpy(u) + bb = nx.from_numpy(b) + Mb = nx.from_numpy(M) + + G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10) + Gl = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + + # check values + np.testing.assert_allclose(G, G0, atol=1e-05) + np.testing.assert_allclose(G, Gl, atol=1e-05) + np.testing.assert_allclose(G0, Gs, atol=1e-05) + + +@pytest.skip_backend("jax") +def test_sinkhorn2_variants_multi_b(nx): + # test sinkhorn + n = 50 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + b = rng.rand(n, 3) + b = b / np.sum(b, 0, keepdims=True) + + M = ot.dist(x, x) + + ub = nx.from_numpy(u) + bb = nx.from_numpy(b) + Mb = nx.from_numpy(M) + + G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10) + Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + + # check values + np.testing.assert_allclose(G, G0, atol=1e-05) + np.testing.assert_allclose(G, Gl, atol=1e-05) + np.testing.assert_allclose(G0, Gs, atol=1e-05) def test_sinkhorn_variants_log(): # test sinkhorn - n = 100 + n = 50 rng = np.random.RandomState(0) x = rng.randn(n, 2) @@ -172,6 +252,7 @@ def test_sinkhorn_variants_log(): M = ot.dist(x, x) G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10, log=True) + Gl, logl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True) Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) Ges, loges = ot.sinkhorn( u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True) @@ -179,9 +260,30 @@ def test_sinkhorn_variants_log(): # check values np.testing.assert_allclose(G0, Gs, atol=1e-05) + np.testing.assert_allclose(G0, Gl, atol=1e-05) np.testing.assert_allclose(G0, Ges, atol=1e-05) np.testing.assert_allclose(G0, G_green, atol=1e-5) - print(G0, G_green) + + +def test_sinkhorn_variants_log_multib(): + # test sinkhorn + n = 50 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + b = rng.rand(n, 3) + b = b / np.sum(b, 0, keepdims=True) + + M = ot.dist(x, x) + + G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10, log=True) + Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True) + Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) + + # check values + np.testing.assert_allclose(G0, Gs, atol=1e-05) + np.testing.assert_allclose(G0, Gl, atol=1e-05) @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) @@ -326,10 +428,10 @@ def test_empirical_sinkhorn(nx): a = ot.unif(n) b = ot.unif(n) - X_s = np.reshape(np.arange(n), (n, 1)) - X_t = np.reshape(np.arange(0, n), (n, 1)) + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) M = ot.dist(X_s, X_t) - M_m = ot.dist(X_s, X_t, metric='minkowski') + M_m = ot.dist(X_s, X_t, metric='euclidean') ab = nx.from_numpy(a) bb = nx.from_numpy(b) @@ -346,7 +448,7 @@ def test_empirical_sinkhorn(nx): sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) - G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='minkowski')) + G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean')) sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1)) @@ -378,7 +480,7 @@ def test_lazy_empirical_sinkhorn(nx): X_s = np.reshape(np.arange(n), (n, 1)) X_t = np.reshape(np.arange(0, n), (n, 1)) M = ot.dist(X_s, X_t) - M_m = ot.dist(X_s, X_t, metric='minkowski') + M_m = ot.dist(X_s, X_t, metric='euclidean') ab = nx.from_numpy(a) bb = nx.from_numpy(b) @@ -398,7 +500,7 @@ def test_lazy_empirical_sinkhorn(nx): sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) - f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='minkowski', numIterMax=numIterMax, isLazy=True, batchSize=1) + f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1) f, g = nx.to_numpy(f), nx.to_numpy(g) G_m = np.exp(f[:, None] + g[None, :] - M_m / 1) sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) diff --git a/test/test_gromov.py b/test/test_gromov.py index 19d61b1..0242d72 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -180,8 +180,8 @@ def test_sampled_gromov(): def test_gromov_barycenter(): - ns = 50 - nt = 60 + ns = 10 + nt = 20 Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) @@ -208,8 +208,8 @@ def test_gromov_barycenter(): @pytest.mark.filterwarnings("ignore:divide") def test_gromov_entropic_barycenter(): - ns = 20 - nt = 30 + ns = 10 + nt = 20 Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) @@ -222,7 +222,7 @@ def test_gromov_entropic_barycenter(): [ot.unif(ns), ot.unif(nt) ], ot.unif(n_samples), [.5, .5], 'square_loss', 1e-3, - max_iter=50, tol=1e-5, + max_iter=50, tol=1e-3, verbose=True) np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) diff --git a/test/test_helpers.py b/test/test_helpers.py index 8bd0015..cc4c90e 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -9,8 +9,8 @@ import sys sys.path.append(os.path.join("ot", "helpers")) -from openmp_helpers import get_openmp_flag, check_openmp_support # noqa -from pre_build_helpers import _get_compiler, compile_test_program # noqa +from openmp_helpers import get_openmp_flag, check_openmp_support # noqa +from pre_build_helpers import _get_compiler, compile_test_program # noqa def test_helpers(): diff --git a/test/test_utils.py b/test/test_utils.py index 60ad5d3..0650ce2 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -7,6 +7,7 @@ import ot import numpy as np import sys +import pytest def test_proj_simplex(nx): @@ -108,6 +109,10 @@ def test_dist(): D2 = ot.dist(x, x) D3 = ot.dist(x) + D4 = ot.dist(x, x, metric='minkowski', p=0.5) + + assert D4[0, 1] == D4[1, 0] + # dist shoul return squared euclidean np.testing.assert_allclose(D, D2, atol=1e-14) np.testing.assert_allclose(D, D3, atol=1e-14) @@ -220,6 +225,13 @@ def test_deprecated_func(): class Class(): pass + with pytest.warns(DeprecationWarning): + fun() + + with pytest.warns(DeprecationWarning): + cl = Class() + print(cl) + if sys.version_info < (3, 5): print('Not tested') else: @@ -250,4 +262,7 @@ def test_BaseEstimator(): params['first'] = 'spam again' cl.set_params(**params) + with pytest.raises(ValueError): + cl.set_params(bibi=10) + assert cl.first == 'spam again' -- cgit v1.2.3 From 0cb2b2efe901ed74c614046d250518769f870313 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Fri, 29 Oct 2021 18:39:13 +0200 Subject: [MRG] Add tesing on wda (#296) --- test/test_dr.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) (limited to 'test') diff --git a/test/test_dr.py b/test/test_dr.py index fa75a18..741f2ad 100644 --- a/test/test_dr.py +++ b/test/test_dr.py @@ -60,6 +60,31 @@ def test_wda(): np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p)) +@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)") +def test_wda_normalized(): + + n_samples = 100 # nb samples in source and target datasets + np.random.seed(0) + + # generate gaussian dataset + xs, ys = ot.datasets.make_data_classif('gaussrot', n_samples) + + n_features_noise = 8 + + xs = np.hstack((xs, np.random.randn(n_samples, n_features_noise))) + + p = 2 + + P0 = np.random.randn(10, p) + P0 /= P0.sum(0, keepdims=True) + + Pwda, projwda = ot.dr.wda(xs, ys, p, maxiter=10, P0=P0, normalize=True) + + projwda(xs) + + np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p)) + + @pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)") def test_prw(): d = 100 # Dimension -- cgit v1.2.3 From a335324d008e8982be61d7ace937815a2bfa98f9 Mon Sep 17 00:00:00 2001 From: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> Date: Tue, 2 Nov 2021 13:42:02 +0100 Subject: [MRG] Backend for gromov (#294) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * bregman: small correction * gromov backend first draft * Removing decorators * Reworked casting method * Bug solve * Removing casting * Bug solve * toarray renamed todense ; expand_dims removed * Warning (jax not supporting sparse matrix) moved * Mistake corrected * test backend * Sparsity test for older versions of pytorch * Trying pytorch/1.10 * Attempt to correct torch sparse bug * Backend version of gromov tests * Random state introduced for remaining gromov functions * review changes * code coverage * Docs (first draft, to be continued) * Gromov docs * Prettified docs * mistake corrected in the docs * little change Co-authored-by: Rémi Flamary --- ot/backend.py | 214 ++++++++- ot/bregman.py | 184 ++++---- ot/gromov.py | 1220 +++++++++++++++++++++++++++----------------------- ot/lp/__init__.py | 58 ++- ot/optim.py | 22 +- test/test_backend.py | 56 +++ test/test_bregman.py | 4 +- test/test_gromov.py | 297 ++++++++---- 8 files changed, 1289 insertions(+), 766 deletions(-) (limited to 'test') diff --git a/ot/backend.py b/ot/backend.py index 876b96a..358297c 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -26,6 +26,7 @@ Examples import numpy as np import scipy.special as scipy +from scipy.sparse import issparse, coo_matrix, csr_matrix try: import torch @@ -539,6 +540,86 @@ class Backend(): """ raise NotImplementedError() + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): + r""" + Creates a sparse tensor in COOrdinate format. + + This function follows the api from :any:`scipy.sparse.coo_matrix` + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html + """ + raise NotImplementedError() + + def issparse(self, a): + r""" + Checks whether or not the input tensor is a sparse tensor. + + This function follows the api from :any:`scipy.sparse.issparse` + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.issparse.html + """ + raise NotImplementedError() + + def tocsr(self, a): + r""" + Converts this matrix to Compressed Sparse Row format. + + This function follows the api from :any:`scipy.sparse.coo_matrix.tocsr` + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.tocsr.html + """ + raise NotImplementedError() + + def eliminate_zeros(self, a, threshold=0.): + r""" + Removes entries smaller than the given threshold from the sparse tensor. + + This function follows the api from :any:`scipy.sparse.csr_matrix.eliminate_zeros` + + See: https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.sparse.csr_matrix.eliminate_zeros.html + """ + raise NotImplementedError() + + def todense(self, a): + r""" + Converts a sparse tensor to a dense tensor. + + This function follows the api from :any:`scipy.sparse.csr_matrix.toarray` + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.toarray.html + """ + raise NotImplementedError() + + def where(self, condition, x, y): + r""" + Returns elements chosen from x or y depending on condition. + + This function follows the api from :any:`numpy.where` + + See: https://numpy.org/doc/stable/reference/generated/numpy.where.html + """ + raise NotImplementedError() + + def copy(self, a): + r""" + Returns a copy of the given tensor. + + This function follows the api from :any:`numpy.copy` + + See: https://numpy.org/doc/stable/reference/generated/numpy.copy.html + """ + raise NotImplementedError() + + def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): + r""" + Returns True if two arrays are element-wise equal within a tolerance. + + This function follows the api from :any:`numpy.allclose` + + See: https://numpy.org/doc/stable/reference/generated/numpy.allclose.html + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -712,6 +793,46 @@ class NumpyBackend(Backend): def reshape(self, a, shape): return np.reshape(a, shape) + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): + if type_as is None: + return coo_matrix((data, (rows, cols)), shape=shape) + else: + return coo_matrix((data, (rows, cols)), shape=shape, dtype=type_as.dtype) + + def issparse(self, a): + return issparse(a) + + def tocsr(self, a): + if self.issparse(a): + return a.tocsr() + else: + return csr_matrix(a) + + def eliminate_zeros(self, a, threshold=0.): + if threshold > 0: + if self.issparse(a): + a.data[self.abs(a.data) <= threshold] = 0 + else: + a[self.abs(a) <= threshold] = 0 + if self.issparse(a): + a.eliminate_zeros() + return a + + def todense(self, a): + if self.issparse(a): + return a.toarray() + else: + return a + + def where(self, condition, x, y): + return np.where(condition, x, y) + + def copy(self, a): + return a.copy() + + def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): + return np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + class JaxBackend(Backend): """ @@ -889,6 +1010,48 @@ class JaxBackend(Backend): def reshape(self, a, shape): return jnp.reshape(a, shape) + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): + # Currently, JAX does not support sparse matrices + data = self.to_numpy(data) + rows = self.to_numpy(rows) + cols = self.to_numpy(cols) + nx = NumpyBackend() + coo_matrix = nx.coo_matrix(data, rows, cols, shape=shape, type_as=type_as) + matrix = nx.todense(coo_matrix) + return self.from_numpy(matrix) + + def issparse(self, a): + # Currently, JAX does not support sparse matrices + return False + + def tocsr(self, a): + # Currently, JAX does not support sparse matrices + return a + + def eliminate_zeros(self, a, threshold=0.): + # Currently, JAX does not support sparse matrices + if threshold > 0: + return self.where( + self.abs(a) <= threshold, + self.zeros((1,), type_as=a), + a + ) + return a + + def todense(self, a): + # Currently, JAX does not support sparse matrices + return a + + def where(self, condition, x, y): + return jnp.where(condition, x, y) + + def copy(self, a): + # No need to copy, JAX arrays are immutable + return a + + def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): + return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + class TorchBackend(Backend): """ @@ -999,7 +1162,7 @@ class TorchBackend(Backend): a = torch.tensor([float(a)], dtype=b.dtype, device=b.device) if isinstance(b, int) or isinstance(b, float): b = torch.tensor([float(b)], dtype=a.dtype, device=a.device) - if torch.__version__ >= '1.7.0': + if hasattr(torch, "maximum"): return torch.maximum(a, b) else: return torch.max(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0] @@ -1009,7 +1172,7 @@ class TorchBackend(Backend): a = torch.tensor([float(a)], dtype=b.dtype, device=b.device) if isinstance(b, int) or isinstance(b, float): b = torch.tensor([float(b)], dtype=a.dtype, device=a.device) - if torch.__version__ >= '1.7.0': + if hasattr(torch, "minimum"): return torch.minimum(a, b) else: return torch.min(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0] @@ -1129,3 +1292,50 @@ class TorchBackend(Backend): def reshape(self, a, shape): return torch.reshape(a, shape) + + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): + if type_as is None: + return torch.sparse_coo_tensor(torch.stack([rows, cols]), data, size=shape) + else: + return torch.sparse_coo_tensor( + torch.stack([rows, cols]), data, size=shape, + dtype=type_as.dtype, device=type_as.device + ) + + def issparse(self, a): + return getattr(a, "is_sparse", False) or getattr(a, "is_sparse_csr", False) + + def tocsr(self, a): + # Versions older than 1.9 do not support CSR tensors. PyTorch 1.9 and 1.10 offer a very limited support + return self.todense(a) + + def eliminate_zeros(self, a, threshold=0.): + if self.issparse(a): + if threshold > 0: + mask = self.abs(a) <= threshold + mask = ~mask + mask = mask.nonzero() + else: + mask = a._values().nonzero() + nv = a._values().index_select(0, mask.view(-1)) + ni = a._indices().index_select(1, mask.view(-1)) + return self.coo_matrix(nv, ni[0], ni[1], shape=a.shape, type_as=a) + else: + if threshold > 0: + a[self.abs(a) <= threshold] = 0 + return a + + def todense(self, a): + if self.issparse(a): + return a.to_dense() + else: + return a + + def where(self, condition, x, y): + return torch.where(condition, x, y) + + def copy(self, a): + return torch.clone(a) + + def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): + return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) diff --git a/ot/bregman.py b/ot/bregman.py index 2aa76ff..0499b8e 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -32,13 +32,14 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) - s.t. \ \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} + + \gamma &\geq 0 - \gamma\geq 0 where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix @@ -167,13 +168,14 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + W = \min_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) + + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - s.t. \ \gamma 1 = a + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma^T 1= b + \gamma &\geq 0 - \gamma\geq 0 where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix @@ -323,13 +325,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix @@ -489,13 +491,13 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix @@ -550,8 +552,7 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal - Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. @@ -675,13 +676,13 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) - s.t. \ \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix @@ -820,13 +821,13 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) - s.t. \ \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix @@ -965,7 +966,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, # remove numerical problems and store them in K if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau: if n_hists: - alpha, beta = alpha + reg * nx.max(nx.log(u), 1), beta + reg * nx.max(np.log(v)) + alpha, beta = alpha + reg * nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v)) else: alpha, beta = alpha + reg * nx.log(u), beta + reg * nx.log(v) if n_hists: @@ -1055,13 +1056,13 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) - s.t. \ \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix @@ -1245,12 +1246,12 @@ def projC(gamma, q): def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, stopThr=1e-4, verbose=False, log=False, **kwargs): - r"""Compute the entropic regularized wasserstein barycenter of distributions A + r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` The function solves the following optimization problem: .. math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) where : @@ -1263,7 +1264,7 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, Parameters ---------- A : array-like, shape (dim, n_hists) - `n_hists` training distributions :math:`a_i` of size `dim` + `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim` M : array-like, shape (dim, dim) loss matrix for OT reg : float @@ -1271,7 +1272,7 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, method : str (optional) method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized' weights : array-like, shape (n_hists,) - Weights of each histogram :math:`a_i` on the simplex (barycentric coodinates) + Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) numItermax : int, optional Max number of iterations stopThr : float, optional @@ -1314,12 +1315,12 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False): - r"""Compute the entropic regularized wasserstein barycenter of distributions A + r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` The function solves the following optimization problem: .. math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) where : @@ -1332,13 +1333,13 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, Parameters ---------- A : array-like, shape (dim, n_hists) - `n_hists` training distributions :math:`a_i` of size `dim` + `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim` M : array-like, shape (dim, dim) loss matrix for OT reg : float Regularization term > 0 weights : array-like, shape (n_hists,) - Weights of each histogram :math:`a_i` on the simplex (barycentric coodinates) + Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) numItermax : int, optional Max number of iterations stopThr : float, optional @@ -1414,12 +1415,12 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False): - r"""Compute the entropic regularized wasserstein barycenter of distributions A with stabilization. + r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` with stabilization. The function solves the following optimization problem: .. math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) where : @@ -1432,7 +1433,7 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, Parameters ---------- A : array-like, shape (dim, n_hists) - `n_hists` training distributions :math:`a_i` of size `dim` + `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim` M : array-like, shape (dim, dim) loss matrix for OT reg : float @@ -1440,7 +1441,7 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, tau : float threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` for log scaling weights : array-like, shape (n_hists,) - Weights of each histogram :math:`a_i` on the simplex (barycentric coodinates) + Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) numItermax : int, optional Max number of iterations stopThr : float, optional @@ -1533,8 +1534,8 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, "Or a larger absorption threshold `tau`.") if log: log['niter'] = cpt - log['logu'] = np.log(u + 1e-16) - log['logv'] = np.log(v + 1e-16) + log['logu'] = nx.log(u + 1e-16) + log['logv'] = nx.log(v + 1e-16) return q, log else: return q @@ -1543,13 +1544,13 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False): - r"""Compute the entropic regularized wasserstein barycenter of distributions A - where A is a collection of 2D images. + r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` + where :math:`\mathbf{A}` is a collection of 2D images. The function solves the following optimization problem: .. math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) where : @@ -1673,12 +1674,12 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, .. math:: - \mathbf{h} = arg\min_\mathbf{h} (1- \alpha) W_{M,reg}(\mathbf{a},\mathbf{Dh})+\alpha W_{M_0,reg_0}(\mathbf{h}_0,\mathbf{h}) + \mathbf{h} = \mathop{\arg \min}_\mathbf{h} (1- \alpha) W_{\mathbf{M}, \mathrm{reg}}(\mathbf{a},\mathbf{Dh})+\alpha W_{\mathbf{M_0},\mathrm{reg}_0}(\mathbf{h}_0,\mathbf{h}) where : - - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with M loss matrix (see :py:func:`ot.bregman.sinkhorn`) + - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with :math:`\mathbf{M}` loss matrix (see :py:func:`ot.bregman.sinkhorn`) - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, its expected shape is `(dim_a, n_atoms)` - :math:`\mathbf{h}` is the estimated unmixing of dimension `n_atoms` - :math:`\mathbf{a}` is an observed distribution of dimension `dim_a` @@ -1790,7 +1791,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, .. math:: - \mathbf{h} = arg\min_{\mathbf{h}}\quad \sum_{k=1}^{K} \lambda_k + \mathbf{h} = \mathop{\arg \min}_{\mathbf{h}} \sum_{k=1}^{K} \lambda_k W_{reg}((\mathbf{D}_2^{(k)} \mathbf{h})^T, \mathbf{a}) s.t. \ \forall k, \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h} @@ -1898,7 +1899,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, K.append(Ktmp) # uniform target distribution - a = nx.from_numpy(unif(np.shape(Xt)[0])) + a = nx.from_numpy(unif(Xt.shape[0]), type_as=Xs[0]) cpt = 0 # iterations count err = 1 @@ -1956,13 +1957,13 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg} \cdot\Omega(\gamma) - s.t. \ \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= a - \gamma^T 1= b + \gamma^T \mathbf{1} &= b - \gamma\geq 0 + \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix @@ -2010,8 +2011,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', >>> n_samples_a = 2 >>> n_samples_b = 2 >>> reg = 0.1 - >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1)) - >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1)) + >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1)) + >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1)) >>> empirical_sinkhorn(X_s, X_t, reg=reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE array([[4.99977301e-01, 2.26989344e-05], [2.26989344e-05, 4.99977301e-01]]) @@ -2033,9 +2034,9 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', ns, nt = X_s.shape[0], X_t.shape[0] if a is None: - a = nx.from_numpy(unif(ns)) + a = nx.from_numpy(unif(ns), type_as=X_s) if b is None: - b = nx.from_numpy(unif(nt)) + b = nx.from_numpy(unif(nt), type_as=X_s) if isLazy: if log: @@ -2127,13 +2128,13 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + W = \min_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg} \cdot\Omega(\gamma) - s.t. \ \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= a - \gamma^T 1= b + \gamma^T \mathbf{1} &= b - \gamma\geq 0 + \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix @@ -2181,8 +2182,8 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num >>> n_samples_a = 2 >>> n_samples_b = 2 >>> reg = 0.1 - >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1)) - >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1)) + >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1)) + >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1)) >>> b = np.full((n_samples_b, 3), 1/n_samples_b) >>> empirical_sinkhorn2(X_s, X_t, b=b, reg=reg, verbose=False) array([4.53978687e-05, 4.53978687e-05, 4.53978687e-05]) @@ -2204,9 +2205,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num ns, nt = X_s.shape[0], X_t.shape[0] if a is None: - a = nx.from_numpy(unif(ns)) + a = nx.from_numpy(unif(ns), type_as=X_s) if b is None: - b = nx.from_numpy(unif(nt)) + b = nx.from_numpy(unif(nt), type_as=X_s) if isLazy: if log: @@ -2259,32 +2260,32 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli .. math:: - W &= \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + W &= \min_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg} \cdot\Omega(\gamma) - W_a &= \min_{\gamma_a} <\gamma_a,M_a>_F + reg\cdot\Omega(\gamma_a) + W_a &= \min_{\gamma_a} <\gamma_a, \mathbf{M_a}>_F + \mathrm{reg} \cdot\Omega(\gamma_a) - W_b &= \min_{\gamma_b} <\gamma_b,M_b>_F + reg\cdot\Omega(\gamma_b) + W_b &= \min_{\gamma_b} <\gamma_b, \mathbf{M_b}>_F + \mathrm{reg} \cdot\Omega(\gamma_b) S &= W - \frac{W_a + W_b}{2} .. math:: - s.t. \ \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= a - \gamma^T 1= b + \gamma^T \mathbf{1} &= b - \gamma\geq 0 + \gamma &\geq 0 - \gamma_a 1 = a + \gamma_a \mathbf{1} &= \mathbf{a} - \gamma_a^T 1= a + \gamma_a^T \mathbf{1} &= \mathbf{a} - \gamma_a\geq 0 + \gamma_a &\geq 0 - \gamma_b 1 = b + \gamma_b \mathbf{1} &= \mathbf{b} - \gamma_b^T 1= b + \gamma_b^T \mathbf{1} &= \mathbf{b} - \gamma_b\geq 0 + \gamma_b &\geq 0 where : - :math:`\mathbf{M}` (resp. :math:`\mathbf{M_a}`, :math:`\mathbf{M_b}`) is the (`n_samples_a`, `n_samples_b`) metric cost matrix (resp (`n_samples_a, n_samples_a`) and (`n_samples_b`, `n_samples_b`)) @@ -2325,8 +2326,8 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli >>> n_samples_a = 2 >>> n_samples_b = 4 >>> reg = 0.1 - >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1)) - >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1)) + >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1)) + >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1)) >>> empirical_sinkhorn_divergence(X_s, X_t, reg) # doctest: +ELLIPSIS 1.499887176049052 @@ -2380,19 +2381,19 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res .. math:: - (u, v) = arg\min_{u, v} 1_{ns}^T B(u,v) 1_{nt} - <\kappa u, a> - + (\mathbf{u}, \mathbf{v}) = \mathop{\arg \min}_{\mathbf{u}, \mathbf{v}} \ \mathbf{1}_{ns}^T \mathbf{B}(\mathbf{u}, \mathbf{v}) \mathbf{1}_{nt} - <\kappa \mathbf{u}, \mathbf{a}> - <\mathbf{v} / \kappa, \mathbf{b}> where: .. math:: - B(u,v) = \mathrm{diag}(e^u) K \mathrm{diag}(e^v) \text{, with } K = e^{-M/reg} \text{ and} + \mathbf{B}(\mathbf{u}, \mathbf{v}) = \mathrm{diag}(e^\mathbf{u}) \mathbf{K} \mathrm{diag}(e^\mathbf{v}) \text{, with } \mathbf{K} = e^{-\mathbf{M} / \mathrm{reg}} \text{ and} .. math:: - s.t. \ e^{u_i} \geq \epsilon / \kappa, \forall i \in \{1, \ldots, ns\} + s.t. \ e^{u_i} &\geq \epsilon / \kappa, \forall i \in \{1, \ldots, ns\} - e^{v_j} \geq \epsilon \kappa, \forall j \in \{1, \ldots, nt\} + e^{v_j} &\geq \epsilon \kappa, \forall j \in \{1, \ldots, nt\} The parameters `kappa` and `epsilon` are determined w.r.t the couple number budget of points (`ns_budget`, `nt_budget`), see Equation (5) in :ref:`[26] ` @@ -2531,7 +2532,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res epsilon_u_square = a[0] / aK_sort[ns_budget - 1] else: aK_sort = nx.from_numpy( - bottleneck.partition(nx.to_numpy(K_sum_cols), ns_budget - 1)[ns_budget - 1] + bottleneck.partition(nx.to_numpy(K_sum_cols), ns_budget - 1)[ns_budget - 1], + type_as=M ) epsilon_u_square = a[0] / aK_sort @@ -2540,7 +2542,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res epsilon_v_square = b[0] / bK_sort[nt_budget - 1] else: bK_sort = nx.from_numpy( - bottleneck.partition(nx.to_numpy(K_sum_rows), nt_budget - 1)[nt_budget - 1] + bottleneck.partition(nx.to_numpy(K_sum_rows), nt_budget - 1)[nt_budget - 1], + type_as=M ) epsilon_v_square = b[0] / bK_sort else: @@ -2589,10 +2592,9 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res K_IcJ = K[np.ix_(Ic, Jsel)] K_IJc = K[np.ix_(Isel, Jc)] - #K_min = K_IJ.min() K_min = nx.min(K_IJ) if K_min == 0: - K_min = np.finfo(float).tiny + K_min = float(np.finfo(float).tiny) # a_I, b_J, a_Ic, b_Jc a_I = a[Isel] @@ -2713,7 +2715,7 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res maxfun=maxfun, pgtol=pgtol, maxiter=maxiter) - theta = nx.from_numpy(theta) + theta = nx.from_numpy(theta, type_as=M) usc = theta[:ns_budget] vsc = theta[ns_budget:] diff --git a/ot/gromov.py b/ot/gromov.py index 33b4453..a0fbf48 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -14,67 +14,85 @@ import numpy as np from .bregman import sinkhorn -from .utils import dist, UndefinedParameter +from .utils import dist, UndefinedParameter, list_to_array from .optim import cg from .lp import emd_1d, emd from .utils import check_random_state - -from scipy.sparse import issparse +from .backend import get_backend def init_matrix(C1, C2, p, q, loss_fun='square_loss'): r"""Return loss matrices and tensors for Gromov-Wasserstein fast computation - Returns the value of \mathcal{L}(C1,C2) \otimes T with the selected loss - function as the loss function of Gromow-Wasserstein discrepancy. + Returns the value of :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` with the + selected loss function as the loss function of Gromow-Wasserstein discrepancy. - The matrices are computed as described in Proposition 1 in [12] + The matrices are computed as described in Proposition 1 in :ref:`[12] ` Where : - * C1 : Metric cost matrix in the source space - * C2 : Metric cost matrix in the target space - * T : A coupling between those two spaces - - The square-loss function L(a,b)=|a-b|^2 is read as : - L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with : - * f1(a)=(a^2) - * f2(b)=(b^2) - * h1(a)=a - * h2(b)=2*b - - The kl-loss function L(a,b)=a*log(a/b)-a+b is read as : - L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with : - * f1(a)=a*log(a)-a - * f2(b)=b - * h1(a)=a - * h2(b)=log(b) + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{T}`: A coupling between those two spaces + + The square-loss function :math:`L(a, b) = |a - b|^2` is read as : + + .. math:: + + L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b) + + \mathrm{with} \ f_1(a) &= a^2 + + f_2(b) &= b^2 + + h_1(a) &= a + + h_2(b) &= 2b + + The kl-loss function :math:`L(a, b) = a \log\left(\frac{a}{b}\right) - a + b` is read as : + + .. math:: + + L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b) + + \mathrm{with} \ f_1(a) &= a \log(a) - a + + f_2(b) &= b + + h_1(a) &= a + + h_2(b) &= \log(b) Parameters ---------- - C1 : ndarray, shape (ns, ns) + C1 : array-like, shape (ns, ns) Metric cost matrix in the source space - C2 : ndarray, shape (nt, nt) - Metric costfr matrix in the target space - T : ndarray, shape (ns, nt) + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + T : array-like, shape (ns, nt) Coupling between source and target spaces - p : ndarray, shape (ns,) + p : array-like, shape (ns,) Returns ------- - constC : ndarray, shape (ns, nt) - Constant C matrix in Eq. (6) - hC1 : ndarray, shape (ns, ns) - h1(C1) matrix in Eq. (6) - hC2 : ndarray, shape (nt, nt) - h2(C) matrix in Eq. (6) + constC : array-like, shape (ns, nt) + Constant :math:`\mathbf{C}` matrix in Eq. (6) + hC1 : array-like, shape (ns, ns) + :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) + hC2 : array-like, shape (nt, nt) + :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) + + .. _references-init-matrix: References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. """ + C1, C2, p, q = list_to_array(C1, C2, p, q) + nx = get_backend(C1, C2, p, q) if loss_fun == 'square_loss': def f1(a): @@ -90,7 +108,7 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'): return 2 * b elif loss_fun == 'kl_loss': def f1(a): - return a * np.log(a + 1e-15) - a + return a * nx.log(a + 1e-15) - a def f2(b): return b @@ -99,12 +117,16 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'): return a def h2(b): - return np.log(b + 1e-15) - - constC1 = np.dot(np.dot(f1(C1), p.reshape(-1, 1)), - np.ones(len(q)).reshape(1, -1)) - constC2 = np.dot(np.ones(len(p)).reshape(-1, 1), - np.dot(q.reshape(1, -1), f2(C2).T)) + return nx.log(b + 1e-15) + + constC1 = nx.dot( + nx.dot(f1(C1), nx.reshape(p, (-1, 1))), + nx.ones((1, len(q)), type_as=q) + ) + constC2 = nx.dot( + nx.ones((len(p), 1), type_as=p), + nx.dot(nx.reshape(q, (1, -1)), f2(C2).T) + ) constC = constC1 + constC2 hC1 = h1(C1) hC2 = h2(C2) @@ -115,30 +137,37 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'): def tensor_product(constC, hC1, hC2, T): r"""Return the tensor for Gromov-Wasserstein fast computation - The tensor is computed as described in Proposition 1 Eq. (6) in [12]. + The tensor is computed as described in Proposition 1 Eq. (6) in :ref:`[12] ` Parameters ---------- - constC : ndarray, shape (ns, nt) - Constant C matrix in Eq. (6) - hC1 : ndarray, shape (ns, ns) - h1(C1) matrix in Eq. (6) - hC2 : ndarray, shape (nt, nt) - h2(C) matrix in Eq. (6) + constC : array-like, shape (ns, nt) + Constant :math:`\mathbf{C}` matrix in Eq. (6) + hC1 : array-like, shape (ns, ns) + :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) + hC2 : array-like, shape (nt, nt) + :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) Returns ------- - tens : ndarray, shape (ns, nt) - \mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result + tens : array-like, shape (`ns`, `nt`) + :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` tensor-matrix multiplication result + + .. _references-tensor-product: References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. """ - A = -np.dot(hC1, T).dot(hC2.T) + constC, hC1, hC2, T = list_to_array(constC, hC1, hC2, T) + nx = get_backend(constC, hC1, hC2, T) + + A = - nx.dot( + nx.dot(hC1, T), hC2.T + ) tens = constC + A # tens -= tens.min() return tens @@ -147,27 +176,29 @@ def tensor_product(constC, hC1, hC2, T): def gwloss(constC, hC1, hC2, T): """Return the Loss for Gromov-Wasserstein - The loss is computed as described in Proposition 1 Eq. (6) in [12]. + The loss is computed as described in Proposition 1 Eq. (6) in :ref:`[12] ` Parameters ---------- - constC : ndarray, shape (ns, nt) - Constant C matrix in Eq. (6) - hC1 : ndarray, shape (ns, ns) - h1(C1) matrix in Eq. (6) - hC2 : ndarray, shape (nt, nt) - h2(C) matrix in Eq. (6) - T : ndarray, shape (ns, nt) - Current value of transport matrix T + constC : array-like, shape (ns, nt) + Constant :math:`\mathbf{C}` matrix in Eq. (6) + hC1 : array-like, shape (ns, ns) + :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) + hC2 : array-like, shape (nt, nt) + :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) + T : array-like, shape (ns, nt) + Current value of transport matrix :math:`\mathbf{T}` Returns ------- loss : float Gromov Wasserstein loss + + .. _references-gwloss: References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. @@ -175,33 +206,38 @@ def gwloss(constC, hC1, hC2, T): tens = tensor_product(constC, hC1, hC2, T) - return np.sum(tens * T) + tens, T = list_to_array(tens, T) + nx = get_backend(tens, T) + + return nx.sum(tens * T) def gwggrad(constC, hC1, hC2, T): """Return the gradient for Gromov-Wasserstein - The gradient is computed as described in Proposition 2 in [12]. + The gradient is computed as described in Proposition 2 in :ref:`[12] ` Parameters ---------- - constC : ndarray, shape (ns, nt) - Constant C matrix in Eq. (6) - hC1 : ndarray, shape (ns, ns) - h1(C1) matrix in Eq. (6) - hC2 : ndarray, shape (nt, nt) - h2(C) matrix in Eq. (6) - T : ndarray, shape (ns, nt) - Current value of transport matrix T + constC : array-like, shape (ns, nt) + Constant :math:`\mathbf{C}` matrix in Eq. (6) + hC1 : array-like, shape (ns, ns) + :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) + hC2 : array-like, shape (nt, nt) + :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) + T : array-like, shape (ns, nt) + Current value of transport matrix :math:`\mathbf{T}` Returns ------- - grad : ndarray, shape (ns, nt) + grad : array-like, shape (`ns`, `nt`) Gromov Wasserstein gradient + + .. _references-gwggrad: References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. @@ -212,88 +248,107 @@ def gwggrad(constC, hC1, hC2, T): def update_square_loss(p, lambdas, T, Cs): """ - Updates C according to the L2 Loss kernel with the S Ts couplings - calculated at each iteration + Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` + couplings calculated at each iteration Parameters ---------- - p : ndarray, shape (N,) + p : array-like, shape (N,) Masses in the targeted barycenter. lambdas : list of float - List of the S spaces' weights. - T : list of S np.ndarray of shape (ns,N) - The S Ts couplings calculated at each iteration. - Cs : list of S ndarray, shape(ns,ns) + List of the `S` spaces' weights. + T : list of S array-like of shape (ns,N) + The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. + Cs : list of S array-like, shape(ns,ns) Metric cost matrices. Returns ---------- - C : ndarray, shape (nt, nt) - Updated C matrix. + C : array-like, shape (`nt`, `nt`) + Updated :math:`\mathbf{C}` matrix. """ - tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) - for s in range(len(T))]) - ppt = np.outer(p, p) + T = list_to_array(*T) + Cs = list_to_array(*Cs) + p = list_to_array(p) + nx = get_backend(p, *T, *Cs) + + tmpsum = sum([ + lambdas[s] * nx.dot( + nx.dot(T[s].T, Cs[s]), + T[s] + ) for s in range(len(T)) + ]) + ppt = nx.outer(p, p) - return np.divide(tmpsum, ppt) + return tmpsum / ppt def update_kl_loss(p, lambdas, T, Cs): """ - Updates C according to the KL Loss kernel with the S Ts couplings calculated at each iteration + Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration Parameters ---------- - p : ndarray, shape (N,) + p : array-like, shape (N,) Weights in the targeted barycenter. - lambdas : list of the S spaces' weights - T : list of S np.ndarray of shape (ns,N) - The S Ts couplings calculated at each iteration. - Cs : list of S ndarray, shape(ns,ns) + lambdas : list of float + List of the `S` spaces' weights + T : list of S array-like of shape (ns,N) + The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. + Cs : list of S array-like, shape(ns,ns) Metric cost matrices. Returns ---------- - C : ndarray, shape (ns,ns) - updated C matrix + C : array-like, shape (`ns`, `ns`) + updated :math:`\mathbf{C}` matrix """ - tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) - for s in range(len(T))]) - ppt = np.outer(p, p) + Cs = list_to_array(*Cs) + T = list_to_array(*T) + p = list_to_array(p) + nx = get_backend(p, *T, *Cs) - return np.exp(np.divide(tmpsum, ppt)) + tmpsum = sum([ + lambdas[s] * nx.dot( + nx.dot(T[s].T, Cs[s]), + T[s] + ) for s in range(len(T)) + ]) + ppt = nx.outer(p, p) + + return nx.exp(tmpsum / ppt) def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): r""" - Returns the gromov-wasserstein transport between (C1,p) and (C2,q) + Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` The function solves the following optimization problem: .. math:: - GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} Where : - - C1 : Metric cost matrix in the source space - - C2 : Metric cost matrix in the target space - - p : distribution in the source space - - q : distribution in the target space - - L : loss function to account for the misfit between the similarity matrices + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices Parameters ---------- - C1 : ndarray, shape (ns, ns) + C1 : array-like, shape (ns, ns) Metric cost matrix in the source space - C2 : ndarray, shape (nt, nt) - Metric costfr matrix in the target space - p : ndarray, shape (ns,) + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) Distribution in the source space - q : ndarray, shape (nt,) + q : array-like, shape (nt,) Distribution in the target space loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss' - max_iter : int, optional Max number of iterations tol : float, optional @@ -303,22 +358,23 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs log : bool, optional record log if True armijo : bool, optional - If True the steps of the line-search is found via an armijo research. Else closed form is used. - If there is convergence issues use False. + If True the step of the line-search is found via an armijo research. Else closed form is used. + If there are convergence issues use False. **kwargs : dict parameters can be directly passed to the ot.optim.cg solver Returns ------- - T : ndarray, shape (ns, nt) - Doupling between the two spaces that minimizes: - \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + T : array-like, shape (`ns`, `nt`) + Coupling between the two spaces that minimizes: + + :math:`\sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}` log : dict Convergence information and loss. References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. @@ -327,6 +383,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs mathematics 11.4 (2011): 417-487. """ + p, q = list_to_array(p, q) constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) @@ -348,29 +405,30 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): r""" - Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q) + Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` The function solves the following optimization problem: .. math:: - GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + GW = \min_\mathbf{T} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} Where : - - C1 : Metric cost matrix in the source space - - C2 : Metric cost matrix in the target space - - p : distribution in the source space - - q : distribution in the target space - - L : loss function to account for the misfit between the similarity matrices + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices Parameters ---------- - C1 : ndarray, shape (ns, ns) + C1 : array-like, shape (ns, ns) Metric cost matrix in the source space - C2 : ndarray, shape (nt, nt) + C2 : array-like, shape (nt, nt) Metric cost matrix in the target space - p : ndarray, shape (ns,) + p : array-like, shape (ns,) Distribution in the source space. - q : ndarray, shape (nt,) + q : array-like, shape (nt,) Distribution in the target space. loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss' @@ -383,8 +441,8 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg log : bool, optional record log if True armijo : bool, optional - If True the steps of the line-search is found via an armijo research. Else closed form is used. - If there is convergence issues use False. + If True the step of the line-search is found via an armijo research. Else closed form is used. + If there are convergence issues use False. Returns ------- @@ -395,7 +453,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. @@ -404,6 +462,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg mathematics 11.4 (2011): 417-487. """ + p, q = list_to_array(p, q) constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) @@ -425,42 +484,45 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): r""" - Computes the FGW transport between two graphs see [24] + Computes the FGW transport between two graphs (see :ref:`[24] `) .. math:: - \gamma = arg\min_\gamma (1-\\alpha)*<\gamma,M>_F + \\alpha* \sum_{i,j,k,l} - L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + \gamma = \mathop{\arg \min}_\gamma (1 - \alpha) <\gamma, \mathbf{M}>_F + \alpha \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} - s.t. \gamma 1 = p - \gamma^T 1= q - \gamma\geq 0 + \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q} + + \mathbf{\gamma} &\geq 0 where : - - M is the (ns,nt) metric cost matrix - - p and q are source and target weights (sum to 1) - - L is a loss function to account for the misfit between the similarity matrices - The algorithm used for solving the problem is conditional gradient as discussed in [24]_ + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1) + - `L` is a loss function to account for the misfit between the similarity matrices + + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] ` Parameters ---------- - M : ndarray, shape (ns, nt) + M : array-like, shape (ns, nt) Metric cost matrix between features across domains - C1 : ndarray, shape (ns, ns) + C1 : array-like, shape (ns, ns) Metric cost matrix representative of the structure in the source space - C2 : ndarray, shape (nt, nt) + C2 : array-like, shape (nt, nt) Metric cost matrix representative of the structure in the target space - p : ndarray, shape (ns,) + p : array-like, shape (ns,) Distribution in the source space - q : ndarray, shape (nt,) + q : array-like, shape (nt,) Distribution in the target space loss_fun : str, optional Loss function used for the solver alpha : float, optional Trade-off parameter (0 < alpha < 1) armijo : bool, optional - If True the steps of the line-search is found via an armijo research. Else closed form is used. - If there is convergence issues use False. + If True the step of the line-search is found via an armijo research. Else closed form is used. + If there are convergence issues use False. log : bool, optional record log if True **kwargs : dict @@ -468,18 +530,21 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, Returns ------- - gamma : ndarray, shape (ns, nt) + gamma : array-like, shape (`ns`, `nt`) Optimal transportation matrix for the given parameters. log : dict Log dictionary return only if log==True in parameters. + + .. _references-fused-gromov-wasserstein: References ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas "Optimal Transport for structured data with application on graphs", International Conference on Machine Learning (ICML). 2019. """ + p, q = list_to_array(p, q) constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) @@ -501,61 +566,67 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): r""" - Computes the FGW distance between two graphs see [24] + Computes the FGW distance between two graphs see (see :ref:`[24] `) .. math:: - \min_\gamma (1-\\alpha)*<\gamma,M>_F + \\alpha* \sum_{i,j,k,l} - L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + \min_\gamma (1 - \alpha) <\gamma, \mathbf{M}>_F + \alpha \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} - s.t. \gamma 1 = p - \gamma^T 1= q - \gamma\geq 0 + \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q} + + \mathbf{\gamma} &\geq 0 where : - - M is the (ns,nt) metric cost matrix - - p and q are source and target weights (sum to 1) - - L is a loss function to account for the misfit between the similarity matrices - The algorithm used for solving the problem is conditional gradient as discussed in [1]_ + + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1) + - `L` is a loss function to account for the misfit between the similarity matrices + + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] ` Parameters ---------- - M : ndarray, shape (ns, nt) + M : array-like, shape (ns, nt) Metric cost matrix between features across domains - C1 : ndarray, shape (ns, ns) + C1 : array-like, shape (ns, ns) Metric cost matrix respresentative of the structure in the source space. - C2 : ndarray, shape (nt, nt) + C2 : array-like, shape (nt, nt) Metric cost matrix espresentative of the structure in the target space. - p : ndarray, shape (ns,) + p : array-like, shape (ns,) Distribution in the source space. - q : ndarray, shape (nt,) + q : array-like, shape (nt,) Distribution in the target space. loss_fun : str, optional Loss function used for the solver. alpha : float, optional Trade-off parameter (0 < alpha < 1) armijo : bool, optional - If True the steps of the line-search is found via an armijo research. - Else closed form is used. If there is convergence issues use False. + If True the step of the line-search is found via an armijo research. + Else closed form is used. If there are convergence issues use False. log : bool, optional Record log if True. **kwargs : dict - Parameters can be directly pased to the ot.optim.cg solver. + Parameters can be directly passed to the ot.optim.cg solver. Returns ------- - gamma : ndarray, shape (ns, nt) + gamma : array-like, shape (ns, nt) Optimal transportation matrix for the given parameters. log : dict Log dictionary return only if log==True in parameters. + + .. _references-fused-gromov-wasserstein2: References ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. """ + p, q = list_to_array(p, q) constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) @@ -579,60 +650,64 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 def GW_distance_estimation(C1, C2, p, q, loss_fun, T, nb_samples_p=None, nb_samples_q=None, std=True, random_state=None): r""" - Returns an approximation of the gromov-wasserstein cost between (C1,p) and (C2,q) - with a fixed transport plan T. - - The function gives an unbiased approximation of the following equation: - - .. math:: - GW = \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} - - Where : - - - C1 : Metric cost matrix in the source space - - C2 : Metric cost matrix in the target space - - L : Loss function to account for the misfit between the similarity matrices - - T : Matrix with marginal p and q - - Parameters - ---------- - C1 : ndarray, shape (ns, ns) - Metric cost matrix in the source space - C2 : ndarray, shape (nt, nt) - Metric costfr matrix in the target space - p : ndarray, shape (ns,) - Distribution in the source space - q : ndarray, shape (nt,) - Distribution in the target space - loss_fun : function: \mathcal{R} \times \mathcal{R} \shortarrow \mathcal{R} - Loss function used for the distance, the transport plan does not depend on the loss function - T : csr or ndarray, shape (ns, nt) - Transport plan matrix, either a sparse csr matrix or - nb_samples_p : int, optional - nb_samples_p is the number of samples (without replacement) along the first dimension of T. - nb_samples_q : int, optional - nb_samples_q is the number of samples along the second dimension of T, for each sample along the first. - std : bool, optional - Standard deviation associated with the prediction of the gromov-wasserstein cost. - random_state : int or RandomState instance, optional - Fix the seed for to allow reproducibility - - Returns - ------- - : float - Gromov-wasserstein cost - - References - ---------- - .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc - "Sampled Gromov Wasserstein." - Machine Learning Journal (MLJ). 2021. - - """ + Returns an approximation of the gromov-wasserstein cost between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` + with a fixed transport plan :math:`\mathbf{T}`. + + The function gives an unbiased approximation of the following equation: + + .. math:: + + GW = \sum_{i,j,k,l} L(\mathbf{C_{1}}_{i,k}, \mathbf{C_{2}}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - `L` : Loss function to account for the misfit between the similarity matrices + - :math:`\mathbf{T}`: Matrix with marginal :math:`\mathbf{p}` and :math:`\mathbf{q}` + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Distribution in the source space + q : array-like, shape (nt,) + Distribution in the target space + loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}` + Loss function used for the distance, the transport plan does not depend on the loss function + T : csr or array-like, shape (ns, nt) + Transport plan matrix, either a sparse csr or a dense matrix + nb_samples_p : int, optional + `nb_samples_p` is the number of samples (without replacement) along the first dimension of :math:`\mathbf{T}` + nb_samples_q : int, optional + `nb_samples_q` is the number of samples along the second dimension of :math:`\mathbf{T}`, for each sample along the first + std : bool, optional + Standard deviation associated with the prediction of the gromov-wasserstein cost + random_state : int or RandomState instance, optional + Fix the seed for reproducibility + + Returns + ------- + : float + Gromov-wasserstein cost + + References + ---------- + .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc + "Sampled Gromov Wasserstein." + Machine Learning Journal (MLJ). 2021. + + """ + C1, C2, p, q = list_to_array(C1, C2, p, q) + nx = get_backend(C1, C2, p, q) + generator = check_random_state(random_state) - len_p = len(p) - len_q = len(q) + len_p = p.shape[0] + len_q = q.shape[0] # It is always better to sample from the biggest distribution first. if len_p < len_q: @@ -642,7 +717,7 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T, T = T.T if nb_samples_p is None: - if issparse(T): + if nx.issparse(T): # If T is sparse, it probably mean that PoGroW was used, thus the number of sample is reduced nb_samples_p = min(int(5 * (len_p * np.log(len_p)) ** 0.5), len_p) else: @@ -657,100 +732,112 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T, index_k = np.zeros((nb_samples_p, nb_samples_q), dtype=int) index_l = np.zeros((nb_samples_p, nb_samples_q), dtype=int) - list_value_sample = np.zeros((nb_samples_p, nb_samples_p, nb_samples_q)) index_i = generator.choice(len_p, size=nb_samples_p, p=p, replace=False) index_j = generator.choice(len_p, size=nb_samples_p, p=p, replace=False) for i in range(nb_samples_p): - if issparse(T): - T_indexi = T[index_i[i], :].toarray()[0] - T_indexj = T[index_j[i], :].toarray()[0] + if nx.issparse(T): + T_indexi = nx.reshape(nx.todense(T[index_i[i], :]), (-1,)) + T_indexj = nx.reshape(nx.todense(T[index_j[i], :]), (-1,)) else: T_indexi = T[index_i[i], :] T_indexj = T[index_j[i], :] # For each of the row sampled, the column is sampled. - index_k[i] = generator.choice(len_q, size=nb_samples_q, p=T_indexi / T_indexi.sum(), replace=True) - index_l[i] = generator.choice(len_q, size=nb_samples_q, p=T_indexj / T_indexj.sum(), replace=True) - - for n in range(nb_samples_q): - list_value_sample[:, :, n] = loss_fun(C1[np.ix_(index_i, index_j)], C2[np.ix_(index_k[:, n], index_l[:, n])]) + index_k[i] = generator.choice( + len_q, + size=nb_samples_q, + p=T_indexi / nx.sum(T_indexi), + replace=True + ) + index_l[i] = generator.choice( + len_q, + size=nb_samples_q, + p=T_indexj / nx.sum(T_indexj), + replace=True + ) + + list_value_sample = nx.stack([ + loss_fun( + C1[np.ix_(index_i, index_j)], + C2[np.ix_(index_k[:, n], index_l[:, n])] + ) for n in range(nb_samples_q) + ], axis=2) if std: - std_value = np.sum(np.std(list_value_sample, axis=2) ** 2) ** 0.5 - return np.mean(list_value_sample), std_value / (nb_samples_p * nb_samples_p) + std_value = nx.sum(nx.std(list_value_sample, axis=2) ** 2) ** 0.5 + return nx.mean(list_value_sample), std_value / (nb_samples_p * nb_samples_p) else: - return np.mean(list_value_sample) + return nx.mean(list_value_sample) def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, alpha=1, max_iter=100, threshold_plan=0, log=False, verbose=False, random_state=None): r""" - Returns the gromov-wasserstein transport between (C1,p) and (C2,q) using a stochastic Frank-Wolfe. - This method as a O(max_iter \times PN^2) time complexity with P the number of Sinkhorn iterations. - - The function solves the following optimization problem: - - .. math:: - GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} - - s.t. T 1 = p - - T^T 1= q - - T\geq 0 - - Where : - - - C1 : Metric cost matrix in the source space - - C2 : Metric cost matrix in the target space - - p : distribution in the source space - - q : distribution in the target space - - L : loss function to account for the misfit between the similarity matrices - - Parameters - ---------- - C1 : ndarray, shape (ns, ns) - Metric cost matrix in the source space - C2 : ndarray, shape (nt, nt) - Metric costfr matrix in the target space - p : ndarray, shape (ns,) - Distribution in the source space - q : ndarray, shape (nt,) - Distribution in the target space - loss_fun : function: \mathcal{R} \times \mathcal{R} \shortarrow \mathcal{R} - Loss function used for the distance, the transport plan does not depend on the loss function - alpha : float - Step of the Frank-Wolfe algorithm, should be between 0 and 1 - max_iter : int, optional - Max number of iterations - threshold_plan : float, optional - Deleting very small value in the transport plan. If above zero, it violate the marginal constraints. - verbose : bool, optional - Print information along iterations - log : bool, optional - Gives the distance estimated and the standard deviation - random_state : int or RandomState instance, optional - Fix the seed for to allow reproducibility - - Returns - ------- - T : ndarray, shape (ns, nt) - Optimal coupling between the two spaces - - References - ---------- - .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc - "Sampled Gromov Wasserstein." - Machine Learning Journal (MLJ). 2021. - - """ - C1 = np.asarray(C1, dtype=np.float64) - C2 = np.asarray(C2, dtype=np.float64) - p = np.asarray(p, dtype=np.float64) - q = np.asarray(q, dtype=np.float64) - len_p = len(p) - len_q = len(q) + Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a stochastic Frank-Wolfe. + This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times PN^2)` time complexity with `P` the number of Sinkhorn iterations. + + The function solves the following optimization problem: + + .. math:: + \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T}^T \mathbf{1} &= \mathbf{q} + + \mathbf{T} &\geq 0 + + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Distribution in the source space + q : array-like, shape (nt,) + Distribution in the target space + loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}` + Loss function used for the distance, the transport plan does not depend on the loss function + alpha : float + Step of the Frank-Wolfe algorithm, should be between 0 and 1 + max_iter : int, optional + Max number of iterations + threshold_plan : float, optional + Deleting very small values in the transport plan. If above zero, it violates the marginal constraints. + verbose : bool, optional + Print information along iterations + log : bool, optional + Gives the distance estimated and the standard deviation + random_state : int or RandomState instance, optional + Fix the seed for reproducibility + + Returns + ------- + T : array-like, shape (`ns`, `nt`) + Optimal coupling between the two spaces + + References + ---------- + .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc + "Sampled Gromov Wasserstein." + Machine Learning Journal (MLJ). 2021. + + """ + C1, C2, p, q = list_to_array(C1, C2, p, q) + nx = get_backend(C1, C2, p, q) + + len_p = p.shape[0] + len_q = q.shape[0] generator = check_random_state(random_state) @@ -759,30 +846,35 @@ def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, # Initialize with default marginal index[0] = generator.choice(len_p, size=1, p=p) index[1] = generator.choice(len_q, size=1, p=q) - T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr() + T = nx.tocsr(emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False)) best_gw_dist_estimated = np.inf for cpt in range(max_iter): index[0] = generator.choice(len_p, size=1, p=p) - T_index0 = T[index[0], :].toarray()[0] + T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,)) index[1] = generator.choice(len_q, size=1, p=T_index0 / T_index0.sum()) if alpha == 1: - T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr() + T = nx.tocsr( + emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False) + ) else: - new_T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr() + new_T = nx.tocsr( + emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False) + ) T = (1 - alpha) * T + alpha * new_T - # To limit the number of non 0, the values bellow the threshold are set to 0. - T.data[T.data < threshold_plan] = 0 - T.eliminate_zeros() + # To limit the number of non 0, the values below the threshold are set to 0. + T = nx.eliminate_zeros(T, threshold=threshold_plan) if cpt % 10 == 0 or cpt == (max_iter - 1): - gw_dist_estimated = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun, - p=p, q=q, T=T, std=False, random_state=generator) + gw_dist_estimated = GW_distance_estimation( + C1=C1, C2=C2, loss_fun=loss_fun, + p=p, q=q, T=T, std=False, random_state=generator + ) if gw_dist_estimated < best_gw_dist_estimated: best_gw_dist_estimated = gw_dist_estimated - best_T = T.copy() + best_T = nx.copy(T) if verbose: if cpt % 200 == 0: @@ -791,9 +883,10 @@ def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, if log: log = {} - log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun, - p=p, q=q, T=best_T, - random_state=generator) + log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation( + C1=C1, C2=C2, loss_fun=loss_fun, + p=p, q=q, T=best_T, random_state=generator + ) return best_T, log return best_T @@ -802,71 +895,70 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, nb_samples_grad=100, epsilon=1, max_iter=500, log=False, verbose=False, random_state=None): r""" - Returns the gromov-wasserstein transport between (C1,p) and (C2,q) using a 1-stochastic Frank-Wolfe. - This method as a O(max_iter \times Nlog(N)) time complexity by relying on the 1D Optimal Transport solver. - - The function solves the following optimization problem: - - .. math:: - GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} - - s.t. T 1 = p - - T^T 1= q - - T\geq 0 - - Where : - - - C1 : Metric cost matrix in the source space - - C2 : Metric cost matrix in the target space - - p : distribution in the source space - - q : distribution in the target space - - L : loss function to account for the misfit between the similarity matrices - - Parameters - ---------- - C1 : ndarray, shape (ns, ns) - Metric cost matrix in the source space - C2 : ndarray, shape (nt, nt) - Metric costfr matrix in the target space - p : ndarray, shape (ns,) - Distribution in the source space - q : ndarray, shape (nt,) - Distribution in the target space - loss_fun : function: \mathcal{R} \times \mathcal{R} \shortarrow \mathcal{R} - Loss function used for the distance, the transport plan does not depend on the loss function - nb_samples_grad : int - Number of samples to approximate the gradient - epsilon : float - Weight of the Kullback-Leiber regularization - max_iter : int, optional - Max number of iterations - verbose : bool, optional - Print information along iterations - log : bool, optional - Gives the distance estimated and the standard deviation - random_state : int or RandomState instance, optional - Fix the seed for to allow reproducibility - - Returns - ------- - T : ndarray, shape (ns, nt) - Optimal coupling between the two spaces - - References - ---------- - .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc - "Sampled Gromov Wasserstein." - Machine Learning Journal (MLJ). 2021. - - """ - C1 = np.asarray(C1, dtype=np.float64) - C2 = np.asarray(C2, dtype=np.float64) - p = np.asarray(p, dtype=np.float64) - q = np.asarray(q, dtype=np.float64) - len_p = len(p) - len_q = len(q) + Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a 1-stochastic Frank-Wolfe. + This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times N \log(N))` time complexity by relying on the 1D Optimal Transport solver. + + The function solves the following optimization problem: + + .. math:: + \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T}^T \mathbf{1} &= \mathbf{q} + + \mathbf{T} &\geq 0 + + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Distribution in the source space + q : array-like, shape (nt,) + Distribution in the target space + loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}` + Loss function used for the distance, the transport plan does not depend on the loss function + nb_samples_grad : int + Number of samples to approximate the gradient + epsilon : float + Weight of the Kullback-Leibler regularization + max_iter : int, optional + Max number of iterations + verbose : bool, optional + Print information along iterations + log : bool, optional + Gives the distance estimated and the standard deviation + random_state : int or RandomState instance, optional + Fix the seed for reproducibility + + Returns + ------- + T : array-like, shape (`ns`, `nt`) + Optimal coupling between the two spaces + + References + ---------- + .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc + "Sampled Gromov Wasserstein." + Machine Learning Journal (MLJ). 2021. + + """ + C1, C2, p, q = list_to_array(C1, C2, p, q) + nx = get_backend(C1, C2, p, q) + + len_p = p.shape[0] + len_q = q.shape[0] generator = check_random_state(random_state) @@ -880,12 +972,12 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad, 1 else: nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad - T = np.outer(p, q) + T = nx.outer(p, q) # continue_loop allows to stop the loop if there is several successive small modification of T. continue_loop = 0 # The gradient of GW is more complex if the two matrices are not symmetric. - C_are_symmetric = np.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and np.allclose(C2, C2.T, rtol=1e-10, atol=1e-10) + C_are_symmetric = nx.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and nx.allclose(C2, C2.T, rtol=1e-10, atol=1e-10) for cpt in range(max_iter): index0 = generator.choice(len_p, size=nb_samples_grad_p, p=p, replace=False) @@ -893,28 +985,30 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, for i, index0_i in enumerate(index0): index1 = generator.choice(len_q, size=nb_samples_grad_q, - p=T[index0_i, :] / T[index0_i, :].sum(), + p=T[index0_i, :] / nx.sum(T[index0_i, :]), replace=False) # If the matrices C are not symmetric, the gradient has 2 terms, thus the term is chosen randomly. if (not C_are_symmetric) and generator.rand(1) > 0.5: - Lik += np.mean(loss_fun(np.expand_dims(C1[:, np.repeat(index0[i], nb_samples_grad_q)], 1), - np.expand_dims(C2[:, index1], 0)), - axis=2) + Lik += nx.mean(loss_fun( + C1[:, [index0[i]] * nb_samples_grad_q][:, None, :], + C2[:, index1][None, :, :] + ), axis=2) else: - Lik += np.mean(loss_fun(np.expand_dims(C1[np.repeat(index0[i], nb_samples_grad_q), :], 2), - np.expand_dims(C2[index1, :], 1)), - axis=0) + Lik += nx.mean(loss_fun( + C1[[index0[i]] * nb_samples_grad_q, :][:, :, None], + C2[index1, :][:, None, :] + ), axis=0) - max_Lik = np.max(Lik) + max_Lik = nx.max(Lik) if max_Lik == 0: continue # This division by the max is here to facilitate the choice of epsilon. Lik /= max_Lik if epsilon > 0: - # Set to infinity all the numbers bellow exp(-200) to avoid log of 0. - log_T = np.log(np.clip(T, np.exp(-200), 1)) - log_T[log_T == -200] = -np.inf + # Set to infinity all the numbers below exp(-200) to avoid log of 0. + log_T = nx.log(nx.clip(T, np.exp(-200), 1)) + log_T = nx.where(log_T == -200, -np.inf, log_T) Lik = Lik - epsilon * log_T try: @@ -925,11 +1019,11 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, else: new_T = emd(a=p, b=q, M=Lik) - change_T = ((T - new_T) ** 2).mean() + change_T = nx.mean((T - new_T) ** 2) if change_T <= 10e-20: continue_loop += 1 if continue_loop > 100: # Number max of low modifications of T - T = new_T.copy() + T = nx.copy(new_T) break else: continue_loop = 0 @@ -938,12 +1032,14 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, if cpt % 200 == 0: print('{:5s}|{:12s}'.format('It.', '||T_n - T_{n+1}||') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(cpt, change_T)) - T = new_T.copy() + T = nx.copy(new_T) if log: log = {} - log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun, - p=p, q=q, T=T, random_state=generator) + log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation( + C1=C1, C2=C2, loss_fun=loss_fun, + p=p, q=q, T=T, random_state=generator + ) return T, log return T @@ -951,38 +1047,37 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False): r""" - Returns the gromov-wasserstein transport between (C1,p) and (C2,q) - - (C1,p) and (C2,q) + Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` The function solves the following optimization problem: .. math:: - GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T)) + \mathbf{GW} = \mathop{\arg\min}_\mathbf{T} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T})) - s.t. T 1 = p + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} - T^T 1= q + \mathbf{T}^T \mathbf{1} &= \mathbf{q} - T\geq 0 + \mathbf{T} &\geq 0 Where : - - C1 : Metric cost matrix in the source space - - C2 : Metric cost matrix in the target space - - p : distribution in the source space - - q : distribution in the target space - - L : loss function to account for the misfit between the similarity matrices - - H : entropy + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices + - `H`: entropy Parameters ---------- - C1 : ndarray, shape (ns, ns) + C1 : array-like, shape (ns, ns) Metric cost matrix in the source space - C2 : ndarray, shape (nt, nt) - Metric costfr matrix in the target space - p : ndarray, shape (ns,) + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) Distribution in the source space - q : ndarray, shape (nt,) + q : array-like, shape (nt,) Distribution in the target space loss_fun : string Loss function used for the solver either 'square_loss' or 'kl_loss' @@ -999,21 +1094,20 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, Returns ------- - T : ndarray, shape (ns, nt) + T : array-like, shape (`ns`, `nt`) Optimal coupling between the two spaces References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. """ + C1, C2, p, q = list_to_array(C1, C2, p, q) + nx = get_backend(C1, C2, p, q) - C1 = np.asarray(C1, dtype=np.float64) - C2 = np.asarray(C2, dtype=np.float64) - - T = np.outer(p, q) # Initialization + T = nx.outer(p, q) constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) @@ -1035,7 +1129,7 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = np.linalg.norm(T - Tprev) + err = nx.norm(T - Tprev) if log: log['err'].append(err) @@ -1058,32 +1152,31 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False): r""" - Returns the entropic gromov-wasserstein discrepancy between the two measured similarity matrices - - (C1,p) and (C2,q) + Returns the entropic gromov-wasserstein discrepancy between the two measured similarity matrices :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` The function solves the following optimization problem: .. math:: - GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T)) + GW = \min_\mathbf{T} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T})) Where : - - C1 : Metric cost matrix in the source space - - C2 : Metric cost matrix in the target space - - p : distribution in the source space - - q : distribution in the target space - - L : loss function to account for the misfit between the similarity matrices - - H : entropy + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices + - `H`: entropy Parameters ---------- - C1 : ndarray, shape (ns, ns) + C1 : array-like, shape (ns, ns) Metric cost matrix in the source space - C2 : ndarray, shape (nt, nt) - Metric costfr matrix in the target space - p : ndarray, shape (ns,) + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) Distribution in the source space - q : ndarray, shape (nt,) + q : array-like, shape (nt,) Distribution in the target space loss_fun : str Loss function used for the solver either 'square_loss' or 'kl_loss' @@ -1105,7 +1198,7 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. @@ -1122,40 +1215,39 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, - max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None): + max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None): r""" - Returns the gromov-wasserstein barycenters of S measured similarity matrices - - (Cs)_{s=1}^{s=S} + Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}` The function solves the following optimization problem: .. math:: - C = argmin_{C\in R^{NxN}} \sum_s \lambda_s GW(C,C_s,p,p_s) + \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s) Where : - - :math:`C_s` : metric cost matrix - - :math:`p_s` : distribution + - :math:`\mathbf{C}_s`: metric cost matrix + - :math:`\mathbf{p}_s`: distribution Parameters ---------- N : int Size of the targeted barycenter - Cs : list of S np.ndarray of shape (ns,ns) + Cs : list of S array-like of shape (ns,ns) Metric cost matrices - ps : list of S np.ndarray of shape (ns,) - Sample weights in the S spaces - p : ndarray, shape(N,) + ps : list of S array-like of shape (ns,) + Sample weights in the `S` spaces + p : array-like, shape(N,) Weights in the targeted barycenter lambdas : list of float - List of the S spaces' weights. + List of the `S` spaces' weights. loss_fun : callable Tensor-matrix multiplication function based on specific loss function. update : callable - function(p,lambdas,T,Cs) that updates C according to a specific Kernel - with the S Ts couplings calculated at each iteration + function(:math:`\mathbf{p}`, lambdas, :math:`\mathbf{T}`, :math:`\mathbf{Cs}`) that updates + :math:`\mathbf{C}` according to a specific Kernel with the `S` :math:`\mathbf{T}_s` couplings + calculated at each iteration epsilon : float Regularization term >0 max_iter : int, optional @@ -1166,32 +1258,36 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, Print information along iterations. log : bool, optional Record log if True. - init_C : bool | ndarray, shape (N, N) - Random initial value for the C matrix provided by user. + init_C : bool | array-like, shape (N, N) + Random initial value for the :math:`\mathbf{C}` matrix provided by user. + random_state : int or RandomState instance, optional + Fix the seed for reproducibility Returns ------- - C : ndarray, shape (N, N) + C : array-like, shape (`N`, `N`) Similarity matrix in the barycenter space (permutated arbitrarily) References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. """ + Cs = list_to_array(*Cs) + ps = list_to_array(*ps) + p = list_to_array(p) + nx = get_backend(*Cs, *ps, p) S = len(Cs) - Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)] - lambdas = np.asarray(lambdas, dtype=np.float64) - # Initialization of C : random SPD matrix (if not provided by user) if init_C is None: - # XXX use random state - xalea = np.random.randn(N, 2) + generator = check_random_state(random_state) + xalea = generator.randn(N, 2) C = dist(xalea, xalea) C /= C.max() + C = nx.from_numpy(C, type_as=p) else: C = init_C @@ -1214,7 +1310,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = np.linalg.norm(C - Cprev) + err = nx.norm(C - Cprev) error.append(err) if log: @@ -1232,38 +1328,39 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, - max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None): + max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None): r""" - Returns the gromov-wasserstein barycenters of S measured similarity matrices - - (Cs)_{s=1}^{s=S} + Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}` - The function solves the following optimization problem with block - coordinate descent: + The function solves the following optimization problem with block coordinate descent: .. math:: - C = argmin_C\in R^NxN \sum_s \lambda_s GW(C,Cs,p,ps) + + \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s) Where : - - Cs : metric cost matrix - - ps : distribution + - :math:`\mathbf{C}_s`: metric cost matrix + - :math:`\mathbf{p}_s`: distribution Parameters ---------- N : int Size of the targeted barycenter - Cs : list of S np.ndarray of shape (ns, ns) + Cs : list of S array-like of shape (ns, ns) Metric cost matrices - ps : list of S np.ndarray of shape (ns,) - Sample weights in the S spaces - p : ndarray, shape (N,) + ps : list of S array-like of shape (ns,) + Sample weights in the `S` spaces + p : array-like, shape (N,) Weights in the targeted barycenter lambdas : list of float - List of the S spaces' weights - loss_fun : tensor-matrix multiplication function based on specific loss function - update : function(p,lambdas,T,Cs) that updates C according to a specific Kernel - with the S Ts couplings calculated at each iteration + List of the `S` spaces' weights + loss_fun : callable + tensor-matrix multiplication function based on specific loss function + update : callable + function(:math:`\mathbf{p}`, lambdas, :math:`\mathbf{T}`, :math:`\mathbf{Cs}`) that updates + :math:`\mathbf{C}` according to a specific Kernel with the `S` :math:`\mathbf{T}_s` couplings + calculated at each iteration max_iter : int, optional Max number of iterations tol : float, optional @@ -1272,32 +1369,37 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, Print information along iterations. log : bool, optional Record log if True. - init_C : bool | ndarray, shape(N,N) - Random initial value for the C matrix provided by user. + init_C : bool | array-like, shape(N,N) + Random initial value for the :math:`\mathbf{C}` matrix provided by user. + random_state : int or RandomState instance, optional + Fix the seed for reproducibility Returns ------- - C : ndarray, shape (N, N) + C : array-like, shape (`N`, `N`) Similarity matrix in the barycenter space (permutated arbitrarily) References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. """ - S = len(Cs) + Cs = list_to_array(*Cs) + ps = list_to_array(*ps) + p = list_to_array(p) + nx = get_backend(*Cs, *ps, p) - Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)] - lambdas = np.asarray(lambdas, dtype=np.float64) + S = len(Cs) # Initialization of C : random SPD matrix (if not provided by user) if init_C is None: - # XXX : should use a random state and not use the global seed - xalea = np.random.randn(N, 2) + generator = check_random_state(random_state) + xalea = generator.randn(N, 2) C = dist(xalea, xalea) C /= C.max() + C = nx.from_numpy(C, type_as=p) else: C = init_C @@ -1320,7 +1422,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = np.linalg.norm(C - Cprev) + err = nx.norm(C - Cprev) error.append(err) if log: @@ -1339,21 +1441,21 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False, p=None, loss_fun='square_loss', max_iter=100, tol=1e-9, - verbose=False, log=False, init_C=None, init_X=None): - """Compute the fgw barycenter as presented eq (5) in [24]. + verbose=False, log=False, init_C=None, init_X=None, random_state=None): + """Compute the fgw barycenter as presented eq (5) in :ref:`[24] ` Parameters ---------- - N : integer + N : int Desired number of samples of the target barycenter - Ys: list of ndarray, each element has shape (ns,d) + Ys: list of array-like, each element has shape (ns,d) Features of all samples - Cs : list of ndarray, each element has shape (ns,ns) + Cs : list of array-like, each element has shape (ns,ns) Structure matrices of all samples - ps : list of ndarray, each element has shape (ns,) + ps : list of array-like, each element has shape (ns,) Masses of all samples. lambdas : list of float - List of the S spaces' weights + List of the `S` spaces' weights alpha : float Alpha parameter for the fgw distance fixed_structure : bool @@ -1370,41 +1472,46 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ Print information along iterations. log : bool, optional Record log if True. - init_C : ndarray, shape (N,N), optional + init_C : array-like, shape (N,N), optional Initialization for the barycenters' structure matrix. If not set a random init is used. - init_X : ndarray, shape (N,d), optional + init_X : array-like, shape (N,d), optional Initialization for the barycenters' features. If not set a random init is used. + random_state : int or RandomState instance, optional + Fix the seed for reproducibility Returns ------- - X : ndarray, shape (N, d) + X : array-like, shape (`N`, `d`) Barycenters' features - C : ndarray, shape (N, N) + C : array-like, shape (`N`, `N`) Barycenters' structure matrix - log_: dict + log : dict Only returned when log=True. It contains the keys: - T : list of (N,ns) transport matrices - Ms : all distance matrices between the feature of the barycenter and the - other features dist(X,Ys) shape (N,ns) + - :math:`\mathbf{T}`: list of (`N`, `ns`) transport matrices + - :math:`(\mathbf{M}_s)_s`: all distance matrices between the feature of the barycenter and the other features :math:`(dist(\mathbf{X}, \mathbf{Y}_s))_s` shape (`N`, `ns`) + + + .. _references-fgw-barycenters: References ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. """ + Cs = list_to_array(*Cs) + ps = list_to_array(*ps) + Ys = list_to_array(*Ys) + p = list_to_array(p) + nx = get_backend(*Cs, *Ys, *ps) + S = len(Cs) d = Ys[0].shape[1] # dimension on the node features if p is None: - p = np.ones(N) / N - - Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)] - Ys = [np.asarray(Ys[s], dtype=np.float64) for s in range(S)] - - lambdas = np.asarray(lambdas, dtype=np.float64) + p = nx.ones(N, type_as=Cs[0]) / N if fixed_structure: if init_C is None: @@ -1413,8 +1520,10 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ C = init_C else: if init_C is None: - xalea = np.random.randn(N, 2) + generator = check_random_state(random_state) + xalea = generator.randn(N, 2) C = dist(xalea, xalea) + C = nx.from_numpy(C, type_as=ps[0]) else: C = init_C @@ -1425,13 +1534,13 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ X = init_X else: if init_X is None: - X = np.zeros((N, d)) + X = nx.zeros((N, d), type_as=ps[0]) else: X = init_X - T = [np.outer(p, q) for q in ps] + T = [nx.outer(p, q) for q in ps] - Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))] # Ms is N,ns + Ms = [dist(X, Ys[s]) for s in range(len(Ys))] cpt = 0 err_feature = 1 @@ -1451,20 +1560,19 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ Ys_temp = [y.T for y in Ys] X = update_feature_matrix(lambdas, Ys_temp, T, p).T - Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))] + Ms = [dist(X, Ys[s]) for s in range(len(Ys))] if not fixed_structure: if loss_fun == 'square_loss': T_temp = [t.T for t in T] - C = update_sructure_matrix(p, lambdas, T_temp, Cs) + C = update_structure_matrix(p, lambdas, T_temp, Cs) T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)] # T is N,ns - err_feature = np.linalg.norm(X - Xprev.reshape(N, d)) - err_structure = np.linalg.norm(C - Cprev) - + err_feature = nx.norm(X - nx.reshape(Xprev, (N, d))) + err_structure = nx.norm(C - Cprev) if log: log_['err_feature'].append(err_feature) log_['err_structure'].append(err_structure) @@ -1490,64 +1598,80 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ return X, C -def update_sructure_matrix(p, lambdas, T, Cs): - """Updates C according to the L2 Loss kernel with the S Ts couplings. +def update_structure_matrix(p, lambdas, T, Cs): + """Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings. It is calculated at each iteration Parameters ---------- - p : ndarray, shape (N,) + p : array-like, shape (N,) Masses in the targeted barycenter. lambdas : list of float - List of the S spaces' weights. - T : list of S ndarray of shape (ns, N) - The S Ts couplings calculated at each iteration. - Cs : list of S ndarray, shape (ns, ns) - Metric cost matrices. + List of the `S` spaces' weights. + T : list of S array-like of shape (ns, N) + The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. + Cs : list of S array-like, shape (ns, ns) + Metric cost matrices. Returns ------- - C : ndarray, shape (nt, nt) - Updated C matrix. + C : array-like, shape (`nt`, `nt`) + Updated :math:`\mathbf{C}` matrix. """ - tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))]) - ppt = np.outer(p, p) + p = list_to_array(p) + T = list_to_array(*T) + Cs = list_to_array(*Cs) + nx = get_backend(*Cs, *T, p) - return np.divide(tmpsum, ppt) + tmpsum = sum([ + lambdas[s] * nx.dot( + nx.dot(T[s].T, Cs[s]), + T[s] + ) for s in range(len(T)) + ]) + ppt = nx.outer(p, p) + return tmpsum / ppt def update_feature_matrix(lambdas, Ys, Ts, p): - """Updates the feature with respect to the S Ts couplings. + """Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" - in [24] calculated at each iteration + in :ref:`[24] ` calculated at each iteration Parameters ---------- - p : ndarray, shape (N,) + p : array-like, shape (N,) masses in the targeted barycenter lambdas : list of float - List of the S spaces' weights - Ts : list of S np.ndarray(ns,N) - the S Ts couplings calculated at each iteration - Ys : list of S ndarray, shape(d,ns) + List of the `S` spaces' weights + Ts : list of S array-like, shape (ns,N) + The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration + Ys : list of S array-like, shape (d,ns) The features. Returns ------- - X : ndarray, shape (d, N) + X : array-like, shape (`d`, `N`) + + .. _references-update-feature-matrix: References ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain - and Courty Nicolas + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. """ - p = np.array(1. / p).reshape(-1,) - - tmpsum = sum([lambdas[s] * np.dot(Ys[s], Ts[s].T) * p[None, :] for s in range(len(Ts))]) - + p = list_to_array(p) + Ts = list_to_array(*Ts) + Ys = list_to_array(*Ys) + nx = get_backend(*Ys, *Ts, p) + + p = 1. / p + tmpsum = sum([ + lambdas[s] * nx.dot(Ys[s], Ts[s].T) * p[None, :] + for s in range(len(Ts)) + ]) return tmpsum diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index c6757d1..4e95ccf 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -691,10 +691,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the transportation matrix) """ - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - x_a = np.asarray(x_a, dtype=np.float64) - x_b = np.asarray(x_b, dtype=np.float64) + a, b, x_a, x_b = list_to_array(a, b, x_a, x_b) + nx = get_backend(x_a, x_b) assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \ "emd_1d should only be used with monodimensional data" @@ -702,27 +700,43 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, "emd_1d should only be used with monodimensional data" # if empty array given then use uniform distributions - if a.ndim == 0 or len(a) == 0: - a = np.ones((x_a.shape[0],), dtype=np.float64) / x_a.shape[0] - if b.ndim == 0 or len(b) == 0: - b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0] + if a is None or a.ndim == 0 or len(a) == 0: + a = nx.ones((x_a.shape[0],), type_as=x_a) / x_a.shape[0] + if b is None or b.ndim == 0 or len(b) == 0: + b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0] # ensure that same mass - np.testing.assert_almost_equal(a.sum(0),b.sum(0),err_msg='a and b vector must have the same sum') - b=b*a.sum()/b.sum() - - x_a_1d = x_a.reshape((-1,)) - x_b_1d = x_b.reshape((-1,)) - perm_a = np.argsort(x_a_1d) - perm_b = np.argsort(x_b_1d) - - G_sorted, indices, cost = emd_1d_sorted(a[perm_a], b[perm_b], - x_a_1d[perm_a], x_b_1d[perm_b], - metric=metric, p=p) - G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])), - shape=(a.shape[0], b.shape[0])) + np.testing.assert_almost_equal( + nx.sum(a, axis=0), + nx.sum(b, axis=0), + err_msg='a and b vector must have the same sum' + ) + b = b * nx.sum(a) / nx.sum(b) + + x_a_1d = nx.reshape(x_a, (-1,)) + x_b_1d = nx.reshape(x_b, (-1,)) + perm_a = nx.argsort(x_a_1d) + perm_b = nx.argsort(x_b_1d) + + G_sorted, indices, cost = emd_1d_sorted( + nx.to_numpy(a[perm_a]), + nx.to_numpy(b[perm_b]), + nx.to_numpy(x_a_1d[perm_a]), + nx.to_numpy(x_b_1d[perm_b]), + metric=metric, p=p + ) + + G = nx.coo_matrix( + G_sorted, + perm_a[indices[:, 0]], + perm_b[indices[:, 1]], + shape=(a.shape[0], b.shape[0]), + type_as=x_a + ) if dense: - G = G.toarray() + G = nx.todense(G) + elif str(nx) == "jax": + warnings.warn("JAX does not support sparse matrices, converting to dense") if log: log = {'cost': cost} return G, log diff --git a/ot/optim.py b/ot/optim.py index 34cbb17..6456c03 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -23,7 +23,7 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, r""" Armijo linesearch function that works with matrices - Find an approximate minimum of :math:`f(x_k + \\alpha \cdot p_k)` that satisfies the + Find an approximate minimum of :math:`f(x_k + \alpha \cdot p_k)` that satisfies the armijo conditions. Parameters @@ -129,7 +129,7 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val, .. _references-solve-linesearch: References ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain and Courty Nicolas + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. """ @@ -162,13 +162,13 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + \mathrm{reg} \cdot f(\gamma) + \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg} \cdot f(\gamma) - s.t. \ \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix @@ -309,13 +309,13 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + \mathrm{reg_1}\cdot\Omega(\gamma) + \mathrm{reg_2}\cdot f(\gamma) + \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg_1}\cdot\Omega(\gamma) + \mathrm{reg_2}\cdot f(\gamma) - s.t. \ \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix @@ -452,7 +452,7 @@ def solve_1d_linesearch_quad(a, b, c): .. math:: - arg\min_{0 \leq x \leq 1} f(x) = ax^{2} + bx + c + \mathop{\arg \min}_{0 \leq x \leq 1} f(x) = ax^{2} + bx + c Parameters ---------- diff --git a/test/test_backend.py b/test/test_backend.py index 5853282..0f11ace 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -207,6 +207,22 @@ def test_empty_backend(): nx.stack([M, M]) with pytest.raises(NotImplementedError): nx.reshape(M, (5, 3, 2)) + with pytest.raises(NotImplementedError): + nx.coo_matrix(M, M, M) + with pytest.raises(NotImplementedError): + nx.issparse(M) + with pytest.raises(NotImplementedError): + nx.tocsr(M) + with pytest.raises(NotImplementedError): + nx.eliminate_zeros(M) + with pytest.raises(NotImplementedError): + nx.todense(M) + with pytest.raises(NotImplementedError): + nx.where(M, M, M) + with pytest.raises(NotImplementedError): + nx.copy(M) + with pytest.raises(NotImplementedError): + nx.allclose(M, M) def test_func_backends(nx): @@ -216,6 +232,11 @@ def test_func_backends(nx): v = rnd.randn(3) val = np.array([1.0]) + # Sparse tensors test + sp_row = np.array([0, 3, 1, 0, 3]) + sp_col = np.array([0, 3, 1, 2, 2]) + sp_data = np.array([4, 5, 7, 9, 0]) + lst_tot = [] for nx in [ot.backend.NumpyBackend(), nx]: @@ -229,6 +250,10 @@ def test_func_backends(nx): vb = nx.from_numpy(v) val = nx.from_numpy(val) + sp_rowb = nx.from_numpy(sp_row) + sp_colb = nx.from_numpy(sp_col) + sp_datab = nx.from_numpy(sp_data) + A = nx.set_gradients(val, v, v) lst_b.append(nx.to_numpy(A)) lst_name.append('set_gradients') @@ -438,6 +463,37 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append('reshape') + sp_Mb = nx.coo_matrix(sp_datab, sp_rowb, sp_colb, shape=(4, 4)) + nx.todense(Mb) + lst_b.append(nx.to_numpy(nx.todense(sp_Mb))) + lst_name.append('coo_matrix') + + assert not nx.issparse(Mb), 'Assert fail on: issparse (expected False)' + assert nx.issparse(sp_Mb) or nx.__name__ == "jax", 'Assert fail on: issparse (expected True)' + + A = nx.tocsr(sp_Mb) + lst_b.append(nx.to_numpy(nx.todense(A))) + lst_name.append('tocsr') + + A = nx.eliminate_zeros(nx.copy(sp_datab), threshold=5.) + lst_b.append(nx.to_numpy(A)) + lst_name.append('eliminate_zeros (dense)') + + A = nx.eliminate_zeros(sp_Mb) + lst_b.append(nx.to_numpy(nx.todense(A))) + lst_name.append('eliminate_zeros (sparse)') + + A = nx.where(Mb >= nx.stack([nx.linspace(0, 1, 10)] * 3, axis=1), Mb, 0.0) + lst_b.append(nx.to_numpy(A)) + lst_name.append('where') + + A = nx.copy(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('copy') + + assert nx.allclose(Mb, Mb), 'Assert fail on: allclose (expected True)' + assert not nx.allclose(2 * Mb, Mb), 'Assert fail on: allclose (expected False)' + lst_tot.append(lst_b) lst_np = lst_tot[0] diff --git a/test/test_bregman.py b/test/test_bregman.py index c1120ba..6923d31 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -477,8 +477,8 @@ def test_lazy_empirical_sinkhorn(nx): b = ot.unif(n) numIterMax = 1000 - X_s = np.reshape(np.arange(n), (n, 1)) - X_t = np.reshape(np.arange(0, n), (n, 1)) + X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1)) + X_t = np.reshape(np.arange(0, n, dtype=np.float64), (n, 1)) M = ot.dist(X_s, X_t) M_m = ot.dist(X_s, X_t, metric='euclidean') diff --git a/test/test_gromov.py b/test/test_gromov.py index 0242d72..509c54d 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -8,11 +8,12 @@ import numpy as np import ot +from ot.backend import NumpyBackend import pytest -def test_gromov(): +def test_gromov(nx): n_samples = 50 # nb samples mu_s = np.array([0, 0]) @@ -31,37 +32,50 @@ def test_gromov(): C1 /= C1.max() C2 /= C2.max() + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + pb = nx.from_numpy(p) + qb = nx.from_numpy(q) + G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True) + Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)) # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples) - np.testing.assert_allclose( - G, np.flipud(Id), atol=1e-04) + np.testing.assert_allclose(Gb, np.flipud(Id), atol=1e-04) gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True) + gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=True) gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False) + gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False) G = log['T'] + Gb = nx.to_numpy(logb['T']) - np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) + np.testing.assert_allclose(gw, gwb, atol=1e-06) + np.testing.assert_allclose(gwb, 0, atol=1e-1, rtol=1e-1) - np.testing.assert_allclose(gw, gw_val, atol=1e-1, rtol=1e-1) # cf log=False + np.testing.assert_allclose(gw_val, gw_valb, atol=1e-06) + np.testing.assert_allclose(gwb, gw_valb, atol=1e-1, rtol=1e-1) # cf log=False # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov -def test_entropic_gromov(): +@pytest.skip_backend("jax", reason="test very slow with jax backend") +def test_entropic_gromov(nx): n_samples = 50 # nb samples mu_s = np.array([0, 0]) @@ -80,30 +94,44 @@ def test_entropic_gromov(): C1 /= C1.max() C2 /= C2.max() + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + pb = nx.from_numpy(p) + qb = nx.from_numpy(q) + G = ot.gromov.entropic_gromov_wasserstein( C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True) + Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein( + C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True + )) # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov gw, log = ot.gromov.entropic_gromov_wasserstein2( C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True) + gwb, logb = ot.gromov.entropic_gromov_wasserstein2( + C1b, C2b, pb, qb, 'kl_loss', epsilon=1e-2, log=True) G = log['T'] + Gb = nx.to_numpy(logb['T']) + np.testing.assert_allclose(gw, gwb, atol=1e-06) np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov -def test_pointwise_gromov(): +def test_pointwise_gromov(nx): n_samples = 50 # nb samples mu_s = np.array([0, 0]) @@ -122,33 +150,52 @@ def test_pointwise_gromov(): C1 /= C1.max() C2 /= C2.max() + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + pb = nx.from_numpy(p) + qb = nx.from_numpy(q) + def loss(x, y): return np.abs(x - y) + def lossb(x, y): + return nx.abs(x - y) + G, log = ot.gromov.pointwise_gromov_wasserstein( C1, C2, p, q, loss, max_iter=100, log=True, verbose=True, random_state=42) + G = NumpyBackend().todense(G) + Gb, logb = ot.gromov.pointwise_gromov_wasserstein( + C1b, C2b, pb, qb, lossb, max_iter=100, log=True, verbose=True, random_state=42) + Gb = nx.to_numpy(nx.todense(Gb)) # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p[:, np.newaxis], G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( - q[np.newaxis, :], G.sum(0), atol=1e-04) # cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov - assert log['gw_dist_estimated'] == 0.0 - assert log['gw_dist_std'] == 0.0 + np.testing.assert_allclose(logb['gw_dist_estimated'], 0.0, atol=1e-08) + np.testing.assert_allclose(logb['gw_dist_std'], 0.0, atol=1e-08) G, log = ot.gromov.pointwise_gromov_wasserstein( C1, C2, p, q, loss, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42) + G = NumpyBackend().todense(G) + Gb, logb = ot.gromov.pointwise_gromov_wasserstein( + C1b, C2b, pb, qb, lossb, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42) + Gb = nx.to_numpy(nx.todense(Gb)) - assert log['gw_dist_estimated'] == 0.10342276348494964 - assert log['gw_dist_std'] == 0.0015952535464736394 + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(logb['gw_dist_estimated'], 0.10342276348494964, atol=1e-8) + np.testing.assert_allclose(logb['gw_dist_std'], 0.0015952535464736394, atol=1e-8) -def test_sampled_gromov(): +@pytest.skip_backend("jax", reason="test very slow with jax backend") +def test_sampled_gromov(nx): n_samples = 50 # nb samples - mu_s = np.array([0, 0]) - cov_s = np.array([[1, 0], [0, 1]]) + mu_s = np.array([0, 0], dtype=np.float64) + cov_s = np.array([[1, 0], [0, 1]], dtype=np.float64) xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) @@ -163,23 +210,35 @@ def test_sampled_gromov(): C1 /= C1.max() C2 /= C2.max() + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + pb = nx.from_numpy(p) + qb = nx.from_numpy(q) + def loss(x, y): return np.abs(x - y) + def lossb(x, y): + return nx.abs(x - y) + G, log = ot.gromov.sampled_gromov_wasserstein( C1, C2, p, q, loss, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42) + Gb, logb = ot.gromov.sampled_gromov_wasserstein( + C1b, C2b, pb, qb, lossb, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42) + Gb = nx.to_numpy(Gb) # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov - assert log['gw_dist_estimated'] == 0.05679474884977278 - assert log['gw_dist_std'] == 0.0005986592106971995 + np.testing.assert_allclose(logb['gw_dist_estimated'], 0.05679474884977278, atol=1e-08) + np.testing.assert_allclose(logb['gw_dist_std'], 0.0005986592106971995, atol=1e-08) -def test_gromov_barycenter(): +def test_gromov_barycenter(nx): ns = 10 nt = 20 @@ -188,26 +247,42 @@ def test_gromov_barycenter(): C1 = ot.dist(Xs) C2 = ot.dist(Xt) - + p1 = ot.unif(ns) + p2 = ot.unif(nt) n_samples = 3 - Cb = ot.gromov.gromov_barycenters(n_samples, [C1, C2], - [ot.unif(ns), ot.unif(nt) - ], ot.unif(n_samples), [.5, .5], - 'square_loss', # 5e-4, - max_iter=100, tol=1e-3, - verbose=True) - np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) + p = ot.unif(n_samples) - Cb2 = ot.gromov.gromov_barycenters(n_samples, [C1, C2], - [ot.unif(ns), ot.unif(nt) - ], ot.unif(n_samples), [.5, .5], - 'kl_loss', # 5e-4, - max_iter=100, tol=1e-3) - np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples)) + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + p1b = nx.from_numpy(p1) + p2b = nx.from_numpy(p2) + pb = nx.from_numpy(p) + + Cb = ot.gromov.gromov_barycenters( + n_samples, [C1, C2], [p1, p2], p, [.5, .5], + 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42 + ) + Cbb = nx.to_numpy(ot.gromov.gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], + 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42 + )) + np.testing.assert_allclose(Cb, Cbb, atol=1e-06) + np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) + + Cb2 = ot.gromov.gromov_barycenters( + n_samples, [C1, C2], [p1, p2], p, [.5, .5], + 'kl_loss', max_iter=100, tol=1e-3, random_state=42 + ) + Cb2b = nx.to_numpy(ot.gromov.gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], + 'kl_loss', max_iter=100, tol=1e-3, random_state=42 + )) + np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06) + np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples)) @pytest.mark.filterwarnings("ignore:divide") -def test_gromov_entropic_barycenter(): +def test_gromov_entropic_barycenter(nx): ns = 10 nt = 20 @@ -216,26 +291,41 @@ def test_gromov_entropic_barycenter(): C1 = ot.dist(Xs) C2 = ot.dist(Xt) - + p1 = ot.unif(ns) + p2 = ot.unif(nt) n_samples = 2 - Cb = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2], - [ot.unif(ns), ot.unif(nt) - ], ot.unif(n_samples), [.5, .5], - 'square_loss', 1e-3, - max_iter=50, tol=1e-3, - verbose=True) - np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) - - Cb2 = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2], - [ot.unif(ns), ot.unif(nt) - ], ot.unif(n_samples), [.5, .5], - 'kl_loss', 1e-3, - max_iter=100, tol=1e-3) - np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples)) - - -def test_fgw(): + p = ot.unif(n_samples) + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + p1b = nx.from_numpy(p1) + p2b = nx.from_numpy(p2) + pb = nx.from_numpy(p) + + Cb = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1, C2], [p1, p2], p, [.5, .5], + 'square_loss', 1e-3, max_iter=50, tol=1e-3, verbose=True, random_state=42 + ) + Cbb = nx.to_numpy(ot.gromov.entropic_gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], + 'square_loss', 1e-3, max_iter=50, tol=1e-3, verbose=True, random_state=42 + )) + np.testing.assert_allclose(Cb, Cbb, atol=1e-06) + np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) + + Cb2 = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1, C2], [p1, p2], p, [.5, .5], + 'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42 + ) + Cb2b = nx.to_numpy(ot.gromov.entropic_gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], + 'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42 + )) + np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06) + np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples)) + + +def test_fgw(nx): n_samples = 50 # nb samples mu_s = np.array([0, 0]) @@ -260,33 +350,46 @@ def test_fgw(): M = ot.dist(ys, yt) M /= M.max() + Mb = nx.from_numpy(M) + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + pb = nx.from_numpy(p) + qb = nx.from_numpy(q) + G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True) + Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True) + Gb = nx.to_numpy(Gb) # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence fgw + p, Gb.sum(1), atol=1e-04) # cf convergence fgw np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence fgw + q, Gb.sum(0), atol=1e-04) # cf convergence fgw Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples) np.testing.assert_allclose( - G, np.flipud(Id), atol=1e-04) # cf convergence gromov + Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True) + fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True) G = log['T'] + Gb = nx.to_numpy(logb['T']) - np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) + np.testing.assert_allclose(fgw, fgwb, atol=1e-08) + np.testing.assert_allclose(fgwb, 0, atol=1e-1, rtol=1e-1) # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov -def test_fgw_barycenter(): +def test_fgw_barycenter(nx): np.random.seed(42) ns = 50 @@ -300,30 +403,44 @@ def test_fgw_barycenter(): C1 = ot.dist(Xs) C2 = ot.dist(Xt) - + p1, p2 = ot.unif(ns), ot.unif(nt) n_samples = 3 - X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5, - fixed_structure=False, fixed_features=False, - p=ot.unif(n_samples), loss_fun='square_loss', - max_iter=100, tol=1e-3) - np.testing.assert_allclose(C.shape, (n_samples, n_samples)) - np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) + p = ot.unif(n_samples) + + ysb = nx.from_numpy(ys) + ytb = nx.from_numpy(yt) + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + p1b = nx.from_numpy(p1) + p2b = nx.from_numpy(p2) + pb = nx.from_numpy(p) + + Xb, Cb = ot.gromov.fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, fixed_structure=False, + fixed_features=False, p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, random_state=12345 + ) xalea = np.random.randn(n_samples, 2) init_C = ot.dist(xalea, xalea) - - X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], ps=[ot.unif(ns), ot.unif(nt)], lambdas=[.5, .5], alpha=0.5, - fixed_structure=True, init_C=init_C, fixed_features=False, - p=ot.unif(n_samples), loss_fun='square_loss', - max_iter=100, tol=1e-3) - np.testing.assert_allclose(C.shape, (n_samples, n_samples)) - np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) + init_Cb = nx.from_numpy(init_C) + + Xb, Cb = ot.gromov.fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=[.5, .5], + alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False, + p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3 + ) + Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb) + np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) + np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) init_X = np.random.randn(n_samples, ys.shape[1]) - - X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5, - fixed_structure=False, fixed_features=True, init_X=init_X, - p=ot.unif(n_samples), loss_fun='square_loss', - max_iter=100, tol=1e-3, log=True) - np.testing.assert_allclose(C.shape, (n_samples, n_samples)) - np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) + init_Xb = nx.from_numpy(init_X) + + Xb, Cb, logb = ot.gromov.fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, + fixed_structure=False, fixed_features=True, init_X=init_Xb, + p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, log=True, random_state=98765 + ) + Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb) + np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) + np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) -- cgit v1.2.3 From 6775a527f9d3c801f8cdd805d8f205b6a75551b9 Mon Sep 17 00:00:00 2001 From: Nicolas Courty Date: Tue, 2 Nov 2021 14:19:57 +0100 Subject: [MRG] Sliced and 1D Wasserstein distances : backend versions (#256) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add numpy and torch backends * stat sets on functions * proper import * install recent torch on windows * install recent torch on windows * now testing all functions in backedn * add jax backedn * clenaup windowds * proper convert for jax backedn * pep8 * try again windows tests * test jax conversion * try proper widows tests * emd fuction ses backedn * better test partial OT * proper tests to_numpy and teplate Backend * pep8 * pep8 x2 * feaking sinkhorn works with torch * sinkhorn2 compatible * working ot.emd2 * important detach * it should work * jax autodiff emd * pep8 * no tast same for jax * new independat tests per backedn * freaking pep8 * add tests for gradients * deprecate ot.gpu * worging dist function * working dist * dist done in backedn * not in * remove indexing * change accuacy for jax * first pull backend * projection simplex * projection simplex * projection simplex * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 * add backedn discusion to quickstart guide * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 + better doc * proper links * corect doctest * big debug documentation * doctest again * doctest again bis * doctest again ter (last one or i kill myself) * backend test + doc proj simplex * correction test_utils * correction test_utils * correction cumsum * correction flip * correction flip v2 * more debug * more debug * more debug + pep8 * pep8 * argh * proj_simplex * backedn works for sort * proj simplex * jax sucks * update doc * Update test/test_utils.py Co-authored-by: Alexandre Gramfort * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort * Update test/test_utils.py Co-authored-by: Alexandre Gramfort * Update ot/utils.py Co-authored-by: Alexandre Gramfort * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort * Update ot/lp/__init__.py Co-authored-by: Alexandre Gramfort * begin comment alex * comment alex part 2 * optimize test gromov * proj_simplex on vectors * add awesome gradient decsnt example on the weights * pep98 of course * proof read example by alex * pep8 again * encoding oos in translation * correct legend * new backend functions for sliced * small indent pb * Optimized backendversion of sliced W * error in sliced W * after master merge * error sliced * error sliced * pep8 * test_sliced pep8 * doctest + precision for sliced * doctest * type win test_backend gather * type win test_backend gather * Update sliced.py change argument of padding pad_width * Update backend.py update redefinition * Update backend.py pep8 * Update backend.py pep 8 again.... * pep8 * build docs * emd2_1D example * refectoring emd_1d and variants * remove unused previous wasserstein_1d * pep8 * upate example * move stuff * tesys should work + implemù random backend * test random generayor functions * correction * better random generation * update sliced * update sliced * proper tests sliced * max sliced * chae file nam * add stuff * example sliced flow and barycenter * correct typo + update readme * exemple sliced flow done * pep8 * solver1d works * pep8 Co-authored-by: Rémi Flamary Co-authored-by: Alexandre Gramfort --- README.md | 11 +- docs/source/readme.rst | 51 ++- .../backends/plot_sliced_wass_grad_flow_pytorch.py | 185 +++++++++++ examples/backends/plot_wass1d_torch.py | 152 +++++++++ examples/sliced-wasserstein/plot_variance.py | 2 +- ot/__init__.py | 5 +- ot/backend.py | 98 ++++++ ot/lp/__init__.py | 367 ++------------------- ot/lp/solver_1d.py | 367 +++++++++++++++++++++ ot/sliced.py | 181 ++++++++-- test/test_1d_solver.py | 85 +++++ test/test_backend.py | 36 ++ test/test_ot.py | 57 +--- test/test_sliced.py | 90 ++++- test/test_utils.py | 2 +- 15 files changed, 1244 insertions(+), 445 deletions(-) create mode 100644 examples/backends/plot_sliced_wass_grad_flow_pytorch.py create mode 100644 examples/backends/plot_wass1d_torch.py create mode 100644 ot/lp/solver_1d.py create mode 100644 test/test_1d_solver.py (limited to 'test') diff --git a/README.md b/README.md index f0e5227..cfb9744 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ POT provides the following generic OT solvers (links to examples): * [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] formulations). -* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32]. +* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36]. * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/) arrays. POT provides the following Machine Learning related solvers: @@ -285,4 +285,11 @@ You can also post bug reports and feature requests in Github issues. Make sure t [33] Kerdoncuff T., Emonet R., Marc S. [Sampled Gromov Wasserstein](https://hal.archives-ouvertes.fr/hal-03232509/document), Machine Learning Journal (MJL), 2021 -[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. +[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). [Interpolating between optimal transport and MMD using Sinkhorn divergences](http://proceedings.mlr.press/v89/feydy19a/feydy19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. + +[35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., ... & Schwing, A. G. (2019). [Max-sliced wasserstein distance and its use for gans](https://openaccess.thecvf.com/content_CVPR_2019/papers/Deshpande_Max-Sliced_Wasserstein_Distance_and_Its_Use_for_GANs_CVPR_2019_paper.pdf). In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656). + +[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R. +(2019, May). [Sliced-Wasserstein flows: Nonparametric generative modeling +via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on +Machine Learning (pp. 4104-4113). PMLR. diff --git a/docs/source/readme.rst b/docs/source/readme.rst index 82d3e6c..ee32e2b 100644 --- a/docs/source/readme.rst +++ b/docs/source/readme.rst @@ -24,7 +24,7 @@ POT provides the following generic OT solvers (links to examples): for regularized OT [7]. - Entropic regularization OT solver with `Sinkhorn Knopp Algorithm `__ - [2] , stabilized version [9] [10], greedy Sinkhorn [22] and + [2] , stabilized version [9] [10] [34], greedy Sinkhorn [22] and `Screening Sinkhorn [26] `__. - Bregman projections for `Wasserstein @@ -54,6 +54,9 @@ POT provides the following generic OT solvers (links to examples): solver `__ for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19]) +- `Stochastic solver of Gromov + Wasserstein `__ + for large-scale problem with any loss functions [33] - Non regularized `free support Wasserstein barycenters `__ [20]. @@ -137,19 +140,12 @@ following Python modules: - Numpy (>=1.16) - Scipy (>=1.0) -- Cython (>=0.23) (build only, not necessary when installing wheels - from pip or conda) +- Cython (>=0.23) (build only, not necessary when installing from pip + or conda) Pip installation ^^^^^^^^^^^^^^^^ -Note that due to a limitation of pip, ``cython`` and ``numpy`` need to -be installed prior to installing POT. This can be done easily with - -.. code:: console - - pip install numpy cython - You can install the toolbox through PyPI with: .. code:: console @@ -183,7 +179,8 @@ without errors: import ot -Note that for easier access the module is name ot instead of pot. +Note that for easier access the module is named ``ot`` instead of +``pot``. Dependencies ~~~~~~~~~~~~ @@ -222,7 +219,7 @@ Short examples .. code:: python - # a and b are 1D histograms (sum to 1 and positive) + # a,b are 1D histograms (sum to 1 and positive) # M is the ground cost matrix Wd = ot.emd2(a, b, M) # exact linear program Wd_reg = ot.sinkhorn2(a, b, M, reg) # entropic regularized OT @@ -232,7 +229,7 @@ Short examples .. code:: python - # a and b are 1D histograms (sum to 1 and positive) + # a,b are 1D histograms (sum to 1 and positive) # M is the ground cost matrix T = ot.emd(a, b, M) # exact linear program T_reg = ot.sinkhorn(a, b, M, reg) # entropic regularized OT @@ -287,6 +284,10 @@ The contributors to this library are - `Ievgen Redko `__ (Laplacian DA, JCPOT) - `Adrien Corenflos `__ (Sliced Wasserstein Distance) +- `Tanguy Kerdoncuff `__ (Sampled Gromov + Wasserstein) +- `Minhui Huang `__ (Projection Robust + Wasserstein Distance) This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various @@ -476,6 +477,30 @@ of measures `__, Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 +[32] Huang, M., Ma S., Lai, L. (2021). `A Riemannian Block Coordinate +Descent Method for Computing the Projection Robust Wasserstein +Distance `__, +Proceedings of the 38th International Conference on Machine Learning +(ICML). + +[33] Kerdoncuff T., Emonet R., Marc S. `Sampled Gromov +Wasserstein `__, +Machine Learning Journal (MJL), 2021 + +[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., +& Peyré, G. (2019, April). `Interpolating between optimal transport and +MMD using Sinkhorn +divergences `__. +In The 22nd International Conference on Artificial Intelligence and +Statistics (pp. 2681-2690). PMLR. + +[35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., +Koyejo, S., ... & Schwing, A. G. (2019). `Max-sliced wasserstein +distance and its use for +gans `__. +In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern +Recognition (pp. 10648-10656). + .. |PyPI version| image:: https://badge.fury.io/py/POT.svg :target: https://badge.fury.io/py/POT .. |Anaconda Cloud| image:: https://anaconda.org/conda-forge/pot/badges/version.svg diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py new file mode 100644 index 0000000..05b9952 --- /dev/null +++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py @@ -0,0 +1,185 @@ +r""" +================================= +Sliced Wasserstein barycenter and gradient flow with PyTorch +================================= + +In this exemple we use the pytorch backend to optimize the sliced Wasserstein +loss between two empirical distributions [31]. + +In the first example one we perform a +gradient flow on the support of a distribution that minimize the sliced +Wassersein distance as poposed in [36]. + +In the second exemple we optimize with a gradient descent the sliced +Wasserstein barycenter between two distributions as in [31]. + +[31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of +measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 + +[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R. +(2019, May). Sliced-Wasserstein flows: Nonparametric generative modeling +via optimal transport and diffusions. In International Conference on +Machine Learning (pp. 4104-4113). PMLR. + + +""" +# Author: Rémi Flamary +# +# License: MIT License + + +# %% +# Loading the data + + +import numpy as np +import matplotlib.pylab as pl +import torch +import ot +import matplotlib.animation as animation + +I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::4, ::4, 2] +I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::4, ::4, 2] + +sz = I2.shape[0] +XX, YY = np.meshgrid(np.arange(sz), np.arange(sz)) + +x1 = np.stack((XX[I1 == 0], YY[I1 == 0]), 1) * 1.0 +x2 = np.stack((XX[I2 == 0] + 60, -YY[I2 == 0] + 32), 1) * 1.0 +x3 = np.stack((XX[I2 == 0], -YY[I2 == 0] + 32), 1) * 1.0 + +pl.figure(1, (8, 4)) +pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5) +pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5) + +# %% +# Sliced Wasserstein gradient flow with Pytorch +# --------------------------------------------- + + +device = "cuda" if torch.cuda.is_available() else "cpu" + +# use pyTorch for our data +x1_torch = torch.tensor(x1).to(device=device).requires_grad_(True) +x2_torch = torch.tensor(x2).to(device=device) + + +lr = 1e3 +nb_iter_max = 100 + +x_all = np.zeros((nb_iter_max, x1.shape[0], 2)) + +loss_iter = [] + +# generator for random permutations +gen = torch.Generator() +gen.manual_seed(42) + +for i in range(nb_iter_max): + + loss = ot.sliced_wasserstein_distance(x1_torch, x2_torch, n_projections=20, seed=gen) + + loss_iter.append(loss.clone().detach().cpu().numpy()) + loss.backward() + + # performs a step of projected gradient descent + with torch.no_grad(): + grad = x1_torch.grad + x1_torch -= grad * lr / (1 + i / 5e1) # step + x1_torch.grad.zero_() + x_all[i, :, :] = x1_torch.clone().detach().cpu().numpy() + +xb = x1_torch.clone().detach().cpu().numpy() + +pl.figure(2, (8, 4)) +pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu^{(0)}$') +pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$') +pl.scatter(xb[:, 0], xb[:, 1], alpha=0.5, label='$\mu^{(100)}$') +pl.title('Sliced Wasserstein gradient flow') +pl.legend() +ax = pl.axis() + +# %% +# Animate trajectories of the gradient flow along iteration +# ------------------------------------------------------- + +pl.figure(3, (8, 4)) + + +def _update_plot(i): + pl.clf() + pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu^{(0)}$') + pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$') + pl.scatter(x_all[i, :, 0], x_all[i, :, 1], alpha=0.5, label='$\mu^{(100)}$') + pl.title('Sliced Wasserstein gradient flow Iter. {}'.format(i)) + pl.axis(ax) + return 1 + + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100, repeat_delay=2000) + +# %% +# Compute the Sliced Wasserstein Barycenter +# +x1_torch = torch.tensor(x1).to(device=device) +x3_torch = torch.tensor(x3).to(device=device) +xbinit = np.random.randn(500, 2) * 10 + 16 +xbary_torch = torch.tensor(xbinit).to(device=device).requires_grad_(True) + +lr = 1e3 +nb_iter_max = 100 + +x_all = np.zeros((nb_iter_max, xbary_torch.shape[0], 2)) + +loss_iter = [] + +# generator for random permutations +gen = torch.Generator() +gen.manual_seed(42) + +alpha = 0.5 + +for i in range(nb_iter_max): + + loss = alpha * ot.sliced_wasserstein_distance(xbary_torch, x3_torch, n_projections=50, seed=gen) \ + + (1 - alpha) * ot.sliced_wasserstein_distance(xbary_torch, x1_torch, n_projections=50, seed=gen) + + loss_iter.append(loss.clone().detach().cpu().numpy()) + loss.backward() + + # performs a step of projected gradient descent + with torch.no_grad(): + grad = xbary_torch.grad + xbary_torch -= grad * lr # / (1 + i / 5e1) # step + xbary_torch.grad.zero_() + x_all[i, :, :] = xbary_torch.clone().detach().cpu().numpy() + +xb = xbary_torch.clone().detach().cpu().numpy() + +pl.figure(4, (8, 4)) +pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu$') +pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$') +pl.scatter(xb[:, 0] + 30, xb[:, 1], alpha=0.5, label='Barycenter') +pl.title('Sliced Wasserstein barycenter') +pl.legend() +ax = pl.axis() + + +# %% +# Animate trajectories of the barycenter along gradient descent +# ------------------------------------------------------- + +pl.figure(5, (8, 4)) + + +def _update_plot(i): + pl.clf() + pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu^{(0)}$') + pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$') + pl.scatter(x_all[i, :, 0] + 30, x_all[i, :, 1], alpha=0.5, label='$\mu^{(100)}$') + pl.title('Sliced Wasserstein barycenter Iter. {}'.format(i)) + pl.axis(ax) + return 1 + + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100, repeat_delay=2000) diff --git a/examples/backends/plot_wass1d_torch.py b/examples/backends/plot_wass1d_torch.py new file mode 100644 index 0000000..0abdd6d --- /dev/null +++ b/examples/backends/plot_wass1d_torch.py @@ -0,0 +1,152 @@ +r""" +================================= +Wasserstein 1D with PyTorch +================================= + +In this small example, we consider the following minization problem: + +.. math:: + \mu^* = \min_\mu W(\mu,\nu) + +where :math:`\nu` is a reference 1D measure. The problem is handled +by a projected gradient descent method, where the gradient is computed +by pyTorch automatic differentiation. The projection on the simplex +ensures that the iterate will remain on the probability simplex. + +This example illustrates both `wasserstein_1d` function and backend use within +the POT framework. +""" +# Author: Nicolas Courty +# Rémi Flamary +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import matplotlib as mpl +import torch + +from ot.lp import wasserstein_1d +from ot.datasets import make_1D_gauss as gauss +from ot.utils import proj_simplex + +red = np.array(mpl.colors.to_rgb('red')) +blue = np.array(mpl.colors.to_rgb('blue')) + + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a = gauss(n, m=20, s=5) # m= mean, s= std +b = gauss(n, m=60, s=10) + +# enforce sum to one on the support +a = a / a.sum() +b = b / b.sum() + + +device = "cuda" if torch.cuda.is_available() else "cpu" + +# use pyTorch for our data +x_torch = torch.tensor(x).to(device=device) +a_torch = torch.tensor(a).to(device=device).requires_grad_(True) +b_torch = torch.tensor(b).to(device=device) + +lr = 1e-6 +nb_iter_max = 800 + +loss_iter = [] + +pl.figure(1, figsize=(8, 4)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') + +for i in range(nb_iter_max): + # Compute the Wasserstein 1D with torch backend + loss = wasserstein_1d(x_torch, x_torch, a_torch, b_torch, p=2) + # record the corresponding loss value + loss_iter.append(loss.clone().detach().cpu().numpy()) + loss.backward() + + # performs a step of projected gradient descent + with torch.no_grad(): + grad = a_torch.grad + a_torch -= a_torch.grad * lr # step + a_torch.grad.zero_() + a_torch.data = proj_simplex(a_torch) # projection onto the simplex + + # plot one curve every 10 iterations + if i % 10 == 0: + mix = float(i) / nb_iter_max + pl.plot(x, a_torch.clone().detach().cpu().numpy(), c=(1 - mix) * blue + mix * red) + +pl.legend() +pl.title('Distribution along the iterations of the projected gradient descent') +pl.show() + +pl.figure(2) +pl.plot(range(nb_iter_max), loss_iter, lw=3) +pl.title('Evolution of the loss along iterations', fontsize=16) +pl.show() + +# %% +# Wasserstein barycenter +# --------- +# In this example, we consider the following Wasserstein barycenter problem +# $$ \\eta^* = \\min_\\eta\;\;\; (1-t)W(\\mu,\\eta) + tW(\\eta,\\nu)$$ +# where :math:`\\mu` and :math:`\\nu` are reference 1D measures, and :math:`t` +# is a parameter :math:`\in [0,1]`. The problem is handled by a project gradient +# descent method, where the gradient is computed by pyTorch automatic differentiation. +# The projection on the simplex ensures that the iterate will remain on the +# probability simplex. +# +# This example illustrates both `wasserstein_1d` function and backend use within the +# POT framework. + + +device = "cuda" if torch.cuda.is_available() else "cpu" + +# use pyTorch for our data +x_torch = torch.tensor(x).to(device=device) +a_torch = torch.tensor(a).to(device=device) +b_torch = torch.tensor(b).to(device=device) +bary_torch = torch.tensor((a + b).copy() / 2).to(device=device).requires_grad_(True) + + +lr = 1e-6 +nb_iter_max = 2000 + +loss_iter = [] + +# instant of the interpolation +t = 0.5 + +for i in range(nb_iter_max): + # Compute the Wasserstein 1D with torch backend + loss = (1 - t) * wasserstein_1d(x_torch, x_torch, a_torch.detach(), bary_torch, p=2) + t * wasserstein_1d(x_torch, x_torch, b_torch, bary_torch, p=2) + # record the corresponding loss value + loss_iter.append(loss.clone().detach().cpu().numpy()) + loss.backward() + + # performs a step of projected gradient descent + with torch.no_grad(): + grad = bary_torch.grad + bary_torch -= bary_torch.grad * lr # step + bary_torch.grad.zero_() + bary_torch.data = proj_simplex(bary_torch) # projection onto the simplex + +pl.figure(3, figsize=(8, 4)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') +pl.plot(x, bary_torch.clone().detach().cpu().numpy(), c='green', label='W barycenter') +pl.legend() +pl.title('Wasserstein barycenter computed by gradient descent') +pl.show() + +pl.figure(4) +pl.plot(range(nb_iter_max), loss_iter, lw=3) +pl.title('Evolution of the loss along iterations', fontsize=16) +pl.show() diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py index 27df656..7d73907 100644 --- a/examples/sliced-wasserstein/plot_variance.py +++ b/examples/sliced-wasserstein/plot_variance.py @@ -63,7 +63,7 @@ res = np.empty((n_seed, 25)) # %% Compute statistics for seed in range(n_seed): for i, n_projections in enumerate(n_projections_arr): - res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed) + res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed=seed) res_mean = np.mean(res, axis=0) res_std = np.std(res, axis=0) diff --git a/ot/__init__.py b/ot/__init__.py index 5bd4bab..f20332c 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -42,7 +42,7 @@ from .bregman import sinkhorn, sinkhorn2, barycenter from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2) from .da import sinkhorn_lpl1_mm -from .sliced import sliced_wasserstein_distance +from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance # utils functions from .utils import dist, unif, tic, toc, toq @@ -51,8 +51,9 @@ __version__ = "0.8.0dev" __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', - 'emd_1d', 'emd2_1d', 'wasserstein_1d', + 'emd2_1d', 'wasserstein_1d', 'backend', 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim', 'sinkhorn_unbalanced', 'barycenter_unbalanced', 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', + 'max_sliced_wasserstein_distance', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath'] diff --git a/ot/backend.py b/ot/backend.py index 358297c..d3df44c 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -103,6 +103,8 @@ class Backend(): __name__ = None __type__ = None + rng_ = None + def __str__(self): return self.__name__ @@ -540,6 +542,36 @@ class Backend(): """ raise NotImplementedError() + def seed(self, seed=None): + r""" + Sets the seed for the random generator. + + This function follows the api from :any:`numpy.random.seed` + + See: https://numpy.org/doc/stable/reference/generated/numpy.random.seed.html + """ + raise NotImplementedError() + + def rand(self, *size, type_as=None): + r""" + Generate uniform random numbers. + + This function follows the api from :any:`numpy.random.rand` + + See: https://numpy.org/doc/stable/reference/generated/numpy.random.rand.html + """ + raise NotImplementedError() + + def randn(self, *size, type_as=None): + r""" + Generate normal Gaussian random numbers. + + This function follows the api from :any:`numpy.random.rand` + + See: https://numpy.org/doc/stable/reference/generated/numpy.random.rand.html + """ + raise NotImplementedError() + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): r""" Creates a sparse tensor in COOrdinate format. @@ -632,6 +664,8 @@ class NumpyBackend(Backend): __name__ = 'numpy' __type__ = np.ndarray + rng_ = np.random.RandomState() + def to_numpy(self, a): return a @@ -793,6 +827,16 @@ class NumpyBackend(Backend): def reshape(self, a, shape): return np.reshape(a, shape) + def seed(self, seed=None): + if seed is not None: + self.rng_.seed(seed) + + def rand(self, *size, type_as=None): + return self.rng_.rand(*size) + + def randn(self, *size, type_as=None): + return self.rng_.randn(*size) + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): if type_as is None: return coo_matrix((data, (rows, cols)), shape=shape) @@ -845,6 +889,11 @@ class JaxBackend(Backend): __name__ = 'jax' __type__ = jax_type + rng_ = None + + def __init__(self): + self.rng_ = jax.random.PRNGKey(42) + def to_numpy(self, a): return np.array(a) @@ -1010,6 +1059,24 @@ class JaxBackend(Backend): def reshape(self, a, shape): return jnp.reshape(a, shape) + def seed(self, seed=None): + if seed is not None: + self.rng_ = jax.random.PRNGKey(seed) + + def rand(self, *size, type_as=None): + self.rng_, subkey = jax.random.split(self.rng_) + if type_as is not None: + return jax.random.uniform(subkey, shape=size, dtype=type_as.dtype) + else: + return jax.random.uniform(subkey, shape=size) + + def randn(self, *size, type_as=None): + self.rng_, subkey = jax.random.split(self.rng_) + if type_as is not None: + return jax.random.normal(subkey, shape=size, dtype=type_as.dtype) + else: + return jax.random.normal(subkey, shape=size) + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): # Currently, JAX does not support sparse matrices data = self.to_numpy(data) @@ -1064,8 +1131,13 @@ class TorchBackend(Backend): __name__ = 'torch' __type__ = torch_type + rng_ = None + def __init__(self): + self.rng_ = torch.Generator() + self.rng_.seed() + from torch.autograd import Function # define a function that takes inputs val and grads @@ -1102,12 +1174,16 @@ class TorchBackend(Backend): return res def zeros(self, shape, type_as=None): + if isinstance(shape, int): + shape = (shape,) if type_as is None: return torch.zeros(shape) else: return torch.zeros(shape, dtype=type_as.dtype, device=type_as.device) def ones(self, shape, type_as=None): + if isinstance(shape, int): + shape = (shape,) if type_as is None: return torch.ones(shape) else: @@ -1120,6 +1196,8 @@ class TorchBackend(Backend): return torch.arange(start, stop, step, device=type_as.device) def full(self, shape, fill_value, type_as=None): + if isinstance(shape, int): + shape = (shape,) if type_as is None: return torch.full(shape, fill_value) else: @@ -1293,6 +1371,26 @@ class TorchBackend(Backend): def reshape(self, a, shape): return torch.reshape(a, shape) + def seed(self, seed=None): + if isinstance(seed, int): + self.rng_.manual_seed(seed) + elif isinstance(seed, torch.Generator): + self.rng_ = seed + else: + raise ValueError("Non compatible seed : {}".format(seed)) + + def rand(self, *size, type_as=None): + if type_as is not None: + return torch.rand(size=size, generator=self.rng_, dtype=type_as.dtype, device=type_as.device) + else: + return torch.rand(size=size, generator=self.rng_) + + def randn(self, *size, type_as=None): + if type_as is not None: + return torch.randn(size=size, dtype=type_as.dtype, generator=self.rng_, device=type_as.device) + else: + return torch.randn(size=size, generator=self.rng_) + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): if type_as is None: return torch.sparse_coo_tensor(torch.stack([rows, cols]), data, size=shape) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 4e95ccf..2c18a88 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -13,20 +13,23 @@ import multiprocessing import sys import numpy as np -from scipy.sparse import coo_matrix import warnings from . import cvx from .cvx import barycenter + # import compiled emd from .emd_wrap import emd_c, check_result, emd_1d_sorted +from .solver_1d import emd_1d, emd2_1d, wasserstein_1d + from ..utils import dist, list_to_array from ..utils import parmap from ..backend import get_backend -__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', +__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted', 'emd_1d', 'emd2_1d', 'wasserstein_1d'] + def check_number_threads(numThreads): """Checks whether or not the requested number of threads has a valid value. @@ -115,10 +118,10 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M): .. warning:: This function is necessary because the C++ solver in emd_c - discards all samples in the distributions with - zeros weights. This means that while the primal variable (transport + discards all samples in the distributions with + zeros weights. This means that while the primal variable (transport matrix) is exact, the solver only returns feasible dual potentials - on the samples with weights different from zero. + on the samples with weights different from zero. First we compute the constraints violations: @@ -215,26 +218,26 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): format .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. + from all compatible backends. Uses the algorithm proposed in [1]_ Parameters ---------- - a : (ns,) array-like, float + a : (ns,) array-like, float Source histogram (uniform weight if empty list) - b : (nt,) array-like, float - Target histogram (uniform weight if empty list) - M : (ns,nt) array-like, float - Loss matrix (c-order array in numpy with type float64) - numItermax : int, optional (default=100000) + b : (nt,) array-like, float + Target histogram (uniform weight if empty list) + M : (ns,nt) array-like, float + Loss matrix (c-order array in numpy with type float64) + numItermax : int, optional (default=100000) The maximum number of iterations before stopping the optimization - algorithm if it has not converged. - log: bool, optional (default=False) - If True, returns a dictionary containing the cost and dual variables. - Otherwise returns only the optimal transportation matrix. + algorithm if it has not converged. + log: bool, optional (default=False) + If True, returns a dictionary containing the cost and dual variables. + Otherwise returns only the optimal transportation matrix. center_dual: boolean, optional (default=True) - If True, centers the dual potential using function + If True, centers the dual potential using function :ref:`center_ot_dual`. numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) If compiled with OpenMP, chooses the number of threads to parallelize. @@ -242,9 +245,9 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): Returns ------- - gamma: array-like, shape (ns, nt) + gamma: array-like, shape (ns, nt) Optimal transportation matrix for the given - parameters + parameters log: dict, optional If input log is true, a dictionary containing the cost and dual variables and exit status @@ -277,10 +280,10 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): regularized OT""" # convert to numpy if list - a, b, M = list_to_array(a, b, M) + a, b, M = list_to_array(a, b, M) a0, b0, M0 = a, b, M - nx = get_backend(M0, a0, b0) + nx = get_backend(M0, a0, b0) # convert to numpy M = nx.to_numpy(M) @@ -302,9 +305,9 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): "Dimension mismatch, check dimensions of M with a and b" # ensure that same mass - np.testing.assert_almost_equal(a.sum(0), - b.sum(0), err_msg='a and b vector must have the same sum') - b=b*a.sum()/b.sum() + np.testing.assert_almost_equal(a.sum(0), + b.sum(0), err_msg='a and b vector must have the same sum') + b = b * a.sum() / b.sum() asel = a != 0 bsel = b != 0 @@ -415,10 +418,10 @@ def emd2(a, b, M, processes=1, ot.bregman.sinkhorn : Entropic regularized OT ot.optim.cg : General regularized OT""" - a, b, M = list_to_array(a, b, M) + a, b, M = list_to_array(a, b, M) a0, b0, M0 = a, b, M - nx = get_backend(M0, a0, b0) + nx = get_backend(M0, a0, b0) # convert to numpy M = nx.to_numpy(M) @@ -427,7 +430,7 @@ def emd2(a, b, M, processes=1, a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64, order= 'C') + M = np.asarray(M, dtype=np.float64, order='C') # if empty array given then use uniform distributions if len(a) == 0: @@ -463,8 +466,8 @@ def emd2(a, b, M, processes=1, log['v'] = nx.from_numpy(v, type_as=b0) log['warning'] = result_code_string log['result_code'] = result_code - cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), - (a0,b0, M0), (log['u'], log['v'], G)) + cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), + (a0, b0, M0), (log['u'], log['v'], G)) return [cost, log] else: def f(b): @@ -479,8 +482,8 @@ def emd2(a, b, M, processes=1, G = nx.from_numpy(G, type_as=M0) cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), - (a0,b0, M0), (nx.from_numpy(u, type_as=a0), - nx.from_numpy(v, type_as=b0),G)) + (a0, b0, M0), (nx.from_numpy(u, type_as=a0), + nx.from_numpy(v, type_as=b0), G)) check_result(result_code) return cost @@ -603,305 +606,3 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None return X, log_dict else: return X - - -def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, - log=False): - r"""Solves the Earth Movers distance problem between 1d measures and returns - the OT matrix - - - .. math:: - \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j]) - - s.t. \gamma 1 = a, - \gamma^T 1= b, - \gamma\geq 0 - where : - - - d is the metric - - x_a and x_b are the samples - - a and b are the sample weights - - When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`. - - Uses the algorithm detailed in [1]_ - - Parameters - ---------- - x_a : (ns,) or (ns, 1) ndarray, float64 - Source dirac locations (on the real line) - x_b : (nt,) or (ns, 1) ndarray, float64 - Target dirac locations (on the real line) - a : (ns,) ndarray, float64, optional - Source histogram (default is uniform weight) - b : (nt,) ndarray, float64, optional - Target histogram (default is uniform weight) - metric: str, optional (default='sqeuclidean') - Metric to be used. Only strings listed in :func:`ot.dist` are accepted. - Due to implementation details, this function runs faster when - `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used. - p: float, optional (default=1.0) - The p-norm to apply for if metric='minkowski' - dense: boolean, optional (default=True) - If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). - Otherwise returns a sparse representation using scipy's `coo_matrix` - format. Due to implementation details, this function runs faster when - `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics - are used. - log: boolean, optional (default=False) - If True, returns a dictionary containing the cost. - Otherwise returns only the optimal transportation matrix. - - Returns - ------- - gamma: (ns, nt) ndarray - Optimal transportation matrix for the given parameters - log: dict - If input log is True, a dictionary containing the cost - - - Examples - -------- - - Simple example with obvious solution. The function emd_1d accepts lists and - performs automatic conversion to numpy arrays - - >>> import ot - >>> a=[.5, .5] - >>> b=[.5, .5] - >>> x_a = [2., 0.] - >>> x_b = [0., 3.] - >>> ot.emd_1d(x_a, x_b, a, b) - array([[0. , 0.5], - [0.5, 0. ]]) - >>> ot.emd_1d(x_a, x_b) - array([[0. , 0.5], - [0.5, 0. ]]) - - References - ---------- - - .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal - Transport", 2018. - - See Also - -------- - ot.lp.emd : EMD for multidimensional distributions - ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the - transportation matrix) - """ - a, b, x_a, x_b = list_to_array(a, b, x_a, x_b) - nx = get_backend(x_a, x_b) - - assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \ - "emd_1d should only be used with monodimensional data" - assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \ - "emd_1d should only be used with monodimensional data" - - # if empty array given then use uniform distributions - if a is None or a.ndim == 0 or len(a) == 0: - a = nx.ones((x_a.shape[0],), type_as=x_a) / x_a.shape[0] - if b is None or b.ndim == 0 or len(b) == 0: - b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0] - - # ensure that same mass - np.testing.assert_almost_equal( - nx.sum(a, axis=0), - nx.sum(b, axis=0), - err_msg='a and b vector must have the same sum' - ) - b = b * nx.sum(a) / nx.sum(b) - - x_a_1d = nx.reshape(x_a, (-1,)) - x_b_1d = nx.reshape(x_b, (-1,)) - perm_a = nx.argsort(x_a_1d) - perm_b = nx.argsort(x_b_1d) - - G_sorted, indices, cost = emd_1d_sorted( - nx.to_numpy(a[perm_a]), - nx.to_numpy(b[perm_b]), - nx.to_numpy(x_a_1d[perm_a]), - nx.to_numpy(x_b_1d[perm_b]), - metric=metric, p=p - ) - - G = nx.coo_matrix( - G_sorted, - perm_a[indices[:, 0]], - perm_b[indices[:, 1]], - shape=(a.shape[0], b.shape[0]), - type_as=x_a - ) - if dense: - G = nx.todense(G) - elif str(nx) == "jax": - warnings.warn("JAX does not support sparse matrices, converting to dense") - if log: - log = {'cost': cost} - return G, log - return G - - -def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, - log=False): - r"""Solves the Earth Movers distance problem between 1d measures and returns - the loss - - - .. math:: - \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j]) - - s.t. \gamma 1 = a, - \gamma^T 1= b, - \gamma\geq 0 - where : - - - d is the metric - - x_a and x_b are the samples - - a and b are the sample weights - - When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`. - - Uses the algorithm detailed in [1]_ - - Parameters - ---------- - x_a : (ns,) or (ns, 1) ndarray, float64 - Source dirac locations (on the real line) - x_b : (nt,) or (ns, 1) ndarray, float64 - Target dirac locations (on the real line) - a : (ns,) ndarray, float64, optional - Source histogram (default is uniform weight) - b : (nt,) ndarray, float64, optional - Target histogram (default is uniform weight) - metric: str, optional (default='sqeuclidean') - Metric to be used. Only strings listed in :func:`ot.dist` are accepted. - Due to implementation details, this function runs faster when - `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics - are used. - p: float, optional (default=1.0) - The p-norm to apply for if metric='minkowski' - dense: boolean, optional (default=True) - If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). - Otherwise returns a sparse representation using scipy's `coo_matrix` - format. Only used if log is set to True. Due to implementation details, - this function runs faster when dense is set to False. - log: boolean, optional (default=False) - If True, returns a dictionary containing the transportation matrix. - Otherwise returns only the loss. - - Returns - ------- - loss: float - Cost associated to the optimal transportation - log: dict - If input log is True, a dictionary containing the Optimal transportation - matrix for the given parameters - - - Examples - -------- - - Simple example with obvious solution. The function emd2_1d accepts lists and - performs automatic conversion to numpy arrays - - >>> import ot - >>> a=[.5, .5] - >>> b=[.5, .5] - >>> x_a = [2., 0.] - >>> x_b = [0., 3.] - >>> ot.emd2_1d(x_a, x_b, a, b) - 0.5 - >>> ot.emd2_1d(x_a, x_b) - 0.5 - - References - ---------- - - .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal - Transport", 2018. - - See Also - -------- - ot.lp.emd2 : EMD for multidimensional distributions - ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix - instead of the cost) - """ - # If we do not return G (log==False), then we should not to cast it to dense - # (useless overhead) - G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p, - dense=dense and log, log=True) - cost = log_emd['cost'] - if log: - log_emd = {'G': G} - return cost, log_emd - return cost - - -def wasserstein_1d(x_a, x_b, a=None, b=None, p=1.): - r"""Solves the p-Wasserstein distance problem between 1d measures and returns - the distance - - .. math:: - \min_\gamma \left( \sum_i \sum_j \gamma_{ij} \|x_a[i] - x_b[j]\|^p \right)^{1/p} - - s.t. \gamma 1 = a, - \gamma^T 1= b, - \gamma\geq 0 - - where : - - - x_a and x_b are the samples - - a and b are the sample weights - - Uses the algorithm detailed in [1]_ - - Parameters - ---------- - x_a : (ns,) or (ns, 1) ndarray, float64 - Source dirac locations (on the real line) - x_b : (nt,) or (ns, 1) ndarray, float64 - Target dirac locations (on the real line) - a : (ns,) ndarray, float64, optional - Source histogram (default is uniform weight) - b : (nt,) ndarray, float64, optional - Target histogram (default is uniform weight) - p: float, optional (default=1.0) - The order of the p-Wasserstein distance to be computed - - Returns - ------- - dist: float - p-Wasserstein distance - - - Examples - -------- - - Simple example with obvious solution. The function wasserstein_1d accepts - lists and performs automatic conversion to numpy arrays - - >>> import ot - >>> a=[.5, .5] - >>> b=[.5, .5] - >>> x_a = [2., 0.] - >>> x_b = [0., 3.] - >>> ot.wasserstein_1d(x_a, x_b, a, b) - 0.5 - >>> ot.wasserstein_1d(x_a, x_b) - 0.5 - - References - ---------- - - .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal - Transport", 2018. - - See Also - -------- - ot.lp.emd_1d : EMD for 1d distributions - """ - cost_emd = emd2_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p, - dense=False, log=False) - return np.power(cost_emd, 1. / p) diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py new file mode 100644 index 0000000..42554aa --- /dev/null +++ b/ot/lp/solver_1d.py @@ -0,0 +1,367 @@ +# -*- coding: utf-8 -*- +""" +Exact solvers for the 1D Wasserstein distance using cvxopt +""" + +# Author: Remi Flamary +# Author: Nicolas Courty +# +# License: MIT License + +import numpy as np +import warnings + +from .emd_wrap import emd_1d_sorted +from ..backend import get_backend +from ..utils import list_to_array + + +def quantile_function(qs, cws, xs): + r""" Computes the quantile function of an empirical distribution + + Parameters + ---------- + qs: array-like, shape (n,) + Quantiles at which the quantile function is evaluated + cws: array-like, shape (m, ...) + cumulative weights of the 1D empirical distribution, if batched, must be similar to xs + xs: array-like, shape (n, ...) + locations of the 1D empirical distribution, batched against the `xs.ndim - 1` first dimensions + + Returns + ------- + q: array-like, shape (..., n) + The quantiles of the distribution + """ + nx = get_backend(qs, cws) + n = xs.shape[0] + if nx.__name__ == 'torch': + # this is to ensure the best performance for torch searchsorted + # and avoid a warninng related to non-contiguous arrays + cws = cws.T.contiguous() + qs = qs.T.contiguous() + else: + cws = cws.T + qs = qs.T + idx = nx.searchsorted(cws, qs).T + return nx.take_along_axis(xs, nx.clip(idx, 0, n - 1), axis=0) + + +def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True): + r""" + Computes the 1 dimensional OT loss [15] between two (batched) empirical + distributions + + .. math: + OT_{loss} = \int_0^1 |cdf_u^{-1}(q) cdf_v^{-1}(q)|^p dq + + It is formally the p-Wasserstein distance raised to the power p. + We do so in a vectorized way by first building the individual quantile functions then integrating them. + + This function should be preferred to `emd_1d` whenever the backend is + different to numpy, and when gradients over + either sample positions or weights are required. + + Parameters + ---------- + u_values: array-like, shape (n, ...) + locations of the first empirical distribution + v_values: array-like, shape (m, ...) + locations of the second empirical distribution + u_weights: array-like, shape (n, ...), optional + weights of the first empirical distribution, if None then uniform weights are used + v_weights: array-like, shape (m, ...), optional + weights of the second empirical distribution, if None then uniform weights are used + p: int, optional + order of the ground metric used, should be at least 1 (see [2, Chap. 2], default is 1 + require_sort: bool, optional + sort the distributions atoms locations, if False we will consider they have been sorted prior to being passed to + the function, default is True + + Returns + ------- + cost: float/array-like, shape (...) + the batched EMD + + References + ---------- + .. [15] Peyré, G., & Cuturi, M. (2018). Computational Optimal Transport. + + """ + + assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p) + + if u_weights is not None and v_weights is not None: + nx = get_backend(u_values, v_values, u_weights, v_weights) + else: + nx = get_backend(u_values, v_values) + + n = u_values.shape[0] + m = v_values.shape[0] + + if u_weights is None: + u_weights = nx.full(u_values.shape, 1. / n) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + if v_weights is None: + v_weights = nx.full(v_values.shape, 1. / m) + elif v_weights.ndim != v_values.ndim: + v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) + + if require_sort: + u_sorter = nx.argsort(u_values, 0) + u_values = nx.take_along_axis(u_values, u_sorter, 0) + + v_sorter = nx.argsort(v_values, 0) + v_values = nx.take_along_axis(v_values, v_sorter, 0) + + u_weights = nx.take_along_axis(u_weights, u_sorter, 0) + v_weights = nx.take_along_axis(v_weights, v_sorter, 0) + + u_cumweights = nx.cumsum(u_weights, 0) + v_cumweights = nx.cumsum(v_weights, 0) + + qs = nx.sort(nx.concatenate((u_cumweights, v_cumweights), 0), 0) + u_quantiles = quantile_function(qs, u_cumweights, u_values) + v_quantiles = quantile_function(qs, v_cumweights, v_values) + qs = nx.zero_pad(qs, pad_width=[(1, 0)] + (qs.ndim - 1) * [(0, 0)]) + delta = qs[1:, ...] - qs[:-1, ...] + diff_quantiles = nx.abs(u_quantiles - v_quantiles) + + if p == 1: + return nx.sum(delta * nx.abs(diff_quantiles), axis=0) + return nx.sum(delta * nx.power(diff_quantiles, p), axis=0) + + +def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, + log=False): + r"""Solves the Earth Movers distance problem between 1d measures and returns + the OT matrix + + + .. math:: + \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j]) + + s.t. \gamma 1 = a, + \gamma^T 1= b, + \gamma\geq 0 + where : + + - d is the metric + - x_a and x_b are the samples + - a and b are the sample weights + + When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`. + + Uses the algorithm detailed in [1]_ + + Parameters + ---------- + x_a : (ns,) or (ns, 1) ndarray, float64 + Source dirac locations (on the real line) + x_b : (nt,) or (ns, 1) ndarray, float64 + Target dirac locations (on the real line) + a : (ns,) ndarray, float64, optional + Source histogram (default is uniform weight) + b : (nt,) ndarray, float64, optional + Target histogram (default is uniform weight) + metric: str, optional (default='sqeuclidean') + Metric to be used. Only strings listed in :func:`ot.dist` are accepted. + Due to implementation details, this function runs faster when + `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used. + p: float, optional (default=1.0) + The p-norm to apply for if metric='minkowski' + dense: boolean, optional (default=True) + If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). + Otherwise returns a sparse representation using scipy's `coo_matrix` + format. Due to implementation details, this function runs faster when + `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics + are used. + log: boolean, optional (default=False) + If True, returns a dictionary containing the cost. + Otherwise returns only the optimal transportation matrix. + + Returns + ------- + gamma: (ns, nt) ndarray + Optimal transportation matrix for the given parameters + log: dict + If input log is True, a dictionary containing the cost + + + Examples + -------- + + Simple example with obvious solution. The function emd_1d accepts lists and + performs automatic conversion to numpy arrays + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> x_a = [2., 0.] + >>> x_b = [0., 3.] + >>> ot.emd_1d(x_a, x_b, a, b) + array([[0. , 0.5], + [0.5, 0. ]]) + >>> ot.emd_1d(x_a, x_b) + array([[0. , 0.5], + [0.5, 0. ]]) + + References + ---------- + + .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + Transport", 2018. + + See Also + -------- + ot.lp.emd : EMD for multidimensional distributions + ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the + transportation matrix) + """ + a, b, x_a, x_b = list_to_array(a, b, x_a, x_b) + nx = get_backend(x_a, x_b) + + assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \ + "emd_1d should only be used with monodimensional data" + assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \ + "emd_1d should only be used with monodimensional data" + + # if empty array given then use uniform distributions + if a is None or a.ndim == 0 or len(a) == 0: + a = nx.ones((x_a.shape[0],), type_as=x_a) / x_a.shape[0] + if b is None or b.ndim == 0 or len(b) == 0: + b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0] + + # ensure that same mass + np.testing.assert_almost_equal( + nx.sum(a, axis=0), + nx.sum(b, axis=0), + err_msg='a and b vector must have the same sum' + ) + b = b * nx.sum(a) / nx.sum(b) + + x_a_1d = nx.reshape(x_a, (-1,)) + x_b_1d = nx.reshape(x_b, (-1,)) + perm_a = nx.argsort(x_a_1d) + perm_b = nx.argsort(x_b_1d) + + G_sorted, indices, cost = emd_1d_sorted( + nx.to_numpy(a[perm_a]), + nx.to_numpy(b[perm_b]), + nx.to_numpy(x_a_1d[perm_a]), + nx.to_numpy(x_b_1d[perm_b]), + metric=metric, p=p + ) + + G = nx.coo_matrix( + G_sorted, + perm_a[indices[:, 0]], + perm_b[indices[:, 1]], + shape=(a.shape[0], b.shape[0]), + type_as=x_a + ) + if dense: + G = nx.todense(G) + elif str(nx) == "jax": + warnings.warn("JAX does not support sparse matrices, converting to dense") + if log: + log = {'cost': cost} + return G, log + return G + + +def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, + log=False): + r"""Solves the Earth Movers distance problem between 1d measures and returns + the loss + + + .. math:: + \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j]) + + s.t. \gamma 1 = a, + \gamma^T 1= b, + \gamma\geq 0 + where : + + - d is the metric + - x_a and x_b are the samples + - a and b are the sample weights + + When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`. + + Uses the algorithm detailed in [1]_ + + Parameters + ---------- + x_a : (ns,) or (ns, 1) ndarray, float64 + Source dirac locations (on the real line) + x_b : (nt,) or (ns, 1) ndarray, float64 + Target dirac locations (on the real line) + a : (ns,) ndarray, float64, optional + Source histogram (default is uniform weight) + b : (nt,) ndarray, float64, optional + Target histogram (default is uniform weight) + metric: str, optional (default='sqeuclidean') + Metric to be used. Only strings listed in :func:`ot.dist` are accepted. + Due to implementation details, this function runs faster when + `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics + are used. + p: float, optional (default=1.0) + The p-norm to apply for if metric='minkowski' + dense: boolean, optional (default=True) + If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). + Otherwise returns a sparse representation using scipy's `coo_matrix` + format. Only used if log is set to True. Due to implementation details, + this function runs faster when dense is set to False. + log: boolean, optional (default=False) + If True, returns a dictionary containing the transportation matrix. + Otherwise returns only the loss. + + Returns + ------- + loss: float + Cost associated to the optimal transportation + log: dict + If input log is True, a dictionary containing the Optimal transportation + matrix for the given parameters + + + Examples + -------- + + Simple example with obvious solution. The function emd2_1d accepts lists and + performs automatic conversion to numpy arrays + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> x_a = [2., 0.] + >>> x_b = [0., 3.] + >>> ot.emd2_1d(x_a, x_b, a, b) + 0.5 + >>> ot.emd2_1d(x_a, x_b) + 0.5 + + References + ---------- + + .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + Transport", 2018. + + See Also + -------- + ot.lp.emd2 : EMD for multidimensional distributions + ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix + instead of the cost) + """ + # If we do not return G (log==False), then we should not to cast it to dense + # (useless overhead) + G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p, + dense=dense and log, log=True) + cost = log_emd['cost'] + if log: + log_emd = {'G': G} + return cost, log_emd + return cost diff --git a/ot/sliced.py b/ot/sliced.py index 4792576..d3dc3f2 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -1,61 +1,73 @@ """ -Sliced Wasserstein Distance. +Sliced OT Distances """ # Author: Adrien Corenflos +# Nicolas Courty +# Rémi Flamary # # License: MIT License import numpy as np +from .backend import get_backend, NumpyBackend +from .utils import list_to_array -def get_random_projections(n_projections, d, seed=None): +def get_random_projections(d, n_projections, seed=None, backend=None, type_as=None): r""" Generates n_projections samples from the uniform on the unit sphere of dimension d-1: :math:`\mathcal{U}(\mathcal{S}^{d-1})` Parameters ---------- - n_projections : int - number of samples requested d : int dimension of the space + n_projections : int + number of samples requested seed: int or RandomState, optional Seed used for numpy random number generator + backend: + Backend to ue for random generation Returns ------- - out: ndarray, shape (n_projections, d) + out: ndarray, shape (d, n_projections) The uniform unit vectors on the sphere Examples -------- >>> n_projections = 100 >>> d = 5 - >>> projs = get_random_projections(n_projections, d) - >>> np.allclose(np.sum(np.square(projs), 1), 1.) # doctest: +NORMALIZE_WHITESPACE + >>> projs = get_random_projections(d, n_projections) + >>> np.allclose(np.sum(np.square(projs), 0), 1.) # doctest: +NORMALIZE_WHITESPACE True """ - if not isinstance(seed, np.random.RandomState): - random_state = np.random.RandomState(seed) + if backend is None: + nx = NumpyBackend() + else: + nx = backend + + if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy': + projections = seed.randn(d, n_projections) else: - random_state = seed + if seed is not None: + nx.seed(seed) + projections = nx.randn(d, n_projections, type_as=type_as) - projections = random_state.normal(0., 1., [n_projections, d]) - norm = np.linalg.norm(projections, ord=2, axis=1, keepdims=True) - projections = projections / norm + projections = projections / nx.sqrt(nx.sum(projections**2, 0, keepdims=True)) return projections -def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed=None, log=False): +def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, + projections=None, seed=None, log=False): r""" - Computes a Monte-Carlo approximation of the 2-Sliced Wasserstein distance + Computes a Monte-Carlo approximation of the p-Sliced Wasserstein distance .. math:: - \mathcal{SWD}_2(\mu, \nu) = \underset{\theta \sim \mathcal{U}(\mathbb{S}^{d-1})}{\mathbb{E}}[\mathcal{W}_2^2(\theta_\# \mu, \theta_\# \nu)]^{\frac{1}{2}} + \mathcal{SWD}_p(\mu, \nu) = \underset{\theta \sim \mathcal{U}(\mathbb{S}^{d-1})}{\mathbb{E}}[\mathcal{W}_p^p(\theta_\# \mu, \theta_\# \nu)]^{\frac{1}{p}} where : @@ -74,8 +86,12 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed samples weights in the target domain n_projections : int, optional Number of projections used for the Monte-Carlo approximation + p: float, optional = + Power p used for computing the sliced Wasserstein + projections: shape (dim, n_projections), optional + Projection matrix (n_projections and seed are not used in this case) seed: int or RandomState or None, optional - Seed used for numpy random number generator + Seed used for random number generator log: bool, optional if True, sliced_wasserstein_distance returns the projections used and their associated EMD. @@ -100,10 +116,18 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed .. [31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 """ - from .lp import emd2_1d + from .lp import wasserstein_1d - X_s = np.asanyarray(X_s) - X_t = np.asanyarray(X_t) + X_s, X_t = list_to_array(X_s, X_t) + + if a is not None and b is not None and projections is None: + nx = get_backend(X_s, X_t, a, b) + elif a is not None and b is not None and projections is not None: + nx = get_backend(X_s, X_t, a, b, projections) + elif a is None and b is None and projections is not None: + nx = get_backend(X_s, X_t, projections) + else: + nx = get_backend(X_s, X_t) n = X_s.shape[0] m = X_t.shape[0] @@ -114,31 +138,120 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed X_t.shape[1])) if a is None: - a = np.full(n, 1 / n) + a = nx.full(n, 1 / n) if b is None: - b = np.full(m, 1 / m) + b = nx.full(m, 1 / m) d = X_s.shape[1] - projections = get_random_projections(n_projections, d, seed) + if projections is None: + projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s) + + X_s_projections = nx.dot(X_s, projections) + X_t_projections = nx.dot(X_t, projections) - X_s_projections = np.dot(projections, X_s.T) - X_t_projections = np.dot(projections, X_t.T) + projected_emd = wasserstein_1d(X_s_projections, X_t_projections, a, b, p=p) + res = (nx.sum(projected_emd) / n_projections) ** (1.0 / p) if log: - projected_emd = np.empty(n_projections) + return res, {"projections": projections, "projected_emds": projected_emd} + return res + + +def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, + projections=None, seed=None, log=False): + r""" + Computes a Monte-Carlo approximation of the max p-Sliced Wasserstein distance + + .. math:: + \mathcal{Max-SWD}_p(\mu, \nu) = \underset{\theta _in + \mathcal{U}(\mathbb{S}^{d-1})}{\max} [\mathcal{W}_p^p(\theta_\# + \mu, \theta_\# \nu)]^{\frac{1}{p}} + + where : + + - :math:`\theta_\# \mu` stands for the pushforwars of the projection :math:`\mathbb{R}^d \ni X \mapsto \langle \theta, X \rangle` + + + Parameters + ---------- + X_s : ndarray, shape (n_samples_a, dim) + samples in the source domain + X_t : ndarray, shape (n_samples_b, dim) + samples in the target domain + a : ndarray, shape (n_samples_a,), optional + samples weights in the source domain + b : ndarray, shape (n_samples_b,), optional + samples weights in the target domain + n_projections : int, optional + Number of projections used for the Monte-Carlo approximation + p: float, optional = + Power p used for computing the sliced Wasserstein + projections: shape (dim, n_projections), optional + Projection matrix (n_projections and seed are not used in this case) + seed: int or RandomState or None, optional + Seed used for random number generator + log: bool, optional + if True, sliced_wasserstein_distance returns the projections used and their associated EMD. + + Returns + ------- + cost: float + Sliced Wasserstein Cost + log : dict, optional + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> n_samples_a = 20 + >>> reg = 0.1 + >>> X = np.random.normal(0., 1., (n_samples_a, 5)) + >>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE + 0.0 + + References + ---------- + + .. [35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., ... & Schwing, A. G. (2019). Max-sliced wasserstein distance and its use for gans. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656). + """ + from .lp import wasserstein_1d + + X_s, X_t = list_to_array(X_s, X_t) + + if a is not None and b is not None and projections is None: + nx = get_backend(X_s, X_t, a, b) + elif a is not None and b is not None and projections is not None: + nx = get_backend(X_s, X_t, a, b, projections) + elif a is None and b is None and projections is not None: + nx = get_backend(X_s, X_t, projections) else: - projected_emd = None + nx = get_backend(X_s, X_t) + + n = X_s.shape[0] + m = X_t.shape[0] + + if X_s.shape[1] != X_t.shape[1]: + raise ValueError( + "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1], + X_t.shape[1])) + + if a is None: + a = nx.full(n, 1 / n) + if b is None: + b = nx.full(m, 1 / m) + + d = X_s.shape[1] + + if projections is None: + projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s) - res = 0. + X_s_projections = nx.dot(X_s, projections) + X_t_projections = nx.dot(X_t, projections) - for i, (X_s_proj, X_t_proj) in enumerate(zip(X_s_projections, X_t_projections)): - emd = emd2_1d(X_s_proj, X_t_proj, a, b, log=False, dense=False) - if projected_emd is not None: - projected_emd[i] = emd - res += emd + projected_emd = wasserstein_1d(X_s_projections, X_t_projections, a, b, p=p) - res = (res / n_projections) ** 0.5 + res = nx.max(projected_emd) ** (1.0 / p) if log: return res, {"projections": projections, "projected_emds": projected_emd} return res diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py new file mode 100644 index 0000000..2c470c2 --- /dev/null +++ b/test/test_1d_solver.py @@ -0,0 +1,85 @@ +"""Tests for module 1d Wasserstein solver""" + +# Author: Adrien Corenflos +# Nicolas Courty +# +# License: MIT License + +import numpy as np +import pytest + +import ot +from ot.lp import wasserstein_1d + +from ot.backend import get_backend_list +from scipy.stats import wasserstein_distance + +backend_list = get_backend_list() + + +def test_emd_1d_emd2_1d_with_weights(): + # test emd1d gives similar results as emd + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.randn(n, 1) + v = rng.randn(m, 1) + + w_u = rng.uniform(0., 1., n) + w_u = w_u / w_u.sum() + + w_v = rng.uniform(0., 1., m) + w_v = w_v / w_v.sum() + + M = ot.dist(u, v, metric='sqeuclidean') + + G, log = ot.emd(w_u, w_v, M, log=True) + wass = log["cost"] + G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True) + wass1d = log["cost"] + wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False) + wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False) + + # check loss is similar + np.testing.assert_allclose(wass, wass1d) + np.testing.assert_allclose(wass, wass1d_emd2) + + # check loss is similar to scipy's implementation for Euclidean metric + wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v) + np.testing.assert_allclose(wass_sp, wass1d_euc) + + # check constraints + np.testing.assert_allclose(w_u, G.sum(1)) + np.testing.assert_allclose(w_v, G.sum(0)) + + +@pytest.mark.parametrize('nx', backend_list) +def test_wasserstein_1d(nx): + from scipy.stats import wasserstein_distance + + rng = np.random.RandomState(0) + + n = 100 + x = np.linspace(0, 5, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + rho_v = np.abs(rng.randn(n)) + rho_v /= rho_v.sum() + + xb = nx.from_numpy(x) + rho_ub = nx.from_numpy(rho_u) + rho_vb = nx.from_numpy(rho_v) + + # test 1 : wasserstein_1d should be close to scipy W_1 implementation + np.testing.assert_almost_equal(wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1), + wasserstein_distance(x, x, rho_u, rho_v)) + + # test 2 : wasserstein_1d should be close to one when only translating the support + np.testing.assert_almost_equal(wasserstein_1d(xb, xb + 1, p=2), + 1.) + + # test 3 : arrays test + X = np.stack((np.linspace(0, 5, n), np.linspace(0, 5, n) * 10), -1) + Xb = nx.from_numpy(X) + res = wasserstein_1d(Xb, Xb, rho_ub, rho_vb, p=2) + np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4) diff --git a/test/test_backend.py b/test/test_backend.py index 0f11ace..1832b91 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -208,6 +208,11 @@ def test_empty_backend(): with pytest.raises(NotImplementedError): nx.reshape(M, (5, 3, 2)) with pytest.raises(NotImplementedError): + nx.seed(42) + with pytest.raises(NotImplementedError): + nx.rand() + with pytest.raises(NotImplementedError): + nx.randn() nx.coo_matrix(M, M, M) with pytest.raises(NotImplementedError): nx.issparse(M) @@ -248,6 +253,7 @@ def test_func_backends(nx): Mb = nx.from_numpy(M) vb = nx.from_numpy(v) + val = nx.from_numpy(val) sp_rowb = nx.from_numpy(sp_row) @@ -255,6 +261,7 @@ def test_func_backends(nx): sp_datab = nx.from_numpy(sp_data) A = nx.set_gradients(val, v, v) + lst_b.append(nx.to_numpy(A)) lst_name.append('set_gradients') @@ -505,6 +512,35 @@ def test_func_backends(nx): assert np.allclose(a1, a2, atol=1e-7) +def test_random_backends(nx): + + tmp_u = nx.rand() + + assert tmp_u < 1 + + tmp_n = nx.randn() + + nx.seed(0) + M1 = nx.to_numpy(nx.rand(5, 2)) + nx.seed(0) + M2 = nx.to_numpy(nx.rand(5, 2, type_as=tmp_n)) + + assert np.all(M1 >= 0) + assert np.all(M1 < 1) + assert M1.shape == (5, 2) + assert np.allclose(M1, M2) + + nx.seed(0) + M1 = nx.to_numpy(nx.randn(5, 2)) + nx.seed(0) + M2 = nx.to_numpy(nx.randn(5, 2, type_as=tmp_u)) + + nx.seed(42) + v1 = nx.randn() + v2 = nx.randn() + assert v1 != v2 + + def test_gradients_backends(): rnd = np.random.RandomState(0) diff --git a/test/test_ot.py b/test/test_ot.py index 4dfc510..5bfde1d 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -8,11 +8,11 @@ import warnings import numpy as np import pytest -from scipy.stats import wasserstein_distance import ot from ot.datasets import make_1D_gauss as gauss from ot.backend import torch +from scipy.stats import wasserstein_distance def test_emd_dimension_and_mass_mismatch(): @@ -165,61 +165,6 @@ def test_emd_1d_emd2_1d(): ot.emd_1d(u, v, [], []) -def test_emd_1d_emd2_1d_with_weights(): - # test emd1d gives similar results as emd - n = 20 - m = 30 - rng = np.random.RandomState(0) - u = rng.randn(n, 1) - v = rng.randn(m, 1) - - w_u = rng.uniform(0., 1., n) - w_u = w_u / w_u.sum() - - w_v = rng.uniform(0., 1., m) - w_v = w_v / w_v.sum() - - M = ot.dist(u, v, metric='sqeuclidean') - - G, log = ot.emd(w_u, w_v, M, log=True) - wass = log["cost"] - G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True) - wass1d = log["cost"] - wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False) - wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False) - - # check loss is similar - np.testing.assert_allclose(wass, wass1d) - np.testing.assert_allclose(wass, wass1d_emd2) - - # check loss is similar to scipy's implementation for Euclidean metric - wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v) - np.testing.assert_allclose(wass_sp, wass1d_euc) - - # check constraints - np.testing.assert_allclose(w_u, G.sum(1)) - np.testing.assert_allclose(w_v, G.sum(0)) - - -def test_wass_1d(): - # test emd1d gives similar results as emd - n = 20 - m = 30 - rng = np.random.RandomState(0) - u = rng.randn(n, 1) - v = rng.randn(m, 1) - - M = ot.dist(u, v, metric='sqeuclidean') - - G, log = ot.emd([], [], M, log=True) - wass = log["cost"] - - wass1d = ot.wasserstein_1d(u, v, [], [], p=2.) - - # check loss is similar - np.testing.assert_allclose(np.sqrt(wass), wass1d) - - def test_emd_empty(): # test emd and emd2 for simple identity n = 100 diff --git a/test/test_sliced.py b/test/test_sliced.py index a07d975..0bd74ec 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -1,6 +1,7 @@ """Tests for module sliced""" # Author: Adrien Corenflos +# Nicolas Courty # # License: MIT License @@ -14,7 +15,7 @@ from ot.sliced import get_random_projections def test_get_random_projections(): rng = np.random.RandomState(0) projections = get_random_projections(1000, 50, rng) - np.testing.assert_almost_equal(np.sum(projections ** 2, 1), 1.) + np.testing.assert_almost_equal(np.sum(projections ** 2, 0), 1.) def test_sliced_same_dist(): @@ -48,12 +49,12 @@ def test_sliced_log(): y = rng.randn(n, 4) u = ot.utils.unif(n) - res, log = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng, log=True) + res, log = ot.sliced_wasserstein_distance(x, y, u, u, 10, p=1, seed=rng, log=True) assert len(log) == 2 projections = log["projections"] projected_emds = log["projected_emds"] - assert len(projections) == len(projected_emds) == 10 + assert projections.shape[1] == len(projected_emds) == 10 for emd in projected_emds: assert emd > 0 @@ -83,3 +84,86 @@ def test_1d_sliced_equals_emd(): res = ot.sliced_wasserstein_distance(x, y, a, u, 10, seed=42) expected = ot.emd2_1d(x.squeeze(), y.squeeze(), a, u) np.testing.assert_almost_equal(res ** 2, expected) + + +def test_max_sliced_same_dist(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + res = ot.max_sliced_wasserstein_distance(x, x, u, u, 10, seed=rng) + np.testing.assert_almost_equal(res, 0.) + + +def test_max_sliced_different_dists(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + y = rng.randn(n, 2) + + res, log = ot.max_sliced_wasserstein_distance(x, y, u, u, 10, seed=rng, log=True) + assert res > 0. + + +def test_sliced_backend(nx): + + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(2 * n, 2) + + P = rng.randn(2, 20) + P = P / np.sqrt((P**2).sum(0, keepdims=True)) + + n_projections = 20 + + xb = nx.from_numpy(x) + yb = nx.from_numpy(y) + Pb = nx.from_numpy(P) + + val0 = ot.sliced_wasserstein_distance(x, y, projections=P) + + val = ot.sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0) + val2 = ot.sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0) + + assert val > 0 + assert val == val2 + + valb = nx.to_numpy(ot.sliced_wasserstein_distance(xb, yb, projections=Pb)) + + assert np.allclose(val0, valb) + + +def test_max_sliced_backend(nx): + + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(2 * n, 2) + + P = rng.randn(2, 20) + P = P / np.sqrt((P**2).sum(0, keepdims=True)) + + n_projections = 20 + + xb = nx.from_numpy(x) + yb = nx.from_numpy(y) + Pb = nx.from_numpy(P) + + val0 = ot.max_sliced_wasserstein_distance(x, y, projections=P) + + val = ot.max_sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0) + val2 = ot.max_sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0) + + assert val > 0 + assert val == val2 + + valb = nx.to_numpy(ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)) + + assert np.allclose(val0, valb) diff --git a/test/test_utils.py b/test/test_utils.py index 0650ce2..40f4e49 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -109,7 +109,7 @@ def test_dist(): D2 = ot.dist(x, x) D3 = ot.dist(x) - D4 = ot.dist(x, x, metric='minkowski', p=0.5) + D4 = ot.dist(x, x, metric='minkowski', p=2) assert D4[0, 1] == D4[1, 0] -- cgit v1.2.3 From e1b67c641da3b3e497db6811af2c200022b10302 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Wed, 3 Nov 2021 08:41:35 +0100 Subject: [WIP] Add debiased barycenter (Sinkhorn + convolutional sinkhorn) (#291) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add debiased sinkhorn barycenter + make loops pythonic * add debiased arg in tests * add 1d and 2d examples of debiased barycenters * fix doctest * fix flake8 * pep8 + make func private + add convergence warnings * remove rel paths + add rng + pylab to pyplot * fix stopping criterion debiased * pass alex * change params with new API * add logdomain barycenters + separate debiased API * test new API * fix jax read-only ? * raise error for jax * test catch jax error * fix pytest catch error * fix relative path * fix flake8 * add warn arg everywhere * fix ref number * catch warnings in tests * add contrib to readme + change ref number * fix convolution example + gallery thumbnails * increase coverage * fix flake Co-authored-by: Hicham Janati Co-authored-by: Rémi Flamary Co-authored-by: Alexandre Gramfort --- README.md | 8 +- examples/barycenters/plot_barycenter_1D.py | 63 +- .../barycenters/plot_barycenter_lp_vs_entropic.py | 2 +- .../barycenters/plot_convolutional_barycenter.py | 53 +- examples/barycenters/plot_debiased_barycenter.py | 131 ++ .../domain-adaptation/plot_otda_color_images.py | 118 +- .../domain-adaptation/plot_otda_linear_mapping.py | 73 +- .../plot_otda_mapping_colors_images.py | 118 +- examples/gromov/plot_gromov_barycenter.py | 90 +- ot/bregman.py | 1491 +++++++++++++++----- test/test_bregman.py | 365 ++++- 11 files changed, 1837 insertions(+), 675 deletions(-) create mode 100644 examples/barycenters/plot_debiased_barycenter.py (limited to 'test') diff --git a/README.md b/README.md index cfb9744..ff32c53 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,8 @@ POT provides the following generic OT solvers (links to examples): * [Conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) [6] and [Generalized conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) for regularized OT [7]. * Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10] [34], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html). * Bregman projections for [Wasserstein barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) [3], [convolutional barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_convolutional_barycenter.html) [21] and unmixing [4]. -* Sinkhorn divergence [23] and entropic regularization OT from empirical data. +* Sinkhorn divergence [23] and entropic regularization OT from empirical data. +* Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37] * [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17]. * Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html)) with LP solver (only small scale). * [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]) @@ -188,7 +189,7 @@ The contributors to this library are * [Kilian Fatras](https://kilianfatras.github.io/) (Stochastic solvers) * [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home) * [Vayer Titouan](https://tvayer.github.io/) (Gromov-Wasserstein -, Fused-Gromov-Wasserstein) -* [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT) +* [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT, Debiased barycenters) * [Romain Tavenard](https://rtavenar.github.io/) (1d Wasserstein) * [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn) * [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT) @@ -293,3 +294,6 @@ You can also post bug reports and feature requests in Github issues. Make sure t (2019, May). [Sliced-Wasserstein flows: Nonparametric generative modeling via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on Machine Learning (pp. 4104-4113). PMLR. + +[37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International +Conference on Machine Learning, PMLR 119:4692-4701, 2020 \ No newline at end of file diff --git a/examples/barycenters/plot_barycenter_1D.py b/examples/barycenters/plot_barycenter_1D.py index 63dc460..2373e99 100644 --- a/examples/barycenters/plot_barycenter_1D.py +++ b/examples/barycenters/plot_barycenter_1D.py @@ -18,10 +18,10 @@ SIAM Journal on Scientific Computing, 37(2), A1111-A1138. # # License: MIT License -# sphinx_gallery_thumbnail_number = 4 +# sphinx_gallery_thumbnail_number = 1 import numpy as np -import matplotlib.pylab as pl +import matplotlib.pyplot as plt import ot # necessary for 3d plot even if not used from mpl_toolkits.mplot3d import Axes3D # noqa @@ -50,18 +50,6 @@ n_distributions = A.shape[1] M = ot.utils.dist0(n) M /= M.max() -############################################################################## -# Plot data -# --------- - -#%% plot the distributions - -pl.figure(1, figsize=(6.4, 3)) -for i in range(n_distributions): - pl.plot(x, A[:, i]) -pl.title('Distributions') -pl.tight_layout() - ############################################################################## # Barycenter computation # ---------------------- @@ -78,24 +66,20 @@ bary_l2 = A.dot(weights) reg = 1e-3 bary_wass = ot.bregman.barycenter(A, M, reg, weights) -pl.figure(2) -pl.clf() -pl.subplot(2, 1, 1) -for i in range(n_distributions): - pl.plot(x, A[:, i]) -pl.title('Distributions') +f, (ax1, ax2) = plt.subplots(2, 1, tight_layout=True, num=1) +ax1.plot(x, A, color="black") +ax1.set_title('Distributions') -pl.subplot(2, 1, 2) -pl.plot(x, bary_l2, 'r', label='l2') -pl.plot(x, bary_wass, 'g', label='Wasserstein') -pl.legend() -pl.title('Barycenters') -pl.tight_layout() +ax2.plot(x, bary_l2, 'r', label='l2') +ax2.plot(x, bary_wass, 'g', label='Wasserstein') +ax2.set_title('Barycenters') + +plt.legend() +plt.show() ############################################################################## # Barycentric interpolation # ------------------------- - #%% barycenter interpolation n_alpha = 11 @@ -106,24 +90,23 @@ B_l2 = np.zeros((n, n_alpha)) B_wass = np.copy(B_l2) -for i in range(0, n_alpha): +for i in range(n_alpha): alpha = alpha_list[i] weights = np.array([1 - alpha, alpha]) B_l2[:, i] = A.dot(weights) B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights) #%% plot interpolation +plt.figure(2) -pl.figure(3) - -cmap = pl.cm.get_cmap('viridis') +cmap = plt.cm.get_cmap('viridis') verts = [] zs = alpha_list for i, z in enumerate(zs): ys = B_l2[:, i] verts.append(list(zip(x, ys))) -ax = pl.gcf().gca(projection='3d') +ax = plt.gcf().gca(projection='3d') poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list]) poly.set_alpha(0.7) @@ -134,18 +117,18 @@ ax.set_ylabel('$\\alpha$') ax.set_ylim3d(0, 1) ax.set_zlabel('') ax.set_zlim3d(0, B_l2.max() * 1.01) -pl.title('Barycenter interpolation with l2') -pl.tight_layout() +plt.title('Barycenter interpolation with l2') +plt.tight_layout() -pl.figure(4) -cmap = pl.cm.get_cmap('viridis') +plt.figure(3) +cmap = plt.cm.get_cmap('viridis') verts = [] zs = alpha_list for i, z in enumerate(zs): ys = B_wass[:, i] verts.append(list(zip(x, ys))) -ax = pl.gcf().gca(projection='3d') +ax = plt.gcf().gca(projection='3d') poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list]) poly.set_alpha(0.7) @@ -156,7 +139,7 @@ ax.set_ylabel('$\\alpha$') ax.set_ylim3d(0, 1) ax.set_zlabel('') ax.set_zlim3d(0, B_l2.max() * 1.01) -pl.title('Barycenter interpolation with Wasserstein') -pl.tight_layout() +plt.title('Barycenter interpolation with Wasserstein') +plt.tight_layout() -pl.show() +plt.show() diff --git a/examples/barycenters/plot_barycenter_lp_vs_entropic.py b/examples/barycenters/plot_barycenter_lp_vs_entropic.py index 57a6bac..6502f16 100644 --- a/examples/barycenters/plot_barycenter_lp_vs_entropic.py +++ b/examples/barycenters/plot_barycenter_lp_vs_entropic.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ ================================================================================= -1D Wasserstein barycenter comparison between exact LP and entropic regularization +1D Wasserstein barycenter: exact LP vs entropic regularization ================================================================================= This example illustrates the computation of regularized Wasserstein Barycenter diff --git a/examples/barycenters/plot_convolutional_barycenter.py b/examples/barycenters/plot_convolutional_barycenter.py index cbcd4a1..3721f31 100644 --- a/examples/barycenters/plot_convolutional_barycenter.py +++ b/examples/barycenters/plot_convolutional_barycenter.py @@ -6,17 +6,18 @@ Convolutional Wasserstein Barycenter example ============================================ -This example is designed to illustrate how the Convolutional Wasserstein Barycenter -function of POT works. +This example is designed to illustrate how the Convolutional Wasserstein +Barycenter function of POT works. """ # Author: Nicolas Courty # # License: MIT License - +import os +from pathlib import Path import numpy as np -import pylab as pl +import matplotlib.pyplot as plt import ot ############################################################################## @@ -25,22 +26,19 @@ import ot # # The four distributions are constructed from 4 simple images +this_file = os.path.realpath('__file__') +data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') -f1 = 1 - pl.imread('../../data/redcross.png')[:, :, 2] -f2 = 1 - pl.imread('../../data/duck.png')[:, :, 2] -f3 = 1 - pl.imread('../../data/heart.png')[:, :, 2] -f4 = 1 - pl.imread('../../data/tooth.png')[:, :, 2] +f1 = 1 - plt.imread(os.path.join(data_path, 'redcross.png'))[:, :, 2] +f2 = 1 - plt.imread(os.path.join(data_path, 'tooth.png'))[:, :, 2] +f3 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2] +f4 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2] -A = [] f1 = f1 / np.sum(f1) f2 = f2 / np.sum(f2) f3 = f3 / np.sum(f3) f4 = f4 / np.sum(f4) -A.append(f1) -A.append(f2) -A.append(f3) -A.append(f4) -A = np.array(A) +A = np.array([f1, f2, f3, f4]) nb_images = 5 @@ -57,14 +55,13 @@ v4 = np.array((0, 0, 0, 1)) # ---------------------------------------- # -pl.figure(figsize=(10, 10)) -pl.title('Convolutional Wasserstein Barycenters in POT') +fig, axes = plt.subplots(nb_images, nb_images, figsize=(7, 7)) +plt.suptitle('Convolutional Wasserstein Barycenters in POT') cm = 'Blues' # regularization parameter reg = 0.004 for i in range(nb_images): for j in range(nb_images): - pl.subplot(nb_images, nb_images, i * nb_images + j + 1) tx = float(i) / (nb_images - 1) ty = float(j) / (nb_images - 1) @@ -74,19 +71,19 @@ for i in range(nb_images): weights = (1 - ty) * tmp1 + ty * tmp2 if i == 0 and j == 0: - pl.imshow(f1, cmap=cm) - pl.axis('off') + axes[i, j].imshow(f1, cmap=cm) elif i == 0 and j == (nb_images - 1): - pl.imshow(f3, cmap=cm) - pl.axis('off') + axes[i, j].imshow(f3, cmap=cm) elif i == (nb_images - 1) and j == 0: - pl.imshow(f2, cmap=cm) - pl.axis('off') + axes[i, j].imshow(f2, cmap=cm) elif i == (nb_images - 1) and j == (nb_images - 1): - pl.imshow(f4, cmap=cm) - pl.axis('off') + axes[i, j].imshow(f4, cmap=cm) else: # call to barycenter computation - pl.imshow(ot.bregman.convolutional_barycenter2d(A, reg, weights), cmap=cm) - pl.axis('off') -pl.show() + axes[i, j].imshow( + ot.bregman.convolutional_barycenter2d(A, reg, weights), + cmap=cm + ) + axes[i, j].axis('off') +plt.tight_layout() +plt.show() diff --git a/examples/barycenters/plot_debiased_barycenter.py b/examples/barycenters/plot_debiased_barycenter.py new file mode 100644 index 0000000..2a603dd --- /dev/null +++ b/examples/barycenters/plot_debiased_barycenter.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- +""" +================================= +Debiased Sinkhorn barycenter demo +================================= + +This example illustrates the computation of the debiased Sinkhorn barycenter +as proposed in [37]_. + + +.. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th + International Conference on Machine Learning, PMLR 119:4692-4701, 2020 +""" + +# Author: Hicham Janati +# +# License: MIT License +# sphinx_gallery_thumbnail_number = 3 + +import os +from pathlib import Path + +import numpy as np +import matplotlib.pyplot as plt + +import ot +from ot.bregman import (barycenter, barycenter_debiased, + convolutional_barycenter2d, + convolutional_barycenter2d_debiased) + +############################################################################## +# Debiased barycenter of 1D Gaussians +# ------------------------------------ + +#%% parameters + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std +a2 = ot.datasets.make_1D_gauss(n, m=60, s=8) + +# creating matrix A containing all distributions +A = np.vstack((a1, a2)).T +n_distributions = A.shape[1] + +# loss matrix + normalization +M = ot.utils.dist0(n) +M /= M.max() + +#%% barycenter computation + +alpha = 0.2 # 0<=alpha<=1 +weights = np.array([1 - alpha, alpha]) + +epsilons = [5e-3, 1e-2, 5e-2] + + +bars = [barycenter(A, M, reg, weights) for reg in epsilons] +bars_debiased = [barycenter_debiased(A, M, reg, weights) for reg in epsilons] +labels = ["Sinkhorn barycenter", "Debiased barycenter"] +colors = ["indianred", "gold"] + +f, axes = plt.subplots(1, len(epsilons), tight_layout=True, sharey=True, + figsize=(12, 4), num=1) +for ax, eps, bar, bar_debiased in zip(axes, epsilons, bars, bars_debiased): + ax.plot(A[:, 0], color="k", ls="--", label="Input data", alpha=0.3) + ax.plot(A[:, 1], color="k", ls="--", alpha=0.3) + for data, label, color in zip([bar, bar_debiased], labels, colors): + ax.plot(data, color=color, label=label, lw=2) + ax.set_title(r"$\varepsilon = %.3f$" % eps) +plt.legend() +plt.show() + + +############################################################################## +# Debiased barycenter of 2D images +# --------------------------------- +this_file = os.path.realpath('__file__') +data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') +f1 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2] +f2 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2] + +A = np.asarray([f1, f2]) + 1e-2 +A /= A.sum(axis=(1, 2))[:, None, None] + +############################################################################## +# Display the input images + +fig, axes = plt.subplots(1, 2, figsize=(7, 4), num=2) +for ax, img in zip(axes, A): + ax.imshow(img, cmap="Greys") + ax.axis("off") +fig.tight_layout() +plt.show() + + +############################################################################## +# Barycenter computation and visualization +# ---------------------------------------- +# + +bars_sinkhorn, bars_debiased = [], [] +epsilons = [5e-3, 7e-3, 1e-2] +for eps in epsilons: + bar = convolutional_barycenter2d(A, eps) + bar_debiased, log = convolutional_barycenter2d_debiased(A, eps, log=True) + bars_sinkhorn.append(bar) + bars_debiased.append(bar_debiased) + +titles = ["Sinkhorn", "Debiased"] +all_bars = [bars_sinkhorn, bars_debiased] +fig, axes = plt.subplots(2, 3, figsize=(8, 6), num=3) +for jj, (method, ax_row, bars) in enumerate(zip(titles, axes, all_bars)): + for ii, (ax, img, eps) in enumerate(zip(ax_row, bars, epsilons)): + ax.imshow(img, cmap="Greys") + if jj == 0: + ax.set_title(r"$\varepsilon = %.3f$" % eps, fontsize=13) + ax.set_xticks([]) + ax.set_yticks([]) + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + ax.spines['bottom'].set_visible(False) + ax.spines['left'].set_visible(False) + if ii == 0: + ax.set_ylabel(method, fontsize=15) +fig.tight_layout() +plt.show() diff --git a/examples/domain-adaptation/plot_otda_color_images.py b/examples/domain-adaptation/plot_otda_color_images.py index 6218b13..06dc8ab 100644 --- a/examples/domain-adaptation/plot_otda_color_images.py +++ b/examples/domain-adaptation/plot_otda_color_images.py @@ -19,12 +19,15 @@ SIAM Journal on Imaging Sciences, 7(3), 1853-1882. # sphinx_gallery_thumbnail_number = 2 +import os +from pathlib import Path + import numpy as np -import matplotlib.pylab as pl +from matplotlib import pyplot as plt import ot -r = np.random.RandomState(42) +rng = np.random.RandomState(42) def im2mat(img): @@ -46,16 +49,19 @@ def minmax(img): # ------------- # Loading images -I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256 -I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256 +this_file = os.path.realpath('__file__') +data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') + +I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256 +I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256 X1 = im2mat(I1) X2 = im2mat(I2) # training samples nb = 500 -idx1 = r.randint(X1.shape[0], size=(nb,)) -idx2 = r.randint(X2.shape[0], size=(nb,)) +idx1 = rng.randint(X1.shape[0], size=(nb,)) +idx2 = rng.randint(X2.shape[0], size=(nb,)) Xs = X1[idx1, :] Xt = X2[idx2, :] @@ -65,39 +71,39 @@ Xt = X2[idx2, :] # Plot original image # ------------------- -pl.figure(1, figsize=(6.4, 3)) +plt.figure(1, figsize=(6.4, 3)) -pl.subplot(1, 2, 1) -pl.imshow(I1) -pl.axis('off') -pl.title('Image 1') +plt.subplot(1, 2, 1) +plt.imshow(I1) +plt.axis('off') +plt.title('Image 1') -pl.subplot(1, 2, 2) -pl.imshow(I2) -pl.axis('off') -pl.title('Image 2') +plt.subplot(1, 2, 2) +plt.imshow(I2) +plt.axis('off') +plt.title('Image 2') ############################################################################## # Scatter plot of colors # ---------------------- -pl.figure(2, figsize=(6.4, 3)) +plt.figure(2, figsize=(6.4, 3)) -pl.subplot(1, 2, 1) -pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs) -pl.axis([0, 1, 0, 1]) -pl.xlabel('Red') -pl.ylabel('Blue') -pl.title('Image 1') +plt.subplot(1, 2, 1) +plt.scatter(Xs[:, 0], Xs[:, 2], c=Xs) +plt.axis([0, 1, 0, 1]) +plt.xlabel('Red') +plt.ylabel('Blue') +plt.title('Image 1') -pl.subplot(1, 2, 2) -pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt) -pl.axis([0, 1, 0, 1]) -pl.xlabel('Red') -pl.ylabel('Blue') -pl.title('Image 2') -pl.tight_layout() +plt.subplot(1, 2, 2) +plt.scatter(Xt[:, 0], Xt[:, 2], c=Xt) +plt.axis([0, 1, 0, 1]) +plt.xlabel('Red') +plt.ylabel('Blue') +plt.title('Image 2') +plt.tight_layout() ############################################################################## @@ -130,37 +136,37 @@ I2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape)) # Plot new images # --------------- -pl.figure(3, figsize=(8, 4)) +plt.figure(3, figsize=(8, 4)) -pl.subplot(2, 3, 1) -pl.imshow(I1) -pl.axis('off') -pl.title('Image 1') +plt.subplot(2, 3, 1) +plt.imshow(I1) +plt.axis('off') +plt.title('Image 1') -pl.subplot(2, 3, 2) -pl.imshow(I1t) -pl.axis('off') -pl.title('Image 1 Adapt') +plt.subplot(2, 3, 2) +plt.imshow(I1t) +plt.axis('off') +plt.title('Image 1 Adapt') -pl.subplot(2, 3, 3) -pl.imshow(I1te) -pl.axis('off') -pl.title('Image 1 Adapt (reg)') +plt.subplot(2, 3, 3) +plt.imshow(I1te) +plt.axis('off') +plt.title('Image 1 Adapt (reg)') -pl.subplot(2, 3, 4) -pl.imshow(I2) -pl.axis('off') -pl.title('Image 2') +plt.subplot(2, 3, 4) +plt.imshow(I2) +plt.axis('off') +plt.title('Image 2') -pl.subplot(2, 3, 5) -pl.imshow(I2t) -pl.axis('off') -pl.title('Image 2 Adapt') +plt.subplot(2, 3, 5) +plt.imshow(I2t) +plt.axis('off') +plt.title('Image 2 Adapt') -pl.subplot(2, 3, 6) -pl.imshow(I2te) -pl.axis('off') -pl.title('Image 2 Adapt (reg)') -pl.tight_layout() +plt.subplot(2, 3, 6) +plt.imshow(I2te) +plt.axis('off') +plt.title('Image 2 Adapt (reg)') +plt.tight_layout() -pl.show() +plt.show() diff --git a/examples/domain-adaptation/plot_otda_linear_mapping.py b/examples/domain-adaptation/plot_otda_linear_mapping.py index be47510..a44096a 100644 --- a/examples/domain-adaptation/plot_otda_linear_mapping.py +++ b/examples/domain-adaptation/plot_otda_linear_mapping.py @@ -13,9 +13,11 @@ Linear OT mapping estimation # License: MIT License # sphinx_gallery_thumbnail_number = 2 +import os +from pathlib import Path import numpy as np -import pylab as pl +from matplotlib import pyplot as plt import ot ############################################################################## @@ -26,17 +28,19 @@ n = 1000 d = 2 sigma = .1 +rng = np.random.RandomState(42) + # source samples -angles = np.random.rand(n, 1) * 2 * np.pi +angles = rng.rand(n, 1) * 2 * np.pi xs = np.concatenate((np.sin(angles), np.cos(angles)), - axis=1) + sigma * np.random.randn(n, 2) + axis=1) + sigma * rng.randn(n, 2) xs[:n // 2, 1] += 2 # target samples -anglet = np.random.rand(n, 1) * 2 * np.pi +anglet = rng.rand(n, 1) * 2 * np.pi xt = np.concatenate((np.sin(anglet), np.cos(anglet)), - axis=1) + sigma * np.random.randn(n, 2) + axis=1) + sigma * rng.randn(n, 2) xt[:n // 2, 1] += 2 @@ -48,9 +52,9 @@ xt = xt.dot(A) + b # Plot data # --------- -pl.figure(1, (5, 5)) -pl.plot(xs[:, 0], xs[:, 1], '+') -pl.plot(xt[:, 0], xt[:, 1], 'o') +plt.figure(1, (5, 5)) +plt.plot(xs[:, 0], xs[:, 1], '+') +plt.plot(xt[:, 0], xt[:, 1], 'o') ############################################################################## @@ -66,13 +70,13 @@ xst = xs.dot(Ae) + be # Plot transported samples # ------------------------ -pl.figure(1, (5, 5)) -pl.clf() -pl.plot(xs[:, 0], xs[:, 1], '+') -pl.plot(xt[:, 0], xt[:, 1], 'o') -pl.plot(xst[:, 0], xst[:, 1], '+') +plt.figure(1, (5, 5)) +plt.clf() +plt.plot(xs[:, 0], xs[:, 1], '+') +plt.plot(xt[:, 0], xt[:, 1], 'o') +plt.plot(xst[:, 0], xst[:, 1], '+') -pl.show() +plt.show() ############################################################################## # Load image data @@ -94,8 +98,11 @@ def minmax(img): # Loading images -I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256 -I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256 +this_file = os.path.realpath('__file__') +data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') + +I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256 +I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256 X1 = im2mat(I1) @@ -123,24 +130,24 @@ I2t = minmax(mat2im(xts, I2.shape)) # Plot transformed images # ----------------------- -pl.figure(2, figsize=(10, 7)) +plt.figure(2, figsize=(10, 7)) -pl.subplot(2, 2, 1) -pl.imshow(I1) -pl.axis('off') -pl.title('Im. 1') +plt.subplot(2, 2, 1) +plt.imshow(I1) +plt.axis('off') +plt.title('Im. 1') -pl.subplot(2, 2, 2) -pl.imshow(I2) -pl.axis('off') -pl.title('Im. 2') +plt.subplot(2, 2, 2) +plt.imshow(I2) +plt.axis('off') +plt.title('Im. 2') -pl.subplot(2, 2, 3) -pl.imshow(I1t) -pl.axis('off') -pl.title('Mapping Im. 1') +plt.subplot(2, 2, 3) +plt.imshow(I1t) +plt.axis('off') +plt.title('Mapping Im. 1') -pl.subplot(2, 2, 4) -pl.imshow(I2t) -pl.axis('off') -pl.title('Inverse mapping Im. 2') +plt.subplot(2, 2, 4) +plt.imshow(I2t) +plt.axis('off') +plt.title('Inverse mapping Im. 2') diff --git a/examples/domain-adaptation/plot_otda_mapping_colors_images.py b/examples/domain-adaptation/plot_otda_mapping_colors_images.py index 72010a6..dbece70 100644 --- a/examples/domain-adaptation/plot_otda_mapping_colors_images.py +++ b/examples/domain-adaptation/plot_otda_mapping_colors_images.py @@ -21,12 +21,14 @@ discrete optimal transport", Neural Information Processing Systems (NIPS), 2016. # License: MIT License # sphinx_gallery_thumbnail_number = 3 +import os +from pathlib import Path import numpy as np -import matplotlib.pylab as pl +from matplotlib import pyplot as plt import ot -r = np.random.RandomState(42) +rng = np.random.RandomState(42) def im2mat(img): @@ -48,17 +50,19 @@ def minmax(img): # ------------- # Loading images -I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256 -I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256 +this_file = os.path.realpath('__file__') +data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') +I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256 +I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256 X1 = im2mat(I1) X2 = im2mat(I2) # training samples nb = 500 -idx1 = r.randint(X1.shape[0], size=(nb,)) -idx2 = r.randint(X2.shape[0], size=(nb,)) +idx1 = rng.randint(X1.shape[0], size=(nb,)) +idx2 = rng.randint(X2.shape[0], size=(nb,)) Xs = X1[idx1, :] Xt = X2[idx2, :] @@ -99,76 +103,76 @@ Image_mapping_gaussian = minmax(mat2im(X1tn, I1.shape)) # Plot original images # -------------------- -pl.figure(1, figsize=(6.4, 3)) -pl.subplot(1, 2, 1) -pl.imshow(I1) -pl.axis('off') -pl.title('Image 1') +plt.figure(1, figsize=(6.4, 3)) +plt.subplot(1, 2, 1) +plt.imshow(I1) +plt.axis('off') +plt.title('Image 1') -pl.subplot(1, 2, 2) -pl.imshow(I2) -pl.axis('off') -pl.title('Image 2') -pl.tight_layout() +plt.subplot(1, 2, 2) +plt.imshow(I2) +plt.axis('off') +plt.title('Image 2') +plt.tight_layout() ############################################################################## # Plot pixel values distribution # ------------------------------ -pl.figure(2, figsize=(6.4, 5)) +plt.figure(2, figsize=(6.4, 5)) -pl.subplot(1, 2, 1) -pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs) -pl.axis([0, 1, 0, 1]) -pl.xlabel('Red') -pl.ylabel('Blue') -pl.title('Image 1') +plt.subplot(1, 2, 1) +plt.scatter(Xs[:, 0], Xs[:, 2], c=Xs) +plt.axis([0, 1, 0, 1]) +plt.xlabel('Red') +plt.ylabel('Blue') +plt.title('Image 1') -pl.subplot(1, 2, 2) -pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt) -pl.axis([0, 1, 0, 1]) -pl.xlabel('Red') -pl.ylabel('Blue') -pl.title('Image 2') -pl.tight_layout() +plt.subplot(1, 2, 2) +plt.scatter(Xt[:, 0], Xt[:, 2], c=Xt) +plt.axis([0, 1, 0, 1]) +plt.xlabel('Red') +plt.ylabel('Blue') +plt.title('Image 2') +plt.tight_layout() ############################################################################## # Plot transformed images # ----------------------- -pl.figure(2, figsize=(10, 5)) +plt.figure(2, figsize=(10, 5)) -pl.subplot(2, 3, 1) -pl.imshow(I1) -pl.axis('off') -pl.title('Im. 1') +plt.subplot(2, 3, 1) +plt.imshow(I1) +plt.axis('off') +plt.title('Im. 1') -pl.subplot(2, 3, 4) -pl.imshow(I2) -pl.axis('off') -pl.title('Im. 2') +plt.subplot(2, 3, 4) +plt.imshow(I2) +plt.axis('off') +plt.title('Im. 2') -pl.subplot(2, 3, 2) -pl.imshow(Image_emd) -pl.axis('off') -pl.title('EmdTransport') +plt.subplot(2, 3, 2) +plt.imshow(Image_emd) +plt.axis('off') +plt.title('EmdTransport') -pl.subplot(2, 3, 5) -pl.imshow(Image_sinkhorn) -pl.axis('off') -pl.title('SinkhornTransport') +plt.subplot(2, 3, 5) +plt.imshow(Image_sinkhorn) +plt.axis('off') +plt.title('SinkhornTransport') -pl.subplot(2, 3, 3) -pl.imshow(Image_mapping_linear) -pl.axis('off') -pl.title('MappingTransport (linear)') +plt.subplot(2, 3, 3) +plt.imshow(Image_mapping_linear) +plt.axis('off') +plt.title('MappingTransport (linear)') -pl.subplot(2, 3, 6) -pl.imshow(Image_mapping_gaussian) -pl.axis('off') -pl.title('MappingTransport (gaussian)') -pl.tight_layout() +plt.subplot(2, 3, 6) +plt.imshow(Image_mapping_gaussian) +plt.axis('off') +plt.title('MappingTransport (gaussian)') +plt.tight_layout() -pl.show() +plt.show() diff --git a/examples/gromov/plot_gromov_barycenter.py b/examples/gromov/plot_gromov_barycenter.py index e2d88ba..7fe081f 100755 --- a/examples/gromov/plot_gromov_barycenter.py +++ b/examples/gromov/plot_gromov_barycenter.py @@ -13,11 +13,13 @@ computation in POT. # # License: MIT License +import os +from pathlib import Path import numpy as np import scipy as sp -import matplotlib.pylab as pl +from matplotlib import pyplot as plt from sklearn import manifold from sklearn.decomposition import PCA @@ -89,17 +91,19 @@ def im2mat(img): return img.reshape((img.shape[0] * img.shape[1], img.shape[2])) -square = pl.imread('../../data/square.png').astype(np.float64)[:, :, 2] -cross = pl.imread('../../data/cross.png').astype(np.float64)[:, :, 2] -triangle = pl.imread('../../data/triangle.png').astype(np.float64)[:, :, 2] -star = pl.imread('../../data/star.png').astype(np.float64)[:, :, 2] +this_file = os.path.realpath('__file__') +data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') + +square = plt.imread(os.path.join(data_path, 'square.png')).astype(np.float64)[:, :, 2] +cross = plt.imread(os.path.join(data_path, 'cross.png')).astype(np.float64)[:, :, 2] +triangle = plt.imread(os.path.join(data_path, 'triangle.png')).astype(np.float64)[:, :, 2] +star = plt.imread(os.path.join(data_path, 'star.png')).astype(np.float64)[:, :, 2] shapes = [square, cross, triangle, star] S = 4 xs = [[] for i in range(S)] - for nb in range(4): for i in range(8): for j in range(8): @@ -184,64 +188,64 @@ npost23 = [smacof_mds(Ct23[s], 2) for s in range(2)] npost23 = [clf.fit_transform(npost23[s]) for s in range(2)] -fig = pl.figure(figsize=(10, 10)) +fig = plt.figure(figsize=(10, 10)) -ax1 = pl.subplot2grid((4, 4), (0, 0)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax1 = plt.subplot2grid((4, 4), (0, 0)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax1.scatter(npos[0][:, 0], npos[0][:, 1], color='r') -ax2 = pl.subplot2grid((4, 4), (0, 1)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax2 = plt.subplot2grid((4, 4), (0, 1)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax2.scatter(npost01[1][:, 0], npost01[1][:, 1], color='b') -ax3 = pl.subplot2grid((4, 4), (0, 2)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax3 = plt.subplot2grid((4, 4), (0, 2)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax3.scatter(npost01[0][:, 0], npost01[0][:, 1], color='b') -ax4 = pl.subplot2grid((4, 4), (0, 3)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax4 = plt.subplot2grid((4, 4), (0, 3)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax4.scatter(npos[1][:, 0], npos[1][:, 1], color='r') -ax5 = pl.subplot2grid((4, 4), (1, 0)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax5 = plt.subplot2grid((4, 4), (1, 0)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax5.scatter(npost02[1][:, 0], npost02[1][:, 1], color='b') -ax6 = pl.subplot2grid((4, 4), (1, 3)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax6 = plt.subplot2grid((4, 4), (1, 3)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax6.scatter(npost13[1][:, 0], npost13[1][:, 1], color='b') -ax7 = pl.subplot2grid((4, 4), (2, 0)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax7 = plt.subplot2grid((4, 4), (2, 0)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax7.scatter(npost02[0][:, 0], npost02[0][:, 1], color='b') -ax8 = pl.subplot2grid((4, 4), (2, 3)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax8 = plt.subplot2grid((4, 4), (2, 3)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax8.scatter(npost13[0][:, 0], npost13[0][:, 1], color='b') -ax9 = pl.subplot2grid((4, 4), (3, 0)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax9 = plt.subplot2grid((4, 4), (3, 0)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax9.scatter(npos[2][:, 0], npos[2][:, 1], color='r') -ax10 = pl.subplot2grid((4, 4), (3, 1)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax10 = plt.subplot2grid((4, 4), (3, 1)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax10.scatter(npost23[1][:, 0], npost23[1][:, 1], color='b') -ax11 = pl.subplot2grid((4, 4), (3, 2)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax11 = plt.subplot2grid((4, 4), (3, 2)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax11.scatter(npost23[0][:, 0], npost23[0][:, 1], color='b') -ax12 = pl.subplot2grid((4, 4), (3, 3)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax12 = plt.subplot2grid((4, 4), (3, 3)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax12.scatter(npos[3][:, 0], npos[3][:, 1], color='r') diff --git a/ot/bregman.py b/ot/bregman.py index 0499b8e..786f151 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -7,7 +7,7 @@ Bregman projections solvers for entropic regularized OT # Nicolas Courty # Kilian Fatras # Titouan Vayer -# Hicham Janati +# Hicham Janati # Mokhtar Z. Alaya # Alexander Tong # Ievgen Redko @@ -25,7 +25,8 @@ from .backend import get_backend def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): + stopThr=1e-9, verbose=False, log=False, warn=True, + **kwargs): r""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -43,8 +44,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) .. note:: This function is backend-compatible and will work on arrays from all compatible backends. @@ -77,7 +80,8 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, samples weights in the source domain b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) M : array-like, shape (dim_a, dim_b) loss matrix reg : float @@ -94,6 +98,8 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -117,13 +123,21 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation + of Optimal Transport, Advances in Neural Information Processing + Systems (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms + for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. + arXiv preprint arXiv:1607.05816. - .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, + A., & Peyré, G. (2019, April). Interpolating between optimal transport + and MMD using Sinkhorn divergences. In The 22nd International Conference + on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. See Also @@ -131,37 +145,44 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, ot.lp.emd : Unregularized OT ot.optim.cg : General regularized OT ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] ` - ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn :ref:`[9] ` :ref:`[10] ` - ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling :ref:`[9] ` :ref:`[10] ` + ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn + :ref:`[9] ` :ref:`[10] ` + ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling + :ref:`[9] ` :ref:`[10] ` """ if method.lower() == 'sinkhorn': return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, + warn=warn, **kwargs) elif method.lower() == 'sinkhorn_log': return sinkhorn_log(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, + warn=warn, **kwargs) elif method.lower() == 'greenkhorn': return greenkhorn(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log) + stopThr=stopThr, verbose=verbose, log=log, + warn=warn) elif method.lower() == 'sinkhorn_stabilized': return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + log=log, warn=warn, + **kwargs) elif method.lower() == 'sinkhorn_epsilon_scaling': return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + log=log, warn=warn, + **kwargs) else: raise ValueError("Unknown method '%s'." % method) def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): + stopThr=1e-9, verbose=False, log=False, warn=False, **kwargs): r""" Solve the entropic regularization optimal transport problem and return the loss @@ -179,13 +200,16 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) .. note:: This function is backend-compatible and will work on arrays from all compatible backends. - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[2] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm as proposed in :ref:`[2] ` **Choosing a Sinkhorn solver** @@ -212,7 +236,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, samples weights in the source domain b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) M : array-like, shape (dim_a, dim_b) loss matrix reg : float @@ -228,6 +253,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -252,19 +279,27 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of + Optimal Transport, Advances in Neural Information + Processing Systems (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms + for Entropy Regularized Transport Problems. + arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. + arXiv preprint arXiv:1607.05816. .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation - algorithms for optimal transport via Sinkhorn iteration, Advances in Neural - Information Processing Systems (NIPS) 31, 2017 - - .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. - + algorithms for optimal transport via Sinkhorn iteration, + Advances in Neural Information Processing Systems (NIPS) 31, 2017 + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., + Trouvé, A., & Peyré, G. (2019, April). + Interpolating between optimal transport and MMD using Sinkhorn + divergences. In The 22nd International Conference on Artificial + Intelligence and Statistics (pp. 2681-2690). PMLR. See Also -------- @@ -272,7 +307,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, ot.optim.cg : General regularized OT ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] ` ot.bregman.greenkhorn : Greenkhorn :ref:`[21] ` - ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn :ref:`[9] ` :ref:`[10] ` + ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn :ref:`[9] ` + :ref:`[10] ` """ @@ -317,8 +353,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, raise ValueError("Unknown method '%s'." % method) -def sinkhorn_knopp(a, b, M, reg, numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): +def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, + verbose=False, log=False, warn=True, + **kwargs): r""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -335,10 +372,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[2] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp + matrix scaling algorithm as proposed in :ref:`[2] ` Parameters @@ -347,7 +387,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, samples weights in the source domain b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) M : array-like, shape (dim_a, dim_b) loss matrix reg : float @@ -360,6 +401,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -384,7 +427,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation + of Optimal Transport, Advances in Neural Information + Processing Systems (NIPS) 26, 2013 See Also @@ -427,9 +472,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, K = nx.exp(M / (-reg)) Kp = (1 / a).reshape(-1, 1) * K - cpt = 0 + err = 1 - while (err > stopThr and cpt < numItermax): + for ii in range(numItermax): uprev = u vprev = v KtransposeU = nx.dot(K.T, u) @@ -441,11 +486,11 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop - print('Warning: numerical errors at iteration', cpt) + warnings.warn('Warning: numerical errors at iteration %d' % ii) u = uprev v = vprev break - if cpt % 10 == 0: + if ii % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations if n_hists: @@ -457,13 +502,20 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, if log: log['err'].append(err) + if err < stopThr: + break if verbose: - if cpt % 200 == 0: + if ii % 200 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - cpt = cpt + 1 + print('{:5d}|{:8e}|'.format(ii, err)) + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: + log['niter'] = ii log['u'] = u log['v'] = v @@ -482,8 +534,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, return u.reshape((-1, 1)) * K * v.reshape((1, -1)) -def sinkhorn_log(a, b, M, reg, numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): +def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, + log=False, warn=True, **kwargs): r""" Solve the entropic regularization optimal transport problem in log space and return the OT matrix @@ -528,6 +580,8 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -552,9 +606,15 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of + Optimal Transport, Advances in Neural Information Processing + Systems (NIPS) 26, 2013 - .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., + Trouvé, A., & Peyré, G. (2019, April). Interpolating between + optimal transport and MMD using Sinkhorn divergences. In The + 22nd International Conference on Artificial Intelligence and + Statistics (pp. 2681-2690). PMLR. See Also @@ -613,7 +673,7 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, if log: log = {'err': []} - Mr = M / (-reg) + Mr = - M / reg # we assume that no distances are null except those of the diagonal of # distances @@ -630,14 +690,13 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, loga = nx.log(a) logb = nx.log(b) - cpt = 0 err = 1 - while (err > stopThr and cpt < numItermax): + for ii in range(numItermax): v = logb - nx.logsumexp(Mr + u[:, None], 0) u = loga - nx.logsumexp(Mr + v[None, :], 1) - if cpt % 10 == 0: + if ii % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations @@ -648,13 +707,20 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, log['err'].append(err) if verbose: - if cpt % 200 == 0: + if ii % 200 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - cpt = cpt + 1 + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr: + break + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: + log['niter'] = ii log['log_u'] = u log['log_v'] = v log['u'] = nx.exp(u) @@ -667,11 +733,13 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, - log=False): + log=False, warn=True): r""" Solve the entropic regularization optimal transport problem and return the OT matrix - The algorithm used is based on the paper :ref:`[22] ` which is a stochastic version of the Sinkhorn-Knopp algorithm :ref:`[2] ` + The algorithm used is based on the paper :ref:`[22] ` + which is a stochastic version of the Sinkhorn-Knopp + algorithm :ref:`[2] ` The function solves the following optimization problem: @@ -686,8 +754,10 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) Parameters @@ -696,7 +766,8 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, samples weights in the source domain b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) M : array-like, shape (dim_a, dim_b) loss matrix reg : float @@ -707,6 +778,8 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, Stop threshold on error (>0) log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -731,9 +804,14 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation + of Optimal Transport, Advances in Neural Information + Processing Systems (NIPS) 26, 2013 - .. [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 + .. [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time + approximation algorithms for optimal transport via Sinkhorn + iteration, Advances in Neural Information Processing + Systems (NIPS) 31, 2017 See Also @@ -747,7 +825,8 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, nx = get_backend(M, a, b) if nx.__name__ == "jax": - raise TypeError("JAX arrays have been received. Greenkhorn is not compatible with JAX") + raise TypeError("JAX arrays have been received. Greenkhorn is not " + "compatible with JAX") if len(a) == 0: a = nx.ones((M.shape[0],), type_as=M) / M.shape[0] @@ -771,7 +850,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log['u'] = u log['v'] = v - for i in range(numItermax): + for ii in range(numItermax): i_1 = nx.argmax(nx.abs(viol)) i_2 = nx.argmax(nx.abs(viol_2)) m_viol_1 = nx.abs(viol[i_1]) @@ -795,14 +874,17 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, viol += (-old_v + new_v) * K[:, i_2] * u viol_2[i_2] = new_v * K[:, i_2].dot(u) - b[i_2] v[i_2] = new_v - # print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2))) if stopThr_val <= stopThr: break else: - print('Warning: Algorithm did not converge') + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: + log["n_iter"] = ii log['u'] = u log['v'] = v @@ -814,7 +896,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=20, - log=False, **kwargs): + log=False, warn=True, **kwargs): r""" Solve the entropic regularization OT problem with log stabilization @@ -831,13 +913,17 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) The algorithm used for solving the problem is the Sinkhorn-Knopp matrix - scaling algorithm as proposed in :ref:`[2] ` but with the log stabilization - proposed in :ref:`[10] ` an defined in :ref:`[9] ` (Algo 3.1) . + scaling algorithm as proposed in :ref:`[2] ` + but with the log stabilization + proposed in :ref:`[10] ` an defined in + :ref:`[9] ` (Algo 3.1) . Parameters @@ -851,7 +937,8 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, reg : float Regularization term >0 tau : float - threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` for log scaling + threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` + for log scaling warmstart : table of vectors if given then starting values for alpha and beta log scalings numItermax : int, optional @@ -862,6 +949,8 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -886,11 +975,17 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of + Optimal Transport, Advances in Neural Information Processing + Systems (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms + for Entropy Regularized Transport Problems. + arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. + arXiv preprint arXiv:1607.05816. See Also @@ -920,7 +1015,6 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, dim_a = len(a) dim_b = len(b) - cpt = 0 if log: log = {'err': []} @@ -935,7 +1029,9 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, u = nx.ones((dim_a, n_hists), type_as=M) / dim_a v = nx.ones((dim_b, n_hists), type_as=M) / dim_b else: - u, v = nx.ones(dim_a, type_as=M) / dim_a, nx.ones(dim_b, type_as=M) / dim_b + u, v = nx.ones(dim_a, type_as=M), nx.ones(dim_b, type_as=M) + u /= dim_a + v /= dim_b def get_K(alpha, beta): """log space computation""" @@ -947,21 +1043,17 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, return nx.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b))) / reg + nx.log(u.reshape((dim_a, 1))) + nx.log(v.reshape((1, dim_b)))) - # print(np.min(K)) - K = get_K(alpha, beta) transp = K - loop = 1 - cpt = 0 err = 1 - while loop: + for ii in range(numItermax): uprev = u vprev = v # sinkhorn update - v = b / (nx.dot(K.T, u) + 1e-16) - u = a / (nx.dot(K, v) + 1e-16) + v = b / (nx.dot(K.T, u)) + u = a / (nx.dot(K, v)) # remove numerical problems and store them in K if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau: @@ -977,7 +1069,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, v = nx.ones(dim_b, type_as=M) / dim_b K = get_K(alpha, beta) - if cpt % print_period == 0: + if ii % print_period == 0: # we can speed up the process by checking for the error only all # the 10th iterations if n_hists: @@ -993,33 +1085,33 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, log['err'].append(err) if verbose: - if cpt % (print_period * 20) == 0: + if ii % (print_period * 20) == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print('{:5d}|{:8e}|'.format(ii, err)) if err <= stopThr: - loop = False - - if cpt >= numItermax: - loop = False + break if nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)): # we have reached the machine precision # come back to previous solution and quit loop - print('Warning: numerical errors at iteration', cpt) + warnings.warn('Numerical errors at iteration %d' % ii) u = uprev v = vprev break - - cpt = cpt + 1 - + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: if n_hists: alpha = alpha[:, None] beta = beta[:, None] logu = alpha / reg + nx.log(u) logv = beta / reg + nx.log(v) + log["n_iter"] = ii log['logu'] = logu log['logv'] = logv log['alpha'] = alpha + reg * nx.log(u) @@ -1048,13 +1140,11 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, - log=False, **kwargs): + log=False, warn=True, **kwargs): r""" Solve the entropic regularization optimal transport problem with log stabilization and epsilon scaling. - The function solves the following optimization problem: - .. math:: \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) @@ -1064,16 +1154,16 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, \gamma &\geq 0 where : - - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) - - + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights + (histograms, both sum to 1) The algorithm used for solving the problem is the Sinkhorn-Knopp matrix - scaling algorithm as proposed in :ref:`[2] ` but with the log stabilization - proposed in :ref:`[10] ` and the log scaling proposed in :ref:`[9] ` algorithm 3.2 - + scaling algorithm as proposed in :ref:`[2] ` + but with the log stabilization + proposed in :ref:`[10] ` and the log scaling + proposed in :ref:`[9] ` algorithm 3.2 Parameters ---------- @@ -1086,7 +1176,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, reg : float Regularization term >0 tau : float - threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{b}` for log scaling + threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{b}` + for log scaling warmstart : tuple of vectors if given then starting values for alpha and beta log scalings numItermax : int, optional @@ -1101,6 +1192,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -1108,10 +1201,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters - Examples -------- - >>> import ot >>> a=[.5, .5] >>> b=[.5, .5] @@ -1123,19 +1214,19 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, .. _references-sinkhorn-epsilon-scaling: References ---------- + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal + Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 - - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. - - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for + Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. See Also -------- ot.lp.emd : Unregularized OT ot.optim.cg : General regularized OT - """ a, b, M = list_to_array(a, b, M) @@ -1155,7 +1246,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numItermin = 35 numItermax = max(numItermin, numItermax) # ensure that last velue is exact - cpt = 0 + ii = 0 if log: log = {'err': []} @@ -1170,12 +1261,10 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, def get_reg(n): # exponential decreasing return (epsilon0 - reg) * np.exp(-n) + reg - loop = 1 - cpt = 0 err = 1 - while loop: + for ii in range(numItermax): - regi = get_reg(cpt) + regi = get_reg(ii) G, logi = sinkhorn_stabilized(a, b, M, regi, numItermax=numInnerItermax, stopThr=1e-9, @@ -1185,10 +1274,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, alpha = logi['alpha'] beta = logi['beta'] - if cpt >= numItermax: - loop = False - - if cpt % (print_period) == 0: # spsion nearly converged + if ii % (print_period) == 0: # spsion nearly converged # we can speed up the process by checking for the error only all # the 10th iterations transp = G @@ -1197,19 +1283,22 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, log['err'].append(err) if verbose: - if cpt % (print_period * 10) == 0: + if ii % (print_period * 10) == 0: print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - - if err <= stopThr and cpt > numItermin: - loop = False + print('{:5d}|{:8e}|'.format(ii, err)) - cpt = cpt + 1 - # print('err=',err,' cpt=',cpt) + if err <= stopThr and ii > numItermin: + break + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: log['alpha'] = alpha log['beta'] = beta log['warmstart'] = (log['alpha'], log['beta']) + log['niter'] = ii return G, log else: return G @@ -1245,7 +1334,7 @@ def projC(gamma, q): def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, - stopThr=1e-4, verbose=False, log=False, **kwargs): + stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs): r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` The function solves the following optimization problem: @@ -1255,11 +1344,16 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + - :math:`OT_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein + distance (see :py:func:`ot.bregman.sinkhorn`) + if `method` is `sinkhorn` or `sinkhorn_stabilized` or `sinkhorn_log`. + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling + algorithm as proposed in :ref:`[3] ` Parameters ---------- @@ -1270,7 +1364,7 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, reg : float Regularization term > 0 method : str (optional) - method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized' + method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized' or 'sinkhorn_log' weights : array-like, shape (n_hists,) Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) numItermax : int, optional @@ -1281,6 +1375,8 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns @@ -1295,7 +1391,9 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, References ---------- - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). + Iterative Bregman projections for regularized transportation problems. + SIAM Journal on Scientific Computing, 37(2), A1111-A1138. """ @@ -1303,18 +1401,24 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, return barycenter_sinkhorn(A, M, reg, weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, + warn=warn, **kwargs) elif method.lower() == 'sinkhorn_stabilized': return barycenter_stabilized(A, M, reg, weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + log=log, warn=warn, **kwargs) + elif method.lower() == 'sinkhorn_log': + return _barycenter_sinkhorn_log(A, M, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) else: raise ValueError("Unknown method '%s'." % method) def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False): + stopThr=1e-4, verbose=False, log=False, warn=True): r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` The function solves the following optimization problem: @@ -1324,11 +1428,15 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance + (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm as proposed in :ref:`[3]`. Parameters ---------- @@ -1348,6 +1456,8 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns @@ -1362,7 +1472,9 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, References ---------- - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). + Iterative Bregman projections for regularized transportation problems. + SIAM Journal on Scientific Computing, 37(2), A1111-A1138. """ @@ -1378,43 +1490,109 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, if log: log = {'err': []} - # M = M/np.median(M) # suggested by G. Peyre K = nx.exp(-M / reg) - cpt = 0 err = 1 UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T) u = (geometricMean(UKv) / UKv.T).T - while (err > stopThr and cpt < numItermax): - cpt = cpt + 1 + for ii in range(numItermax): + UKv = u * nx.dot(K, A / nx.dot(K, u)) u = (u.T * geometricBar(weights, UKv)).T / UKv - if cpt % 10 == 1: + if ii % 10 == 1: err = nx.sum(nx.std(UKv, axis=1)) # log and verbose print if log: log['err'].append(err) + if err < stopThr: + break if verbose: - if cpt % 200 == 0: + if ii % 200 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - + print('{:5d}|{:8e}|'.format(ii, err)) + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: - log['niter'] = cpt + log['niter'] = ii return geometricBar(weights, UKv), log else: return geometricBar(weights, UKv) +def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False, warn=True): + r"""Compute the entropic wasserstein barycenter in log-domain + """ + + A, M = list_to_array(A, M) + dim, n_hists = A.shape + + nx = get_backend(A, M) + + if nx.__name__ == "jax": + raise NotImplementedError("Log-domain functions are not yet implemented" + " for Jax. Use numpy or torch arrays instead.") + + if weights is None: + weights = nx.ones(n_hists, type_as=A) / n_hists + else: + assert (len(weights) == A.shape[1]) + + if log: + log = {'err': []} + + M = - M / reg + logA = nx.log(A + 1e-15) + log_KU, G = nx.zeros((2, *logA.shape), type_as=A) + err = 1 + for ii in range(numItermax): + log_bar = nx.zeros(dim, type_as=A) + for k in range(n_hists): + f = logA[:, k] - nx.logsumexp(M + G[None, :, k], axis=1) + log_KU[:, k] = nx.logsumexp(M + f[:, None], axis=0) + log_bar = log_bar + weights[k] * log_KU[:, k] + + if ii % 10 == 1: + err = nx.exp(G + log_KU).std(axis=1).sum() + + # log and verbose print + if log: + log['err'].append(err) + + if err < stopThr: + break + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + + G = log_bar[:, None] - log_KU + + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + if log: + log['niter'] = ii + return nx.exp(log_bar), log + else: + return nx.exp(log_bar) + + def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False): + stopThr=1e-4, verbose=False, log=False, warn=True): r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` with stabilization. The function solves the following optimization problem: @@ -1424,11 +1602,15 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein + distance (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling + algorithm as proposed in :ref:`[3] ` Parameters ---------- @@ -1439,7 +1621,8 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, reg : float Regularization term > 0 tau : float - threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` for log scaling + threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` + for log scaling weights : array-like, shape (n_hists,) Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) numItermax : int, optional @@ -1450,6 +1633,8 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns @@ -1464,7 +1649,9 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, References ---------- - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). + Iterative Bregman projections for regularized transportation problems. + SIAM Journal on Scientific Computing, 37(2), A1111-A1138. """ @@ -1486,19 +1673,18 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, K = nx.exp(-M / reg) - cpt = 0 err = 1. alpha = nx.zeros((dim,), type_as=M) beta = nx.zeros((dim,), type_as=M) q = nx.ones((dim,), type_as=M) / dim - while (err > stopThr and cpt < numItermax): + for ii in range(numItermax): qprev = q Kv = nx.dot(K, v) - u = A / (Kv + 1e-16) + u = A / Kv Ktu = nx.dot(K.T, u) q = geometricBar(weights, Ktu) Q = q[:, None] - v = Q / (Ktu + 1e-16) + v = Q / Ktu absorbing = False if nx.any(u > tau) or nx.any(v > tau): absorbing = True @@ -1512,40 +1698,244 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration %s' % cpt) + warnings.warn('Numerical errors at iteration %s' % ii) q = qprev break - if (cpt % 10 == 0 and not absorbing) or cpt == 0: + if (ii % 10 == 0 and not absorbing) or ii == 0: # we can speed up the process by checking for the error only all # the 10th iterations err = nx.max(nx.abs(u * Kv - A)) if log: log['err'].append(err) + if err < stopThr: + break if verbose: - if cpt % 50 == 0: + if ii % 50 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print('{:5d}|{:8e}|'.format(ii, err)) - cpt += 1 - if err > stopThr: - warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." + - "Try a larger entropy `reg`" + - "Or a larger absorption threshold `tau`.") + else: + if warn: + warnings.warn("Stabilized Sinkhorn did not converge." + + "Try a larger entropy `reg`" + + "Or a larger absorption threshold `tau`.") if log: - log['niter'] = cpt - log['logu'] = nx.log(u + 1e-16) - log['logv'] = nx.log(v + 1e-16) + log['niter'] = ii + log['logu'] = np.log(u + 1e-16) + log['logv'] = np.log(v + 1e-16) return q, log else: return q -def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, - stopThr=1e-9, stabThr=1e-30, verbose=False, - log=False): - r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` - where :math:`\mathbf{A}` is a collection of 2D images. +def barycenter_debiased(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, + stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs): + r"""Compute the debiased Sinkhorn barycenter of distributions A + + The function solves the following optimization problem: + + .. math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`S_{reg}(\cdot,\cdot)` is the debiased Sinkhorn divergence + (see :py:func:`ot.bregman.emirical_sinkhorn_divergence`) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT + + The algorithm used for solving the problem is the debiased Sinkhorn + algorithm as proposed in :ref:`[37] ` + + Parameters + ---------- + A : array-like, shape (dim, n_hists) + `n_hists` training distributions :math:`a_i` of size `dim` + M : array-like, shape (dim, dim) + loss matrix for OT + reg : float + Regularization term > 0 + method : str (optional) + method used for the solver either 'sinkhorn' or 'sinkhorn_log' + weights : array-like, shape (n_hists,) + Weights of each histogram :math:`a_i` on the simplex (barycentric coodinates) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. + + + + Returns + ------- + a : (dim,) array-like + Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + .. _references-sinkhorn-debiased: + References + ---------- + + .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International + Conference on Machine Learning, PMLR 119:4692-4701, 2020 + """ + + if method.lower() == 'sinkhorn': + return _barycenter_debiased(A, M, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + warn=warn, **kwargs) + elif method.lower() == 'sinkhorn_log': + return _barycenter_debiased_log(A, M, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) + + +def _barycenter_debiased(A, M, reg, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False, warn=True): + r"""Compute the debiased sinkhorn barycenter of distributions A. + """ + + A, M = list_to_array(A, M) + + nx = get_backend(A, M) + + if weights is None: + weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1] + else: + assert (len(weights) == A.shape[1]) + + if log: + log = {'err': []} + + K = nx.exp(-M / reg) + + err = 1 + + UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T) + + u = (geometricMean(UKv) / UKv.T).T + c = nx.ones(A.shape[0], type_as=A) + bar = nx.ones(A.shape[0], type_as=A) + + for ii in range(numItermax): + bold = bar + UKv = nx.dot(K, A / nx.dot(K, u)) + bar = c * geometricBar(weights, UKv) + u = bar[:, None] / UKv + c = (c * bar / nx.dot(K, c)) ** 0.5 + + if ii % 10 == 9: + err = abs(bar - bold).max() / max(bar.max(), 1.) + + # log and verbose print + if log: + log['err'].append(err) + + # debiased Sinkhorn does not converge monotonically + # guarantee a few iterations are done before stopping + if err < stopThr and ii > 20: + break + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + if log: + log['niter'] = ii + return bar, log + else: + return bar + + +def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False, + warn=True): + r"""Compute the debiased sinkhorn barycenter in log domain. + """ + + A, M = list_to_array(A, M) + dim, n_hists = A.shape + + nx = get_backend(A, M) + if nx.__name__ == "jax": + raise NotImplementedError("Log-domain functions are not yet implemented" + " for Jax. Use numpy or torch arrays instead.") + + if weights is None: + weights = nx.ones(n_hists, type_as=A) / n_hists + else: + assert (len(weights) == A.shape[1]) + + if log: + log = {'err': []} + + M = - M / reg + logA = nx.log(A + 1e-15) + log_KU, G = nx.zeros((2, *logA.shape), type_as=A) + c = nx.zeros(dim, type_as=A) + err = 1 + for ii in range(numItermax): + log_bar = nx.zeros(dim, type_as=A) + for k in range(n_hists): + f = logA[:, k] - nx.logsumexp(M + G[None, :, k], axis=1) + log_KU[:, k] = nx.logsumexp(M + f[:, None], axis=0) + log_bar += weights[k] * log_KU[:, k] + log_bar += c + if ii % 10 == 1: + err = nx.exp(G + log_KU).std(axis=1).sum() + + # log and verbose print + if log: + log['err'].append(err) + + if err < stopThr and ii > 20: + break + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + + G = log_bar[:, None] - log_KU + for _ in range(10): + c = 0.5 * (c + log_bar - nx.logsumexp(M + c[:, None], axis=0)) + + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + if log: + log['niter'] = ii + return nx.exp(log_bar), log + else: + return nx.exp(log_bar) + + +def convolutional_barycenter2d(A, reg, weights=None, method="sinkhorn", numItermax=10000, + stopThr=1e-4, verbose=False, log=False, + warn=True, **kwargs): + r"""Compute the entropic regularized wasserstein barycenter of distributions A + where A is a collection of 2D images. The function solves the following optimization problem: @@ -1554,11 +1944,14 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions of matrix :math:`\mathbf{A}` + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein + distance (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions + of matrix :math:`\mathbf{A}` - `reg` is the regularization strength scalar value - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[21] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm + as proposed in :ref:`[21] ` Parameters ---------- @@ -1568,6 +1961,8 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, Regularization term >0 weights : array-like, shape (n_hists,) Weights of each image on the simplex (barycentric coodinates) + method : string, optional + method used for the solver either 'sinkhorn' or 'sinkhorn_log' numItermax : int, optional Max number of iterations stopThr : float, optional @@ -1578,6 +1973,8 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -1591,9 +1988,36 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, References ---------- - .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). Convolutional wasserstein distances: Efficient optimal transportation on geometric domains. ACM Transactions on Graphics (TOG), 34(4), 66 + .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, + A., Nguyen, A. & Guibas, L. (2015). Convolutional wasserstein distances: + Efficient optimal transportation on geometric domains. ACM Transactions + on Graphics (TOG), 34(4), 66 + .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th + International Conference on Machine Learning, PMLR 119:4692-4701, 2020 + """ + if method.lower() == 'sinkhorn': + return _convolutional_barycenter2d(A, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, + **kwargs) + elif method.lower() == 'sinkhorn_log': + return _convolutional_barycenter2d_log(A, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, + **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) + + +def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, + stopThr=1e-9, stabThr=1e-30, verbose=False, + log=False, warn=True): + r"""Compute the entropic regularized wasserstein barycenter of distributions A + where A is a collection of 2D images. """ A = list_to_array(A) @@ -1608,65 +2032,373 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, if log: log = {'err': []} - b = nx.zeros(A.shape[1:], type_as=A) + bar = nx.ones(A.shape[1:], type_as=A) + bar /= bar.sum() U = nx.ones(A.shape, type_as=A) - KV = nx.ones(A.shape, type_as=A) - - cpt = 0 + V = nx.ones(A.shape, type_as=A) err = 1 # build the convolution operator # this is equivalent to blurring on horizontal then vertical directions t = nx.linspace(0, 1, A.shape[1]) [Y, X] = nx.meshgrid(t, t) - xi1 = nx.exp(-(X - Y) ** 2 / reg) + K1 = nx.exp(-(X - Y) ** 2 / reg) t = nx.linspace(0, 1, A.shape[2]) [Y, X] = nx.meshgrid(t, t) - xi2 = nx.exp(-(X - Y) ** 2 / reg) - - def K(x): - return nx.dot(nx.dot(xi1, x), xi2) - - while (err > stopThr and cpt < numItermax): - - bold = b - cpt = cpt + 1 - - b = nx.zeros(A.shape[1:], type_as=A) - KV_cols = [] - for r in range(A.shape[0]): - KV_col_r = K(A[r, :, :] / nx.maximum(stabThr, K(U[r, :, :]))) - b += weights[r] * nx.log(nx.maximum(stabThr, U[r, :, :] * KV_col_r)) - KV_cols.append(KV_col_r) - KV = nx.stack(KV_cols) - b = nx.exp(b) - - U = nx.stack([ - b / nx.maximum(stabThr, KV[r, :, :]) - for r in range(A.shape[0]) - ]) - if cpt % 10 == 1: - err = nx.sum(nx.abs(bold - b)) + K2 = nx.exp(-(X - Y) ** 2 / reg) + + def convol_imgs(imgs): + kx = nx.einsum("...ij,kjl->kil", K1, imgs) + kxy = nx.einsum("...ij,klj->kli", K2, kx) + return kxy + + KU = convol_imgs(U) + for ii in range(numItermax): + V = bar[None] / KU + KV = convol_imgs(V) + U = A / KV + KU = convol_imgs(U) + bar = nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0)) + if ii % 10 == 9: + err = (V * KU).std(axis=0).sum() + # log and verbose print + if log: + log['err'].append(err) + + if verbose: + if ii % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr: + break + + else: + if warn: + warnings.warn("Convolutional Sinkhorn did not converge. " + "Try a larger number of iterations `numItermax` " + "or a larger entropy `reg`.") + if log: + log['niter'] = ii + log['U'] = U + return bar, log + else: + return bar + + +def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000, + stopThr=1e-4, stabThr=1e-30, verbose=False, + log=False, warn=True): + r"""Compute the entropic regularized wasserstein barycenter of distributions A + where A is a collection of 2D images in log-domain. + """ + + A = list_to_array(A) + + nx = get_backend(A) + if nx.__name__ == "jax": + raise NotImplementedError("Log-domain functions are not yet implemented" + " for Jax. Use numpy or torch arrays instead.") + + n_hists, width, height = A.shape + + if weights is None: + weights = nx.ones((n_hists,), type_as=A) / n_hists + else: + assert (len(weights) == n_hists) + + if log: + log = {'err': []} + + err = 1 + # build the convolution operator + # this is equivalent to blurring on horizontal then vertical directions + t = nx.linspace(0, 1, width) + [Y, X] = nx.meshgrid(t, t) + M1 = - (X - Y) ** 2 / reg + + t = nx.linspace(0, 1, height) + [Y, X] = nx.meshgrid(t, t) + M2 = - (X - Y) ** 2 / reg + + def convol_img(log_img): + log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1) + log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T + return log_img + + logA = nx.log(A + stabThr) + log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A) + err = 1 + for ii in range(numItermax): + log_bar = nx.zeros((width, height), type_as=A) + for k in range(n_hists): + f = logA[k] - convol_img(G[k]) + log_KU[k] = convol_img(f) + log_bar = log_bar + weights[k] * log_KU[k] + + if ii % 10 == 9: + err = nx.exp(G + log_KU).std(axis=0).sum() + # log and verbose print + if log: + log['err'].append(err) + + if verbose: + if ii % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr: + break + G = log_bar[None, :, :] - log_KU + + else: + if warn: + warnings.warn("Convolutional Sinkhorn did not converge. " + "Try a larger number of iterations `numItermax` " + "or a larger entropy `reg`.") + if log: + log['niter'] = ii + return nx.exp(log_bar), log + else: + return nx.exp(log_bar) + + +def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn", + numItermax=10000, stopThr=1e-3, + verbose=False, log=False, warn=True, + **kwargs): + r"""Compute the debiased sinkhorn barycenter of distributions A + where A is a collection of 2D images. + + The function solves the following optimization problem: + + .. math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`S_{reg}(\cdot,\cdot)` is the debiased entropic regularized Wasserstein + distance (see :py:func:`ot.bregman.sinkhorn_debiased`) + - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two + dimensions of matrix :math:`\mathbf{A}` + - `reg` is the regularization strength scalar value + + The algorithm used for solving the problem is the debiased Sinkhorn scaling + algorithm as proposed in :ref:`[37] ` + + Parameters + ---------- + A : array-like, shape (n_hists, width, height) + `n` distributions (2D images) of size `width` x `height` + reg : float + Regularization term >0 + weights : array-like, shape (n_hists,) + Weights of each image on the simplex (barycentric coodinates) + method : string, optional + method used for the solver either 'sinkhorn' or 'sinkhorn_log' + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (> 0) + stabThr : float, optional + Stabilization threshold to avoid numerical precision issue + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. + + + Returns + ------- + a : array-like, shape (width, height) + 2D Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + .. _references-sinkhorn-debiased: + References + ---------- + + .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International + Conference on Machine Learning, PMLR 119:4692-4701, 2020 + """ + + if method.lower() == 'sinkhorn': + return _convolutional_barycenter2d_debiased(A, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, + **kwargs) + elif method.lower() == 'sinkhorn_log': + return _convolutional_barycenter2d_debiased_log(A, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, + **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) + + +def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, + stopThr=1e-3, stabThr=1e-15, verbose=False, + log=False, warn=True): + r"""Compute the debiased barycenter of 2D images via sinkhorn convolutions. + """ + + A = list_to_array(A) + n_hists, width, height = A.shape + + nx = get_backend(A) + + if weights is None: + weights = nx.ones((n_hists,), type_as=A) / n_hists + else: + assert (len(weights) == n_hists) + + if log: + log = {'err': []} + + bar = nx.ones((width, height), type_as=A) + bar /= width * height + U = nx.ones(A.shape, type_as=A) + V = nx.ones(A.shape, type_as=A) + c = nx.ones(A.shape[1:], type_as=A) + err = 1 + + # build the convolution operator + # this is equivalent to blurring on horizontal then vertical directions + t = nx.linspace(0, 1, width) + [Y, X] = nx.meshgrid(t, t) + K1 = nx.exp(-(X - Y) ** 2 / reg) + + t = nx.linspace(0, 1, height) + [Y, X] = nx.meshgrid(t, t) + K2 = nx.exp(-(X - Y) ** 2 / reg) + + def convol_imgs(imgs): + kx = nx.einsum("...ij,kjl->kil", K1, imgs) + kxy = nx.einsum("...ij,klj->kli", K2, kx) + return kxy + + KU = convol_imgs(U) + for ii in range(numItermax): + V = bar[None] / KU + KV = convol_imgs(V) + U = A / KV + KU = convol_imgs(U) + bar = c * nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0)) + + for _ in range(10): + c = (c * bar / convol_imgs(c[None]).squeeze()) ** 0.5 + + if ii % 10 == 9: + err = (V * KU).std(axis=0).sum() # log and verbose print if log: log['err'].append(err) if verbose: - if cpt % 200 == 0: + if ii % 200 == 0: print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print('{:5d}|{:8e}|'.format(ii, err)) + # debiased Sinkhorn does not converge monotonically + # guarantee a few iterations are done before stopping + if err < stopThr and ii > 20: + break + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: - log['niter'] = cpt + log['niter'] = ii log['U'] = U - return b, log + return bar, log + else: + return bar + + +def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10000, + stopThr=1e-3, stabThr=1e-30, verbose=False, + log=False, warn=True): + r"""Compute the debiased barycenter of 2D images in log-domain. + """ + + A = list_to_array(A) + n_hists, width, height = A.shape + nx = get_backend(A) + if nx.__name__ == "jax": + raise NotImplementedError("Log-domain functions are not yet implemented" + " for Jax. Use numpy or torch arrays instead.") + if weights is None: + weights = nx.ones((n_hists,), type_as=A) / n_hists + else: + assert (len(weights) == A.shape[0]) + + if log: + log = {'err': []} + + err = 1 + # build the convolution operator + # this is equivalent to blurring on horizontal then vertical directions + t = nx.linspace(0, 1, width) + [Y, X] = nx.meshgrid(t, t) + M1 = - (X - Y) ** 2 / reg + + t = nx.linspace(0, 1, height) + [Y, X] = nx.meshgrid(t, t) + M2 = - (X - Y) ** 2 / reg + + def convol_img(log_img): + log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1) + log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T + return log_img + + logA = nx.log(A + stabThr) + log_bar, c = nx.zeros((2, width, height), type_as=A) + log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A) + err = 1 + for ii in range(numItermax): + log_bar = nx.zeros((width, height), type_as=A) + for k in range(n_hists): + f = logA[k] - convol_img(G[k]) + log_KU[k] = convol_img(f) + log_bar = log_bar + weights[k] * log_KU[k] + log_bar += c + for _ in range(10): + c = 0.5 * (c + log_bar - convol_img(c)) + + if ii % 10 == 9: + err = nx.exp(G + log_KU).std(axis=0).sum() + # log and verbose print + if log: + log['err'].append(err) + + if verbose: + if ii % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr and ii > 20: + break + G = log_bar[None, :, :] - log_KU + else: - return b + if warn: + warnings.warn("Convolutional Sinkhorn did not converge. " + "Try a larger number of iterations `numItermax` " + "or a larger entropy `reg`.") + if log: + log['niter'] = ii + return nx.exp(log_bar), log + else: + return nx.exp(log_bar) def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, - stopThr=1e-3, verbose=False, log=False): + stopThr=1e-3, verbose=False, log=False, warn=True): r""" Compute the unmixing of an observation with a given dictionary using Wasserstein distance @@ -1679,16 +2411,21 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, where : - - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with :math:`\mathbf{M}` loss matrix (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, its expected shape is `(dim_a, n_atoms)` + - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance + with M loss matrix (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, + its expected shape is `(dim_a, n_atoms)` - :math:`\mathbf{h}` is the estimated unmixing of dimension `n_atoms` - :math:`\mathbf{a}` is an observed distribution of dimension `dim_a` - :math:`\mathbf{h}_0` is a prior on :math:`\mathbf{h}` of dimension `dim_prior` - - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix (`dim_a`, `dim_a`) for OT data fitting - - `reg`:math:`_0` and :math:`\mathbf{M_0}` are respectively the regularization term and the cost matrix (`dim_prior`, `n_atoms`) regularization + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the + cost matrix (`dim_a`, `dim_a`) for OT data fitting + - `reg`:math:`_0` and :math:`\mathbf{M_0}` are respectively the regularization + term and the cost matrix (`dim_prior`, `n_atoms`) regularization - :math:`\\alpha` weight data fitting and regularization - The optimization problem is solved following the algorithm described in :ref:`[4] ` + The optimization problem is solved following the algorithm described + in :ref:`[4] ` Parameters @@ -1717,7 +2454,8 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, Print information along iterations log : bool, optional record log if True - + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -1731,8 +2469,10 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, References ---------- - .. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, Supervised planetary unmixing with optimal transport, Whorkshop on Hyperspectral Image and Signal Processing : Evolution in Remote Sensing (WHISPERS), 2016. - + .. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, + Supervised planetary unmixing with optimal transport, Whorkshop + on Hyperspectral Image and Signal Processing : + Evolution in Remote Sensing (WHISPERS), 2016. """ a, D, M, M0, h0 = list_to_array(a, D, M, M0, h0) @@ -1747,12 +2487,11 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, old = h0 err = 1 - cpt = 0 # log = {'niter':0, 'all_err':[]} if log: log = {'err': []} - while (err > stopThr and cpt < numItermax): + for ii in range(numItermax): K = projC(K, a) K0 = projC(K0, h0) new = nx.sum(K0, axis=1) @@ -1770,22 +2509,27 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, log['err'].append(err) if verbose: - if cpt % 200 == 0: + if ii % 200 == 0: print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - - cpt = cpt + 1 - + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr: + break + else: + if warn: + warnings.warn("Unmixing algorithm did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: - log['niter'] = cpt + log['niter'] = ii return nx.sum(K0, axis=1), log else: return nx.sum(K0, axis=1) def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, - stopThr=1e-6, verbose=False, log=False, **kwargs): - r'''Joint OT and proportion estimation for multi-source target shift as proposed in :ref:`[27] ` + stopThr=1e-6, verbose=False, log=False, warn=True, **kwargs): + r'''Joint OT and proportion estimation for multi-source target shift as + proposed in :ref:`[27] ` The function solves the following optimization problem: @@ -1799,16 +2543,23 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, where : - :math:`\lambda_k` is the weight of `k`-th source domain - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to `k`-th source domain defined as in [p. 5, :ref:`27 `], its expected shape is :math:`(n_k, C)` where :math:`n_k` is the number of elements in the `k`-th source domain and `C` is the number of classes + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance + (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to `k`-th source domain + defined as in [p. 5, :ref:`27 `], its expected shape + is :math:`(n_k, C)` where :math:`n_k` is the number of elements in the `k`-th source + domain and `C` is the number of classes - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size `C` - :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n` - - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, :ref:`27 `], its expected shape is :math:`(n_k, C)` + - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in + [p. 5, :ref:`27 `], its expected shape is :math:`(n_k, C)` - The problem consist in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain. + The problem consist in solving a Wasserstein barycenter problem to estimate + the proportions :math:`\mathbf{h}` in the target domain. The algorithm used for solving the problem is the Iterative Bregman projections algorithm - with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and uniform target distribution. + with two sets of marginal constraints related to the unknown vector + :math:`\mathbf{h}` and uniform target distribution. Parameters ---------- @@ -1826,10 +2577,12 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, Max number of iterations stopThr : float, optional Stop threshold on relative change in the barycenter (>0) - log : bool, optional - record log if True verbose : bool, optional (default=False) Controls the verbosity of the optimization algorithm + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -1844,9 +2597,8 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, ---------- .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia - "Optimal transport for multi-source domain adaptation under target shift", - International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. - + "Optimal transport for multi-source domain adaptation under target shift", + International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. ''' Xs = list_to_array(*Xs) @@ -1901,11 +2653,10 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, # uniform target distribution a = nx.from_numpy(unif(Xt.shape[0]), type_as=Xs[0]) - cpt = 0 # iterations count err = 1 old_bary = nx.ones((nbclasses,), type_as=Xs[0]) - while (err > stopThr and cpt < numItermax): + for ii in range(numItermax): bary = nx.zeros((nbclasses,), type_as=Xs[0]) @@ -1923,21 +2674,27 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, K[d] = projR(K[d], new) err = nx.norm(bary - old_bary) - cpt = cpt + 1 + old_bary = bary if log: log['err'].append(err) + if err < stopThr: + break if verbose: - if cpt % 200 == 0: + if ii % 200 == 0: print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - + print('{:5d}|{:8e}|'.format(ii, err)) + else: + if warn: + warnings.warn("Algorithm did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") bary = bary / nx.sum(bary) if log: - log['niter'] = cpt + log['niter'] = ii log['M'] = M log['D1'] = D1 log['D2'] = D2 @@ -1949,7 +2706,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False, - log=False, **kwargs): + log=False, warn=True, **kwargs): r''' Solve the entropic regularization optimal transport problem and return the OT matrix from empirical data @@ -1967,7 +2724,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', where : - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) @@ -1988,7 +2746,9 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', stopThr : float, optional Stop threshold on error (>0) isLazy: boolean, optional - If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory). If False, calculate full cost matrix and return outputs of sinkhorn function. + If True, then only calculate the cost matrix by block and return + the dual potentials only (to save memory). If False, calculate full + cost matrix and return outputs of sinkhorn function. batchSize: int or tuple of 2 int, optional Size of the batches used to compute the sinkhorn update without memory overhead. When a tuple is provided it sets the size of the left/right batches. @@ -1996,6 +2756,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns @@ -2021,11 +2783,14 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal + Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for + Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. ''' X_s, X_t = list_to_array(X_s, X_t) @@ -2100,7 +2865,11 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', if err <= stopThr: break - + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: dict_log["u"] = f dict_log["v"] = g @@ -2111,15 +2880,18 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', else: M = dist(X_s, X_t, metric=metric) if log: - pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs) + pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, + verbose=verbose, log=True, **kwargs) return pi, log else: - pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs) + pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, + verbose=verbose, log=False, **kwargs) return pi -def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, - isLazy=False, batchSize=100, verbose=False, log=False, **kwargs): +def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', + numIterMax=10000, stopThr=1e-9, isLazy=False, + batchSize=100, verbose=False, log=False, warn=True, **kwargs): r''' Solve the entropic regularization optimal transport problem from empirical data and return the OT loss @@ -2138,7 +2910,8 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num where : - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) @@ -2159,7 +2932,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num stopThr : float, optional Stop threshold on error (>0) isLazy: boolean, optional - If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory). If False, calculate full cost matrix and return outputs of sinkhorn function. + If True, then only calculate the cost matrix by block and return + the dual potentials only (to save memory). If False, calculate + full cost matrix and return outputs of sinkhorn function. batchSize: int or tuple of 2 int, optional Size of the batches used to compute the sinkhorn update without memory overhead. When a tuple is provided it sets the size of the left/right batches. @@ -2167,6 +2942,8 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns @@ -2192,11 +2969,17 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation + of Optimal Transport, Advances in Neural Information + Processing Systems (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling + Algorithms for Entropy Regularized Transport Problems. + arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. + arXiv preprint arXiv:1607.05816. ''' X_s, X_t = list_to_array(X_s, X_t) @@ -2211,11 +2994,19 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num if isLazy: if log: - f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr, - isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log) + f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, + numIterMax=numIterMax, + stopThr=stopThr, + isLazy=isLazy, + batchSize=batchSize, + verbose=verbose, log=log, + warn=warn) else: - f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr, - isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log) + f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, + numIterMax=numIterMax, stopThr=stopThr, + isLazy=isLazy, batchSize=batchSize, + verbose=verbose, log=log, + warn=warn) bs = batchSize if isinstance(batchSize, int) else batchSize[0] range_s = range(0, ns, bs) @@ -2241,17 +3032,21 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num M = nx.from_numpy(M, type_as=a) if log: - sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, - **kwargs) + sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, + stopThr=stopThr, verbose=verbose, log=log, + warn=warn, **kwargs) return sinkhorn_loss, log else: - sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, - **kwargs) + sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, + stopThr=stopThr, verbose=verbose, log=log, + warn=warn, **kwargs) return sinkhorn_loss -def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, - verbose=False, log=False, **kwargs): +def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', + numIterMax=10000, stopThr=1e-9, + verbose=False, log=False, warn=True, + **kwargs): r''' Compute the sinkhorn divergence loss from empirical data @@ -2288,8 +3083,11 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli \gamma_b &\geq 0 where : - - :math:`\mathbf{M}` (resp. :math:`\mathbf{M_a}`, :math:`\mathbf{M_b}`) is the (`n_samples_a`, `n_samples_b`) metric cost matrix (resp (`n_samples_a, n_samples_a`) and (`n_samples_b`, `n_samples_b`)) - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{M}` (resp. :math:`\mathbf{M_a}`, :math:`\mathbf{M_b}`) + is the (`n_samples_a`, `n_samples_b`) metric cost matrix + (resp (`n_samples_a, n_samples_a`) and (`n_samples_b`, `n_samples_b`)) + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) @@ -2313,6 +3111,8 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -2334,17 +3134,26 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli References ---------- - .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018 + .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative + Models with Sinkhorn Divergences, Proceedings of the Twenty-First + International Conference on Artficial Intelligence and Statistics, + (AISTATS) 21, 2018 ''' if log: - sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, - stopThr=1e-9, verbose=verbose, log=log, **kwargs) + sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, + numIterMax=numIterMax, + stopThr=1e-9, verbose=verbose, + log=log, warn=warn, **kwargs) - sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, - stopThr=1e-9, verbose=verbose, log=log, **kwargs) + sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, + numIterMax=numIterMax, + stopThr=1e-9, verbose=verbose, + log=log, warn=warn, **kwargs) - sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, - stopThr=1e-9, verbose=verbose, log=log, **kwargs) + sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, + numIterMax=numIterMax, + stopThr=1e-9, verbose=verbose, + log=log, warn=warn, **kwargs) sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b) @@ -2359,25 +3168,33 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli return max(0, sinkhorn_div), log else: - sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, - verbose=verbose, log=log, **kwargs) + sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, + numIterMax=numIterMax, stopThr=1e-9, + verbose=verbose, log=log, + warn=warn, **kwargs) - sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, stopThr=1e-9, - verbose=verbose, log=log, **kwargs) + sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, + numIterMax=numIterMax, stopThr=1e-9, + verbose=verbose, log=log, + warn=warn, **kwargs) - sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, - verbose=verbose, log=log, **kwargs) + sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, + numIterMax=numIterMax, stopThr=1e-9, + verbose=verbose, log=log, + warn=warn, **kwargs) sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b) return max(0, sinkhorn_div) -def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, restricted=True, - maxiter=10000, maxfun=10000, pgtol=1e-09, verbose=False, log=False): +def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, + restricted=True, maxiter=10000, maxfun=10000, pgtol=1e-09, + verbose=False, log=False): r""" Screening Sinkhorn Algorithm for Regularized Optimal Transport - The function solves an approximate dual of Sinkhorn divergence :ref:`[2] ` which is written as the following optimization problem: + The function solves an approximate dual of Sinkhorn divergence :ref:`[2] + ` which is written as the following optimization problem: .. math:: @@ -2395,56 +3212,49 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res e^{v_j} &\geq \epsilon \kappa, \forall j \in \{1, \ldots, nt\} - The parameters `kappa` and `epsilon` are determined w.r.t the couple number budget of points (`ns_budget`, `nt_budget`), see Equation (5) in :ref:`[26] ` + The parameters `kappa` and `epsilon` are determined w.r.t the couple number + budget of points (`ns_budget`, `nt_budget`), see Equation (5) + in :ref:`[26] ` Parameters ---------- - a : array-like, shape=(ns,) + a: array-like, shape=(ns,) samples weights in the source domain - - b : array-like, shape=(nt,) + b: array-like, shape=(nt,) samples weights in the target domain - - M : array-like, shape=(ns, nt) + M: array-like, shape=(ns, nt) Cost matrix - - reg : `float` + reg: `float` Level of the entropy regularisation - - ns_budget : `int`, default=None + ns_budget: `int`, default=None Number budget of points to be kept in the source domain. If it is None then 50% of the source sample points will be kept - - nt_budget : `int`, default=None + nt_budget: `int`, default=None Number budget of points to be kept in the target domain. If it is None then 50% of the target sample points will be kept - - uniform : `bool`, default=False - If `True`, the source and target distribution are supposed to be uniform, i.e., :math:`a_i = 1 / ns` and :math:`b_j = 1 / nt` - + uniform: `bool`, default=False + If `True`, the source and target distribution are supposed to be uniform, + i.e., :math:`a_i = 1 / ns` and :math:`b_j = 1 / nt` restricted : `bool`, default=True If `True`, a warm-start initialization for the L-BFGS-B solver using a restricted Sinkhorn algorithm with at most 5 iterations - - maxiter : `int`, default=10000 + maxiter: `int`, default=10000 Maximum number of iterations in LBFGS solver - - maxfun : `int`, default=10000 + maxfun: `int`, default=10000 Maximum number of function evaluations in LBFGS solver - - pgtol : `float`, default=1e-09 + pgtol: `float`, default=1e-09 Final objective function accuracy in LBFGS solver - - verbose : `bool`, default=False - If `True`, display informations about the cardinals of the active sets and the parameters kappa - and epsilon - + verbose: `bool`, default=False + If `True`, display informations about the cardinals of the active sets + and the parameters kappa and epsilon Dependency ---------- - To gain more efficiency, screenkhorn needs to call the "Bottleneck" package (https://pypi.org/project/Bottleneck/) - in the screening pre-processing step. If Bottleneck isn't installed, the following error message appears: + To gain more efficiency, screenkhorn needs to call the "Bottleneck" + package (https://pypi.org/project/Bottleneck/) + in the screening pre-processing step. If Bottleneck isn't installed, + the following error message appears: "Bottleneck module doesn't exist. Install it from https://pypi.org/project/Bottleneck/" @@ -2461,9 +3271,11 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res References ----------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, + Advances in Neural Information Processing Systems (NIPS) 26, 2013 - .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). Screening Sinkhorn Algorithm for Regularized Optimal Transport (NIPS) 33, 2019 + .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). + Screening Sinkhorn Algorithm for Regularized Optimal Transport (NIPS) 33, 2019 """ # check if bottleneck module exists @@ -2471,14 +3283,16 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res import bottleneck except ImportError: warnings.warn( - "Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance.") + "Bottleneck module is not installed. Install it from" + " https://pypi.org/project/Bottleneck/ for better performance.") bottleneck = np a, b, M = list_to_array(a, b, M) nx = get_backend(M, a, b) if nx.__name__ == "jax": - raise TypeError("JAX arrays have been received but screenkhorn is not compatible with JAX.") + raise TypeError("JAX arrays have been received but screenkhorn is not " + "compatible with JAX.") ns, nt = M.shape @@ -2582,7 +3396,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res if verbose: print("epsilon = %s\n" % epsilon) print("kappa = %s\n" % kappa) - print('Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n' % (sum(Isel), sum(Jsel))) + print('Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n' + % (sum(Isel), sum(Jsel))) # Ic, Jc: complementary of the active sets I and J Ic = ~Isel @@ -2638,13 +3453,11 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res cst_u = kappa * epsilon * nx.sum(K_IJc, axis=1) cst_v = epsilon * nx.sum(K_IcJ, axis=0) / kappa - cpt = 1 - while cpt < 5: # 5 iterations + for _ in range(5): # 5 iterations K_IJ_v = nx.dot(K_IJ.T, u0) + cst_v v0 = b_J / (kappa * K_IJ_v) KIJ_u = nx.dot(K_IJ, v0) + cst_u u0 = (kappa * a_I) / KIJ_u - cpt += 1 u0 = projection(u0, epsilon / kappa) v0 = projection(v0, epsilon * kappa) @@ -2655,15 +3468,13 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res def restricted_sinkhorn(usc, vsc, max_iter=5): """ - Restricted Sinkhorn Algorithm as a warm-start initialized point for L-BFGS-B (see Algorithm 1 in supplementary of [26]) + Restricted Sinkhorn Algorithm as a warm-start initialized pointfor L-BFGS-B) """ - cpt = 1 - while cpt < max_iter: + for _ in range(max_iter): K_IJ_v = nx.dot(K_IJ.T, usc) + cst_v vsc = b_J / (kappa * K_IJ_v) KIJ_u = nx.dot(K_IJ, vsc) + cst_u usc = (kappa * a_I) / KIJ_u - cpt += 1 usc = projection(usc, epsilon / kappa) vsc = projection(vsc, epsilon * kappa) diff --git a/test/test_bregman.py b/test/test_bregman.py index 6923d31..edfe9c3 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -6,6 +6,8 @@ # # License: MIT License +from itertools import product + import numpy as np import pytest @@ -13,7 +15,8 @@ import ot from ot.backend import torch -def test_sinkhorn(): +@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) +def test_sinkhorn(verbose, warn): # test sinkhorn n = 100 rng = np.random.RandomState(0) @@ -23,7 +26,7 @@ def test_sinkhorn(): M = ot.dist(x, x) - G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10) + G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10, verbose=verbose, warn=warn) # check constraints np.testing.assert_allclose( @@ -31,8 +34,92 @@ def test_sinkhorn(): np.testing.assert_allclose( u, G.sum(0), atol=1e-05) # cf convergence sinkhorn + with pytest.warns(UserWarning): + ot.sinkhorn(u, u, M, 1, stopThr=0, numItermax=1) + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", + "sinkhorn_epsilon_scaling", + "greenkhorn", + "sinkhorn_log"]) +def test_convergence_warning(method): + # test sinkhorn + n = 100 + a1 = ot.datasets.make_1D_gauss(n, m=30, s=10) + a2 = ot.datasets.make_1D_gauss(n, m=40, s=10) + A = np.asarray([a1, a2]).T + M = ot.utils.dist0(n) + + with pytest.warns(UserWarning): + ot.sinkhorn(a1, a2, M, 1., method=method, stopThr=0, numItermax=1) + + if method in ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"]: + with pytest.warns(UserWarning): + ot.barycenter(A, M, 1, method=method, stopThr=0, numItermax=1) + with pytest.warns(UserWarning): + ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1) + + +def test_not_impemented_method(): + # test sinkhorn + w = 10 + n = w ** 2 + rng = np.random.RandomState(42) + A_img = rng.rand(2, w, w) + A_flat = A_img.reshape(n, 2) + a1, a2 = A_flat.T + M_flat = ot.utils.dist0(n) + not_implemented = "new_method" + reg = 0.01 + with pytest.raises(ValueError): + ot.sinkhorn(a1, a2, M_flat, reg, method=not_implemented) + with pytest.raises(ValueError): + ot.sinkhorn2(a1, a2, M_flat, reg, method=not_implemented) + with pytest.raises(ValueError): + ot.barycenter(A_flat, M_flat, reg, method=not_implemented) + with pytest.raises(ValueError): + ot.bregman.barycenter_debiased(A_flat, M_flat, reg, + method=not_implemented) + with pytest.raises(ValueError): + ot.bregman.convolutional_barycenter2d(A_img, reg, + method=not_implemented) + with pytest.raises(ValueError): + ot.bregman.convolutional_barycenter2d_debiased(A_img, reg, + method=not_implemented) + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) +def test_nan_warning(method): + # test sinkhorn + n = 100 + a1 = ot.datasets.make_1D_gauss(n, m=30, s=10) + a2 = ot.datasets.make_1D_gauss(n, m=40, s=10) + + M = ot.utils.dist0(n) + reg = 0 + with pytest.warns(UserWarning): + # warn set to False to avoid catching a convergence warning instead + ot.sinkhorn(a1, a2, M, reg, method=method, warn=False) + + +def test_sinkhorn_stabilization(): + # test sinkhorn + n = 100 + a1 = ot.datasets.make_1D_gauss(n, m=30, s=10) + a2 = ot.datasets.make_1D_gauss(n, m=40, s=10) + M = ot.utils.dist0(n) + reg = 1e-5 + loss1 = ot.sinkhorn2(a1, a2, M, reg, method="sinkhorn_log") + loss2 = ot.sinkhorn2(a1, a2, M, reg, tau=1, method="sinkhorn_stabilized") + np.testing.assert_allclose( + loss1, loss2, atol=1e-06) # cf convergence sinkhorn + -def test_sinkhorn_multi_b(): +@pytest.mark.parametrize("method, verbose, warn", + product(["sinkhorn", "sinkhorn_stabilized", + "sinkhorn_log"], + [True, False], [True, False])) +def test_sinkhorn_multi_b(method, verbose, warn): # test sinkhorn n = 10 rng = np.random.RandomState(0) @@ -45,12 +132,14 @@ def test_sinkhorn_multi_b(): M = ot.dist(x, x) - loss0, log = ot.sinkhorn(u, b, M, .1, stopThr=1e-10, log=True) + loss0, log = ot.sinkhorn(u, b, M, .1, method=method, stopThr=1e-10, + log=True) - loss = [ot.sinkhorn2(u, b[:, k], M, .1, stopThr=1e-10) for k in range(3)] + loss = [ot.sinkhorn2(u, b[:, k], M, .1, method=method, stopThr=1e-10, + verbose=verbose, warn=warn) for k in range(3)] # check constraints np.testing.assert_allclose( - loss0, loss, atol=1e-06) # cf convergence sinkhorn + loss0, loss, atol=1e-4) # cf convergence sinkhorn def test_sinkhorn_backends(nx): @@ -67,9 +156,9 @@ def test_sinkhorn_backends(nx): G = ot.sinkhorn(a, a, M, 1) ab = nx.from_numpy(a) - Mb = nx.from_numpy(M) + M_nx = nx.from_numpy(M) - Gb = ot.sinkhorn(ab, ab, Mb, 1) + Gb = ot.sinkhorn(ab, ab, M_nx, 1) np.allclose(G, nx.to_numpy(Gb)) @@ -88,9 +177,9 @@ def test_sinkhorn2_backends(nx): G = ot.sinkhorn(a, a, M, 1) ab = nx.from_numpy(a) - Mb = nx.from_numpy(M) + M_nx = nx.from_numpy(M) - Gb = ot.sinkhorn2(ab, ab, Mb, 1) + Gb = ot.sinkhorn2(ab, ab, M_nx, 1) np.allclose(G, nx.to_numpy(Gb)) @@ -131,6 +220,12 @@ def test_sinkhorn_empty(): M = ot.dist(x, x) + G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, method="sinkhorn_log", + verbose=True, log=True) + # check constraints + np.testing.assert_allclose(u, G.sum(1), atol=1e-05) + np.testing.assert_allclose(u, G.sum(0), atol=1e-05) + G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, verbose=True, log=True) # check constraints np.testing.assert_allclose(u, G.sum(1), atol=1e-05) @@ -165,15 +260,15 @@ def test_sinkhorn_variants(nx): M = ot.dist(x, x) ub = nx.from_numpy(u) - Mb = nx.from_numpy(M) + M_nx = nx.from_numpy(M) G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_log', stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn', stopThr=1e-10)) - Gs = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) Ges = nx.to_numpy(ot.sinkhorn( - ub, ub, Mb, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)) - G_green = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='greenkhorn', stopThr=1e-10)) + ub, ub, M_nx, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)) + G_green = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10)) # check values np.testing.assert_allclose(G, G0, atol=1e-05) @@ -199,12 +294,12 @@ def test_sinkhorn_variants_multi_b(nx): ub = nx.from_numpy(u) bb = nx.from_numpy(b) - Mb = nx.from_numpy(M) + M_nx = nx.from_numpy(M) G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn_log', stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn', stopThr=1e-10)) - Gs = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) # check values np.testing.assert_allclose(G, G0, atol=1e-05) @@ -228,12 +323,12 @@ def test_sinkhorn2_variants_multi_b(nx): ub = nx.from_numpy(u) bb = nx.from_numpy(b) - Mb = nx.from_numpy(M) + M_nx = nx.from_numpy(M) G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn_log', stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn', stopThr=1e-10)) - Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) # check values np.testing.assert_allclose(G, G0, atol=1e-05) @@ -255,7 +350,7 @@ def test_sinkhorn_variants_log(): Gl, logl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True) Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) Ges, loges = ot.sinkhorn( - u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True) + u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True,) G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True) # check values @@ -265,7 +360,8 @@ def test_sinkhorn_variants_log(): np.testing.assert_allclose(G0, G_green, atol=1e-5) -def test_sinkhorn_variants_log_multib(): +@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) +def test_sinkhorn_variants_log_multib(verbose, warn): # test sinkhorn n = 50 rng = np.random.RandomState(0) @@ -278,16 +374,20 @@ def test_sinkhorn_variants_log_multib(): M = ot.dist(x, x) G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10, log=True) - Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True) - Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) + Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True, + verbose=verbose, warn=warn) + Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True, + verbose=verbose, warn=warn) # check values np.testing.assert_allclose(G0, Gs, atol=1e-05) np.testing.assert_allclose(G0, Gl, atol=1e-05) -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) -def test_barycenter(nx, method): +@pytest.mark.parametrize("method, verbose, warn", + product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], + [True, False], [True, False])) +def test_barycenter(nx, method, verbose, warn): n_bins = 100 # nb bins # Gaussian distributions @@ -304,20 +404,98 @@ def test_barycenter(nx, method): alpha = 0.5 # 0<=alpha<=1 weights = np.array([1 - alpha, alpha]) - Ab = nx.from_numpy(A) - Mb = nx.from_numpy(M) - weightsb = nx.from_numpy(weights) + A_nx = nx.from_numpy(A) + M_nx = nx.from_numpy(M) + weights_nx = nx.from_numpy(weights) + reg = 1e-2 + + if nx.__name__ == "jax" and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.barycenter(A_nx, M_nx, reg, weights, method=method) + else: + # wasserstein + bary_wass_np = ot.bregman.barycenter(A, M, reg, weights, method=method, verbose=verbose, warn=warn) + bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, weights_nx, method=method, log=True) + bary_wass = nx.to_numpy(bary_wass) + + np.testing.assert_allclose(1, np.sum(bary_wass)) + np.testing.assert_allclose(bary_wass, bary_wass_np) + + ot.bregman.barycenter(A_nx, M_nx, reg, log=True) + + +@pytest.mark.parametrize("method, verbose, warn", + product(["sinkhorn", "sinkhorn_log"], + [True, False], [True, False])) +def test_barycenter_debiased(nx, method, verbose, warn): + n_bins = 100 # nb bins + + # Gaussian distributions + a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std + a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10) + + # creating matrix A containing all distributions + A = np.vstack((a1, a2)).T + + # loss matrix + normalization + M = ot.utils.dist0(n_bins) + M /= M.max() + + alpha = 0.5 # 0<=alpha<=1 + weights = np.array([1 - alpha, alpha]) + + A_nx = nx.from_numpy(A) + M_nx = nx.from_numpy(M) + weights_nx = nx.from_numpy(weights) # wasserstein reg = 1e-2 - bary_wass_np, log = ot.bregman.barycenter(A, M, reg, weights, method=method, log=True) - bary_wass, _ = ot.bregman.barycenter(Ab, Mb, reg, weightsb, method=method, log=True) - bary_wass = nx.to_numpy(bary_wass) + if nx.__name__ == "jax" and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method) + else: + bary_wass_np = ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, + verbose=verbose, warn=warn) + bary_wass, _ = ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights_nx, method=method, log=True) + bary_wass = nx.to_numpy(bary_wass) + + np.testing.assert_allclose(1, np.sum(bary_wass), atol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-5) + + ot.bregman.barycenter_debiased(A_nx, M_nx, reg, log=True, verbose=False) - np.testing.assert_allclose(1, np.sum(bary_wass)) - np.testing.assert_allclose(bary_wass, bary_wass_np) - ot.bregman.barycenter(Ab, Mb, reg, log=True, verbose=True) +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) +def test_convergence_warning_barycenters(method): + w = 10 + n_bins = w ** 2 # nb bins + + # Gaussian distributions + a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std + a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10) + + # creating matrix A containing all distributions + A = np.vstack((a1, a2)).T + A_img = A.reshape(2, w, w) + A_img /= A_img.sum((1, 2))[:, None, None] + + # loss matrix + normalization + M = ot.utils.dist0(n_bins) + M /= M.max() + + alpha = 0.5 # 0<=alpha<=1 + weights = np.array([1 - alpha, alpha]) + reg = 0.1 + with pytest.warns(UserWarning): + ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, numItermax=1) + with pytest.warns(UserWarning): + ot.bregman.barycenter(A, M, reg, weights, method=method, numItermax=1) + with pytest.warns(UserWarning): + ot.bregman.convolutional_barycenter2d(A_img, reg, weights, + method=method, numItermax=1) + with pytest.warns(UserWarning): + ot.bregman.convolutional_barycenter2d_debiased(A_img, reg, weights, + method=method, numItermax=1) def test_barycenter_stabilization(nx): @@ -337,31 +515,64 @@ def test_barycenter_stabilization(nx): alpha = 0.5 # 0<=alpha<=1 weights = np.array([1 - alpha, alpha]) - Ab = nx.from_numpy(A) - Mb = nx.from_numpy(M) + A_nx = nx.from_numpy(A) + M_nx = nx.from_numpy(M) weights_b = nx.from_numpy(weights) # wasserstein reg = 1e-2 bar_np = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True) bar_stable = nx.to_numpy(ot.bregman.barycenter( - Ab, Mb, reg, weights_b, method="sinkhorn_stabilized", + A_nx, M_nx, reg, weights_b, method="sinkhorn_stabilized", stopThr=1e-8, verbose=True )) bar = nx.to_numpy(ot.bregman.barycenter( - Ab, Mb, reg, weights_b, method="sinkhorn", + A_nx, M_nx, reg, weights_b, method="sinkhorn", stopThr=1e-8, verbose=True )) np.testing.assert_allclose(bar, bar_stable) np.testing.assert_allclose(bar, bar_np) -def test_wasserstein_bary_2d(nx): - size = 100 # size of a square image - a1 = np.random.randn(size, size) +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) +def test_wasserstein_bary_2d(nx, method): + size = 20 # size of a square image + a1 = np.random.rand(size, size) + a1 += a1.min() + a1 = a1 / np.sum(a1) + a2 = np.random.rand(size, size) + a2 += a2.min() + a2 = a2 / np.sum(a2) + # creating matrix A containing all distributions + A = np.zeros((2, size, size)) + A[0, :, :] = a1 + A[1, :, :] = a2 + + A_nx = nx.from_numpy(A) + + # wasserstein + reg = 1e-2 + if nx.__name__ == "jax" and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method) + else: + bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method) + bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)) + + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) +def test_wasserstein_bary_2d_debiased(nx, method): + size = 20 # size of a square image + a1 = np.random.rand(size, size) a1 += a1.min() a1 = a1 / np.sum(a1) - a2 = np.random.randn(size, size) + a2 = np.random.rand(size, size) a2 += a2.min() a2 = a2 / np.sum(a2) # creating matrix A containing all distributions @@ -369,18 +580,22 @@ def test_wasserstein_bary_2d(nx): A[0, :, :] = a1 A[1, :, :] = a2 - Ab = nx.from_numpy(A) + A_nx = nx.from_numpy(A) # wasserstein reg = 1e-2 - bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg) - bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(Ab, reg)) + if nx.__name__ == "jax" and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method) + else: + bary_wass_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method) + bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)) - np.testing.assert_allclose(1, np.sum(bary_wass)) - np.testing.assert_allclose(bary_wass, bary_wass_np) + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) - # help in checking if log and verbose do not bug the function - ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) def test_unmix(nx): @@ -405,20 +620,20 @@ def test_unmix(nx): ab = nx.from_numpy(a) Db = nx.from_numpy(D) - Mb = nx.from_numpy(M) + M_nx = nx.from_numpy(M) M0b = nx.from_numpy(M0) h0b = nx.from_numpy(h0) # wasserstein reg = 1e-3 um_np = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01) - um = nx.to_numpy(ot.bregman.unmix(ab, Db, Mb, M0b, h0b, reg, 1, alpha=0.01)) + um = nx.to_numpy(ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01)) np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03) np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03) np.testing.assert_allclose(um, um_np) - ot.bregman.unmix(ab, Db, Mb, M0b, h0b, reg, + ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01, log=True, verbose=True) @@ -437,22 +652,22 @@ def test_empirical_sinkhorn(nx): bb = nx.from_numpy(b) X_sb = nx.from_numpy(X_s) X_tb = nx.from_numpy(X_t) - Mb = nx.from_numpy(M, type_as=ab) + M_nx = nx.from_numpy(M, type_as=ab) M_mb = nx.from_numpy(M_m, type_as=ab) G_sqe = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1)) - sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1)) + sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1)) G_log, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, log=True) G_log = nx.to_numpy(G_log) - sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True) + sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean')) sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1)) - loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, Mb, 1)) + loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1)) # check constraints np.testing.assert_allclose( @@ -486,18 +701,18 @@ def test_lazy_empirical_sinkhorn(nx): bb = nx.from_numpy(b) X_sb = nx.from_numpy(X_s) X_tb = nx.from_numpy(X_t) - Mb = nx.from_numpy(M, type_as=ab) + M_nx = nx.from_numpy(M, type_as=ab) M_mb = nx.from_numpy(M_m, type_as=ab) f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True) f, g = nx.to_numpy(f), nx.to_numpy(g) G_sqe = np.exp(f[:, None] + g[None, :] - M / 1) - sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1)) + sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1)) f, g, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) f, g = nx.to_numpy(f), nx.to_numpy(g) G_log = np.exp(f[:, None] + g[None, :] - M / 0.1) - sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True) + sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1) @@ -507,7 +722,7 @@ def test_lazy_empirical_sinkhorn(nx): loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn) - loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, Mb, 1)) + loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1)) # check constraints np.testing.assert_allclose( @@ -541,13 +756,13 @@ def test_empirical_sinkhorn_divergence(nx): bb = nx.from_numpy(b) X_sb = nx.from_numpy(X_s) X_tb = nx.from_numpy(X_t) - Mb = nx.from_numpy(M, type_as=ab) + M_nx = nx.from_numpy(M, type_as=ab) M_sb = nx.from_numpy(M_s, type_as=ab) M_tb = nx.from_numpy(M_t, type_as=ab) emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb)) sinkhorn_div = nx.to_numpy( - ot.sinkhorn2(ab, bb, Mb, 1) + ot.sinkhorn2(ab, bb, M_nx, 1) - 1 / 2 * ot.sinkhorn2(ab, ab, M_sb, 1) - 1 / 2 * ot.sinkhorn2(bb, bb, M_tb, 1) ) @@ -580,14 +795,14 @@ def test_stabilized_vs_sinkhorn_multidim(nx): ab = nx.from_numpy(a) bb = nx.from_numpy(b) - Mb = nx.from_numpy(M, type_as=ab) + M_nx = nx.from_numpy(M, type_as=ab) G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True) - G, log = ot.bregman.sinkhorn(ab, bb, Mb, reg=epsilon, + G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon, method="sinkhorn_stabilized", log=True) G = nx.to_numpy(G) - G2, log2 = ot.bregman.sinkhorn(ab, bb, Mb, epsilon, + G2, log2 = ot.bregman.sinkhorn(ab, bb, M_nx, epsilon, method="sinkhorn", log=True) G2 = nx.to_numpy(G2) @@ -642,14 +857,14 @@ def test_screenkhorn(nx): ab = nx.from_numpy(a) bb = nx.from_numpy(b) - Mb = nx.from_numpy(M, type_as=ab) + M_nx = nx.from_numpy(M, type_as=ab) # np sinkhorn G_sink_np = ot.sinkhorn(a, b, M, 1e-03) # sinkhorn - G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1e-03)) + G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-03)) # screenkhorn - G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, Mb, 1e-03, uniform=True, verbose=True)) + G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-03, uniform=True, verbose=True)) # check marginals np.testing.assert_allclose(G_sink_np, G_sink) np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02) @@ -659,10 +874,10 @@ def test_screenkhorn(nx): def test_convolutional_barycenter_non_square(nx): # test for image with height not equal width A = np.ones((2, 2, 3)) / (2 * 3) - Ab = nx.from_numpy(A) + A_nx = nx.from_numpy(A) b_np = ot.bregman.convolutional_barycenter2d(A, 1e-03) - b = nx.to_numpy(ot.bregman.convolutional_barycenter2d(Ab, 1e-03)) + b = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, 1e-03)) np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02) np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02) -- cgit v1.2.3 From 2fe69eb130827560ada704bc25998397c4357821 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Thu, 4 Nov 2021 11:00:09 +0100 Subject: [MRG] Make gromov loss differentiable wrt matrices and weights (#302) * grmov differentable * new stuff * test gromov gradients * fgwdifferentiable * fgw tested * correc name test * add awesome example with gromov optimizatrion * pep8+ typos * damn pep8 * thunbnail * remove prints --- README.md | 9 +- examples/backends/plot_optim_gromov_pytorch.py | 260 +++++++++++++++++++++++++ ot/__init__.py | 2 + ot/gromov.py | 141 +++++++++++--- ot/optim.py | 3 +- test/test_gromov.py | 76 ++++++++ 6 files changed, 460 insertions(+), 31 deletions(-) create mode 100644 examples/backends/plot_optim_gromov_pytorch.py (limited to 'test') diff --git a/README.md b/README.md index ff32c53..08db003 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ POT provides the following generic OT solvers (links to examples): * Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37] * [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17]. * Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html)) with LP solver (only small scale). -* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]) +* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from * [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24] * [Stochastic solver](https://pythonot.github.io/auto_examples/plot_stochastic.html) for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19]) * [Stochastic solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33] @@ -295,5 +295,8 @@ You can also post bug reports and feature requests in Github issues. Make sure t via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on Machine Learning (pp. 4104-4113). PMLR. -[37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International -Conference on Machine Learning, PMLR 119:4692-4701, 2020 \ No newline at end of file +[37] Janati, H., Cuturi, M., Gramfort, A. [Debiased sinkhorn barycenters](http://proceedings.mlr.press/v119/janati20a/janati20a.pdf) Proceedings of the 37th International +Conference on Machine Learning, PMLR 119:4692-4701, 2020 + +[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph +Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021. \ No newline at end of file diff --git a/examples/backends/plot_optim_gromov_pytorch.py b/examples/backends/plot_optim_gromov_pytorch.py new file mode 100644 index 0000000..465f612 --- /dev/null +++ b/examples/backends/plot_optim_gromov_pytorch.py @@ -0,0 +1,260 @@ +r""" +================================= +Optimizing the Gromov-Wasserstein distance with PyTorch +================================= + +In this exemple we use the pytorch backend to optimize the Gromov-Wasserstein +(GW) loss between two graphs expressed as empirical distribution. + +In the first example we optimize the weights on the node of a simple template +graph so that it minimizes the GW with a given Stochastic Block Model graph. +We can see that this actually recovers the proportion of classes in the SBM +and allows for an accurate clustering of the nodes using the GW optimal plan. + +In a second example we optimize simultaneously the weights and the sructure of +the template graph which allows us to perform graph compression and to recover +other properties of the SBM. + +The backend actually uses the gradients expressed in [38] to optimize the +weights. + +[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online Graph +Dictionary Learning, International Conference on Machine Learning (ICML), 2021. + +""" +# Author: Rémi Flamary +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 3 + +from sklearn.manifold import MDS +import numpy as np +import matplotlib.pylab as pl +import torch + +import ot +from ot.gromov import gromov_wasserstein2 + +# %% +# Graph generation +# --------------- + +rng = np.random.RandomState(42) + + +def get_sbm(n, nc, ratio, P): + nbpc = np.round(n * ratio).astype(int) + n = np.sum(nbpc) + C = np.zeros((n, n)) + for c1 in range(nc): + for c2 in range(c1 + 1): + if c1 == c2: + for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[:c1 + 1])): + for j in range(np.sum(nbpc[:c2]), i): + if rng.rand() <= P[c1, c2]: + C[i, j] = 1 + else: + for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[:c1 + 1])): + for j in range(np.sum(nbpc[:c2]), np.sum(nbpc[:c2 + 1])): + if rng.rand() <= P[c1, c2]: + C[i, j] = 1 + + return C + C.T + + +n = 100 +nc = 3 +ratio = np.array([.5, .3, .2]) +P = np.array(0.6 * np.eye(3) + 0.05 * np.ones((3, 3))) +C1 = get_sbm(n, nc, ratio, P) + +# get 2d position for nodes +x1 = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C1) + + +def plot_graph(x, C, color='C0', s=None): + for j in range(C.shape[0]): + for i in range(j): + if C[i, j] > 0: + pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color='k') + pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors='k', cmap='tab10', vmax=9) + + +pl.figure(1, (10, 5)) +pl.clf() +pl.subplot(1, 2, 1) +plot_graph(x1, C1, color='C0') +pl.title("SBM Graph") +pl.axis("off") +pl.subplot(1, 2, 2) +pl.imshow(C1, interpolation='nearest') +pl.title("Adjacency matrix") +pl.axis("off") + + +# %% +# Optimizing the weights of a simple template C0=eye(3) to fit Graph 1 +# ------------------------------------------------ +# The adajacency matrix C1 is block diagonal with 3 blocks. We want to +# optimize the weights of a simple template C0=eye(3) and see if we can +# recover the proportion of classes from the SBM (up to a permutation). + +C0 = np.eye(3) + + +def min_weight_gw(C1, C2, a2, nb_iter_max=100, lr=1e-2): + """ solve min_a GW(C1,C2,a, a2) by gradient descent""" + + # use pyTorch for our data + C1_torch = torch.tensor(C1) + C2_torch = torch.tensor(C2) + + a0 = rng.rand(C1.shape[0]) # random_init + a0 /= a0.sum() # on simplex + a1_torch = torch.tensor(a0).requires_grad_(True) + a2_torch = torch.tensor(a2) + + loss_iter = [] + + for i in range(nb_iter_max): + + loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch) + + loss_iter.append(loss.clone().detach().cpu().numpy()) + loss.backward() + + #print("{:03d} | {}".format(i, loss_iter[-1])) + + # performs a step of projected gradient descent + with torch.no_grad(): + grad = a1_torch.grad + a1_torch -= grad * lr # step + a1_torch.grad.zero_() + a1_torch.data = ot.utils.proj_simplex(a1_torch) + + a1 = a1_torch.clone().detach().cpu().numpy() + + return a1, loss_iter + + +a0_est, loss_iter0 = min_weight_gw(C0, C1, ot.unif(n), nb_iter_max=100, lr=1e-2) + +pl.figure(2) +pl.plot(loss_iter0) +pl.title("Loss along iterations") + +print("Estimated weights : ", a0_est) +print("True proportions : ", ratio) + + +# %% +# It is clear that the optimization has converged and that we recover the +# ratio of the different classes in the SBM graph up to a permutation. + + +# %% +# Community clustering with uniform and estimated weights +# -------------------------------------------- +# The GW OT plan can be used to perform a clustering of the nodes of a graph +# when computing the GW with a simple template like C0 by labeling nodes in +# the original graph using by the index of the noe in the template receiving +# the most mass. +# +# We show here the result of such a clustering when using uniform weights on +# the template C0 and when using the optimal weights previously estimated. + + +T_unif = ot.gromov_wasserstein(C1, C0, ot.unif(n), ot.unif(3)) +label_unif = T_unif.argmax(1) + +T_est = ot.gromov_wasserstein(C1, C0, ot.unif(n), a0_est) +label_est = T_est.argmax(1) + +pl.figure(3, (10, 5)) +pl.clf() +pl.subplot(1, 2, 1) +plot_graph(x1, C1, color=label_unif) +pl.title("Graph clustering unif. weights") +pl.axis("off") +pl.subplot(1, 2, 2) +plot_graph(x1, C1, color=label_est) +pl.title("Graph clustering est. weights") +pl.axis("off") + + +# %% +# Graph compression with GW +# ------------------------- + +# Now we optimize both the weights and structure of a small graph that +# minimize the GW distance wrt our data graph. This can be seen as graph +# compression but can also recover important properties of an SBM such +# as its class proportion but also its matrix of probability of links between +# classes + + +def graph_compession_gw(nb_nodes, C2, a2, nb_iter_max=100, lr=1e-2): + """ solve min_a GW(C1,C2,a, a2) by gradient descent""" + + # use pyTorch for our data + + C2_torch = torch.tensor(C2) + a2_torch = torch.tensor(a2) + + a0 = rng.rand(nb_nodes) # random_init + a0 /= a0.sum() # on simplex + a1_torch = torch.tensor(a0).requires_grad_(True) + C0 = np.eye(nb_nodes) + C1_torch = torch.tensor(C0).requires_grad_(True) + + loss_iter = [] + + for i in range(nb_iter_max): + + loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch) + + loss_iter.append(loss.clone().detach().cpu().numpy()) + loss.backward() + + #print("{:03d} | {}".format(i, loss_iter[-1])) + + # performs a step of projected gradient descent + with torch.no_grad(): + grad = a1_torch.grad + a1_torch -= grad * lr # step + a1_torch.grad.zero_() + a1_torch.data = ot.utils.proj_simplex(a1_torch) + + grad = C1_torch.grad + C1_torch -= grad * lr # step + C1_torch.grad.zero_() + C1_torch.data = torch.clamp(C1_torch, 0, 1) + + a1 = a1_torch.clone().detach().cpu().numpy() + C1 = C1_torch.clone().detach().cpu().numpy() + + return a1, C1, loss_iter + + +nb_nodes = 3 +a0_est2, C0_est2, loss_iter2 = graph_compession_gw(nb_nodes, C1, ot.unif(n), + nb_iter_max=100, lr=5e-2) + +pl.figure(4) +pl.plot(loss_iter2) +pl.title("Loss along iterations") + + +print("Estimated weights : ", a0_est2) +print("True proportions : ", ratio) + +pl.figure(6, (10, 3.5)) +pl.clf() +pl.subplot(1, 2, 1) +pl.imshow(P, vmin=0, vmax=1) +pl.title('True SBM P matrix') +pl.subplot(1, 2, 2) +pl.imshow(C0_est2, vmin=0, vmax=1) +pl.title('Estimated C0 matrix') +pl.colorbar() diff --git a/ot/__init__.py b/ot/__init__.py index f20332c..4292b41 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -43,6 +43,8 @@ from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2) from .da import sinkhorn_lpl1_mm from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance +from .gromov import (gromov_wasserstein, gromov_wasserstein2, + gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) # utils functions from .utils import dist, unif, tic, toc, toq diff --git a/ot/gromov.py b/ot/gromov.py index 465693d..ea667e4 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -174,7 +174,7 @@ def tensor_product(constC, hC1, hC2, T): def gwloss(constC, hC1, hC2, T): - """Return the Loss for Gromov-Wasserstein + r"""Return the Loss for Gromov-Wasserstein The loss is computed as described in Proposition 1 Eq. (6) in :ref:`[12] ` @@ -213,7 +213,7 @@ def gwloss(constC, hC1, hC2, T): def gwggrad(constC, hC1, hC2, T): - """Return the gradient for Gromov-Wasserstein + r"""Return the gradient for Gromov-Wasserstein The gradient is computed as described in Proposition 2 in :ref:`[12] ` @@ -247,7 +247,7 @@ def gwggrad(constC, hC1, hC2, T): def update_square_loss(p, lambdas, T, Cs): - """ + r""" Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration @@ -284,7 +284,7 @@ def update_square_loss(p, lambdas, T, Cs): def update_kl_loss(p, lambdas, T, Cs): - """ + r""" Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration @@ -320,7 +320,7 @@ def update_kl_loss(p, lambdas, T, Cs): return nx.exp(tmpsum / ppt) -def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): +def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs): r""" Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` @@ -386,6 +386,14 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs """ p, q = list_to_array(p, q) + p0, q0, C10, C20 = p, q, C1, C2 + nx = get_backend(p0, q0, C10, C20) + + p = nx.to_numpy(p) + q = nx.to_numpy(q) + C1 = nx.to_numpy(C10) + C2 = nx.to_numpy(C20) + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) G0 = p[:, None] * q[None, :] @@ -398,13 +406,15 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs if log: res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) - log['gw_dist'] = gwloss(constC, hC1, hC2, res) - return res, log + log['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, res), type_as=C10) + log['u'] = nx.from_numpy(log['u'], type_as=C10) + log['v'] = nx.from_numpy(log['v'], type_as=C10) + return nx.from_numpy(res, type_as=C10), log else: - return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) + return nx.from_numpy(cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=False, **kwargs), type_as=C10) -def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): +def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs): r""" Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` @@ -420,7 +430,11 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg - :math:`\mathbf{C_2}`: Metric cost matrix in the target space - :math:`\mathbf{p}`: distribution in the source space - :math:`\mathbf{q}`: distribution in the target space - - `L`: loss function to account for the misfit between the similarity matrices + - `L`: loss function to account for the misfit between the similarity + matrices + + Note that when using backends, this loss function is differentiable wrt the + marices and weights for quadratic loss using the gradients from [38]_. Parameters ---------- @@ -463,9 +477,21 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg metric approach to object matching. Foundations of computational mathematics 11.4 (2011): 417-487. + .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online + Graph Dictionary Learning, International Conference on Machine Learning + (ICML), 2021. + """ p, q = list_to_array(p, q) + p0, q0, C10, C20 = p, q, C1, C2 + nx = get_backend(p0, q0, C10, C20) + + p = nx.to_numpy(p) + q = nx.to_numpy(q) + C1 = nx.to_numpy(C10) + C2 = nx.to_numpy(C20) + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) G0 = p[:, None] * q[None, :] @@ -475,13 +501,28 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg def df(G): return gwggrad(constC, hC1, hC2, G) - res, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) - log_gw['gw_dist'] = gwloss(constC, hC1, hC2, res) - log_gw['T'] = res + + T, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) + + T0 = nx.from_numpy(T, type_as=C10) + + log_gw['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, T), type_as=C10) + log_gw['u'] = nx.from_numpy(log_gw['u'], type_as=C10) + log_gw['v'] = nx.from_numpy(log_gw['v'], type_as=C10) + log_gw['T'] = T0 + + gw = log_gw['gw_dist'] + + if loss_fun == 'square_loss': + gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)) + gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)) + gw = nx.set_gradients(gw, (p0, q0, C10, C20), + (log_gw['u'], log_gw['v'], gC1, gC2)) + if log: - return log_gw['gw_dist'], log_gw + return gw, log_gw else: - return log_gw['gw_dist'] + return gw def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): @@ -548,6 +589,15 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, """ p, q = list_to_array(p, q) + p0, q0, C10, C20, M0 = p, q, C1, C2, M + nx = get_backend(p0, q0, C10, C20, M0) + + p = nx.to_numpy(p) + q = nx.to_numpy(q) + C1 = nx.to_numpy(C10) + C2 = nx.to_numpy(C20) + M = nx.to_numpy(M0) + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) G0 = p[:, None] * q[None, :] @@ -560,10 +610,16 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, if log: res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) - log['fgw_dist'] = log['loss'][::-1][0] - return res, log + + fgw_dist = nx.from_numpy(log['loss'][-1], type_as=C10) + + log['fgw_dist'] = fgw_dist + log['u'] = nx.from_numpy(log['u'], type_as=C10) + log['v'] = nx.from_numpy(log['v'], type_as=C10) + return nx.from_numpy(res, type_as=C10), log + else: - return cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) + return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs), type_as=C10) def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): @@ -586,7 +642,11 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1) - `L` is a loss function to account for the misfit between the similarity matrices - The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] ` + The algorithm used for solving the problem is conditional gradient as + discussed in :ref:`[24] ` + + Note that when using backends, this loss function is differentiable wrt the + marices and weights for quadratic loss using the gradients from [38]_. Parameters ---------- @@ -627,9 +687,22 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. + + .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online + Graph Dictionary Learning, International Conference on Machine Learning + (ICML), 2021. """ p, q = list_to_array(p, q) + p0, q0, C10, C20, M0 = p, q, C1, C2, M + nx = get_backend(p0, q0, C10, C20, M0) + + p = nx.to_numpy(p) + q = nx.to_numpy(q) + C1 = nx.to_numpy(C10) + C2 = nx.to_numpy(C20) + M = nx.to_numpy(M0) + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) G0 = p[:, None] * q[None, :] @@ -640,13 +713,27 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 def df(G): return gwggrad(constC, hC1, hC2, G) - res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) + T, log_fgw = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) + + fgw_dist = nx.from_numpy(log_fgw['loss'][-1], type_as=C10) + + T0 = nx.from_numpy(T, type_as=C10) + + log_fgw['fgw_dist'] = fgw_dist + log_fgw['u'] = nx.from_numpy(log_fgw['u'], type_as=C10) + log_fgw['v'] = nx.from_numpy(log_fgw['v'], type_as=C10) + log_fgw['T'] = T0 + + if loss_fun == 'square_loss': + gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)) + gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)) + fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0), + (log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0)) + if log: - log['fgw_dist'] = log['loss'][::-1][0] - log['T'] = res - return log['fgw_dist'], log + return fgw_dist, log_fgw else: - return log['fgw_dist'] + return fgw_dist def GW_distance_estimation(C1, C2, p, q, loss_fun, T, @@ -1447,7 +1534,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False, p=None, loss_fun='square_loss', max_iter=100, tol=1e-9, verbose=False, log=False, init_C=None, init_X=None, random_state=None): - """Compute the fgw barycenter as presented eq (5) in :ref:`[24] ` + r"""Compute the fgw barycenter as presented eq (5) in :ref:`[24] ` Parameters ---------- @@ -1604,7 +1691,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ def update_structure_matrix(p, lambdas, T, Cs): - """Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings. + r"""Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings. It is calculated at each iteration @@ -1640,7 +1727,7 @@ def update_structure_matrix(p, lambdas, T, Cs): def update_feature_matrix(lambdas, Ys, Ts, p): - """Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings. + r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" diff --git a/ot/optim.py b/ot/optim.py index cc286b6..bd8ca26 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -267,7 +267,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, Mi += nx.min(Mi) # solve linear program - Gc = emd(a, b, Mi, numItermax=numItermaxEmd) + Gc, logemd = emd(a, b, Mi, numItermax=numItermaxEmd, log=True) deltaG = Gc - G @@ -297,6 +297,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) if log: + log.update(logemd) return G, log else: return G diff --git a/test/test_gromov.py b/test/test_gromov.py index 509c54d..bcbcc3a 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -9,6 +9,7 @@ import numpy as np import ot from ot.backend import NumpyBackend +from ot.backend import torch import pytest @@ -74,6 +75,42 @@ def test_gromov(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov +def test_gromov2_gradients(): + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) + + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + if torch: + + p1 = torch.tensor(p, requires_grad=True) + q1 = torch.tensor(q, requires_grad=True) + C11 = torch.tensor(C1, requires_grad=True) + C12 = torch.tensor(C2, requires_grad=True) + + val = ot.gromov_wasserstein2(C11, C12, p1, q1) + + val.backward() + + assert q1.shape == q1.grad.shape + assert p1.shape == p1.grad.shape + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + + @pytest.skip_backend("jax", reason="test very slow with jax backend") def test_entropic_gromov(nx): n_samples = 50 # nb samples @@ -389,6 +426,45 @@ def test_fgw(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov +def test_fgw2_gradients(): + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) + + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + M = ot.dist(xs, xt) + + C1 /= C1.max() + C2 /= C2.max() + + if torch: + + p1 = torch.tensor(p, requires_grad=True) + q1 = torch.tensor(q, requires_grad=True) + C11 = torch.tensor(C1, requires_grad=True) + C12 = torch.tensor(C2, requires_grad=True) + M1 = torch.tensor(M, requires_grad=True) + + val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1) + + val.backward() + + assert q1.shape == q1.grad.shape + assert p1.shape == p1.grad.shape + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + assert M1.shape == M1.grad.shape + + def test_fgw_barycenter(nx): np.random.seed(42) -- cgit v1.2.3 From 0e431c203a66c6d48e6bb1efeda149460472a0f0 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Thu, 4 Nov 2021 15:19:57 +0100 Subject: [MRG] Add tests about type and GPU for emd/emd2 + 1d variants + wasserstein1d (#304) * new test gpu * pep 8 of couse * debug torch * jax with gpu * device put * device put * it works * emd1d and emd2_1d working * emd_1d and emd2_1d done * cleanup * of course * should work on gpu now * tests done+ pep8 --- ot/backend.py | 20 ++++++++++- ot/lp/solver_1d.py | 14 ++++---- test/test_1d_solver.py | 93 ++++++++++++++++++++++++++++++++++++++++++++++++++ test/test_ot.py | 67 +++++++++++++++--------------------- 4 files changed, 146 insertions(+), 48 deletions(-) (limited to 'test') diff --git a/ot/backend.py b/ot/backend.py index d3df44c..55e10d3 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -102,6 +102,7 @@ class Backend(): __name__ = None __type__ = None + __type_list__ = None rng_ = None @@ -663,6 +664,8 @@ class NumpyBackend(Backend): __name__ = 'numpy' __type__ = np.ndarray + __type_list__ = [np.array(1, dtype=np.float32), + np.array(1, dtype=np.float64)] rng_ = np.random.RandomState() @@ -888,12 +891,17 @@ class JaxBackend(Backend): __name__ = 'jax' __type__ = jax_type + __type_list__ = None rng_ = None def __init__(self): self.rng_ = jax.random.PRNGKey(42) + for d in jax.devices(): + self.__type_list__ = [jax.device_put(jnp.array(1, dtype=np.float32), d), + jax.device_put(jnp.array(1, dtype=np.float64), d)] + def to_numpy(self, a): return np.array(a) @@ -901,7 +909,7 @@ class JaxBackend(Backend): if type_as is None: return jnp.array(a) else: - return jnp.array(a).astype(type_as.dtype) + return jax.device_put(jnp.array(a).astype(type_as.dtype), type_as.device_buffer.device()) def set_gradients(self, val, inputs, grads): from jax.flatten_util import ravel_pytree @@ -1130,6 +1138,7 @@ class TorchBackend(Backend): __name__ = 'torch' __type__ = torch_type + __type_list__ = None rng_ = None @@ -1138,6 +1147,13 @@ class TorchBackend(Backend): self.rng_ = torch.Generator() self.rng_.seed() + self.__type_list__ = [torch.tensor(1, dtype=torch.float32), + torch.tensor(1, dtype=torch.float64)] + + if torch.cuda.is_available(): + self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda')) + self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda')) + from torch.autograd import Function # define a function that takes inputs val and grads @@ -1160,6 +1176,8 @@ class TorchBackend(Backend): return a.cpu().detach().numpy() def from_numpy(self, a, type_as=None): + if isinstance(a, float): + a = np.array(a) if type_as is None: return torch.from_numpy(a) else: diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 42554aa..8b4d0c3 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -235,8 +235,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, # ensure that same mass np.testing.assert_almost_equal( - nx.sum(a, axis=0), - nx.sum(b, axis=0), + nx.to_numpy(nx.sum(a, axis=0)), + nx.to_numpy(nx.sum(b, axis=0)), err_msg='a and b vector must have the same sum' ) b = b * nx.sum(a) / nx.sum(b) @@ -247,10 +247,10 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, perm_b = nx.argsort(x_b_1d) G_sorted, indices, cost = emd_1d_sorted( - nx.to_numpy(a[perm_a]), - nx.to_numpy(b[perm_b]), - nx.to_numpy(x_a_1d[perm_a]), - nx.to_numpy(x_b_1d[perm_b]), + nx.to_numpy(a[perm_a]).astype(np.float64), + nx.to_numpy(b[perm_b]).astype(np.float64), + nx.to_numpy(x_a_1d[perm_a]).astype(np.float64), + nx.to_numpy(x_b_1d[perm_b]).astype(np.float64), metric=metric, p=p ) @@ -266,7 +266,7 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, elif str(nx) == "jax": warnings.warn("JAX does not support sparse matrices, converting to dense") if log: - log = {'cost': cost} + log = {'cost': nx.from_numpy(cost, type_as=x_a)} return G, log return G diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 2c470c2..77b1234 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -83,3 +83,96 @@ def test_wasserstein_1d(nx): Xb = nx.from_numpy(X) res = wasserstein_1d(Xb, Xb, rho_ub, rho_vb, p=2) np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4) + + +@pytest.mark.parametrize('nx', backend_list) +def test_wasserstein_1d_type_devices(nx): + + rng = np.random.RandomState(0) + + n = 10 + x = np.linspace(0, 5, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + rho_v = np.abs(rng.randn(n)) + rho_v /= rho_v.sum() + + for tp in nx.__type_list__: + + print(tp.dtype) + + xb = nx.from_numpy(x, type_as=tp) + rho_ub = nx.from_numpy(rho_u, type_as=tp) + rho_vb = nx.from_numpy(rho_v, type_as=tp) + + res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1) + + if not str(nx) == 'numpy': + assert res.dtype == xb.dtype + + +def test_emd_1d_emd2_1d(): + # test emd1d gives similar results as emd + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.randn(n, 1) + v = rng.randn(m, 1) + + M = ot.dist(u, v, metric='sqeuclidean') + + G, log = ot.emd([], [], M, log=True) + wass = log["cost"] + G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True) + wass1d = log["cost"] + wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False) + wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False) + + # check loss is similar + np.testing.assert_allclose(wass, wass1d) + np.testing.assert_allclose(wass, wass1d_emd2) + + # check loss is similar to scipy's implementation for Euclidean metric + wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,))) + np.testing.assert_allclose(wass_sp, wass1d_euc) + + # check constraints + np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1)) + np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0)) + + # check G is similar + np.testing.assert_allclose(G, G_1d, atol=1e-15) + + # check AssertionError is raised if called on non 1d arrays + u = np.random.randn(n, 2) + v = np.random.randn(m, 2) + with pytest.raises(AssertionError): + ot.emd_1d(u, v, [], []) + + +def test_emd1d_type_devices(nx): + + rng = np.random.RandomState(0) + + n = 10 + x = np.linspace(0, 5, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + rho_v = np.abs(rng.randn(n)) + rho_v /= rho_v.sum() + + for tp in nx.__type_list__: + + print(tp.dtype) + + xb = nx.from_numpy(x, type_as=tp) + rho_ub = nx.from_numpy(rho_u, type_as=tp) + rho_vb = nx.from_numpy(rho_v, type_as=tp) + + emd = ot.emd_1d(xb, xb, rho_ub, rho_vb) + + emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb) + + assert emd.dtype == xb.dtype + if not str(nx) == 'numpy': + assert emd2.dtype == xb.dtype diff --git a/test/test_ot.py b/test/test_ot.py index 5bfde1d..dc3930a 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -12,7 +12,6 @@ import pytest import ot from ot.datasets import make_1D_gauss as gauss from ot.backend import torch -from scipy.stats import wasserstein_distance def test_emd_dimension_and_mass_mismatch(): @@ -77,6 +76,33 @@ def test_emd2_backends(nx): np.allclose(val, nx.to_numpy(valb)) +def test_emd_emd2_types_devices(nx): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + for tp in nx.__type_list__: + + print(tp.dtype) + + ab = nx.from_numpy(a, type_as=tp) + Mb = nx.from_numpy(M, type_as=tp) + + Gb = ot.emd(ab, ab, Mb) + + w = ot.emd2(ab, ab, Mb) + + assert Gb.dtype == Mb.dtype + if not str(nx) == 'numpy': + assert w.dtype == Mb.dtype + + def test_emd2_gradients(): n_samples = 100 n_features = 2 @@ -126,45 +152,6 @@ def test_emd_emd2(): np.testing.assert_allclose(w, 0) -def test_emd_1d_emd2_1d(): - # test emd1d gives similar results as emd - n = 20 - m = 30 - rng = np.random.RandomState(0) - u = rng.randn(n, 1) - v = rng.randn(m, 1) - - M = ot.dist(u, v, metric='sqeuclidean') - - G, log = ot.emd([], [], M, log=True) - wass = log["cost"] - G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True) - wass1d = log["cost"] - wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False) - wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False) - - # check loss is similar - np.testing.assert_allclose(wass, wass1d) - np.testing.assert_allclose(wass, wass1d_emd2) - - # check loss is similar to scipy's implementation for Euclidean metric - wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,))) - np.testing.assert_allclose(wass_sp, wass1d_euc) - - # check constraints - np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1)) - np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0)) - - # check G is similar - np.testing.assert_allclose(G, G_1d, atol=1e-15) - - # check AssertionError is raised if called on non 1d arrays - u = np.random.randn(n, 2) - v = np.random.randn(m, 2) - with pytest.raises(AssertionError): - ot.emd_1d(u, v, [], []) - - def test_emd_empty(): # test emd and emd2 for simple identity n = 100 -- cgit v1.2.3 From 0eac835c70cc1a13bb998f3b6cdb0515fafc05e1 Mon Sep 17 00:00:00 2001 From: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> Date: Fri, 5 Nov 2021 15:57:08 +0100 Subject: [MRG] Tests with types/device on sliced/bregman/gromov functions (#303) * First draft : making pytest use gpu for torch testing * bug solve * Revert "bug solve" This reverts commit 29b013abd162f8693128f26d8129186b79923609. * Revert "First draft : making pytest use gpu for torch testing" This reverts commit 2778175bcc338016c704efa4187d132fe5162e3a. * sliced * sliced * ot 1dsolver * bregman * better print * jax works with sinkhorn, sinkhorn_log and sinkhornn_stabilized, no need to skip them * gromov & entropic gromov --- ot/backend.py | 59 ++++++++++++++++++++++++++++++++++----- ot/sliced.py | 8 +++--- test/conftest.py | 25 +++++++++++++---- test/test_1d_solver.py | 16 ++++------- test/test_bregman.py | 45 ++++++++++++++++++++++++++++++ test/test_gromov.py | 75 ++++++++++++++++++++++++++++++++++++++++++++++++++ test/test_ot.py | 8 ++---- test/test_sliced.py | 44 +++++++++++++++++++++++++++++ 8 files changed, 247 insertions(+), 33 deletions(-) (limited to 'test') diff --git a/ot/backend.py b/ot/backend.py index 55e10d3..a044f84 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -653,6 +653,18 @@ class Backend(): """ raise NotImplementedError() + def dtype_device(self, a): + r""" + Returns the dtype and the device of the given tensor. + """ + raise NotImplementedError() + + def assert_same_dtype_device(self, a, b): + r""" + Checks whether or not the two given inputs have the same dtype as well as the same device + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -880,6 +892,16 @@ class NumpyBackend(Backend): def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): return np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + def dtype_device(self, a): + if hasattr(a, "dtype"): + return a.dtype, "cpu" + else: + return type(a), "cpu" + + def assert_same_dtype_device(self, a, b): + # numpy has implicit type conversion so we automatically validate the test + pass + class JaxBackend(Backend): """ @@ -899,17 +921,20 @@ class JaxBackend(Backend): self.rng_ = jax.random.PRNGKey(42) for d in jax.devices(): - self.__type_list__ = [jax.device_put(jnp.array(1, dtype=np.float32), d), - jax.device_put(jnp.array(1, dtype=np.float64), d)] + self.__type_list__ = [jax.device_put(jnp.array(1, dtype=jnp.float32), d), + jax.device_put(jnp.array(1, dtype=jnp.float64), d)] def to_numpy(self, a): return np.array(a) + def _change_device(self, a, type_as): + return jax.device_put(a, type_as.device_buffer.device()) + def from_numpy(self, a, type_as=None): if type_as is None: return jnp.array(a) else: - return jax.device_put(jnp.array(a).astype(type_as.dtype), type_as.device_buffer.device()) + return self._change_device(jnp.array(a).astype(type_as.dtype), type_as) def set_gradients(self, val, inputs, grads): from jax.flatten_util import ravel_pytree @@ -928,13 +953,13 @@ class JaxBackend(Backend): if type_as is None: return jnp.zeros(shape) else: - return jnp.zeros(shape, dtype=type_as.dtype) + return self._change_device(jnp.zeros(shape, dtype=type_as.dtype), type_as) def ones(self, shape, type_as=None): if type_as is None: return jnp.ones(shape) else: - return jnp.ones(shape, dtype=type_as.dtype) + return self._change_device(jnp.ones(shape, dtype=type_as.dtype), type_as) def arange(self, stop, start=0, step=1, type_as=None): return jnp.arange(start, stop, step) @@ -943,13 +968,13 @@ class JaxBackend(Backend): if type_as is None: return jnp.full(shape, fill_value) else: - return jnp.full(shape, fill_value, dtype=type_as.dtype) + return self._change_device(jnp.full(shape, fill_value, dtype=type_as.dtype), type_as) def eye(self, N, M=None, type_as=None): if type_as is None: return jnp.eye(N, M) else: - return jnp.eye(N, M, dtype=type_as.dtype) + return self._change_device(jnp.eye(N, M, dtype=type_as.dtype), type_as) def sum(self, a, axis=None, keepdims=False): return jnp.sum(a, axis, keepdims=keepdims) @@ -1127,6 +1152,16 @@ class JaxBackend(Backend): def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + def dtype_device(self, a): + return a.dtype, a.device_buffer.device() + + def assert_same_dtype_device(self, a, b): + a_dtype, a_device = self.dtype_device(a) + b_dtype, b_device = self.dtype_device(b) + + assert a_dtype == b_dtype, "Dtype discrepancy" + assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}" + class TorchBackend(Backend): """ @@ -1455,3 +1490,13 @@ class TorchBackend(Backend): def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + + def dtype_device(self, a): + return a.dtype, a.device + + def assert_same_dtype_device(self, a, b): + a_dtype, a_device = self.dtype_device(a) + b_dtype, b_device = self.dtype_device(b) + + assert a_dtype == b_dtype, "Dtype discrepancy" + assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}" diff --git a/ot/sliced.py b/ot/sliced.py index 7c09111..cf2d3be 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -139,9 +139,9 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, X_t.shape[1])) if a is None: - a = nx.full(n, 1 / n) + a = nx.full(n, 1 / n, type_as=X_s) if b is None: - b = nx.full(m, 1 / m) + b = nx.full(m, 1 / m, type_as=X_s) d = X_s.shape[1] @@ -238,9 +238,9 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, X_t.shape[1])) if a is None: - a = nx.full(n, 1 / n) + a = nx.full(n, 1 / n, type_as=X_s) if b is None: - b = nx.full(m, 1 / m) + b = nx.full(m, 1 / m, type_as=X_s) d = X_s.shape[1] diff --git a/test/conftest.py b/test/conftest.py index 876b525..987d98e 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -11,6 +11,7 @@ import functools if jax: from jax.config import config + config.update("jax_enable_x64", True) backend_list = get_backend_list() @@ -18,16 +19,25 @@ backend_list = get_backend_list() @pytest.fixture(params=backend_list) def nx(request): backend = request.param - if backend.__name__ == "jax": - config.update("jax_enable_x64", True) yield backend - if backend.__name__ == "jax": - config.update("jax_enable_x64", False) - def skip_arg(arg, value, reason=None, getter=lambda x: x): + if isinstance(arg, tuple) or isinstance(arg, list): + n = len(arg) + else: + arg = (arg, ) + n = 1 + if n != 1 and (isinstance(value, tuple) or isinstance(value, list)): + pass + else: + value = (value, ) + if isinstance(getter, tuple) or isinstance(value, list): + pass + else: + getter = [getter] * n + if reason is None: reason = f"Param {arg} should be skipped for value {value}" @@ -35,7 +45,10 @@ def skip_arg(arg, value, reason=None, getter=lambda x: x): @functools.wraps(function) def wrapped(*args, **kwargs): - if arg in kwargs.keys() and getter(kwargs[arg]) == value: + if all( + arg[i] in kwargs.keys() and getter[i](kwargs[arg[i]]) == value[i] + for i in range(n) + ): pytest.skip(reason) return function(*args, **kwargs) diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 77b1234..cb85cb9 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -85,7 +85,6 @@ def test_wasserstein_1d(nx): np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4) -@pytest.mark.parametrize('nx', backend_list) def test_wasserstein_1d_type_devices(nx): rng = np.random.RandomState(0) @@ -98,8 +97,7 @@ def test_wasserstein_1d_type_devices(nx): rho_v /= rho_v.sum() for tp in nx.__type_list__: - - print(tp.dtype) + print(nx.dtype_device(tp)) xb = nx.from_numpy(x, type_as=tp) rho_ub = nx.from_numpy(rho_u, type_as=tp) @@ -107,8 +105,7 @@ def test_wasserstein_1d_type_devices(nx): res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1) - if not str(nx) == 'numpy': - assert res.dtype == xb.dtype + nx.assert_same_dtype_device(xb, res) def test_emd_1d_emd2_1d(): @@ -162,17 +159,14 @@ def test_emd1d_type_devices(nx): rho_v /= rho_v.sum() for tp in nx.__type_list__: - - print(tp.dtype) + print(nx.dtype_device(tp)) xb = nx.from_numpy(x, type_as=tp) rho_ub = nx.from_numpy(rho_u, type_as=tp) rho_vb = nx.from_numpy(rho_v, type_as=tp) emd = ot.emd_1d(xb, xb, rho_ub, rho_vb) - emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb) - assert emd.dtype == xb.dtype - if not str(nx) == 'numpy': - assert emd2.dtype == xb.dtype + nx.assert_same_dtype_device(xb, emd) + nx.assert_same_dtype_device(xb, emd2) diff --git a/test/test_bregman.py b/test/test_bregman.py index edfe9c3..830052d 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -278,6 +278,51 @@ def test_sinkhorn_variants(nx): np.testing.assert_allclose(G0, G_green, atol=1e-5) +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", + "sinkhorn_epsilon_scaling", + "greenkhorn", + "sinkhorn_log"]) +@pytest.skip_arg(("nx", "method"), ("jax", "sinkhorn_epsilon_scaling"), reason="jax does not support sinkhorn_epsilon_scaling", getter=str) +@pytest.skip_arg(("nx", "method"), ("jax", "greenkhorn"), reason="jax does not support greenkhorn", getter=str) +def test_sinkhorn_variants_dtype_device(nx, method): + n = 100 + + x = np.random.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + ub = nx.from_numpy(u, type_as=tp) + Mb = nx.from_numpy(M, type_as=tp) + + Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10) + + nx.assert_same_dtype_device(Mb, Gb) + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"]) +def test_sinkhorn2_variants_dtype_device(nx, method): + n = 100 + + x = np.random.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + ub = nx.from_numpy(u, type_as=tp) + Mb = nx.from_numpy(M, type_as=tp) + + lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10) + + nx.assert_same_dtype_device(Mb, lossb) + + @pytest.skip_backend("jax") def test_sinkhorn_variants_multi_b(nx): # test sinkhorn diff --git a/test/test_gromov.py b/test/test_gromov.py index bcbcc3a..c4bc04c 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -75,6 +75,41 @@ def test_gromov(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov +def test_gromov_dtype_device(nx): + # setup + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + C1b = nx.from_numpy(C1, type_as=tp) + C2b = nx.from_numpy(C2, type_as=tp) + pb = nx.from_numpy(p, type_as=tp) + qb = nx.from_numpy(q, type_as=tp) + + Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True) + gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False) + + nx.assert_same_dtype_device(C1b, Gb) + nx.assert_same_dtype_device(C1b, gw_valb) + + def test_gromov2_gradients(): n_samples = 50 # nb samples @@ -168,6 +203,46 @@ def test_entropic_gromov(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov +@pytest.skip_backend("jax", reason="test very slow with jax backend") +def test_entropic_gromov_dtype_device(nx): + # setup + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + C1b = nx.from_numpy(C1, type_as=tp) + C2b = nx.from_numpy(C2, type_as=tp) + pb = nx.from_numpy(p, type_as=tp) + qb = nx.from_numpy(q, type_as=tp) + + Gb = ot.gromov.entropic_gromov_wasserstein( + C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True + ) + gw_valb = ot.gromov.entropic_gromov_wasserstein2( + C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True + ) + + nx.assert_same_dtype_device(C1b, Gb) + nx.assert_same_dtype_device(C1b, gw_valb) + + def test_pointwise_gromov(nx): n_samples = 50 # nb samples diff --git a/test/test_ot.py b/test/test_ot.py index dc3930a..92f26a7 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -88,8 +88,7 @@ def test_emd_emd2_types_devices(nx): M = ot.dist(x, y) for tp in nx.__type_list__: - - print(tp.dtype) + print(nx.dtype_device(tp)) ab = nx.from_numpy(a, type_as=tp) Mb = nx.from_numpy(M, type_as=tp) @@ -98,9 +97,8 @@ def test_emd_emd2_types_devices(nx): w = ot.emd2(ab, ab, Mb) - assert Gb.dtype == Mb.dtype - if not str(nx) == 'numpy': - assert w.dtype == Mb.dtype + nx.assert_same_dtype_device(Mb, Gb) + nx.assert_same_dtype_device(Mb, w) def test_emd2_gradients(): diff --git a/test/test_sliced.py b/test/test_sliced.py index 0bd74ec..245202c 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -139,6 +139,28 @@ def test_sliced_backend(nx): assert np.allclose(val0, valb) +def test_sliced_backend_type_devices(nx): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(2 * n, 2) + + P = rng.randn(2, 20) + P = P / np.sqrt((P**2).sum(0, keepdims=True)) + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb = nx.from_numpy(x, type_as=tp) + yb = nx.from_numpy(y, type_as=tp) + Pb = nx.from_numpy(P, type_as=tp) + + valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb) + + nx.assert_same_dtype_device(xb, valb) + + def test_max_sliced_backend(nx): n = 100 @@ -167,3 +189,25 @@ def test_max_sliced_backend(nx): valb = nx.to_numpy(ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)) assert np.allclose(val0, valb) + + +def test_max_sliced_backend_type_devices(nx): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(2 * n, 2) + + P = rng.randn(2, 20) + P = P / np.sqrt((P**2).sum(0, keepdims=True)) + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb = nx.from_numpy(x, type_as=tp) + yb = nx.from_numpy(y, type_as=tp) + Pb = nx.from_numpy(P, type_as=tp) + + valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb) + + nx.assert_same_dtype_device(xb, valb) -- cgit v1.2.3