From 184f8f4f7ac78f1dd7f653496d2753211a4e3426 Mon Sep 17 00:00:00 2001
From: Rémi Flamary
Date: Tue, 1 Jun 2021 10:10:54 +0200
Subject: [MRG] POT numpy/torch/jax backends (#249)

* add numpy and torch backends
* stat sets on functions
* proper import
* install recent torch on windows
* install recent torch on windows
* now testing all functions in backedn
* add jax backedn
* clenaup windowds
* proper convert for jax backedn
* pep8
* try again windows tests
* test jax conversion
* try proper widows tests
* emd fuction ses backedn
* better test partial OT
* proper tests to_numpy and teplate Backend
* pep8
* pep8 x2
* feaking sinkhorn works with torch
* sinkhorn2 compatible
* working ot.emd2
* important detach
* it should work
* jax autodiff emd
* pep8
* no tast same for jax
* new independat tests per backedn
* freaking pep8
* add tests for gradients
* deprecate ot.gpu
* worging dist function
* working dist
* dist done in backedn
* not in
* remove indexing
* change accuacy for jax
* first pull backend
* projection simplex
* projection simplex
* projection simplex
* projection simplex no ci
* projection simplex no ci
* projection simplex no ci
* pep8
* add backedn discusion to quickstart guide
* projection simplex no ci
* projection simplex no ci
* projection simplex no ci
* pep8 + better doc
* proper links
* corect doctest
* big debug documentation
* doctest again
* doctest again bis
* doctest again ter (last one or i kill myself)
* backend test + doc proj simplex
* correction test_utils
* correction test_utils
* correction cumsum
* correction flip
* correction flip v2
* more debug
* more debug
* more debug + pep8
* pep8
* argh
* proj_simplex
* backedn works for sort
* proj simplex
* jax sucks
* update doc
* Update test/test_utils.py Co-authored-by: Alexandre Gramfort
* Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort
* Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort
* Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort
* Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort
* Update test/test_utils.py Co-authored-by: Alexandre Gramfort
* Update ot/utils.py Co-authored-by: Alexandre Gramfort
* Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort
* Update ot/lp/__init__.py Co-authored-by: Alexandre Gramfort
* begin comment alex
* comment alex part 2
* optimize test gromov
* proj_simplex on vectors
* add awesome gradient decsnt example on the weights
* pep98 of course
* proof read example by alex
* pep8 again
* encoding oos in translation
* correct legend

Co-authored-by: Nicolas Courty
Co-authored-by: Alexandre Gramfort
---
 test/test_gromov.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

(limited to 'test/test_gromov.py')

diff --git a/test/test_gromov.py b/test/test_gromov.py
index 43da9fc..81138ca 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -181,7 +181,7 @@ def test_fgw():
     M = ot.dist(ys, yt)
     M /= M.max()

-    G = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5)
+    G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)

    # check constratints
    np.testing.assert_allclose(
@@ -242,9 +242,9 @@ def test_fgw_barycenter():

     init_X = np.random.randn(n_samples, ys.shape[1])

-    X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
-                                     fixed_structure=False, fixed_features=True, init_X=init_X,
-                                     p=ot.unif(n_samples), loss_fun='square_loss',
-                                     max_iter=100, tol=1e-3)
+    X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
+                                          fixed_structure=False, fixed_features=True, init_X=init_X,
+                                          p=ot.unif(n_samples), loss_fun='square_loss',
+                                          max_iter=100, tol=1e-3, log=True)

     np.testing.assert_allclose(C.shape, (n_samples, n_samples))
     np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
-- cgit v1.2.3

From 8ef3341a472909f223ec0f678f11f136f55c1406 Mon Sep 17 00:00:00 2001
From: Rémi Flamary
Date: Thu, 17 Jun 2021 11:46:37 +0200
Subject: [MRG] Speedup tests (#262)

* speedup tests
* add color to tests and timings
* add test unbalanced
* stupid missing -
---
 .github/workflows/build_tests.yml | 8 ++++----
 Makefile | 4 ++--
 test/test_bregman.py | 7 ++++---
 test/test_da.py | 8 ++++----
 test/test_gromov.py | 15 +++++++++------
 test/test_optim.py | 6 +++---
 test/test_stochastic.py | 40 +++++++++++++++++++--------------------
 test/test_unbalanced.py | 33 ++++++++++++++++++++++++++++++--
 8 files changed, 77 insertions(+), 44 deletions(-)

(limited to 'test/test_gromov.py')

diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml
index 92a07b5..fd0ade6 100644
--- a/.github/workflows/build_tests.yml
+++ b/.github/workflows/build_tests.yml
@@ -40,7 +40,7 @@ jobs:
         pip install -e .
     - name: Run tests
       run: |
-        python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes
+        python -m pytest --durations=20 -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes
     - name: Upload codecov
       run: |
         codecov
@@ -95,7 +95,7 @@ jobs:
         pip install -e .
     - name: Run tests
       run: |
-        python -m pytest -v test/ ot/ --ignore ot/gpu/
+        python -m pytest --durations=20 -v test/ ot/ --ignore ot/gpu/ --color=yes


  macos:
@@ -122,7 +122,7 @@ jobs:
         pip install -e .
     - name: Run tests
       run: |
-        python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot
+        python -m pytest --durations=20 -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes


  windows:
@@ -150,4 +150,4 @@ jobs:
         python -m pip install -e .
- name: Run tests run: | - python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot + python -m pytest --durations=20 -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes diff --git a/Makefile b/Makefile index 32332b4..315218d 100644 --- a/Makefile +++ b/Makefile @@ -45,10 +45,10 @@ pep8 : flake8 examples/ ot/ test/ test : FORCE pep8 - $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/ + $(PYTHON) -m pytest --durations=20 -v test/ --doctest-modules --ignore ot/gpu/ pytest : FORCE - $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/ + $(PYTHON) -m pytest --durations=20 -v test/ --doctest-modules --ignore ot/gpu/ release : twine upload dist/* diff --git a/test/test_bregman.py b/test/test_bregman.py index 9665229..88166a5 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -293,7 +293,7 @@ def test_unmix(): def test_empirical_sinkhorn(): # test sinkhorn - n = 100 + n = 10 a = ot.unif(n) b = ot.unif(n) @@ -332,7 +332,7 @@ def test_empirical_sinkhorn(): def test_lazy_empirical_sinkhorn(): # test sinkhorn - n = 100 + n = 10 a = ot.unif(n) b = ot.unif(n) numIterMax = 1000 @@ -342,7 +342,7 @@ def test_lazy_empirical_sinkhorn(): M = ot.dist(X_s, X_t) M_m = ot.dist(X_s, X_t, metric='minkowski') - f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 1), verbose=True) + f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True) G_sqe = np.exp(f[:, None] + g[None, :] - M / 1) sinkhorn_sqe = ot.sinkhorn(a, b, M, 1) @@ -458,6 +458,7 @@ def test_implemented_methods(): ot.bregman.sinkhorn2(a, b, M, epsilon, method=method) +@pytest.mark.filterwarnings("ignore:Bottleneck") def test_screenkhorn(): # test screenkhorn rng = np.random.RandomState(0) diff --git a/test/test_da.py b/test/test_da.py index 52c6a48..44bb2e9 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -106,8 +106,8 @@ def test_sinkhorn_l1l2_transport_class(): """test_sinkhorn_transport """ - ns = 150 - nt = 200 + ns = 50 + nt = 100 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -448,8 +448,8 @@ def test_mapping_transport_class(): """test_mapping_transport """ - ns = 60 - nt = 120 + ns = 20 + nt = 30 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) diff --git a/test/test_gromov.py b/test/test_gromov.py index 81138ca..56414a8 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -9,6 +9,8 @@ import numpy as np import ot +import pytest + def test_gromov(): n_samples = 50 # nb samples @@ -128,9 +130,10 @@ def test_gromov_barycenter(): np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples)) +@pytest.mark.filterwarnings("ignore:divide") def test_gromov_entropic_barycenter(): - ns = 50 - nt = 60 + ns = 20 + nt = 30 Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) @@ -138,19 +141,19 @@ def test_gromov_entropic_barycenter(): C1 = ot.dist(Xs) C2 = ot.dist(Xt) - n_samples = 3 + n_samples = 2 Cb = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2], [ot.unif(ns), ot.unif(nt) ], ot.unif(n_samples), [.5, .5], - 'square_loss', 2e-3, - max_iter=100, tol=1e-3, + 'square_loss', 1e-3, + max_iter=50, tol=1e-5, verbose=True) np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) Cb2 = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2], [ot.unif(ns), ot.unif(nt) ], ot.unif(n_samples), [.5, .5], - 
'kl_loss', 2e-3, + 'kl_loss', 1e-3, max_iter=100, tol=1e-3) np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples)) diff --git a/test/test_optim.py b/test/test_optim.py index 48de38a..fd194c2 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -37,8 +37,8 @@ def test_conditional_gradient(): np.testing.assert_allclose(b, G.sum(0)) -def test_conditional_gradient2(): - n = 1000 # nb samples +def test_conditional_gradient_itermax(): + n = 100 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -63,7 +63,7 @@ def test_conditional_gradient2(): reg = 1e-1 - G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=200000, + G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=10000, verbose=True, log=True) np.testing.assert_allclose(a, G.sum(1)) diff --git a/test/test_stochastic.py b/test/test_stochastic.py index 155622c..98e93ec 100644 --- a/test/test_stochastic.py +++ b/test/test_stochastic.py @@ -30,7 +30,7 @@ import ot def test_stochastic_sag(): # test sag - n = 15 + n = 10 reg = 1 numItermax = 30000 rng = np.random.RandomState(0) @@ -45,9 +45,9 @@ def test_stochastic_sag(): # check constratints np.testing.assert_allclose( - u, G.sum(1), atol=1e-04) # cf convergence sag + u, G.sum(1), atol=1e-03) # cf convergence sag np.testing.assert_allclose( - u, G.sum(0), atol=1e-04) # cf convergence sag + u, G.sum(0), atol=1e-03) # cf convergence sag ############################################################################# @@ -60,9 +60,9 @@ def test_stochastic_sag(): def test_stochastic_asgd(): # test asgd - n = 15 + n = 10 reg = 1 - numItermax = 100000 + numItermax = 10000 rng = np.random.RandomState(0) x = rng.randn(n, 2) @@ -75,9 +75,9 @@ def test_stochastic_asgd(): # check constratints np.testing.assert_allclose( - u, G.sum(1), atol=1e-03) # cf convergence asgd + u, G.sum(1), atol=1e-02) # cf convergence asgd np.testing.assert_allclose( - u, G.sum(0), atol=1e-03) # cf convergence asgd + u, G.sum(0), atol=1e-02) # cf convergence asgd ############################################################################# @@ -90,9 +90,9 @@ def test_stochastic_asgd(): def test_sag_asgd_sinkhorn(): # test all algorithms - n = 15 + n = 10 reg = 1 - nb_iter = 100000 + nb_iter = 10000 rng = np.random.RandomState(0) x = rng.randn(n, 2) @@ -107,17 +107,17 @@ def test_sag_asgd_sinkhorn(): # check constratints np.testing.assert_allclose( - G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-03) + G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-02) np.testing.assert_allclose( - G_sag.sum(0), G_sinkhorn.sum(0), atol=1e-03) + G_sag.sum(0), G_sinkhorn.sum(0), atol=1e-02) np.testing.assert_allclose( - G_asgd.sum(1), G_sinkhorn.sum(1), atol=1e-03) + G_asgd.sum(1), G_sinkhorn.sum(1), atol=1e-02) np.testing.assert_allclose( - G_asgd.sum(0), G_sinkhorn.sum(0), atol=1e-03) + G_asgd.sum(0), G_sinkhorn.sum(0), atol=1e-02) np.testing.assert_allclose( - G_sag, G_sinkhorn, atol=1e-03) # cf convergence sag + G_sag, G_sinkhorn, atol=1e-02) # cf convergence sag np.testing.assert_allclose( - G_asgd, G_sinkhorn, atol=1e-03) # cf convergence asgd + G_asgd, G_sinkhorn, atol=1e-02) # cf convergence asgd ############################################################################# @@ -136,7 +136,7 @@ def test_stochastic_dual_sgd(): # test sgd n = 10 reg = 1 - numItermax = 15000 + numItermax = 5000 batch_size = 10 rng = np.random.RandomState(0) @@ -167,7 +167,7 @@ def test_dual_sgd_sinkhorn(): # test all dual algorithms n = 10 reg = 1 - nb_iter = 15000 + nb_iter = 5000 batch_size = 10 rng = np.random.RandomState(0) @@ 
-183,11 +183,11 @@ def test_dual_sgd_sinkhorn(): # check constratints np.testing.assert_allclose( - G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03) + G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-02) np.testing.assert_allclose( - G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03) + G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-02) np.testing.assert_allclose( - G_sgd, G_sinkhorn, atol=1e-03) # cf convergence sgd + G_sgd, G_sinkhorn, atol=1e-02) # cf convergence sgd # Test gaussian n = 30 diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index dfeaad9..e8349d1 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -115,7 +115,8 @@ def test_stabilized_vs_sinkhorn(): G, log = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=epsilon, method="sinkhorn_stabilized", reg_m=reg_m, - log=True) + log=True, + verbose=True) G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, method="sinkhorn", log=True) @@ -138,7 +139,7 @@ def test_unbalanced_barycenter(method): reg_m = 1. q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, - method=method, log=True) + method=method, log=True, verbose=True) # check fixed point equations fi = reg_m / (reg_m + epsilon) logA = np.log(A + 1e-16) @@ -173,6 +174,7 @@ def test_barycenter_stabilized_vs_sinkhorn(): reg_m=reg_m, log=True, tau=100, method="sinkhorn_stabilized", + verbose=True ) q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, method="sinkhorn", @@ -182,6 +184,33 @@ def test_barycenter_stabilized_vs_sinkhorn(): q, qstable, atol=1e-05) +def test_wrong_method(): + + n = 10 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + a = ot.utils.unif(n) + + # make dists unbalanced + b = ot.utils.unif(n) * 1.5 + + M = ot.dist(x, x) + epsilon = 1. + reg_m = 1. + + with pytest.raises(ValueError): + ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, + reg_m=reg_m, + method='badmethod', + log=True, + verbose=True) + with pytest.raises(ValueError): + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, + method='badmethod', + verbose=True) + + def test_implemented_methods(): IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized'] TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling'] -- cgit v1.2.3 From e0ba31ce39a7d9e65e50ea970a574b3db54e4207 Mon Sep 17 00:00:00 2001 From: Tanguy Date: Fri, 17 Sep 2021 18:36:33 +0200 Subject: [MRG] Implementation of two news algorithms: SaGroW and PoGroW. (#275) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add two new algorithms to solve Gromov Wasserstein: Sampled Gromov Wasserstein and Pointwise Gromov Wasserstein. * Correct some lines in SaGroW and PoGroW to follow pep8 guide. * Change nb_samples name. Use rdm state. Change symmetric check. * Change names of len(p) and len(q) in SaGroW and PoGroW. 
* Re-add some deleted lines in the comments of gromov.py Co-authored-by: Rémi Flamary --- README.md | 4 + examples/gromov/plot_gromov.py | 34 ++++ ot/gromov.py | 376 +++++++++++++++++++++++++++++++++++++++++ test/test_gromov.py | 88 +++++++++- 4 files changed, 496 insertions(+), 6 deletions(-) (limited to 'test/test_gromov.py') diff --git a/README.md b/README.md index 6a2cf15..266d847 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ POT provides the following generic OT solvers (links to examples): * [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]) * [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24] * [Stochastic solver](https://pythonot.github.io/auto_examples/plot_stochastic.html) for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19]) +* [Stochastic solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33] * Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) [20]. * [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] @@ -198,6 +199,7 @@ The contributors to this library are * [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn) * [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT) * [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance) +* [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein) * [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance) This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages): @@ -286,3 +288,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t [31] Bonneel, Nicolas, et al. [Sliced and radon wasserstein barycenters of measures](https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf), Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 [32] Huang, M., Ma S., Lai, L. (2021). [A Riemannian Block Coordinate Descent Method for Computing the Projection Robust Wasserstein Distance](http://proceedings.mlr.press/v139/huang21e.html), Proceedings of the 38th International Conference on Machine Learning (ICML). + +[33] Kerdoncuff T., Emonet R., Marc S. 
[Sampled Gromov Wasserstein](https://hal.archives-ouvertes.fr/hal-03232509/document), Machine Learning Journal (MJL), 2021 diff --git a/examples/gromov/plot_gromov.py b/examples/gromov/plot_gromov.py index deb2f86..5a362cf 100644 --- a/examples/gromov/plot_gromov.py +++ b/examples/gromov/plot_gromov.py @@ -104,3 +104,37 @@ pl.imshow(gw, cmap='jet') pl.title('Entropic Gromov Wasserstein') pl.show() + +############################################################################# +# +# Compute GW with a scalable stochastic method with any loss function +# ---------------------------------------------------------------------- + + +def loss(x, y): + return np.abs(x - y) + + +pgw, plog = ot.gromov.pointwise_gromov_wasserstein(C1, C2, p, q, loss, max_iter=100, + log=True) + +sgw, slog = ot.gromov.sampled_gromov_wasserstein(C1, C2, p, q, loss, epsilon=0.1, max_iter=100, + log=True) + +print('Pointwise Gromov-Wasserstein distance estimated: ' + str(plog['gw_dist_estimated'])) +print('Variance estimated: ' + str(plog['gw_dist_std'])) +print('Sampled Gromov-Wasserstein distance: ' + str(slog['gw_dist_estimated'])) +print('Variance estimated: ' + str(slog['gw_dist_std'])) + + +pl.figure(1, (10, 5)) + +pl.subplot(1, 2, 1) +pl.imshow(pgw.toarray(), cmap='jet') +pl.title('Pointwise Gromov Wasserstein') + +pl.subplot(1, 2, 2) +pl.imshow(sgw, cmap='jet') +pl.title('Sampled Gromov Wasserstein') + +pl.show() diff --git a/ot/gromov.py b/ot/gromov.py index 8f457e9..a27217a 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -16,6 +16,10 @@ import numpy as np from .bregman import sinkhorn from .utils import dist, UndefinedParameter from .optim import cg +from .lp import emd_1d, emd +from .utils import check_random_state + +from scipy.sparse import issparse def init_matrix(C1, C2, p, q, loss_fun='square_loss'): @@ -572,6 +576,378 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 return log['fgw_dist'] +def GW_distance_estimation(C1, C2, p, q, loss_fun, T, + nb_samples_p=None, nb_samples_q=None, std=True, random_state=None): + r""" + Returns an approximation of the gromov-wasserstein cost between (C1,p) and (C2,q) + with a fixed transport plan T. + + The function gives an unbiased approximation of the following equation: + + .. math:: + GW = \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + + Where : + + - C1 : Metric cost matrix in the source space + - C2 : Metric cost matrix in the target space + - L : Loss function to account for the misfit between the similarity matrices + - T : Matrix with marginal p and q + + Parameters + ---------- + C1 : ndarray, shape (ns, ns) + Metric cost matrix in the source space + C2 : ndarray, shape (nt, nt) + Metric costfr matrix in the target space + p : ndarray, shape (ns,) + Distribution in the source space + q : ndarray, shape (nt,) + Distribution in the target space + loss_fun : function: \mathcal{R} \times \mathcal{R} \shortarrow \mathcal{R} + Loss function used for the distance, the transport plan does not depend on the loss function + T : csr or ndarray, shape (ns, nt) + Transport plan matrix, either a sparse csr matrix or + nb_samples_p : int, optional + nb_samples_p is the number of samples (without replacement) along the first dimension of T. + nb_samples_q : int, optional + nb_samples_q is the number of samples along the second dimension of T, for each sample along the first. + std : bool, optional + Standard deviation associated with the prediction of the gromov-wasserstein cost. 
+ random_state : int or RandomState instance, optional + Fix the seed for to allow reproducibility + + Returns + ------- + : float + Gromov-wasserstein cost + + References + ---------- + .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc + "Sampled Gromov Wasserstein." + Machine Learning Journal (MLJ). 2021. + + """ + generator = check_random_state(random_state) + + len_p = len(p) + len_q = len(q) + + # It is always better to sample from the biggest distribution first. + if len_p < len_q: + p, q = q, p + len_p, len_q = len_q, len_p + C1, C2 = C2, C1 + T = T.T + + if nb_samples_p is None: + if issparse(T): + # If T is sparse, it probably mean that PoGroW was used, thus the number of sample is reduced + nb_samples_p = min(int(5 * (len_p * np.log(len_p)) ** 0.5), len_p) + else: + nb_samples_p = len_p + else: + # The number of sample along the first dimension is without replacement. + nb_samples_p = min(nb_samples_p, len_p) + if nb_samples_q is None: + nb_samples_q = 1 + if std: + nb_samples_q = max(2, nb_samples_q) + + index_k = np.zeros((nb_samples_p, nb_samples_q), dtype=int) + index_l = np.zeros((nb_samples_p, nb_samples_q), dtype=int) + list_value_sample = np.zeros((nb_samples_p, nb_samples_p, nb_samples_q)) + + index_i = generator.choice(len_p, size=nb_samples_p, p=p, replace=False) + index_j = generator.choice(len_p, size=nb_samples_p, p=p, replace=False) + + for i in range(nb_samples_p): + if issparse(T): + T_indexi = T[index_i[i], :].toarray()[0] + T_indexj = T[index_j[i], :].toarray()[0] + else: + T_indexi = T[index_i[i], :] + T_indexj = T[index_j[i], :] + # For each of the row sampled, the column is sampled. + index_k[i] = generator.choice(len_q, size=nb_samples_q, p=T_indexi / T_indexi.sum(), replace=True) + index_l[i] = generator.choice(len_q, size=nb_samples_q, p=T_indexj / T_indexj.sum(), replace=True) + + for n in range(nb_samples_q): + list_value_sample[:, :, n] = loss_fun(C1[np.ix_(index_i, index_j)], C2[np.ix_(index_k[:, n], index_l[:, n])]) + + if std: + std_value = np.sum(np.std(list_value_sample, axis=2) ** 2) ** 0.5 + return np.mean(list_value_sample), std_value / (nb_samples_p * nb_samples_p) + else: + return np.mean(list_value_sample) + + +def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, + alpha=1, max_iter=100, threshold_plan=0, log=False, verbose=False, random_state=None): + r""" + Returns the gromov-wasserstein transport between (C1,p) and (C2,q) using a stochastic Frank-Wolfe. + This method as a O(max_iter \times PN^2) time complexity with P the number of Sinkhorn iterations. + + The function solves the following optimization problem: + + .. math:: + GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + + s.t. 
T 1 = p + + T^T 1= q + + T\geq 0 + + Where : + + - C1 : Metric cost matrix in the source space + - C2 : Metric cost matrix in the target space + - p : distribution in the source space + - q : distribution in the target space + - L : loss function to account for the misfit between the similarity matrices + + Parameters + ---------- + C1 : ndarray, shape (ns, ns) + Metric cost matrix in the source space + C2 : ndarray, shape (nt, nt) + Metric costfr matrix in the target space + p : ndarray, shape (ns,) + Distribution in the source space + q : ndarray, shape (nt,) + Distribution in the target space + loss_fun : function: \mathcal{R} \times \mathcal{R} \shortarrow \mathcal{R} + Loss function used for the distance, the transport plan does not depend on the loss function + alpha : float + Step of the Frank-Wolfe algorithm, should be between 0 and 1 + max_iter : int, optional + Max number of iterations + threshold_plan : float, optional + Deleting very small value in the transport plan. If above zero, it violate the marginal constraints. + verbose : bool, optional + Print information along iterations + log : bool, optional + Gives the distance estimated and the standard deviation + random_state : int or RandomState instance, optional + Fix the seed for to allow reproducibility + + Returns + ------- + T : ndarray, shape (ns, nt) + Optimal coupling between the two spaces + + References + ---------- + .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc + "Sampled Gromov Wasserstein." + Machine Learning Journal (MLJ). 2021. + + """ + C1 = np.asarray(C1, dtype=np.float64) + C2 = np.asarray(C2, dtype=np.float64) + p = np.asarray(p, dtype=np.float64) + q = np.asarray(q, dtype=np.float64) + len_p = len(p) + len_q = len(q) + + generator = check_random_state(random_state) + + index = np.zeros(2, dtype=int) + + # Initialize with default marginal + index[0] = generator.choice(len_p, size=1, p=p) + index[1] = generator.choice(len_q, size=1, p=q) + T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr() + + best_gw_dist_estimated = np.inf + for cpt in range(max_iter): + index[0] = generator.choice(len_p, size=1, p=p) + T_index0 = T[index[0], :].toarray()[0] + index[1] = generator.choice(len_q, size=1, p=T_index0 / T_index0.sum()) + + if alpha == 1: + T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr() + else: + new_T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr() + T = (1 - alpha) * T + alpha * new_T + # To limit the number of non 0, the values bellow the threshold are set to 0. 
+ T.data[T.data < threshold_plan] = 0 + T.eliminate_zeros() + + if cpt % 10 == 0 or cpt == (max_iter - 1): + gw_dist_estimated = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun, + p=p, q=q, T=T, std=False, random_state=generator) + + if gw_dist_estimated < best_gw_dist_estimated: + best_gw_dist_estimated = gw_dist_estimated + best_T = T.copy() + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Best gw estimated') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, best_gw_dist_estimated)) + + if log: + log = {} + log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun, + p=p, q=q, T=best_T, + random_state=generator) + return best_T, log + return best_T + + +def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, + nb_samples_grad=100, epsilon=1, max_iter=500, log=False, verbose=False, + random_state=None): + r""" + Returns the gromov-wasserstein transport between (C1,p) and (C2,q) using a 1-stochastic Frank-Wolfe. + This method as a O(max_iter \times Nlog(N)) time complexity by relying on the 1D Optimal Transport solver. + + The function solves the following optimization problem: + + .. math:: + GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + + s.t. T 1 = p + + T^T 1= q + + T\geq 0 + + Where : + + - C1 : Metric cost matrix in the source space + - C2 : Metric cost matrix in the target space + - p : distribution in the source space + - q : distribution in the target space + - L : loss function to account for the misfit between the similarity matrices + + Parameters + ---------- + C1 : ndarray, shape (ns, ns) + Metric cost matrix in the source space + C2 : ndarray, shape (nt, nt) + Metric costfr matrix in the target space + p : ndarray, shape (ns,) + Distribution in the source space + q : ndarray, shape (nt,) + Distribution in the target space + loss_fun : function: \mathcal{R} \times \mathcal{R} \shortarrow \mathcal{R} + Loss function used for the distance, the transport plan does not depend on the loss function + nb_samples_grad : int + Number of samples to approximate the gradient + epsilon : float + Weight of the Kullback-Leiber regularization + max_iter : int, optional + Max number of iterations + verbose : bool, optional + Print information along iterations + log : bool, optional + Gives the distance estimated and the standard deviation + random_state : int or RandomState instance, optional + Fix the seed for to allow reproducibility + + Returns + ------- + T : ndarray, shape (ns, nt) + Optimal coupling between the two spaces + + References + ---------- + .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc + "Sampled Gromov Wasserstein." + Machine Learning Journal (MLJ). 2021. + + """ + C1 = np.asarray(C1, dtype=np.float64) + C2 = np.asarray(C2, dtype=np.float64) + p = np.asarray(p, dtype=np.float64) + q = np.asarray(q, dtype=np.float64) + len_p = len(p) + len_q = len(q) + + generator = check_random_state(random_state) + + # The most natural way to define nb_sample is with a simple integer. + if isinstance(nb_samples_grad, int): + if nb_samples_grad > len_p: + # As the sampling along the first dimension is done without replacement, the rest is reported to the second + # dimension. 
+ nb_samples_grad_p, nb_samples_grad_q = len_p, nb_samples_grad // len_p + else: + nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad, 1 + else: + nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad + T = np.outer(p, q) + # continue_loop allows to stop the loop if there is several successive small modification of T. + continue_loop = 0 + + # The gradient of GW is more complex if the two matrices are not symmetric. + C_are_symmetric = np.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and np.allclose(C2, C2.T, rtol=1e-10, atol=1e-10) + + for cpt in range(max_iter): + index0 = generator.choice(len_p, size=nb_samples_grad_p, p=p, replace=False) + Lik = 0 + for i, index0_i in enumerate(index0): + index1 = generator.choice(len_q, + size=nb_samples_grad_q, + p=T[index0_i, :] / T[index0_i, :].sum(), + replace=False) + # If the matrices C are not symmetric, the gradient has 2 terms, thus the term is chosen randomly. + if (not C_are_symmetric) and generator.rand(1) > 0.5: + Lik += np.mean(loss_fun(np.expand_dims(C1[:, np.repeat(index0[i], nb_samples_grad_q)], 1), + np.expand_dims(C2[:, index1], 0)), + axis=2) + else: + Lik += np.mean(loss_fun(np.expand_dims(C1[np.repeat(index0[i], nb_samples_grad_q), :], 2), + np.expand_dims(C2[index1, :], 1)), + axis=0) + + max_Lik = np.max(Lik) + if max_Lik == 0: + continue + # This division by the max is here to facilitate the choice of epsilon. + Lik /= max_Lik + + if epsilon > 0: + # Set to infinity all the numbers bellow exp(-200) to avoid log of 0. + log_T = np.log(np.clip(T, np.exp(-200), 1)) + log_T[log_T == -200] = -np.inf + Lik = Lik - epsilon * log_T + + try: + new_T = sinkhorn(a=p, b=q, M=Lik, reg=epsilon) + except (RuntimeWarning, UserWarning): + print("Warning catched in Sinkhorn: Return last stable T") + break + else: + new_T = emd(a=p, b=q, M=Lik) + + change_T = ((T - new_T) ** 2).mean() + if change_T <= 10e-20: + continue_loop += 1 + if continue_loop > 100: # Number max of low modifications of T + T = new_T.copy() + break + else: + continue_loop = 0 + + if verbose and cpt % 10 == 0: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format('It.', '||T_n - T_{n+1}||') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, change_T)) + T = new_T.copy() + + if log: + log = {} + log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun, + p=p, q=q, T=T, random_state=generator) + return T, log + return T + + def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False): r""" diff --git a/test/test_gromov.py b/test/test_gromov.py index 56414a8..19d61b1 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -33,7 +33,7 @@ def test_gromov(): G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True) - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( @@ -54,7 +54,7 @@ def test_gromov(): np.testing.assert_allclose(gw, gw_val, atol=1e-1, rtol=1e-1) # cf log=False - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( @@ -83,7 +83,7 @@ def test_entropic_gromov(): G = ot.gromov.entropic_gromov_wasserstein( C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True) - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( @@ -96,13 +96,89 @@ def test_entropic_gromov(): 
np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( q, G.sum(0), atol=1e-04) # cf convergence gromov +def test_pointwise_gromov(): + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + def loss(x, y): + return np.abs(x - y) + + G, log = ot.gromov.pointwise_gromov_wasserstein( + C1, C2, p, q, loss, max_iter=100, log=True, verbose=True, random_state=42) + + # check constraints + np.testing.assert_allclose( + p[:, np.newaxis], G.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q[np.newaxis, :], G.sum(0), atol=1e-04) # cf convergence gromov + + assert log['gw_dist_estimated'] == 0.0 + assert log['gw_dist_std'] == 0.0 + + G, log = ot.gromov.pointwise_gromov_wasserstein( + C1, C2, p, q, loss, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42) + + assert log['gw_dist_estimated'] == 0.10342276348494964 + assert log['gw_dist_std'] == 0.0015952535464736394 + + +def test_sampled_gromov(): + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + def loss(x, y): + return np.abs(x - y) + + G, log = ot.gromov.sampled_gromov_wasserstein( + C1, C2, p, q, loss, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42) + + # check constraints + np.testing.assert_allclose( + p, G.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, G.sum(0), atol=1e-04) # cf convergence gromov + + assert log['gw_dist_estimated'] == 0.05679474884977278 + assert log['gw_dist_std'] == 0.0005986592106971995 + + def test_gromov_barycenter(): ns = 50 nt = 60 @@ -186,7 +262,7 @@ def test_fgw(): G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True) - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence fgw np.testing.assert_allclose( @@ -203,7 +279,7 @@ def test_fgw(): np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( -- cgit v1.2.3 From d7554331fc409fea48ee758fd630909dd9dc4827 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Wed, 27 Oct 2021 08:41:08 +0200 Subject: [WIP] Sinkhorn in log space (#290) * adda sinkhorn log and working sinkhorn2 function * more tests pass * more tests pass * it works but not by default yet * remove warningd * update circleci doc * update circleci doc * new sinkhorn implemeted but not by default * better * doctest pass * test doctest * new test utils * remove pep8 errors * remove pep8 errors * doc new implementtaion with log * test sinkhorn 2 * doc for log implementation --- .circleci/config.yml | 14 +-- README.md | 4 +- docs/source/quickstart.rst | 10 +- ot/bregman.py | 272 +++++++++++++++++++++++++++++++++++++++++---- ot/dr.py | 4 +- ot/gromov.py | 4 +- 
ot/optim.py | 4 +- ot/utils.py | 4 +- test/test_bregman.py | 120 ++++++++++++++++++-- test/test_gromov.py | 10 +- test/test_helpers.py | 4 +- test/test_utils.py | 15 +++ 12 files changed, 403 insertions(+), 62 deletions(-) (limited to 'test/test_gromov.py') diff --git a/.circleci/config.yml b/.circleci/config.yml index e4c71dd..379394a 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -4,7 +4,7 @@ version: 2 jobs: build_docs: docker: - - image: circleci/python:3.7-stretch + - image: cimg/python:3.9 steps: - checkout - run: @@ -34,18 +34,6 @@ jobs: - data-cache-0 - pip-cache - - run: - name: Spin up Xvfb - command: | - /sbin/start-stop-daemon --start --quiet --pidfile /tmp/custom_xvfb_99.pid --make-pidfile --background --exec /usr/bin/Xvfb -- :99 -screen 0 1400x900x24 -ac +extension GLX +render -noreset; - - # https://github.com/ContinuumIO/anaconda-issues/issues/9190#issuecomment-386508136 - # https://github.com/golemfactory/golem/issues/1019 - - run: - name: Fix libgcc_s.so.1 pthread_cancel bug - command: | - sudo apt-get install qt5-default - - run: name: Get Python running command: | diff --git a/README.md b/README.md index 266d847..ffad0bd 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ POT provides the following generic OT solvers (links to examples): * [OT Network Simplex solver](https://pythonot.github.io/auto_examples/plot_OT_1D.html) for the linear program/ Earth Movers Distance [1] . * [Conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) [6] and [Generalized conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) for regularized OT [7]. -* Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html). +* Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10] [34], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html). * Bregman projections for [Wasserstein barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) [3], [convolutional barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_convolutional_barycenter.html) [21] and unmixing [4]. * Sinkhorn divergence [23] and entropic regularization OT from empirical data. * [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17]. @@ -290,3 +290,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t [32] Huang, M., Ma S., Lai, L. (2021). [A Riemannian Block Coordinate Descent Method for Computing the Projection Robust Wasserstein Distance](http://proceedings.mlr.press/v139/huang21e.html), Proceedings of the 38th International Conference on Machine Learning (ICML). [33] Kerdoncuff T., Emonet R., Marc S. [Sampled Gromov Wasserstein](https://hal.archives-ouvertes.fr/hal-03232509/document), Machine Learning Journal (MJL), 2021 + +[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 
2681-2690). PMLR. diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index fd046a1..232df7b 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -358,6 +358,11 @@ More details about the algorithms used are given in the following note. + :code:`method='sinkhorn'` calls :any:`ot.bregman.sinkhorn_knopp` the classic algorithm [2]_. + + :code:`method='sinkhorn_log'` calls :any:`ot.bregman.sinkhorn_log` the + sinkhorn algorithm in log space [2]_ that is more stable but can be + slower in numpy since `logsumexp` is not implmemented in parallel. + It is the recommended solver for applications that requires + differentiability with a small number of iterations. + :code:`method='sinkhorn_stabilized'` calls :any:`ot.bregman.sinkhorn_stabilized` the log stabilized version of the algorithm [9]_. + :code:`method='sinkhorn_epsilon_scaling'` calls @@ -389,7 +394,10 @@ More details about the algorithms used are given in the following note. solutions. Note that the greedy version of the Sinkhorn :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening version of the Sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a - fast approximation of the Sinkhorn problem. + fast approximation of the Sinkhorn problem. For use of GPU and gradient + computation with small number of iterations we strongly recommend the + :any:`ot.bregman.sinkhorn_log` solver that will no need to check for + numerical problems. diff --git a/ot/bregman.py b/ot/bregman.py index b59ee1b..2aa76ff 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -64,7 +64,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, solutions. Note that the greedy version of the sinkhorn :py:func:`ot.bregman.greenkhorn` can also lead to a speedup and the screening version of the sinkhorn :py:func:`ot.bregman.screenkhorn` aim at providing a - fast approximation of the Sinkhorn problem. + fast approximation of the Sinkhorn problem. For use of GPU and gradient + computation with small number of iterations we strongly recommend the + :any:`ot.bregman.sinkhorn_log` solver that will no need to check for + numerical problems. Parameters @@ -79,8 +82,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, reg : float Regularization term >0 method : str - method used for the solver either 'sinkhorn', 'greenkhorn', 'sinkhorn_stabilized' or - 'sinkhorn_epsilon_scaling', see those function for specific parameters + method used for the solver either 'sinkhorn','sinkhorn_log', + 'greenkhorn', 'sinkhorn_stabilized' or 'sinkhorn_epsilon_scaling', see + those function for specific parameters numItermax : int, optional Max number of iterations stopThr : float, optional @@ -118,6 +122,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. 
See Also @@ -134,6 +139,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) + elif method.lower() == 'sinkhorn_log': + return sinkhorn_log(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) elif method.lower() == 'greenkhorn': return greenkhorn(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log) @@ -182,7 +191,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, By default and when using a regularization parameter that is not too small the default sinkhorn solver should be enough. If you need to use a small regularization to get sharper OT matrices, you should use the - :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical + :any:`ot.bregman.sinkhorn_log` solver that will avoid numerical errors. This last solver can be very slow in practice and might not even converge to a reasonable OT matrix in a finite time. This is why :any:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value @@ -190,7 +199,10 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, solutions. Note that the greedy version of the sinkhorn :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a - fast approximation of the Sinkhorn problem. + fast approximation of the Sinkhorn problem. For use of GPU and gradient + computation with small number of iterations we strongly recommend the + :any:`ot.bregman.sinkhorn_log` solver that will no need to check for + numerical problems. Parameters ---------- @@ -204,7 +216,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, reg : float Regularization term >0 method : str - method used for the solver either 'sinkhorn', 'sinkhorn_stabilized', see those function for specific parameters + method used for the solver either 'sinkhorn','sinkhorn_log', + 'sinkhorn_stabilized', see those function for specific parameters numItermax : int, optional Max number of iterations stopThr : float, optional @@ -230,7 +243,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, >>> b=[.5, .5] >>> M=[[0., 1.], [1., 0.]] >>> ot.sinkhorn2(a, b, M, 1) - array([0.26894142]) + 0.26894142136999516 .. _references-sinkhorn2: @@ -243,7 +256,11 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. - .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 + .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation + algorithms for optimal transport via Sinkhorn iteration, Advances in Neural + Information Processing Systems (NIPS) 31, 2017 + + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. 
@@ -257,20 +274,45 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, """ - b = list_to_array(b) + M, a, b = list_to_array(M, a, b) + nx = get_backend(M, a, b) + if len(b.shape) < 2: - b = b[:, None] + if method.lower() == 'sinkhorn': + res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + elif method.lower() == 'sinkhorn_log': + res = sinkhorn_log(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + elif method.lower() == 'sinkhorn_stabilized': + res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) + if log: + return nx.sum(M * res[0]), res[1] + else: + return nx.sum(M * res) - if method.lower() == 'sinkhorn': - return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, - **kwargs) - elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, - **kwargs) else: - raise ValueError("Unknown method '%s'." % method) + + if method.lower() == 'sinkhorn': + return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + elif method.lower() == 'sinkhorn_log': + return sinkhorn_log(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + elif method.lower() == 'sinkhorn_stabilized': + return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) def sinkhorn_knopp(a, b, M, reg, numItermax=1000, @@ -361,7 +403,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, # init data dim_a = len(a) - dim_b = len(b) + dim_b = b.shape[0] if len(b.shape) > 1: n_hists = b.shape[1] @@ -438,6 +480,191 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, return u.reshape((-1, 1)) * K * v.reshape((1, -1)) +def sinkhorn_log(a, b, M, reg, numItermax=1000, + stopThr=1e-9, verbose=False, log=False, **kwargs): + r""" + Solve the entropic regularization optimal transport problem in log space + and return the OT matrix + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + s.t. 
\gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm :ref:`[2] ` with the + implementation from :ref:`[34] ` + + + Parameters + ---------- + a : array-like, shape (dim_a,) + samples weights in the source domain + b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists) + samples in the target domain, compute sinkhorn with multiple targets + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) + M : array-like, shape (dim_a, dim_b) + loss matrix + reg : float + Regularization term >0 + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + gamma : array-like, shape (dim_a, dim_b) + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[0., 1.], [1., 0.]] + >>> ot.sinkhorn(a, b, M, 1) + array([[0.36552929, 0.13447071], + [0.13447071, 0.36552929]]) + + + .. _references-sinkhorn-log: + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal + Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. 
+ + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + + """ + + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) + + if len(a) == 0: + a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M) + if len(b) == 0: + b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M) + + # init data + dim_a = len(a) + dim_b = b.shape[0] + + if len(b.shape) > 1: + n_hists = b.shape[1] + else: + n_hists = 0 + + if n_hists: # we do not want to use tensors sor we do a loop + + lst_loss = [] + lst_u = [] + lst_v = [] + + for k in range(n_hists): + res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, **kwargs) + + if log: + lst_loss.append(nx.sum(M * res[0])) + lst_u.append(res[1]['log_u']) + lst_v.append(res[1]['log_v']) + else: + lst_loss.append(nx.sum(M * res)) + res = nx.stack(lst_loss) + if log: + log = {'log_u': nx.stack(lst_u, 1), + 'log_v': nx.stack(lst_v, 1), } + log['u'] = nx.exp(log['log_u']) + log['v'] = nx.exp(log['log_v']) + return res, log + else: + return res + + else: + + if log: + log = {'err': []} + + Mr = M / (-reg) + + # we assume that no distances are null except those of the diagonal of + # distances + + u = nx.zeros(dim_a, type_as=M) + v = nx.zeros(dim_b, type_as=M) + + def get_logT(u, v): + if n_hists: + return Mr[:, :, None] + u + v + else: + return Mr + u[:, None] + v[None, :] + + loga = nx.log(a) + logb = nx.log(b) + + cpt = 0 + err = 1 + while (err > stopThr and cpt < numItermax): + + v = logb - nx.logsumexp(Mr + u[:, None], 0) + u = loga - nx.logsumexp(Mr + v[None, :], 1) + + if cpt % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + + # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 + tmp2 = nx.sum(nx.exp(get_logT(u, v)), 0) + err = nx.norm(tmp2 - b) # violation of marginal + if log: + log['err'].append(err) + + if verbose: + if cpt % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + cpt = cpt + 1 + + if log: + log['log_u'] = u + log['log_v'] = v + log['u'] = nx.exp(u) + log['v'] = nx.exp(v) + + return nx.exp(get_logT(u, v)), log + + else: + return nx.exp(get_logT(u, v)) + + def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=False): r""" @@ -1881,8 +2108,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', return (f, g) else: - M = dist(nx.to_numpy(X_s), nx.to_numpy(X_t), metric=metric) - M = nx.from_numpy(M, type_as=a) + M = dist(X_s, X_t, metric=metric) if log: pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs) return pi, log @@ -2102,7 +2328,7 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1)) >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1)) >>> empirical_sinkhorn_divergence(X_s, X_t, reg) # doctest: +ELLIPSIS - array([1.499...]) + 1.499887176049052 References diff --git a/ot/dr.py b/ot/dr.py index 64588cf..de39662 100644 --- a/ot/dr.py +++ b/ot/dr.py @@ -209,11 +209,11 @@ def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopTh .. 
math:: \max_{U \in St(d, k)} \min_{\pi \in \Pi(\mu,\nu)} \sum_{i,j} \pi_{i,j} \|U^T(x_i - y_j)\|^2 - reg * H(\pi) - + - :math:`U` is a linear projection operator in the Stiefel(d, k) manifold - :math:`H(\pi)` is entropy regularizer - :math:`x_i`, :math:`y_j` are samples of measures \mu and \nu respectively - + Parameters ---------- X : ndarray, shape (n, d) diff --git a/ot/gromov.py b/ot/gromov.py index 85b1549..33b4453 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -1030,7 +1030,7 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, # compute the gradient tens = gwggrad(constC, hC1, hC2, T) - T = sinkhorn(p, q, tens, epsilon) + T = sinkhorn(p, q, tens, epsilon, method='sinkhorn') if cpt % 10 == 0: # we can speed up the process by checking for the error only all @@ -1204,7 +1204,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, Cprev = C T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon, - max_iter, 1e-5, verbose, log) for s in range(S)] + max_iter, 1e-4, verbose, log) for s in range(S)] if loss_fun == 'square_loss': C = update_square_loss(p, lambdas, T, Cs) diff --git a/ot/optim.py b/ot/optim.py index 6822e4e..34cbb17 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -20,7 +20,7 @@ from .backend import get_backend def line_search_armijo(f, xk, pk, gfk, old_fval, args=(), c1=1e-4, alpha0=0.99): - """ + r""" Armijo linesearch function that works with matrices Find an approximate minimum of :math:`f(x_k + \\alpha \cdot p_k)` that satisfies the @@ -447,7 +447,7 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, def solve_1d_linesearch_quad(a, b, c): - """ + r""" For any convex or non-convex 1d quadratic function `f`, solve the following problem: .. math:: diff --git a/ot/utils.py b/ot/utils.py index 6a782e6..0608aee 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -183,7 +183,7 @@ def euclidean_distances(X, Y, squared=False): return c -def dist(x1, x2=None, metric='sqeuclidean'): +def dist(x1, x2=None, metric='sqeuclidean', p=2): """Compute distance between samples in x1 and x2 .. 
note:: This function is backend-compatible and will work on arrays @@ -222,7 +222,7 @@ def dist(x1, x2=None, metric='sqeuclidean'): if not get_backend(x1, x2).__name__ == 'numpy': raise NotImplementedError() else: - return cdist(x1, x2, metric=metric) + return cdist(x1, x2, metric=metric, p=p) def dist0(n, method='lin_square'): diff --git a/test/test_bregman.py b/test/test_bregman.py index 942cb6d..c1120ba 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -32,6 +32,27 @@ def test_sinkhorn(): u, G.sum(0), atol=1e-05) # cf convergence sinkhorn +def test_sinkhorn_multi_b(): + # test sinkhorn + n = 10 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + b = rng.rand(n, 3) + b = b / np.sum(b, 0, keepdims=True) + + M = ot.dist(x, x) + + loss0, log = ot.sinkhorn(u, b, M, .1, stopThr=1e-10, log=True) + + loss = [ot.sinkhorn2(u, b[:, k], M, .1, stopThr=1e-10) for k in range(3)] + # check constraints + np.testing.assert_allclose( + loss0, loss, atol=1e-06) # cf convergence sinkhorn + + def test_sinkhorn_backends(nx): n_samples = 100 n_features = 2 @@ -147,6 +168,7 @@ def test_sinkhorn_variants(nx): Mb = nx.from_numpy(M) G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10) + Gl = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_log', stopThr=1e-10)) G0 = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn', stopThr=1e-10)) Gs = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10)) Ges = nx.to_numpy(ot.sinkhorn( @@ -155,15 +177,73 @@ def test_sinkhorn_variants(nx): # check values np.testing.assert_allclose(G, G0, atol=1e-05) + np.testing.assert_allclose(G, Gl, atol=1e-05) np.testing.assert_allclose(G0, Gs, atol=1e-05) np.testing.assert_allclose(G0, Ges, atol=1e-05) np.testing.assert_allclose(G0, G_green, atol=1e-5) - print(G0, G_green) + + +@pytest.skip_backend("jax") +def test_sinkhorn_variants_multi_b(nx): + # test sinkhorn + n = 50 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + b = rng.rand(n, 3) + b = b / np.sum(b, 0, keepdims=True) + + M = ot.dist(x, x) + + ub = nx.from_numpy(u) + bb = nx.from_numpy(b) + Mb = nx.from_numpy(M) + + G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10) + Gl = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + + # check values + np.testing.assert_allclose(G, G0, atol=1e-05) + np.testing.assert_allclose(G, Gl, atol=1e-05) + np.testing.assert_allclose(G0, Gs, atol=1e-05) + + +@pytest.skip_backend("jax") +def test_sinkhorn2_variants_multi_b(nx): + # test sinkhorn + n = 50 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + b = rng.rand(n, 3) + b = b / np.sum(b, 0, keepdims=True) + + M = ot.dist(x, x) + + ub = nx.from_numpy(u) + bb = nx.from_numpy(b) + Mb = nx.from_numpy(M) + + G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10) + Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + + # check values + np.testing.assert_allclose(G, G0, atol=1e-05) + np.testing.assert_allclose(G, Gl, atol=1e-05) + np.testing.assert_allclose(G0, Gs, atol=1e-05) def test_sinkhorn_variants_log(): # test 
sinkhorn - n = 100 + n = 50 rng = np.random.RandomState(0) x = rng.randn(n, 2) @@ -172,6 +252,7 @@ def test_sinkhorn_variants_log(): M = ot.dist(x, x) G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10, log=True) + Gl, logl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True) Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) Ges, loges = ot.sinkhorn( u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True) @@ -179,9 +260,30 @@ def test_sinkhorn_variants_log(): # check values np.testing.assert_allclose(G0, Gs, atol=1e-05) + np.testing.assert_allclose(G0, Gl, atol=1e-05) np.testing.assert_allclose(G0, Ges, atol=1e-05) np.testing.assert_allclose(G0, G_green, atol=1e-5) - print(G0, G_green) + + +def test_sinkhorn_variants_log_multib(): + # test sinkhorn + n = 50 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + b = rng.rand(n, 3) + b = b / np.sum(b, 0, keepdims=True) + + M = ot.dist(x, x) + + G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10, log=True) + Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True) + Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) + + # check values + np.testing.assert_allclose(G0, Gs, atol=1e-05) + np.testing.assert_allclose(G0, Gl, atol=1e-05) @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) @@ -326,10 +428,10 @@ def test_empirical_sinkhorn(nx): a = ot.unif(n) b = ot.unif(n) - X_s = np.reshape(np.arange(n), (n, 1)) - X_t = np.reshape(np.arange(0, n), (n, 1)) + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) M = ot.dist(X_s, X_t) - M_m = ot.dist(X_s, X_t, metric='minkowski') + M_m = ot.dist(X_s, X_t, metric='euclidean') ab = nx.from_numpy(a) bb = nx.from_numpy(b) @@ -346,7 +448,7 @@ def test_empirical_sinkhorn(nx): sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) - G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='minkowski')) + G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean')) sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1)) @@ -378,7 +480,7 @@ def test_lazy_empirical_sinkhorn(nx): X_s = np.reshape(np.arange(n), (n, 1)) X_t = np.reshape(np.arange(0, n), (n, 1)) M = ot.dist(X_s, X_t) - M_m = ot.dist(X_s, X_t, metric='minkowski') + M_m = ot.dist(X_s, X_t, metric='euclidean') ab = nx.from_numpy(a) bb = nx.from_numpy(b) @@ -398,7 +500,7 @@ def test_lazy_empirical_sinkhorn(nx): sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) - f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='minkowski', numIterMax=numIterMax, isLazy=True, batchSize=1) + f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1) f, g = nx.to_numpy(f), nx.to_numpy(g) G_m = np.exp(f[:, None] + g[None, :] - M_m / 1) sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) diff --git a/test/test_gromov.py b/test/test_gromov.py index 19d61b1..0242d72 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -180,8 +180,8 @@ def test_sampled_gromov(): def test_gromov_barycenter(): - ns = 50 - nt = 60 + ns = 10 + nt = 20 Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) Xt, yt = 
ot.datasets.make_data_classif('3gauss2', nt, random_state=42) @@ -208,8 +208,8 @@ def test_gromov_barycenter(): @pytest.mark.filterwarnings("ignore:divide") def test_gromov_entropic_barycenter(): - ns = 20 - nt = 30 + ns = 10 + nt = 20 Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) @@ -222,7 +222,7 @@ def test_gromov_entropic_barycenter(): [ot.unif(ns), ot.unif(nt) ], ot.unif(n_samples), [.5, .5], 'square_loss', 1e-3, - max_iter=50, tol=1e-5, + max_iter=50, tol=1e-3, verbose=True) np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) diff --git a/test/test_helpers.py b/test/test_helpers.py index 8bd0015..cc4c90e 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -9,8 +9,8 @@ import sys sys.path.append(os.path.join("ot", "helpers")) -from openmp_helpers import get_openmp_flag, check_openmp_support # noqa -from pre_build_helpers import _get_compiler, compile_test_program # noqa +from openmp_helpers import get_openmp_flag, check_openmp_support # noqa +from pre_build_helpers import _get_compiler, compile_test_program # noqa def test_helpers(): diff --git a/test/test_utils.py b/test/test_utils.py index 60ad5d3..0650ce2 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -7,6 +7,7 @@ import ot import numpy as np import sys +import pytest def test_proj_simplex(nx): @@ -108,6 +109,10 @@ def test_dist(): D2 = ot.dist(x, x) D3 = ot.dist(x) + D4 = ot.dist(x, x, metric='minkowski', p=0.5) + + assert D4[0, 1] == D4[1, 0] + # dist shoul return squared euclidean np.testing.assert_allclose(D, D2, atol=1e-14) np.testing.assert_allclose(D, D3, atol=1e-14) @@ -220,6 +225,13 @@ def test_deprecated_func(): class Class(): pass + with pytest.warns(DeprecationWarning): + fun() + + with pytest.warns(DeprecationWarning): + cl = Class() + print(cl) + if sys.version_info < (3, 5): print('Not tested') else: @@ -250,4 +262,7 @@ def test_BaseEstimator(): params['first'] = 'spam again' cl.set_params(**params) + with pytest.raises(ValueError): + cl.set_params(bibi=10) + assert cl.first == 'spam again' -- cgit v1.2.3 From a335324d008e8982be61d7ace937815a2bfa98f9 Mon Sep 17 00:00:00 2001 From: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> Date: Tue, 2 Nov 2021 13:42:02 +0100 Subject: [MRG] Backend for gromov (#294) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * bregman: small correction * gromov backend first draft * Removing decorators * Reworked casting method * Bug solve * Removing casting * Bug solve * toarray renamed todense ; expand_dims removed * Warning (jax not supporting sparse matrix) moved * Mistake corrected * test backend * Sparsity test for older versions of pytorch * Trying pytorch/1.10 * Attempt to correct torch sparse bug * Backend version of gromov tests * Random state introduced for remaining gromov functions * review changes * code coverage * Docs (first draft, to be continued) * Gromov docs * Prettified docs * mistake corrected in the docs * little change Co-authored-by: Rémi Flamary --- ot/backend.py | 214 ++++++++- ot/bregman.py | 184 ++++---- ot/gromov.py | 1220 +++++++++++++++++++++++++++----------------------- ot/lp/__init__.py | 58 ++- ot/optim.py | 22 +- test/test_backend.py | 56 +++ test/test_bregman.py | 4 +- test/test_gromov.py | 297 ++++++++---- 8 files changed, 1289 insertions(+), 766 deletions(-) (limited to 'test/test_gromov.py') diff --git a/ot/backend.py b/ot/backend.py index 876b96a..358297c 100644 
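The backend.py diff below adds a small sparse-tensor layer to the Backend API (coo_matrix, issparse, tocsr, eliminate_zeros, todense, where, copy, allclose). As a rough mental model only, here is a minimal NumPy/SciPy sketch of the round trip those wrappers standardize; it uses plain scipy.sparse calls, while the actual backend-agnostic API is the nx.* methods introduced in the diff:

    import numpy as np
    from scipy.sparse import coo_matrix, issparse

    # build a sparse coupling from (data, rows, cols) triplets, as nx.coo_matrix does
    data = np.array([0.4, 0.6, 1e-12])
    rows = np.array([0, 1, 1])
    cols = np.array([1, 0, 2])
    G = coo_matrix((data, (rows, cols)), shape=(2, 3))

    # drop near-zero entries, mirroring nx.eliminate_zeros(G, threshold=1e-10)
    G.data[np.abs(G.data) <= 1e-10] = 0
    G.eliminate_zeros()

    # densify when a dense solver is needed, mirroring nx.todense
    G_dense = G.toarray() if issparse(G) else np.asarray(G)
    print(G_dense)

For the JAX backend the same calls fall back to dense arrays, since sparse matrices are not supported there; for PyTorch they map to torch.sparse_coo_tensor.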
--- a/ot/backend.py +++ b/ot/backend.py @@ -26,6 +26,7 @@ Examples import numpy as np import scipy.special as scipy +from scipy.sparse import issparse, coo_matrix, csr_matrix try: import torch @@ -539,6 +540,86 @@ class Backend(): """ raise NotImplementedError() + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): + r""" + Creates a sparse tensor in COOrdinate format. + + This function follows the api from :any:`scipy.sparse.coo_matrix` + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html + """ + raise NotImplementedError() + + def issparse(self, a): + r""" + Checks whether or not the input tensor is a sparse tensor. + + This function follows the api from :any:`scipy.sparse.issparse` + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.issparse.html + """ + raise NotImplementedError() + + def tocsr(self, a): + r""" + Converts this matrix to Compressed Sparse Row format. + + This function follows the api from :any:`scipy.sparse.coo_matrix.tocsr` + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.tocsr.html + """ + raise NotImplementedError() + + def eliminate_zeros(self, a, threshold=0.): + r""" + Removes entries smaller than the given threshold from the sparse tensor. + + This function follows the api from :any:`scipy.sparse.csr_matrix.eliminate_zeros` + + See: https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.sparse.csr_matrix.eliminate_zeros.html + """ + raise NotImplementedError() + + def todense(self, a): + r""" + Converts a sparse tensor to a dense tensor. + + This function follows the api from :any:`scipy.sparse.csr_matrix.toarray` + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.toarray.html + """ + raise NotImplementedError() + + def where(self, condition, x, y): + r""" + Returns elements chosen from x or y depending on condition. + + This function follows the api from :any:`numpy.where` + + See: https://numpy.org/doc/stable/reference/generated/numpy.where.html + """ + raise NotImplementedError() + + def copy(self, a): + r""" + Returns a copy of the given tensor. + + This function follows the api from :any:`numpy.copy` + + See: https://numpy.org/doc/stable/reference/generated/numpy.copy.html + """ + raise NotImplementedError() + + def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): + r""" + Returns True if two arrays are element-wise equal within a tolerance. 
+ + This function follows the api from :any:`numpy.allclose` + + See: https://numpy.org/doc/stable/reference/generated/numpy.allclose.html + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -712,6 +793,46 @@ class NumpyBackend(Backend): def reshape(self, a, shape): return np.reshape(a, shape) + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): + if type_as is None: + return coo_matrix((data, (rows, cols)), shape=shape) + else: + return coo_matrix((data, (rows, cols)), shape=shape, dtype=type_as.dtype) + + def issparse(self, a): + return issparse(a) + + def tocsr(self, a): + if self.issparse(a): + return a.tocsr() + else: + return csr_matrix(a) + + def eliminate_zeros(self, a, threshold=0.): + if threshold > 0: + if self.issparse(a): + a.data[self.abs(a.data) <= threshold] = 0 + else: + a[self.abs(a) <= threshold] = 0 + if self.issparse(a): + a.eliminate_zeros() + return a + + def todense(self, a): + if self.issparse(a): + return a.toarray() + else: + return a + + def where(self, condition, x, y): + return np.where(condition, x, y) + + def copy(self, a): + return a.copy() + + def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): + return np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + class JaxBackend(Backend): """ @@ -889,6 +1010,48 @@ class JaxBackend(Backend): def reshape(self, a, shape): return jnp.reshape(a, shape) + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): + # Currently, JAX does not support sparse matrices + data = self.to_numpy(data) + rows = self.to_numpy(rows) + cols = self.to_numpy(cols) + nx = NumpyBackend() + coo_matrix = nx.coo_matrix(data, rows, cols, shape=shape, type_as=type_as) + matrix = nx.todense(coo_matrix) + return self.from_numpy(matrix) + + def issparse(self, a): + # Currently, JAX does not support sparse matrices + return False + + def tocsr(self, a): + # Currently, JAX does not support sparse matrices + return a + + def eliminate_zeros(self, a, threshold=0.): + # Currently, JAX does not support sparse matrices + if threshold > 0: + return self.where( + self.abs(a) <= threshold, + self.zeros((1,), type_as=a), + a + ) + return a + + def todense(self, a): + # Currently, JAX does not support sparse matrices + return a + + def where(self, condition, x, y): + return jnp.where(condition, x, y) + + def copy(self, a): + # No need to copy, JAX arrays are immutable + return a + + def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): + return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + class TorchBackend(Backend): """ @@ -999,7 +1162,7 @@ class TorchBackend(Backend): a = torch.tensor([float(a)], dtype=b.dtype, device=b.device) if isinstance(b, int) or isinstance(b, float): b = torch.tensor([float(b)], dtype=a.dtype, device=a.device) - if torch.__version__ >= '1.7.0': + if hasattr(torch, "maximum"): return torch.maximum(a, b) else: return torch.max(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0] @@ -1009,7 +1172,7 @@ class TorchBackend(Backend): a = torch.tensor([float(a)], dtype=b.dtype, device=b.device) if isinstance(b, int) or isinstance(b, float): b = torch.tensor([float(b)], dtype=a.dtype, device=a.device) - if torch.__version__ >= '1.7.0': + if hasattr(torch, "minimum"): return torch.minimum(a, b) else: return torch.min(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0] @@ -1129,3 +1292,50 @@ class TorchBackend(Backend): def reshape(self, a, shape): return torch.reshape(a, shape) + + def coo_matrix(self, data, rows, cols, shape=None, 
type_as=None): + if type_as is None: + return torch.sparse_coo_tensor(torch.stack([rows, cols]), data, size=shape) + else: + return torch.sparse_coo_tensor( + torch.stack([rows, cols]), data, size=shape, + dtype=type_as.dtype, device=type_as.device + ) + + def issparse(self, a): + return getattr(a, "is_sparse", False) or getattr(a, "is_sparse_csr", False) + + def tocsr(self, a): + # Versions older than 1.9 do not support CSR tensors. PyTorch 1.9 and 1.10 offer a very limited support + return self.todense(a) + + def eliminate_zeros(self, a, threshold=0.): + if self.issparse(a): + if threshold > 0: + mask = self.abs(a) <= threshold + mask = ~mask + mask = mask.nonzero() + else: + mask = a._values().nonzero() + nv = a._values().index_select(0, mask.view(-1)) + ni = a._indices().index_select(1, mask.view(-1)) + return self.coo_matrix(nv, ni[0], ni[1], shape=a.shape, type_as=a) + else: + if threshold > 0: + a[self.abs(a) <= threshold] = 0 + return a + + def todense(self, a): + if self.issparse(a): + return a.to_dense() + else: + return a + + def where(self, condition, x, y): + return torch.where(condition, x, y) + + def copy(self, a): + return torch.clone(a) + + def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): + return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) diff --git a/ot/bregman.py b/ot/bregman.py index 2aa76ff..0499b8e 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -32,13 +32,14 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) - s.t. \ \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} + + \gamma &\geq 0 - \gamma\geq 0 where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix @@ -167,13 +168,14 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + W = \min_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) + + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - s.t. \ \gamma 1 = a + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma^T 1= b + \gamma &\geq 0 - \gamma\geq 0 where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix @@ -323,13 +325,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix @@ -489,13 +491,13 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. 
\ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix @@ -550,8 +552,7 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal - Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. @@ -675,13 +676,13 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) - s.t. \ \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix @@ -820,13 +821,13 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) - s.t. \ \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix @@ -965,7 +966,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, # remove numerical problems and store them in K if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau: if n_hists: - alpha, beta = alpha + reg * nx.max(nx.log(u), 1), beta + reg * nx.max(np.log(v)) + alpha, beta = alpha + reg * nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v)) else: alpha, beta = alpha + reg * nx.log(u), beta + reg * nx.log(v) if n_hists: @@ -1055,13 +1056,13 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) - s.t. \ \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix @@ -1245,12 +1246,12 @@ def projC(gamma, q): def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, stopThr=1e-4, verbose=False, log=False, **kwargs): - r"""Compute the entropic regularized wasserstein barycenter of distributions A + r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` The function solves the following optimization problem: .. 
math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) where : @@ -1263,7 +1264,7 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, Parameters ---------- A : array-like, shape (dim, n_hists) - `n_hists` training distributions :math:`a_i` of size `dim` + `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim` M : array-like, shape (dim, dim) loss matrix for OT reg : float @@ -1271,7 +1272,7 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, method : str (optional) method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized' weights : array-like, shape (n_hists,) - Weights of each histogram :math:`a_i` on the simplex (barycentric coodinates) + Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) numItermax : int, optional Max number of iterations stopThr : float, optional @@ -1314,12 +1315,12 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False): - r"""Compute the entropic regularized wasserstein barycenter of distributions A + r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` The function solves the following optimization problem: .. math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) where : @@ -1332,13 +1333,13 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, Parameters ---------- A : array-like, shape (dim, n_hists) - `n_hists` training distributions :math:`a_i` of size `dim` + `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim` M : array-like, shape (dim, dim) loss matrix for OT reg : float Regularization term > 0 weights : array-like, shape (n_hists,) - Weights of each histogram :math:`a_i` on the simplex (barycentric coodinates) + Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) numItermax : int, optional Max number of iterations stopThr : float, optional @@ -1414,12 +1415,12 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False): - r"""Compute the entropic regularized wasserstein barycenter of distributions A with stabilization. + r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` with stabilization. The function solves the following optimization problem: .. 
math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) where : @@ -1432,7 +1433,7 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, Parameters ---------- A : array-like, shape (dim, n_hists) - `n_hists` training distributions :math:`a_i` of size `dim` + `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim` M : array-like, shape (dim, dim) loss matrix for OT reg : float @@ -1440,7 +1441,7 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, tau : float threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` for log scaling weights : array-like, shape (n_hists,) - Weights of each histogram :math:`a_i` on the simplex (barycentric coodinates) + Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) numItermax : int, optional Max number of iterations stopThr : float, optional @@ -1533,8 +1534,8 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, "Or a larger absorption threshold `tau`.") if log: log['niter'] = cpt - log['logu'] = np.log(u + 1e-16) - log['logv'] = np.log(v + 1e-16) + log['logu'] = nx.log(u + 1e-16) + log['logv'] = nx.log(v + 1e-16) return q, log else: return q @@ -1543,13 +1544,13 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False): - r"""Compute the entropic regularized wasserstein barycenter of distributions A - where A is a collection of 2D images. + r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` + where :math:`\mathbf{A}` is a collection of 2D images. The function solves the following optimization problem: .. math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) where : @@ -1673,12 +1674,12 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, .. math:: - \mathbf{h} = arg\min_\mathbf{h} (1- \alpha) W_{M,reg}(\mathbf{a},\mathbf{Dh})+\alpha W_{M_0,reg_0}(\mathbf{h}_0,\mathbf{h}) + \mathbf{h} = \mathop{\arg \min}_\mathbf{h} (1- \alpha) W_{\mathbf{M}, \mathrm{reg}}(\mathbf{a},\mathbf{Dh})+\alpha W_{\mathbf{M_0},\mathrm{reg}_0}(\mathbf{h}_0,\mathbf{h}) where : - - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with M loss matrix (see :py:func:`ot.bregman.sinkhorn`) + - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with :math:`\mathbf{M}` loss matrix (see :py:func:`ot.bregman.sinkhorn`) - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, its expected shape is `(dim_a, n_atoms)` - :math:`\mathbf{h}` is the estimated unmixing of dimension `n_atoms` - :math:`\mathbf{a}` is an observed distribution of dimension `dim_a` @@ -1790,7 +1791,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, .. math:: - \mathbf{h} = arg\min_{\mathbf{h}}\quad \sum_{k=1}^{K} \lambda_k + \mathbf{h} = \mathop{\arg \min}_{\mathbf{h}} \sum_{k=1}^{K} \lambda_k W_{reg}((\mathbf{D}_2^{(k)} \mathbf{h})^T, \mathbf{a}) s.t. 
\ \forall k, \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h} @@ -1898,7 +1899,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, K.append(Ktmp) # uniform target distribution - a = nx.from_numpy(unif(np.shape(Xt)[0])) + a = nx.from_numpy(unif(Xt.shape[0]), type_as=Xs[0]) cpt = 0 # iterations count err = 1 @@ -1956,13 +1957,13 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg} \cdot\Omega(\gamma) - s.t. \ \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= a - \gamma^T 1= b + \gamma^T \mathbf{1} &= b - \gamma\geq 0 + \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix @@ -2010,8 +2011,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', >>> n_samples_a = 2 >>> n_samples_b = 2 >>> reg = 0.1 - >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1)) - >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1)) + >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1)) + >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1)) >>> empirical_sinkhorn(X_s, X_t, reg=reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE array([[4.99977301e-01, 2.26989344e-05], [2.26989344e-05, 4.99977301e-01]]) @@ -2033,9 +2034,9 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', ns, nt = X_s.shape[0], X_t.shape[0] if a is None: - a = nx.from_numpy(unif(ns)) + a = nx.from_numpy(unif(ns), type_as=X_s) if b is None: - b = nx.from_numpy(unif(nt)) + b = nx.from_numpy(unif(nt), type_as=X_s) if isLazy: if log: @@ -2127,13 +2128,13 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + W = \min_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg} \cdot\Omega(\gamma) - s.t. \ \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= a - \gamma^T 1= b + \gamma^T \mathbf{1} &= b - \gamma\geq 0 + \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix @@ -2181,8 +2182,8 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num >>> n_samples_a = 2 >>> n_samples_b = 2 >>> reg = 0.1 - >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1)) - >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1)) + >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1)) + >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1)) >>> b = np.full((n_samples_b, 3), 1/n_samples_b) >>> empirical_sinkhorn2(X_s, X_t, b=b, reg=reg, verbose=False) array([4.53978687e-05, 4.53978687e-05, 4.53978687e-05]) @@ -2204,9 +2205,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num ns, nt = X_s.shape[0], X_t.shape[0] if a is None: - a = nx.from_numpy(unif(ns)) + a = nx.from_numpy(unif(ns), type_as=X_s) if b is None: - b = nx.from_numpy(unif(nt)) + b = nx.from_numpy(unif(nt), type_as=X_s) if isLazy: if log: @@ -2259,32 +2260,32 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli .. 
math:: - W &= \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + W &= \min_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg} \cdot\Omega(\gamma) - W_a &= \min_{\gamma_a} <\gamma_a,M_a>_F + reg\cdot\Omega(\gamma_a) + W_a &= \min_{\gamma_a} <\gamma_a, \mathbf{M_a}>_F + \mathrm{reg} \cdot\Omega(\gamma_a) - W_b &= \min_{\gamma_b} <\gamma_b,M_b>_F + reg\cdot\Omega(\gamma_b) + W_b &= \min_{\gamma_b} <\gamma_b, \mathbf{M_b}>_F + \mathrm{reg} \cdot\Omega(\gamma_b) S &= W - \frac{W_a + W_b}{2} .. math:: - s.t. \ \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= a - \gamma^T 1= b + \gamma^T \mathbf{1} &= b - \gamma\geq 0 + \gamma &\geq 0 - \gamma_a 1 = a + \gamma_a \mathbf{1} &= \mathbf{a} - \gamma_a^T 1= a + \gamma_a^T \mathbf{1} &= \mathbf{a} - \gamma_a\geq 0 + \gamma_a &\geq 0 - \gamma_b 1 = b + \gamma_b \mathbf{1} &= \mathbf{b} - \gamma_b^T 1= b + \gamma_b^T \mathbf{1} &= \mathbf{b} - \gamma_b\geq 0 + \gamma_b &\geq 0 where : - :math:`\mathbf{M}` (resp. :math:`\mathbf{M_a}`, :math:`\mathbf{M_b}`) is the (`n_samples_a`, `n_samples_b`) metric cost matrix (resp (`n_samples_a, n_samples_a`) and (`n_samples_b`, `n_samples_b`)) @@ -2325,8 +2326,8 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli >>> n_samples_a = 2 >>> n_samples_b = 4 >>> reg = 0.1 - >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1)) - >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1)) + >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1)) + >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1)) >>> empirical_sinkhorn_divergence(X_s, X_t, reg) # doctest: +ELLIPSIS 1.499887176049052 @@ -2380,19 +2381,19 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res .. math:: - (u, v) = arg\min_{u, v} 1_{ns}^T B(u,v) 1_{nt} - <\kappa u, a> - + (\mathbf{u}, \mathbf{v}) = \mathop{\arg \min}_{\mathbf{u}, \mathbf{v}} \ \mathbf{1}_{ns}^T \mathbf{B}(\mathbf{u}, \mathbf{v}) \mathbf{1}_{nt} - <\kappa \mathbf{u}, \mathbf{a}> - <\mathbf{v} / \kappa, \mathbf{b}> where: .. math:: - B(u,v) = \mathrm{diag}(e^u) K \mathrm{diag}(e^v) \text{, with } K = e^{-M/reg} \text{ and} + \mathbf{B}(\mathbf{u}, \mathbf{v}) = \mathrm{diag}(e^\mathbf{u}) \mathbf{K} \mathrm{diag}(e^\mathbf{v}) \text{, with } \mathbf{K} = e^{-\mathbf{M} / \mathrm{reg}} \text{ and} .. math:: - s.t. \ e^{u_i} \geq \epsilon / \kappa, \forall i \in \{1, \ldots, ns\} + s.t. 
\ e^{u_i} &\geq \epsilon / \kappa, \forall i \in \{1, \ldots, ns\} - e^{v_j} \geq \epsilon \kappa, \forall j \in \{1, \ldots, nt\} + e^{v_j} &\geq \epsilon \kappa, \forall j \in \{1, \ldots, nt\} The parameters `kappa` and `epsilon` are determined w.r.t the couple number budget of points (`ns_budget`, `nt_budget`), see Equation (5) in :ref:`[26] ` @@ -2531,7 +2532,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res epsilon_u_square = a[0] / aK_sort[ns_budget - 1] else: aK_sort = nx.from_numpy( - bottleneck.partition(nx.to_numpy(K_sum_cols), ns_budget - 1)[ns_budget - 1] + bottleneck.partition(nx.to_numpy(K_sum_cols), ns_budget - 1)[ns_budget - 1], + type_as=M ) epsilon_u_square = a[0] / aK_sort @@ -2540,7 +2542,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res epsilon_v_square = b[0] / bK_sort[nt_budget - 1] else: bK_sort = nx.from_numpy( - bottleneck.partition(nx.to_numpy(K_sum_rows), nt_budget - 1)[nt_budget - 1] + bottleneck.partition(nx.to_numpy(K_sum_rows), nt_budget - 1)[nt_budget - 1], + type_as=M ) epsilon_v_square = b[0] / bK_sort else: @@ -2589,10 +2592,9 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res K_IcJ = K[np.ix_(Ic, Jsel)] K_IJc = K[np.ix_(Isel, Jc)] - #K_min = K_IJ.min() K_min = nx.min(K_IJ) if K_min == 0: - K_min = np.finfo(float).tiny + K_min = float(np.finfo(float).tiny) # a_I, b_J, a_Ic, b_Jc a_I = a[Isel] @@ -2713,7 +2715,7 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res maxfun=maxfun, pgtol=pgtol, maxiter=maxiter) - theta = nx.from_numpy(theta) + theta = nx.from_numpy(theta, type_as=M) usc = theta[:ns_budget] vsc = theta[ns_budget:] diff --git a/ot/gromov.py b/ot/gromov.py index 33b4453..a0fbf48 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -14,67 +14,85 @@ import numpy as np from .bregman import sinkhorn -from .utils import dist, UndefinedParameter +from .utils import dist, UndefinedParameter, list_to_array from .optim import cg from .lp import emd_1d, emd from .utils import check_random_state - -from scipy.sparse import issparse +from .backend import get_backend def init_matrix(C1, C2, p, q, loss_fun='square_loss'): r"""Return loss matrices and tensors for Gromov-Wasserstein fast computation - Returns the value of \mathcal{L}(C1,C2) \otimes T with the selected loss - function as the loss function of Gromow-Wasserstein discrepancy. + Returns the value of :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` with the + selected loss function as the loss function of Gromow-Wasserstein discrepancy. - The matrices are computed as described in Proposition 1 in [12] + The matrices are computed as described in Proposition 1 in :ref:`[12] ` Where : - * C1 : Metric cost matrix in the source space - * C2 : Metric cost matrix in the target space - * T : A coupling between those two spaces - - The square-loss function L(a,b)=|a-b|^2 is read as : - L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with : - * f1(a)=(a^2) - * f2(b)=(b^2) - * h1(a)=a - * h2(b)=2*b - - The kl-loss function L(a,b)=a*log(a/b)-a+b is read as : - L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with : - * f1(a)=a*log(a)-a - * f2(b)=b - * h1(a)=a - * h2(b)=log(b) + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{T}`: A coupling between those two spaces + + The square-loss function :math:`L(a, b) = |a - b|^2` is read as : + + .. 
math:: + + L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b) + + \mathrm{with} \ f_1(a) &= a^2 + + f_2(b) &= b^2 + + h_1(a) &= a + + h_2(b) &= 2b + + The kl-loss function :math:`L(a, b) = a \log\left(\frac{a}{b}\right) - a + b` is read as : + + .. math:: + + L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b) + + \mathrm{with} \ f_1(a) &= a \log(a) - a + + f_2(b) &= b + + h_1(a) &= a + + h_2(b) &= \log(b) Parameters ---------- - C1 : ndarray, shape (ns, ns) + C1 : array-like, shape (ns, ns) Metric cost matrix in the source space - C2 : ndarray, shape (nt, nt) - Metric costfr matrix in the target space - T : ndarray, shape (ns, nt) + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + T : array-like, shape (ns, nt) Coupling between source and target spaces - p : ndarray, shape (ns,) + p : array-like, shape (ns,) Returns ------- - constC : ndarray, shape (ns, nt) - Constant C matrix in Eq. (6) - hC1 : ndarray, shape (ns, ns) - h1(C1) matrix in Eq. (6) - hC2 : ndarray, shape (nt, nt) - h2(C) matrix in Eq. (6) + constC : array-like, shape (ns, nt) + Constant :math:`\mathbf{C}` matrix in Eq. (6) + hC1 : array-like, shape (ns, ns) + :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) + hC2 : array-like, shape (nt, nt) + :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) + + .. _references-init-matrix: References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. """ + C1, C2, p, q = list_to_array(C1, C2, p, q) + nx = get_backend(C1, C2, p, q) if loss_fun == 'square_loss': def f1(a): @@ -90,7 +108,7 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'): return 2 * b elif loss_fun == 'kl_loss': def f1(a): - return a * np.log(a + 1e-15) - a + return a * nx.log(a + 1e-15) - a def f2(b): return b @@ -99,12 +117,16 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'): return a def h2(b): - return np.log(b + 1e-15) - - constC1 = np.dot(np.dot(f1(C1), p.reshape(-1, 1)), - np.ones(len(q)).reshape(1, -1)) - constC2 = np.dot(np.ones(len(p)).reshape(-1, 1), - np.dot(q.reshape(1, -1), f2(C2).T)) + return nx.log(b + 1e-15) + + constC1 = nx.dot( + nx.dot(f1(C1), nx.reshape(p, (-1, 1))), + nx.ones((1, len(q)), type_as=q) + ) + constC2 = nx.dot( + nx.ones((len(p), 1), type_as=p), + nx.dot(nx.reshape(q, (1, -1)), f2(C2).T) + ) constC = constC1 + constC2 hC1 = h1(C1) hC2 = h2(C2) @@ -115,30 +137,37 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'): def tensor_product(constC, hC1, hC2, T): r"""Return the tensor for Gromov-Wasserstein fast computation - The tensor is computed as described in Proposition 1 Eq. (6) in [12]. + The tensor is computed as described in Proposition 1 Eq. (6) in :ref:`[12] ` Parameters ---------- - constC : ndarray, shape (ns, nt) - Constant C matrix in Eq. (6) - hC1 : ndarray, shape (ns, ns) - h1(C1) matrix in Eq. (6) - hC2 : ndarray, shape (nt, nt) - h2(C) matrix in Eq. (6) + constC : array-like, shape (ns, nt) + Constant :math:`\mathbf{C}` matrix in Eq. (6) + hC1 : array-like, shape (ns, ns) + :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) + hC2 : array-like, shape (nt, nt) + :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. 
(6) Returns ------- - tens : ndarray, shape (ns, nt) - \mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result + tens : array-like, shape (`ns`, `nt`) + :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` tensor-matrix multiplication result + + .. _references-tensor-product: References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. """ - A = -np.dot(hC1, T).dot(hC2.T) + constC, hC1, hC2, T = list_to_array(constC, hC1, hC2, T) + nx = get_backend(constC, hC1, hC2, T) + + A = - nx.dot( + nx.dot(hC1, T), hC2.T + ) tens = constC + A # tens -= tens.min() return tens @@ -147,27 +176,29 @@ def tensor_product(constC, hC1, hC2, T): def gwloss(constC, hC1, hC2, T): """Return the Loss for Gromov-Wasserstein - The loss is computed as described in Proposition 1 Eq. (6) in [12]. + The loss is computed as described in Proposition 1 Eq. (6) in :ref:`[12] ` Parameters ---------- - constC : ndarray, shape (ns, nt) - Constant C matrix in Eq. (6) - hC1 : ndarray, shape (ns, ns) - h1(C1) matrix in Eq. (6) - hC2 : ndarray, shape (nt, nt) - h2(C) matrix in Eq. (6) - T : ndarray, shape (ns, nt) - Current value of transport matrix T + constC : array-like, shape (ns, nt) + Constant :math:`\mathbf{C}` matrix in Eq. (6) + hC1 : array-like, shape (ns, ns) + :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) + hC2 : array-like, shape (nt, nt) + :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) + T : array-like, shape (ns, nt) + Current value of transport matrix :math:`\mathbf{T}` Returns ------- loss : float Gromov Wasserstein loss + + .. _references-gwloss: References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. @@ -175,33 +206,38 @@ def gwloss(constC, hC1, hC2, T): tens = tensor_product(constC, hC1, hC2, T) - return np.sum(tens * T) + tens, T = list_to_array(tens, T) + nx = get_backend(tens, T) + + return nx.sum(tens * T) def gwggrad(constC, hC1, hC2, T): """Return the gradient for Gromov-Wasserstein - The gradient is computed as described in Proposition 2 in [12]. + The gradient is computed as described in Proposition 2 in :ref:`[12] ` Parameters ---------- - constC : ndarray, shape (ns, nt) - Constant C matrix in Eq. (6) - hC1 : ndarray, shape (ns, ns) - h1(C1) matrix in Eq. (6) - hC2 : ndarray, shape (nt, nt) - h2(C) matrix in Eq. (6) - T : ndarray, shape (ns, nt) - Current value of transport matrix T + constC : array-like, shape (ns, nt) + Constant :math:`\mathbf{C}` matrix in Eq. (6) + hC1 : array-like, shape (ns, ns) + :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) + hC2 : array-like, shape (nt, nt) + :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) + T : array-like, shape (ns, nt) + Current value of transport matrix :math:`\mathbf{T}` Returns ------- - grad : ndarray, shape (ns, nt) + grad : array-like, shape (`ns`, `nt`) Gromov Wasserstein gradient + + .. _references-gwggrad: References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. 
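The init_matrix / tensor_product / gwloss functions above implement the Proposition 1 factorization of [12]: for a coupling T with marginals (p, q), the tensor product L(C1, C2) ⊗ T equals constC - h1(C1) T h2(C2)^T, which replaces the O(ns^2 nt^2) quadruple sum by a few matrix products. A small self-contained NumPy check of that identity for the square loss (toy sizes and a product coupling, chosen here purely for illustration):

    import numpy as np

    rng = np.random.RandomState(0)
    ns, nt = 4, 5
    C1 = rng.rand(ns, ns)
    C1 = (C1 + C1.T) / 2                      # toy symmetric cost matrices
    C2 = rng.rand(nt, nt)
    C2 = (C2 + C2.T) / 2
    p, q = np.full(ns, 1 / ns), np.full(nt, 1 / nt)
    T = np.outer(p, q)                        # any coupling with marginals (p, q)

    # naive quadruple sum: sum_{i,j,k,l} (C1_ik - C2_jl)^2 T_ij T_kl
    naive = sum((C1[i, k] - C2[j, l]) ** 2 * T[i, j] * T[k, l]
                for i in range(ns) for j in range(nt)
                for k in range(ns) for l in range(nt))

    # factorized form, square loss: f1(a)=a^2, f2(b)=b^2, h1(a)=a, h2(b)=2b
    constC = ((C1 ** 2) @ p[:, None] @ np.ones((1, nt))
              + np.ones((ns, 1)) @ (q[None, :] @ (C2 ** 2).T))
    tens = constC - C1 @ T @ (2 * C2).T
    fast = np.sum(tens * T)

    print(np.allclose(naive, fast))           # True

The same tensor product is reused for the gradient with respect to T, which is why gwggrad only has to rescale it.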
@@ -212,88 +248,107 @@ def gwggrad(constC, hC1, hC2, T): def update_square_loss(p, lambdas, T, Cs): """ - Updates C according to the L2 Loss kernel with the S Ts couplings - calculated at each iteration + Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` + couplings calculated at each iteration Parameters ---------- - p : ndarray, shape (N,) + p : array-like, shape (N,) Masses in the targeted barycenter. lambdas : list of float - List of the S spaces' weights. - T : list of S np.ndarray of shape (ns,N) - The S Ts couplings calculated at each iteration. - Cs : list of S ndarray, shape(ns,ns) + List of the `S` spaces' weights. + T : list of S array-like of shape (ns,N) + The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. + Cs : list of S array-like, shape(ns,ns) Metric cost matrices. Returns ---------- - C : ndarray, shape (nt, nt) - Updated C matrix. + C : array-like, shape (`nt`, `nt`) + Updated :math:`\mathbf{C}` matrix. """ - tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) - for s in range(len(T))]) - ppt = np.outer(p, p) + T = list_to_array(*T) + Cs = list_to_array(*Cs) + p = list_to_array(p) + nx = get_backend(p, *T, *Cs) + + tmpsum = sum([ + lambdas[s] * nx.dot( + nx.dot(T[s].T, Cs[s]), + T[s] + ) for s in range(len(T)) + ]) + ppt = nx.outer(p, p) - return np.divide(tmpsum, ppt) + return tmpsum / ppt def update_kl_loss(p, lambdas, T, Cs): """ - Updates C according to the KL Loss kernel with the S Ts couplings calculated at each iteration + Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration Parameters ---------- - p : ndarray, shape (N,) + p : array-like, shape (N,) Weights in the targeted barycenter. - lambdas : list of the S spaces' weights - T : list of S np.ndarray of shape (ns,N) - The S Ts couplings calculated at each iteration. - Cs : list of S ndarray, shape(ns,ns) + lambdas : list of float + List of the `S` spaces' weights + T : list of S array-like of shape (ns,N) + The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. + Cs : list of S array-like, shape(ns,ns) Metric cost matrices. Returns ---------- - C : ndarray, shape (ns,ns) - updated C matrix + C : array-like, shape (`ns`, `ns`) + updated :math:`\mathbf{C}` matrix """ - tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) - for s in range(len(T))]) - ppt = np.outer(p, p) + Cs = list_to_array(*Cs) + T = list_to_array(*T) + p = list_to_array(p) + nx = get_backend(p, *T, *Cs) - return np.exp(np.divide(tmpsum, ppt)) + tmpsum = sum([ + lambdas[s] * nx.dot( + nx.dot(T[s].T, Cs[s]), + T[s] + ) for s in range(len(T)) + ]) + ppt = nx.outer(p, p) + + return nx.exp(tmpsum / ppt) def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): r""" - Returns the gromov-wasserstein transport between (C1,p) and (C2,q) + Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` The function solves the following optimization problem: .. 
math:: - GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} Where : - - C1 : Metric cost matrix in the source space - - C2 : Metric cost matrix in the target space - - p : distribution in the source space - - q : distribution in the target space - - L : loss function to account for the misfit between the similarity matrices + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices Parameters ---------- - C1 : ndarray, shape (ns, ns) + C1 : array-like, shape (ns, ns) Metric cost matrix in the source space - C2 : ndarray, shape (nt, nt) - Metric costfr matrix in the target space - p : ndarray, shape (ns,) + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) Distribution in the source space - q : ndarray, shape (nt,) + q : array-like, shape (nt,) Distribution in the target space loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss' - max_iter : int, optional Max number of iterations tol : float, optional @@ -303,22 +358,23 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs log : bool, optional record log if True armijo : bool, optional - If True the steps of the line-search is found via an armijo research. Else closed form is used. - If there is convergence issues use False. + If True the step of the line-search is found via an armijo research. Else closed form is used. + If there are convergence issues use False. **kwargs : dict parameters can be directly passed to the ot.optim.cg solver Returns ------- - T : ndarray, shape (ns, nt) - Doupling between the two spaces that minimizes: - \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + T : array-like, shape (`ns`, `nt`) + Coupling between the two spaces that minimizes: + + :math:`\sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}` log : dict Convergence information and loss. References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. @@ -327,6 +383,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs mathematics 11.4 (2011): 417-487. """ + p, q = list_to_array(p, q) constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) @@ -348,29 +405,30 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): r""" - Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q) + Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` The function solves the following optimization problem: .. 
math:: - GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + GW = \min_\mathbf{T} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} Where : - - C1 : Metric cost matrix in the source space - - C2 : Metric cost matrix in the target space - - p : distribution in the source space - - q : distribution in the target space - - L : loss function to account for the misfit between the similarity matrices + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices Parameters ---------- - C1 : ndarray, shape (ns, ns) + C1 : array-like, shape (ns, ns) Metric cost matrix in the source space - C2 : ndarray, shape (nt, nt) + C2 : array-like, shape (nt, nt) Metric cost matrix in the target space - p : ndarray, shape (ns,) + p : array-like, shape (ns,) Distribution in the source space. - q : ndarray, shape (nt,) + q : array-like, shape (nt,) Distribution in the target space. loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss' @@ -383,8 +441,8 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg log : bool, optional record log if True armijo : bool, optional - If True the steps of the line-search is found via an armijo research. Else closed form is used. - If there is convergence issues use False. + If True the step of the line-search is found via an armijo research. Else closed form is used. + If there are convergence issues use False. Returns ------- @@ -395,7 +453,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. @@ -404,6 +462,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg mathematics 11.4 (2011): 417-487. """ + p, q = list_to_array(p, q) constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) @@ -425,42 +484,45 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): r""" - Computes the FGW transport between two graphs see [24] + Computes the FGW transport between two graphs (see :ref:`[24] `) .. math:: - \gamma = arg\min_\gamma (1-\\alpha)*<\gamma,M>_F + \\alpha* \sum_{i,j,k,l} - L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + \gamma = \mathop{\arg \min}_\gamma (1 - \alpha) <\gamma, \mathbf{M}>_F + \alpha \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} - s.t. 
\gamma 1 = p - \gamma^T 1= q - \gamma\geq 0 + \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q} + + \mathbf{\gamma} &\geq 0 where : - - M is the (ns,nt) metric cost matrix - - p and q are source and target weights (sum to 1) - - L is a loss function to account for the misfit between the similarity matrices - The algorithm used for solving the problem is conditional gradient as discussed in [24]_ + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1) + - `L` is a loss function to account for the misfit between the similarity matrices + + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] ` Parameters ---------- - M : ndarray, shape (ns, nt) + M : array-like, shape (ns, nt) Metric cost matrix between features across domains - C1 : ndarray, shape (ns, ns) + C1 : array-like, shape (ns, ns) Metric cost matrix representative of the structure in the source space - C2 : ndarray, shape (nt, nt) + C2 : array-like, shape (nt, nt) Metric cost matrix representative of the structure in the target space - p : ndarray, shape (ns,) + p : array-like, shape (ns,) Distribution in the source space - q : ndarray, shape (nt,) + q : array-like, shape (nt,) Distribution in the target space loss_fun : str, optional Loss function used for the solver alpha : float, optional Trade-off parameter (0 < alpha < 1) armijo : bool, optional - If True the steps of the line-search is found via an armijo research. Else closed form is used. - If there is convergence issues use False. + If True the step of the line-search is found via an armijo research. Else closed form is used. + If there are convergence issues use False. log : bool, optional record log if True **kwargs : dict @@ -468,18 +530,21 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, Returns ------- - gamma : ndarray, shape (ns, nt) + gamma : array-like, shape (`ns`, `nt`) Optimal transportation matrix for the given parameters. log : dict Log dictionary return only if log==True in parameters. + + .. _references-fused-gromov-wasserstein: References ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas "Optimal Transport for structured data with application on graphs", International Conference on Machine Learning (ICML). 2019. """ + p, q = list_to_array(p, q) constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) @@ -501,61 +566,67 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): r""" - Computes the FGW distance between two graphs see [24] + Computes the FGW distance between two graphs see (see :ref:`[24] `) .. math:: - \min_\gamma (1-\\alpha)*<\gamma,M>_F + \\alpha* \sum_{i,j,k,l} - L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + \min_\gamma (1 - \alpha) <\gamma, \mathbf{M}>_F + \alpha \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} - s.t. 
\gamma 1 = p - \gamma^T 1= q - \gamma\geq 0 + \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q} + + \mathbf{\gamma} &\geq 0 where : - - M is the (ns,nt) metric cost matrix - - p and q are source and target weights (sum to 1) - - L is a loss function to account for the misfit between the similarity matrices - The algorithm used for solving the problem is conditional gradient as discussed in [1]_ + + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1) + - `L` is a loss function to account for the misfit between the similarity matrices + + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] ` Parameters ---------- - M : ndarray, shape (ns, nt) + M : array-like, shape (ns, nt) Metric cost matrix between features across domains - C1 : ndarray, shape (ns, ns) + C1 : array-like, shape (ns, ns) Metric cost matrix respresentative of the structure in the source space. - C2 : ndarray, shape (nt, nt) + C2 : array-like, shape (nt, nt) Metric cost matrix espresentative of the structure in the target space. - p : ndarray, shape (ns,) + p : array-like, shape (ns,) Distribution in the source space. - q : ndarray, shape (nt,) + q : array-like, shape (nt,) Distribution in the target space. loss_fun : str, optional Loss function used for the solver. alpha : float, optional Trade-off parameter (0 < alpha < 1) armijo : bool, optional - If True the steps of the line-search is found via an armijo research. - Else closed form is used. If there is convergence issues use False. + If True the step of the line-search is found via an armijo research. + Else closed form is used. If there are convergence issues use False. log : bool, optional Record log if True. **kwargs : dict - Parameters can be directly pased to the ot.optim.cg solver. + Parameters can be directly passed to the ot.optim.cg solver. Returns ------- - gamma : ndarray, shape (ns, nt) + gamma : array-like, shape (ns, nt) Optimal transportation matrix for the given parameters. log : dict Log dictionary return only if log==True in parameters. + + .. _references-fused-gromov-wasserstein2: References ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. """ + p, q = list_to_array(p, q) constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) @@ -579,60 +650,64 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 def GW_distance_estimation(C1, C2, p, q, loss_fun, T, nb_samples_p=None, nb_samples_q=None, std=True, random_state=None): r""" - Returns an approximation of the gromov-wasserstein cost between (C1,p) and (C2,q) - with a fixed transport plan T. - - The function gives an unbiased approximation of the following equation: - - .. 
math:: - GW = \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} - - Where : - - - C1 : Metric cost matrix in the source space - - C2 : Metric cost matrix in the target space - - L : Loss function to account for the misfit between the similarity matrices - - T : Matrix with marginal p and q - - Parameters - ---------- - C1 : ndarray, shape (ns, ns) - Metric cost matrix in the source space - C2 : ndarray, shape (nt, nt) - Metric costfr matrix in the target space - p : ndarray, shape (ns,) - Distribution in the source space - q : ndarray, shape (nt,) - Distribution in the target space - loss_fun : function: \mathcal{R} \times \mathcal{R} \shortarrow \mathcal{R} - Loss function used for the distance, the transport plan does not depend on the loss function - T : csr or ndarray, shape (ns, nt) - Transport plan matrix, either a sparse csr matrix or - nb_samples_p : int, optional - nb_samples_p is the number of samples (without replacement) along the first dimension of T. - nb_samples_q : int, optional - nb_samples_q is the number of samples along the second dimension of T, for each sample along the first. - std : bool, optional - Standard deviation associated with the prediction of the gromov-wasserstein cost. - random_state : int or RandomState instance, optional - Fix the seed for to allow reproducibility - - Returns - ------- - : float - Gromov-wasserstein cost - - References - ---------- - .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc - "Sampled Gromov Wasserstein." - Machine Learning Journal (MLJ). 2021. - - """ + Returns an approximation of the gromov-wasserstein cost between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` + with a fixed transport plan :math:`\mathbf{T}`. + + The function gives an unbiased approximation of the following equation: + + .. math:: + + GW = \sum_{i,j,k,l} L(\mathbf{C_{1}}_{i,k}, \mathbf{C_{2}}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - `L` : Loss function to account for the misfit between the similarity matrices + - :math:`\mathbf{T}`: Matrix with marginal :math:`\mathbf{p}` and :math:`\mathbf{q}` + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Distribution in the source space + q : array-like, shape (nt,) + Distribution in the target space + loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}` + Loss function used for the distance, the transport plan does not depend on the loss function + T : csr or array-like, shape (ns, nt) + Transport plan matrix, either a sparse csr or a dense matrix + nb_samples_p : int, optional + `nb_samples_p` is the number of samples (without replacement) along the first dimension of :math:`\mathbf{T}` + nb_samples_q : int, optional + `nb_samples_q` is the number of samples along the second dimension of :math:`\mathbf{T}`, for each sample along the first + std : bool, optional + Standard deviation associated with the prediction of the gromov-wasserstein cost + random_state : int or RandomState instance, optional + Fix the seed for reproducibility + + Returns + ------- + : float + Gromov-wasserstein cost + + References + ---------- + .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc + "Sampled Gromov Wasserstein." + Machine Learning Journal (MLJ). 2021. 
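A usage sketch (illustrative, not taken from the patch), assuming the function stays importable as ``ot.gromov.GW_distance_estimation``: the independent coupling :math:`\mathbf{p} \mathbf{q}^T` is a feasible plan with marginals :math:`\mathbf{p}` and :math:`\mathbf{q}`, and the absolute-difference loss mirrors the one used in the updated tests later in this diff::

    import numpy as np
    import ot
    from ot.gromov import GW_distance_estimation

    # two synthetic point clouds and their normalized similarity matrices
    xs = np.random.RandomState(0).randn(30, 2)
    xt = np.random.RandomState(1).randn(40, 2)
    C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
    C1, C2 = C1 / C1.max(), C2 / C2.max()
    p, q = ot.unif(30), ot.unif(40)

    # independent coupling: crude but feasible, with marginals exactly p and q
    T = np.outer(p, q)

    # unbiased sampled estimate of the GW cost of T, plus its deviation
    gw_mean, gw_dev = GW_distance_estimation(
        C1, C2, p, q, loss_fun=lambda x, y: np.abs(x - y), T=T,
        std=True, random_state=42)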
+ + """ + C1, C2, p, q = list_to_array(C1, C2, p, q) + nx = get_backend(C1, C2, p, q) + generator = check_random_state(random_state) - len_p = len(p) - len_q = len(q) + len_p = p.shape[0] + len_q = q.shape[0] # It is always better to sample from the biggest distribution first. if len_p < len_q: @@ -642,7 +717,7 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T, T = T.T if nb_samples_p is None: - if issparse(T): + if nx.issparse(T): # If T is sparse, it probably mean that PoGroW was used, thus the number of sample is reduced nb_samples_p = min(int(5 * (len_p * np.log(len_p)) ** 0.5), len_p) else: @@ -657,100 +732,112 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T, index_k = np.zeros((nb_samples_p, nb_samples_q), dtype=int) index_l = np.zeros((nb_samples_p, nb_samples_q), dtype=int) - list_value_sample = np.zeros((nb_samples_p, nb_samples_p, nb_samples_q)) index_i = generator.choice(len_p, size=nb_samples_p, p=p, replace=False) index_j = generator.choice(len_p, size=nb_samples_p, p=p, replace=False) for i in range(nb_samples_p): - if issparse(T): - T_indexi = T[index_i[i], :].toarray()[0] - T_indexj = T[index_j[i], :].toarray()[0] + if nx.issparse(T): + T_indexi = nx.reshape(nx.todense(T[index_i[i], :]), (-1,)) + T_indexj = nx.reshape(nx.todense(T[index_j[i], :]), (-1,)) else: T_indexi = T[index_i[i], :] T_indexj = T[index_j[i], :] # For each of the row sampled, the column is sampled. - index_k[i] = generator.choice(len_q, size=nb_samples_q, p=T_indexi / T_indexi.sum(), replace=True) - index_l[i] = generator.choice(len_q, size=nb_samples_q, p=T_indexj / T_indexj.sum(), replace=True) - - for n in range(nb_samples_q): - list_value_sample[:, :, n] = loss_fun(C1[np.ix_(index_i, index_j)], C2[np.ix_(index_k[:, n], index_l[:, n])]) + index_k[i] = generator.choice( + len_q, + size=nb_samples_q, + p=T_indexi / nx.sum(T_indexi), + replace=True + ) + index_l[i] = generator.choice( + len_q, + size=nb_samples_q, + p=T_indexj / nx.sum(T_indexj), + replace=True + ) + + list_value_sample = nx.stack([ + loss_fun( + C1[np.ix_(index_i, index_j)], + C2[np.ix_(index_k[:, n], index_l[:, n])] + ) for n in range(nb_samples_q) + ], axis=2) if std: - std_value = np.sum(np.std(list_value_sample, axis=2) ** 2) ** 0.5 - return np.mean(list_value_sample), std_value / (nb_samples_p * nb_samples_p) + std_value = nx.sum(nx.std(list_value_sample, axis=2) ** 2) ** 0.5 + return nx.mean(list_value_sample), std_value / (nb_samples_p * nb_samples_p) else: - return np.mean(list_value_sample) + return nx.mean(list_value_sample) def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, alpha=1, max_iter=100, threshold_plan=0, log=False, verbose=False, random_state=None): r""" - Returns the gromov-wasserstein transport between (C1,p) and (C2,q) using a stochastic Frank-Wolfe. - This method as a O(max_iter \times PN^2) time complexity with P the number of Sinkhorn iterations. - - The function solves the following optimization problem: - - .. math:: - GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} - - s.t. 
T 1 = p - - T^T 1= q - - T\geq 0 - - Where : - - - C1 : Metric cost matrix in the source space - - C2 : Metric cost matrix in the target space - - p : distribution in the source space - - q : distribution in the target space - - L : loss function to account for the misfit between the similarity matrices - - Parameters - ---------- - C1 : ndarray, shape (ns, ns) - Metric cost matrix in the source space - C2 : ndarray, shape (nt, nt) - Metric costfr matrix in the target space - p : ndarray, shape (ns,) - Distribution in the source space - q : ndarray, shape (nt,) - Distribution in the target space - loss_fun : function: \mathcal{R} \times \mathcal{R} \shortarrow \mathcal{R} - Loss function used for the distance, the transport plan does not depend on the loss function - alpha : float - Step of the Frank-Wolfe algorithm, should be between 0 and 1 - max_iter : int, optional - Max number of iterations - threshold_plan : float, optional - Deleting very small value in the transport plan. If above zero, it violate the marginal constraints. - verbose : bool, optional - Print information along iterations - log : bool, optional - Gives the distance estimated and the standard deviation - random_state : int or RandomState instance, optional - Fix the seed for to allow reproducibility - - Returns - ------- - T : ndarray, shape (ns, nt) - Optimal coupling between the two spaces - - References - ---------- - .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc - "Sampled Gromov Wasserstein." - Machine Learning Journal (MLJ). 2021. - - """ - C1 = np.asarray(C1, dtype=np.float64) - C2 = np.asarray(C2, dtype=np.float64) - p = np.asarray(p, dtype=np.float64) - q = np.asarray(q, dtype=np.float64) - len_p = len(p) - len_q = len(q) + Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a stochastic Frank-Wolfe. + This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times PN^2)` time complexity with `P` the number of Sinkhorn iterations. + + The function solves the following optimization problem: + + .. math:: + \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T}^T \mathbf{1} &= \mathbf{q} + + \mathbf{T} &\geq 0 + + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Distribution in the source space + q : array-like, shape (nt,) + Distribution in the target space + loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}` + Loss function used for the distance, the transport plan does not depend on the loss function + alpha : float + Step of the Frank-Wolfe algorithm, should be between 0 and 1 + max_iter : int, optional + Max number of iterations + threshold_plan : float, optional + Deleting very small values in the transport plan. If above zero, it violates the marginal constraints. 
+ verbose : bool, optional + Print information along iterations + log : bool, optional + Gives the distance estimated and the standard deviation + random_state : int or RandomState instance, optional + Fix the seed for reproducibility + + Returns + ------- + T : array-like, shape (`ns`, `nt`) + Optimal coupling between the two spaces + + References + ---------- + .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc + "Sampled Gromov Wasserstein." + Machine Learning Journal (MLJ). 2021. + + """ + C1, C2, p, q = list_to_array(C1, C2, p, q) + nx = get_backend(C1, C2, p, q) + + len_p = p.shape[0] + len_q = q.shape[0] generator = check_random_state(random_state) @@ -759,30 +846,35 @@ def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, # Initialize with default marginal index[0] = generator.choice(len_p, size=1, p=p) index[1] = generator.choice(len_q, size=1, p=q) - T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr() + T = nx.tocsr(emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False)) best_gw_dist_estimated = np.inf for cpt in range(max_iter): index[0] = generator.choice(len_p, size=1, p=p) - T_index0 = T[index[0], :].toarray()[0] + T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,)) index[1] = generator.choice(len_q, size=1, p=T_index0 / T_index0.sum()) if alpha == 1: - T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr() + T = nx.tocsr( + emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False) + ) else: - new_T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr() + new_T = nx.tocsr( + emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False) + ) T = (1 - alpha) * T + alpha * new_T - # To limit the number of non 0, the values bellow the threshold are set to 0. - T.data[T.data < threshold_plan] = 0 - T.eliminate_zeros() + # To limit the number of non 0, the values below the threshold are set to 0. + T = nx.eliminate_zeros(T, threshold=threshold_plan) if cpt % 10 == 0 or cpt == (max_iter - 1): - gw_dist_estimated = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun, - p=p, q=q, T=T, std=False, random_state=generator) + gw_dist_estimated = GW_distance_estimation( + C1=C1, C2=C2, loss_fun=loss_fun, + p=p, q=q, T=T, std=False, random_state=generator + ) if gw_dist_estimated < best_gw_dist_estimated: best_gw_dist_estimated = gw_dist_estimated - best_T = T.copy() + best_T = nx.copy(T) if verbose: if cpt % 200 == 0: @@ -791,9 +883,10 @@ def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, if log: log = {} - log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun, - p=p, q=q, T=best_T, - random_state=generator) + log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation( + C1=C1, C2=C2, loss_fun=loss_fun, + p=p, q=q, T=best_T, random_state=generator + ) return best_T, log return best_T @@ -802,71 +895,70 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, nb_samples_grad=100, epsilon=1, max_iter=500, log=False, verbose=False, random_state=None): r""" - Returns the gromov-wasserstein transport between (C1,p) and (C2,q) using a 1-stochastic Frank-Wolfe. - This method as a O(max_iter \times Nlog(N)) time complexity by relying on the 1D Optimal Transport solver. - - The function solves the following optimization problem: - - .. math:: - GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} - - s.t. 
T 1 = p - - T^T 1= q - - T\geq 0 - - Where : - - - C1 : Metric cost matrix in the source space - - C2 : Metric cost matrix in the target space - - p : distribution in the source space - - q : distribution in the target space - - L : loss function to account for the misfit between the similarity matrices - - Parameters - ---------- - C1 : ndarray, shape (ns, ns) - Metric cost matrix in the source space - C2 : ndarray, shape (nt, nt) - Metric costfr matrix in the target space - p : ndarray, shape (ns,) - Distribution in the source space - q : ndarray, shape (nt,) - Distribution in the target space - loss_fun : function: \mathcal{R} \times \mathcal{R} \shortarrow \mathcal{R} - Loss function used for the distance, the transport plan does not depend on the loss function - nb_samples_grad : int - Number of samples to approximate the gradient - epsilon : float - Weight of the Kullback-Leiber regularization - max_iter : int, optional - Max number of iterations - verbose : bool, optional - Print information along iterations - log : bool, optional - Gives the distance estimated and the standard deviation - random_state : int or RandomState instance, optional - Fix the seed for to allow reproducibility - - Returns - ------- - T : ndarray, shape (ns, nt) - Optimal coupling between the two spaces - - References - ---------- - .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc - "Sampled Gromov Wasserstein." - Machine Learning Journal (MLJ). 2021. - - """ - C1 = np.asarray(C1, dtype=np.float64) - C2 = np.asarray(C2, dtype=np.float64) - p = np.asarray(p, dtype=np.float64) - q = np.asarray(q, dtype=np.float64) - len_p = len(p) - len_q = len(q) + Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a 1-stochastic Frank-Wolfe. + This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times N \log(N))` time complexity by relying on the 1D Optimal Transport solver. + + The function solves the following optimization problem: + + .. math:: + \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. 
\ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T}^T \mathbf{1} &= \mathbf{q} + + \mathbf{T} &\geq 0 + + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Distribution in the source space + q : array-like, shape (nt,) + Distribution in the target space + loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}` + Loss function used for the distance, the transport plan does not depend on the loss function + nb_samples_grad : int + Number of samples to approximate the gradient + epsilon : float + Weight of the Kullback-Leibler regularization + max_iter : int, optional + Max number of iterations + verbose : bool, optional + Print information along iterations + log : bool, optional + Gives the distance estimated and the standard deviation + random_state : int or RandomState instance, optional + Fix the seed for reproducibility + + Returns + ------- + T : array-like, shape (`ns`, `nt`) + Optimal coupling between the two spaces + + References + ---------- + .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc + "Sampled Gromov Wasserstein." + Machine Learning Journal (MLJ). 2021. + + """ + C1, C2, p, q = list_to_array(C1, C2, p, q) + nx = get_backend(C1, C2, p, q) + + len_p = p.shape[0] + len_q = q.shape[0] generator = check_random_state(random_state) @@ -880,12 +972,12 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad, 1 else: nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad - T = np.outer(p, q) + T = nx.outer(p, q) # continue_loop allows to stop the loop if there is several successive small modification of T. continue_loop = 0 # The gradient of GW is more complex if the two matrices are not symmetric. - C_are_symmetric = np.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and np.allclose(C2, C2.T, rtol=1e-10, atol=1e-10) + C_are_symmetric = nx.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and nx.allclose(C2, C2.T, rtol=1e-10, atol=1e-10) for cpt in range(max_iter): index0 = generator.choice(len_p, size=nb_samples_grad_p, p=p, replace=False) @@ -893,28 +985,30 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, for i, index0_i in enumerate(index0): index1 = generator.choice(len_q, size=nb_samples_grad_q, - p=T[index0_i, :] / T[index0_i, :].sum(), + p=T[index0_i, :] / nx.sum(T[index0_i, :]), replace=False) # If the matrices C are not symmetric, the gradient has 2 terms, thus the term is chosen randomly. 
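            # Note on the two branches below: each builds the same kind of stochastic
            # gradient estimate. The sampled index index0[i] of C1 is paired with the
            # nb_samples_grad_q indices index1 drawn from the row T[index0[i], :], and
            # the loss is averaged over that sampled axis, so Lik accumulates a
            # (len_p, len_q) matrix approximating the gradient of the GW objective
            # with respect to T. When C1 and C2 are asymmetric this gradient has two
            # such terms (one built from rows, one from columns), hence one of the
            # two is chosen at random at each pass.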
if (not C_are_symmetric) and generator.rand(1) > 0.5: - Lik += np.mean(loss_fun(np.expand_dims(C1[:, np.repeat(index0[i], nb_samples_grad_q)], 1), - np.expand_dims(C2[:, index1], 0)), - axis=2) + Lik += nx.mean(loss_fun( + C1[:, [index0[i]] * nb_samples_grad_q][:, None, :], + C2[:, index1][None, :, :] + ), axis=2) else: - Lik += np.mean(loss_fun(np.expand_dims(C1[np.repeat(index0[i], nb_samples_grad_q), :], 2), - np.expand_dims(C2[index1, :], 1)), - axis=0) + Lik += nx.mean(loss_fun( + C1[[index0[i]] * nb_samples_grad_q, :][:, :, None], + C2[index1, :][:, None, :] + ), axis=0) - max_Lik = np.max(Lik) + max_Lik = nx.max(Lik) if max_Lik == 0: continue # This division by the max is here to facilitate the choice of epsilon. Lik /= max_Lik if epsilon > 0: - # Set to infinity all the numbers bellow exp(-200) to avoid log of 0. - log_T = np.log(np.clip(T, np.exp(-200), 1)) - log_T[log_T == -200] = -np.inf + # Set to infinity all the numbers below exp(-200) to avoid log of 0. + log_T = nx.log(nx.clip(T, np.exp(-200), 1)) + log_T = nx.where(log_T == -200, -np.inf, log_T) Lik = Lik - epsilon * log_T try: @@ -925,11 +1019,11 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, else: new_T = emd(a=p, b=q, M=Lik) - change_T = ((T - new_T) ** 2).mean() + change_T = nx.mean((T - new_T) ** 2) if change_T <= 10e-20: continue_loop += 1 if continue_loop > 100: # Number max of low modifications of T - T = new_T.copy() + T = nx.copy(new_T) break else: continue_loop = 0 @@ -938,12 +1032,14 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, if cpt % 200 == 0: print('{:5s}|{:12s}'.format('It.', '||T_n - T_{n+1}||') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(cpt, change_T)) - T = new_T.copy() + T = nx.copy(new_T) if log: log = {} - log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun, - p=p, q=q, T=T, random_state=generator) + log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation( + C1=C1, C2=C2, loss_fun=loss_fun, + p=p, q=q, T=T, random_state=generator + ) return T, log return T @@ -951,38 +1047,37 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False): r""" - Returns the gromov-wasserstein transport between (C1,p) and (C2,q) - - (C1,p) and (C2,q) + Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` The function solves the following optimization problem: .. math:: - GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T)) + \mathbf{GW} = \mathop{\arg\min}_\mathbf{T} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T})) - s.t. T 1 = p + s.t. 
\ \mathbf{T} \mathbf{1} &= \mathbf{p} - T^T 1= q + \mathbf{T}^T \mathbf{1} &= \mathbf{q} - T\geq 0 + \mathbf{T} &\geq 0 Where : - - C1 : Metric cost matrix in the source space - - C2 : Metric cost matrix in the target space - - p : distribution in the source space - - q : distribution in the target space - - L : loss function to account for the misfit between the similarity matrices - - H : entropy + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices + - `H`: entropy Parameters ---------- - C1 : ndarray, shape (ns, ns) + C1 : array-like, shape (ns, ns) Metric cost matrix in the source space - C2 : ndarray, shape (nt, nt) - Metric costfr matrix in the target space - p : ndarray, shape (ns,) + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) Distribution in the source space - q : ndarray, shape (nt,) + q : array-like, shape (nt,) Distribution in the target space loss_fun : string Loss function used for the solver either 'square_loss' or 'kl_loss' @@ -999,21 +1094,20 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, Returns ------- - T : ndarray, shape (ns, nt) + T : array-like, shape (`ns`, `nt`) Optimal coupling between the two spaces References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. """ + C1, C2, p, q = list_to_array(C1, C2, p, q) + nx = get_backend(C1, C2, p, q) - C1 = np.asarray(C1, dtype=np.float64) - C2 = np.asarray(C2, dtype=np.float64) - - T = np.outer(p, q) # Initialization + T = nx.outer(p, q) constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) @@ -1035,7 +1129,7 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = np.linalg.norm(T - Tprev) + err = nx.norm(T - Tprev) if log: log['err'].append(err) @@ -1058,32 +1152,31 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False): r""" - Returns the entropic gromov-wasserstein discrepancy between the two measured similarity matrices - - (C1,p) and (C2,q) + Returns the entropic gromov-wasserstein discrepancy between the two measured similarity matrices :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` The function solves the following optimization problem: .. 
math:: - GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T)) + GW = \min_\mathbf{T} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T})) Where : - - C1 : Metric cost matrix in the source space - - C2 : Metric cost matrix in the target space - - p : distribution in the source space - - q : distribution in the target space - - L : loss function to account for the misfit between the similarity matrices - - H : entropy + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices + - `H`: entropy Parameters ---------- - C1 : ndarray, shape (ns, ns) + C1 : array-like, shape (ns, ns) Metric cost matrix in the source space - C2 : ndarray, shape (nt, nt) - Metric costfr matrix in the target space - p : ndarray, shape (ns,) + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) Distribution in the source space - q : ndarray, shape (nt,) + q : array-like, shape (nt,) Distribution in the target space loss_fun : str Loss function used for the solver either 'square_loss' or 'kl_loss' @@ -1105,7 +1198,7 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. @@ -1122,40 +1215,39 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, - max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None): + max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None): r""" - Returns the gromov-wasserstein barycenters of S measured similarity matrices - - (Cs)_{s=1}^{s=S} + Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}` The function solves the following optimization problem: .. math:: - C = argmin_{C\in R^{NxN}} \sum_s \lambda_s GW(C,C_s,p,p_s) + \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s) Where : - - :math:`C_s` : metric cost matrix - - :math:`p_s` : distribution + - :math:`\mathbf{C}_s`: metric cost matrix + - :math:`\mathbf{p}_s`: distribution Parameters ---------- N : int Size of the targeted barycenter - Cs : list of S np.ndarray of shape (ns,ns) + Cs : list of S array-like of shape (ns,ns) Metric cost matrices - ps : list of S np.ndarray of shape (ns,) - Sample weights in the S spaces - p : ndarray, shape(N,) + ps : list of S array-like of shape (ns,) + Sample weights in the `S` spaces + p : array-like, shape(N,) Weights in the targeted barycenter lambdas : list of float - List of the S spaces' weights. + List of the `S` spaces' weights. loss_fun : callable Tensor-matrix multiplication function based on specific loss function. 
update : callable - function(p,lambdas,T,Cs) that updates C according to a specific Kernel - with the S Ts couplings calculated at each iteration + function(:math:`\mathbf{p}`, lambdas, :math:`\mathbf{T}`, :math:`\mathbf{Cs}`) that updates + :math:`\mathbf{C}` according to a specific Kernel with the `S` :math:`\mathbf{T}_s` couplings + calculated at each iteration epsilon : float Regularization term >0 max_iter : int, optional @@ -1166,32 +1258,36 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, Print information along iterations. log : bool, optional Record log if True. - init_C : bool | ndarray, shape (N, N) - Random initial value for the C matrix provided by user. + init_C : bool | array-like, shape (N, N) + Random initial value for the :math:`\mathbf{C}` matrix provided by user. + random_state : int or RandomState instance, optional + Fix the seed for reproducibility Returns ------- - C : ndarray, shape (N, N) + C : array-like, shape (`N`, `N`) Similarity matrix in the barycenter space (permutated arbitrarily) References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. """ + Cs = list_to_array(*Cs) + ps = list_to_array(*ps) + p = list_to_array(p) + nx = get_backend(*Cs, *ps, p) S = len(Cs) - Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)] - lambdas = np.asarray(lambdas, dtype=np.float64) - # Initialization of C : random SPD matrix (if not provided by user) if init_C is None: - # XXX use random state - xalea = np.random.randn(N, 2) + generator = check_random_state(random_state) + xalea = generator.randn(N, 2) C = dist(xalea, xalea) C /= C.max() + C = nx.from_numpy(C, type_as=p) else: C = init_C @@ -1214,7 +1310,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = np.linalg.norm(C - Cprev) + err = nx.norm(C - Cprev) error.append(err) if log: @@ -1232,38 +1328,39 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, - max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None): + max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None): r""" - Returns the gromov-wasserstein barycenters of S measured similarity matrices - - (Cs)_{s=1}^{s=S} + Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}` - The function solves the following optimization problem with block - coordinate descent: + The function solves the following optimization problem with block coordinate descent: .. 
math:: - C = argmin_C\in R^NxN \sum_s \lambda_s GW(C,Cs,p,ps) + + \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s) Where : - - Cs : metric cost matrix - - ps : distribution + - :math:`\mathbf{C}_s`: metric cost matrix + - :math:`\mathbf{p}_s`: distribution Parameters ---------- N : int Size of the targeted barycenter - Cs : list of S np.ndarray of shape (ns, ns) + Cs : list of S array-like of shape (ns, ns) Metric cost matrices - ps : list of S np.ndarray of shape (ns,) - Sample weights in the S spaces - p : ndarray, shape (N,) + ps : list of S array-like of shape (ns,) + Sample weights in the `S` spaces + p : array-like, shape (N,) Weights in the targeted barycenter lambdas : list of float - List of the S spaces' weights - loss_fun : tensor-matrix multiplication function based on specific loss function - update : function(p,lambdas,T,Cs) that updates C according to a specific Kernel - with the S Ts couplings calculated at each iteration + List of the `S` spaces' weights + loss_fun : callable + tensor-matrix multiplication function based on specific loss function + update : callable + function(:math:`\mathbf{p}`, lambdas, :math:`\mathbf{T}`, :math:`\mathbf{Cs}`) that updates + :math:`\mathbf{C}` according to a specific Kernel with the `S` :math:`\mathbf{T}_s` couplings + calculated at each iteration max_iter : int, optional Max number of iterations tol : float, optional @@ -1272,32 +1369,37 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, Print information along iterations. log : bool, optional Record log if True. - init_C : bool | ndarray, shape(N,N) - Random initial value for the C matrix provided by user. + init_C : bool | array-like, shape(N,N) + Random initial value for the :math:`\mathbf{C}` matrix provided by user. + random_state : int or RandomState instance, optional + Fix the seed for reproducibility Returns ------- - C : ndarray, shape (N, N) + C : array-like, shape (`N`, `N`) Similarity matrix in the barycenter space (permutated arbitrarily) References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. 
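A call sketch (illustrative, not from the patch) mirroring the updated ``test_gromov_barycenter`` further down in this diff, with two toy similarity matrices averaged into a 3-point barycenter::

    import numpy as np
    import ot

    rng = np.random.RandomState(42)
    C1 = ot.dist(rng.randn(10, 2))
    C2 = ot.dist(rng.randn(20, 2))

    Cb = ot.gromov.gromov_barycenters(
        3, [C1, C2], [ot.unif(10), ot.unif(20)], ot.unif(3), [.5, .5],
        'square_loss', max_iter=100, tol=1e-3, random_state=42)
    assert Cb.shape == (3, 3)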
""" - S = len(Cs) + Cs = list_to_array(*Cs) + ps = list_to_array(*ps) + p = list_to_array(p) + nx = get_backend(*Cs, *ps, p) - Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)] - lambdas = np.asarray(lambdas, dtype=np.float64) + S = len(Cs) # Initialization of C : random SPD matrix (if not provided by user) if init_C is None: - # XXX : should use a random state and not use the global seed - xalea = np.random.randn(N, 2) + generator = check_random_state(random_state) + xalea = generator.randn(N, 2) C = dist(xalea, xalea) C /= C.max() + C = nx.from_numpy(C, type_as=p) else: C = init_C @@ -1320,7 +1422,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = np.linalg.norm(C - Cprev) + err = nx.norm(C - Cprev) error.append(err) if log: @@ -1339,21 +1441,21 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False, p=None, loss_fun='square_loss', max_iter=100, tol=1e-9, - verbose=False, log=False, init_C=None, init_X=None): - """Compute the fgw barycenter as presented eq (5) in [24]. + verbose=False, log=False, init_C=None, init_X=None, random_state=None): + """Compute the fgw barycenter as presented eq (5) in :ref:`[24] ` Parameters ---------- - N : integer + N : int Desired number of samples of the target barycenter - Ys: list of ndarray, each element has shape (ns,d) + Ys: list of array-like, each element has shape (ns,d) Features of all samples - Cs : list of ndarray, each element has shape (ns,ns) + Cs : list of array-like, each element has shape (ns,ns) Structure matrices of all samples - ps : list of ndarray, each element has shape (ns,) + ps : list of array-like, each element has shape (ns,) Masses of all samples. lambdas : list of float - List of the S spaces' weights + List of the `S` spaces' weights alpha : float Alpha parameter for the fgw distance fixed_structure : bool @@ -1370,41 +1472,46 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ Print information along iterations. log : bool, optional Record log if True. - init_C : ndarray, shape (N,N), optional + init_C : array-like, shape (N,N), optional Initialization for the barycenters' structure matrix. If not set a random init is used. - init_X : ndarray, shape (N,d), optional + init_X : array-like, shape (N,d), optional Initialization for the barycenters' features. If not set a random init is used. + random_state : int or RandomState instance, optional + Fix the seed for reproducibility Returns ------- - X : ndarray, shape (N, d) + X : array-like, shape (`N`, `d`) Barycenters' features - C : ndarray, shape (N, N) + C : array-like, shape (`N`, `N`) Barycenters' structure matrix - log_: dict + log : dict Only returned when log=True. It contains the keys: - T : list of (N,ns) transport matrices - Ms : all distance matrices between the feature of the barycenter and the - other features dist(X,Ys) shape (N,ns) + - :math:`\mathbf{T}`: list of (`N`, `ns`) transport matrices + - :math:`(\mathbf{M}_s)_s`: all distance matrices between the feature of the barycenter and the other features :math:`(dist(\mathbf{X}, \mathbf{Y}_s))_s` shape (`N`, `ns`) + + + .. _references-fgw-barycenters: References ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + .. 
[24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. """ + Cs = list_to_array(*Cs) + ps = list_to_array(*ps) + Ys = list_to_array(*Ys) + p = list_to_array(p) + nx = get_backend(*Cs, *Ys, *ps) + S = len(Cs) d = Ys[0].shape[1] # dimension on the node features if p is None: - p = np.ones(N) / N - - Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)] - Ys = [np.asarray(Ys[s], dtype=np.float64) for s in range(S)] - - lambdas = np.asarray(lambdas, dtype=np.float64) + p = nx.ones(N, type_as=Cs[0]) / N if fixed_structure: if init_C is None: @@ -1413,8 +1520,10 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ C = init_C else: if init_C is None: - xalea = np.random.randn(N, 2) + generator = check_random_state(random_state) + xalea = generator.randn(N, 2) C = dist(xalea, xalea) + C = nx.from_numpy(C, type_as=ps[0]) else: C = init_C @@ -1425,13 +1534,13 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ X = init_X else: if init_X is None: - X = np.zeros((N, d)) + X = nx.zeros((N, d), type_as=ps[0]) else: X = init_X - T = [np.outer(p, q) for q in ps] + T = [nx.outer(p, q) for q in ps] - Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))] # Ms is N,ns + Ms = [dist(X, Ys[s]) for s in range(len(Ys))] cpt = 0 err_feature = 1 @@ -1451,20 +1560,19 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ Ys_temp = [y.T for y in Ys] X = update_feature_matrix(lambdas, Ys_temp, T, p).T - Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))] + Ms = [dist(X, Ys[s]) for s in range(len(Ys))] if not fixed_structure: if loss_fun == 'square_loss': T_temp = [t.T for t in T] - C = update_sructure_matrix(p, lambdas, T_temp, Cs) + C = update_structure_matrix(p, lambdas, T_temp, Cs) T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)] # T is N,ns - err_feature = np.linalg.norm(X - Xprev.reshape(N, d)) - err_structure = np.linalg.norm(C - Cprev) - + err_feature = nx.norm(X - nx.reshape(Xprev, (N, d))) + err_structure = nx.norm(C - Cprev) if log: log_['err_feature'].append(err_feature) log_['err_structure'].append(err_structure) @@ -1490,64 +1598,80 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ return X, C -def update_sructure_matrix(p, lambdas, T, Cs): - """Updates C according to the L2 Loss kernel with the S Ts couplings. +def update_structure_matrix(p, lambdas, T, Cs): + """Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings. It is calculated at each iteration Parameters ---------- - p : ndarray, shape (N,) + p : array-like, shape (N,) Masses in the targeted barycenter. lambdas : list of float - List of the S spaces' weights. - T : list of S ndarray of shape (ns, N) - The S Ts couplings calculated at each iteration. - Cs : list of S ndarray, shape (ns, ns) - Metric cost matrices. + List of the `S` spaces' weights. + T : list of S array-like of shape (ns, N) + The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. + Cs : list of S array-like, shape (ns, ns) + Metric cost matrices. Returns ------- - C : ndarray, shape (nt, nt) - Updated C matrix. + C : array-like, shape (`nt`, `nt`) + Updated :math:`\mathbf{C}` matrix. 
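For intuition, a plain NumPy sketch (illustrative only) of the closed-form update this helper computes, with made-up sizes ``S = 2``, ``ns = 6`` and barycenter size ``N = 4``::

    import numpy as np

    rng = np.random.RandomState(0)
    N, ns = 4, 6
    p = np.full(N, 1. / N)
    lambdas = [.5, .5]
    Cs = [rng.rand(ns, ns) for _ in range(2)]
    # independent couplings of shape (ns, N); column j sums to p[j]
    T = [np.outer(np.full(ns, 1. / ns), p) for _ in range(2)]

    # C = (sum_s lambda_s * T_s^T C_s T_s) / (p p^T)
    C = sum(lam * Ts.T @ C_s @ Ts
            for lam, Ts, C_s in zip(lambdas, T, Cs)) / np.outer(p, p)
    assert C.shape == (N, N)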
""" - tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))]) - ppt = np.outer(p, p) + p = list_to_array(p) + T = list_to_array(*T) + Cs = list_to_array(*Cs) + nx = get_backend(*Cs, *T, p) - return np.divide(tmpsum, ppt) + tmpsum = sum([ + lambdas[s] * nx.dot( + nx.dot(T[s].T, Cs[s]), + T[s] + ) for s in range(len(T)) + ]) + ppt = nx.outer(p, p) + return tmpsum / ppt def update_feature_matrix(lambdas, Ys, Ts, p): - """Updates the feature with respect to the S Ts couplings. + """Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" - in [24] calculated at each iteration + in :ref:`[24] ` calculated at each iteration Parameters ---------- - p : ndarray, shape (N,) + p : array-like, shape (N,) masses in the targeted barycenter lambdas : list of float - List of the S spaces' weights - Ts : list of S np.ndarray(ns,N) - the S Ts couplings calculated at each iteration - Ys : list of S ndarray, shape(d,ns) + List of the `S` spaces' weights + Ts : list of S array-like, shape (ns,N) + The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration + Ys : list of S array-like, shape (d,ns) The features. Returns ------- - X : ndarray, shape (d, N) + X : array-like, shape (`d`, `N`) + + .. _references-update-feature-matrix: References ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain - and Courty Nicolas + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. """ - p = np.array(1. / p).reshape(-1,) - - tmpsum = sum([lambdas[s] * np.dot(Ys[s], Ts[s].T) * p[None, :] for s in range(len(Ts))]) - + p = list_to_array(p) + Ts = list_to_array(*Ts) + Ys = list_to_array(*Ys) + nx = get_backend(*Ys, *Ts, p) + + p = 1. 
/ p + tmpsum = sum([ + lambdas[s] * nx.dot(Ys[s], Ts[s].T) * p[None, :] + for s in range(len(Ts)) + ]) return tmpsum diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index c6757d1..4e95ccf 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -691,10 +691,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the transportation matrix) """ - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - x_a = np.asarray(x_a, dtype=np.float64) - x_b = np.asarray(x_b, dtype=np.float64) + a, b, x_a, x_b = list_to_array(a, b, x_a, x_b) + nx = get_backend(x_a, x_b) assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \ "emd_1d should only be used with monodimensional data" @@ -702,27 +700,43 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, "emd_1d should only be used with monodimensional data" # if empty array given then use uniform distributions - if a.ndim == 0 or len(a) == 0: - a = np.ones((x_a.shape[0],), dtype=np.float64) / x_a.shape[0] - if b.ndim == 0 or len(b) == 0: - b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0] + if a is None or a.ndim == 0 or len(a) == 0: + a = nx.ones((x_a.shape[0],), type_as=x_a) / x_a.shape[0] + if b is None or b.ndim == 0 or len(b) == 0: + b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0] # ensure that same mass - np.testing.assert_almost_equal(a.sum(0),b.sum(0),err_msg='a and b vector must have the same sum') - b=b*a.sum()/b.sum() - - x_a_1d = x_a.reshape((-1,)) - x_b_1d = x_b.reshape((-1,)) - perm_a = np.argsort(x_a_1d) - perm_b = np.argsort(x_b_1d) - - G_sorted, indices, cost = emd_1d_sorted(a[perm_a], b[perm_b], - x_a_1d[perm_a], x_b_1d[perm_b], - metric=metric, p=p) - G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])), - shape=(a.shape[0], b.shape[0])) + np.testing.assert_almost_equal( + nx.sum(a, axis=0), + nx.sum(b, axis=0), + err_msg='a and b vector must have the same sum' + ) + b = b * nx.sum(a) / nx.sum(b) + + x_a_1d = nx.reshape(x_a, (-1,)) + x_b_1d = nx.reshape(x_b, (-1,)) + perm_a = nx.argsort(x_a_1d) + perm_b = nx.argsort(x_b_1d) + + G_sorted, indices, cost = emd_1d_sorted( + nx.to_numpy(a[perm_a]), + nx.to_numpy(b[perm_b]), + nx.to_numpy(x_a_1d[perm_a]), + nx.to_numpy(x_b_1d[perm_b]), + metric=metric, p=p + ) + + G = nx.coo_matrix( + G_sorted, + perm_a[indices[:, 0]], + perm_b[indices[:, 1]], + shape=(a.shape[0], b.shape[0]), + type_as=x_a + ) if dense: - G = G.toarray() + G = nx.todense(G) + elif str(nx) == "jax": + warnings.warn("JAX does not support sparse matrices, converting to dense") if log: log = {'cost': cost} return G, log diff --git a/ot/optim.py b/ot/optim.py index 34cbb17..6456c03 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -23,7 +23,7 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, r""" Armijo linesearch function that works with matrices - Find an approximate minimum of :math:`f(x_k + \\alpha \cdot p_k)` that satisfies the + Find an approximate minimum of :math:`f(x_k + \alpha \cdot p_k)` that satisfies the armijo conditions. Parameters @@ -129,7 +129,7 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val, .. _references-solve-linesearch: References ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain and Courty Nicolas + .. 
[24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. """ @@ -162,13 +162,13 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + \mathrm{reg} \cdot f(\gamma) + \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg} \cdot f(\gamma) - s.t. \ \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix @@ -309,13 +309,13 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + \mathrm{reg_1}\cdot\Omega(\gamma) + \mathrm{reg_2}\cdot f(\gamma) + \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg_1}\cdot\Omega(\gamma) + \mathrm{reg_2}\cdot f(\gamma) - s.t. \ \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix @@ -452,7 +452,7 @@ def solve_1d_linesearch_quad(a, b, c): .. math:: - arg\min_{0 \leq x \leq 1} f(x) = ax^{2} + bx + c + \mathop{\arg \min}_{0 \leq x \leq 1} f(x) = ax^{2} + bx + c Parameters ---------- diff --git a/test/test_backend.py b/test/test_backend.py index 5853282..0f11ace 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -207,6 +207,22 @@ def test_empty_backend(): nx.stack([M, M]) with pytest.raises(NotImplementedError): nx.reshape(M, (5, 3, 2)) + with pytest.raises(NotImplementedError): + nx.coo_matrix(M, M, M) + with pytest.raises(NotImplementedError): + nx.issparse(M) + with pytest.raises(NotImplementedError): + nx.tocsr(M) + with pytest.raises(NotImplementedError): + nx.eliminate_zeros(M) + with pytest.raises(NotImplementedError): + nx.todense(M) + with pytest.raises(NotImplementedError): + nx.where(M, M, M) + with pytest.raises(NotImplementedError): + nx.copy(M) + with pytest.raises(NotImplementedError): + nx.allclose(M, M) def test_func_backends(nx): @@ -216,6 +232,11 @@ def test_func_backends(nx): v = rnd.randn(3) val = np.array([1.0]) + # Sparse tensors test + sp_row = np.array([0, 3, 1, 0, 3]) + sp_col = np.array([0, 3, 1, 2, 2]) + sp_data = np.array([4, 5, 7, 9, 0]) + lst_tot = [] for nx in [ot.backend.NumpyBackend(), nx]: @@ -229,6 +250,10 @@ def test_func_backends(nx): vb = nx.from_numpy(v) val = nx.from_numpy(val) + sp_rowb = nx.from_numpy(sp_row) + sp_colb = nx.from_numpy(sp_col) + sp_datab = nx.from_numpy(sp_data) + A = nx.set_gradients(val, v, v) lst_b.append(nx.to_numpy(A)) lst_name.append('set_gradients') @@ -438,6 +463,37 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append('reshape') + sp_Mb = nx.coo_matrix(sp_datab, sp_rowb, sp_colb, shape=(4, 4)) + nx.todense(Mb) + lst_b.append(nx.to_numpy(nx.todense(sp_Mb))) + lst_name.append('coo_matrix') + + assert not nx.issparse(Mb), 'Assert fail on: issparse (expected False)' + assert nx.issparse(sp_Mb) or nx.__name__ == "jax", 'Assert fail on: issparse (expected True)' + + A = nx.tocsr(sp_Mb) + lst_b.append(nx.to_numpy(nx.todense(A))) + lst_name.append('tocsr') + + A = nx.eliminate_zeros(nx.copy(sp_datab), threshold=5.) 
+ lst_b.append(nx.to_numpy(A)) + lst_name.append('eliminate_zeros (dense)') + + A = nx.eliminate_zeros(sp_Mb) + lst_b.append(nx.to_numpy(nx.todense(A))) + lst_name.append('eliminate_zeros (sparse)') + + A = nx.where(Mb >= nx.stack([nx.linspace(0, 1, 10)] * 3, axis=1), Mb, 0.0) + lst_b.append(nx.to_numpy(A)) + lst_name.append('where') + + A = nx.copy(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('copy') + + assert nx.allclose(Mb, Mb), 'Assert fail on: allclose (expected True)' + assert not nx.allclose(2 * Mb, Mb), 'Assert fail on: allclose (expected False)' + lst_tot.append(lst_b) lst_np = lst_tot[0] diff --git a/test/test_bregman.py b/test/test_bregman.py index c1120ba..6923d31 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -477,8 +477,8 @@ def test_lazy_empirical_sinkhorn(nx): b = ot.unif(n) numIterMax = 1000 - X_s = np.reshape(np.arange(n), (n, 1)) - X_t = np.reshape(np.arange(0, n), (n, 1)) + X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1)) + X_t = np.reshape(np.arange(0, n, dtype=np.float64), (n, 1)) M = ot.dist(X_s, X_t) M_m = ot.dist(X_s, X_t, metric='euclidean') diff --git a/test/test_gromov.py b/test/test_gromov.py index 0242d72..509c54d 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -8,11 +8,12 @@ import numpy as np import ot +from ot.backend import NumpyBackend import pytest -def test_gromov(): +def test_gromov(nx): n_samples = 50 # nb samples mu_s = np.array([0, 0]) @@ -31,37 +32,50 @@ def test_gromov(): C1 /= C1.max() C2 /= C2.max() + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + pb = nx.from_numpy(p) + qb = nx.from_numpy(q) + G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True) + Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)) # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples) - np.testing.assert_allclose( - G, np.flipud(Id), atol=1e-04) + np.testing.assert_allclose(Gb, np.flipud(Id), atol=1e-04) gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True) + gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=True) gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False) + gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False) G = log['T'] + Gb = nx.to_numpy(logb['T']) - np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) + np.testing.assert_allclose(gw, gwb, atol=1e-06) + np.testing.assert_allclose(gwb, 0, atol=1e-1, rtol=1e-1) - np.testing.assert_allclose(gw, gw_val, atol=1e-1, rtol=1e-1) # cf log=False + np.testing.assert_allclose(gw_val, gw_valb, atol=1e-06) + np.testing.assert_allclose(gwb, gw_valb, atol=1e-1, rtol=1e-1) # cf log=False # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov -def test_entropic_gromov(): +@pytest.skip_backend("jax", reason="test very slow with jax backend") +def test_entropic_gromov(nx): n_samples = 50 # nb samples mu_s = np.array([0, 
0]) @@ -80,30 +94,44 @@ def test_entropic_gromov(): C1 /= C1.max() C2 /= C2.max() + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + pb = nx.from_numpy(p) + qb = nx.from_numpy(q) + G = ot.gromov.entropic_gromov_wasserstein( C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True) + Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein( + C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True + )) # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov gw, log = ot.gromov.entropic_gromov_wasserstein2( C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True) + gwb, logb = ot.gromov.entropic_gromov_wasserstein2( + C1b, C2b, pb, qb, 'kl_loss', epsilon=1e-2, log=True) G = log['T'] + Gb = nx.to_numpy(logb['T']) + np.testing.assert_allclose(gw, gwb, atol=1e-06) np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov -def test_pointwise_gromov(): +def test_pointwise_gromov(nx): n_samples = 50 # nb samples mu_s = np.array([0, 0]) @@ -122,33 +150,52 @@ def test_pointwise_gromov(): C1 /= C1.max() C2 /= C2.max() + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + pb = nx.from_numpy(p) + qb = nx.from_numpy(q) + def loss(x, y): return np.abs(x - y) + def lossb(x, y): + return nx.abs(x - y) + G, log = ot.gromov.pointwise_gromov_wasserstein( C1, C2, p, q, loss, max_iter=100, log=True, verbose=True, random_state=42) + G = NumpyBackend().todense(G) + Gb, logb = ot.gromov.pointwise_gromov_wasserstein( + C1b, C2b, pb, qb, lossb, max_iter=100, log=True, verbose=True, random_state=42) + Gb = nx.to_numpy(nx.todense(Gb)) # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p[:, np.newaxis], G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( - q[np.newaxis, :], G.sum(0), atol=1e-04) # cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov - assert log['gw_dist_estimated'] == 0.0 - assert log['gw_dist_std'] == 0.0 + np.testing.assert_allclose(logb['gw_dist_estimated'], 0.0, atol=1e-08) + np.testing.assert_allclose(logb['gw_dist_std'], 0.0, atol=1e-08) G, log = ot.gromov.pointwise_gromov_wasserstein( C1, C2, p, q, loss, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42) + G = NumpyBackend().todense(G) + Gb, logb = ot.gromov.pointwise_gromov_wasserstein( + C1b, C2b, pb, qb, lossb, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42) + Gb = nx.to_numpy(nx.todense(Gb)) - assert log['gw_dist_estimated'] == 0.10342276348494964 - assert log['gw_dist_std'] == 0.0015952535464736394 + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(logb['gw_dist_estimated'], 0.10342276348494964, atol=1e-8) + np.testing.assert_allclose(logb['gw_dist_std'], 0.0015952535464736394, atol=1e-8) -def test_sampled_gromov(): +@pytest.skip_backend("jax", reason="test very slow with jax backend") +def test_sampled_gromov(nx): n_samples = 50 # nb samples - mu_s = 
np.array([0, 0]) - cov_s = np.array([[1, 0], [0, 1]]) + mu_s = np.array([0, 0], dtype=np.float64) + cov_s = np.array([[1, 0], [0, 1]], dtype=np.float64) xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) @@ -163,23 +210,35 @@ def test_sampled_gromov(): C1 /= C1.max() C2 /= C2.max() + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + pb = nx.from_numpy(p) + qb = nx.from_numpy(q) + def loss(x, y): return np.abs(x - y) + def lossb(x, y): + return nx.abs(x - y) + G, log = ot.gromov.sampled_gromov_wasserstein( C1, C2, p, q, loss, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42) + Gb, logb = ot.gromov.sampled_gromov_wasserstein( + C1b, C2b, pb, qb, lossb, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42) + Gb = nx.to_numpy(Gb) # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov - assert log['gw_dist_estimated'] == 0.05679474884977278 - assert log['gw_dist_std'] == 0.0005986592106971995 + np.testing.assert_allclose(logb['gw_dist_estimated'], 0.05679474884977278, atol=1e-08) + np.testing.assert_allclose(logb['gw_dist_std'], 0.0005986592106971995, atol=1e-08) -def test_gromov_barycenter(): +def test_gromov_barycenter(nx): ns = 10 nt = 20 @@ -188,26 +247,42 @@ def test_gromov_barycenter(): C1 = ot.dist(Xs) C2 = ot.dist(Xt) - + p1 = ot.unif(ns) + p2 = ot.unif(nt) n_samples = 3 - Cb = ot.gromov.gromov_barycenters(n_samples, [C1, C2], - [ot.unif(ns), ot.unif(nt) - ], ot.unif(n_samples), [.5, .5], - 'square_loss', # 5e-4, - max_iter=100, tol=1e-3, - verbose=True) - np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) + p = ot.unif(n_samples) - Cb2 = ot.gromov.gromov_barycenters(n_samples, [C1, C2], - [ot.unif(ns), ot.unif(nt) - ], ot.unif(n_samples), [.5, .5], - 'kl_loss', # 5e-4, - max_iter=100, tol=1e-3) - np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples)) + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + p1b = nx.from_numpy(p1) + p2b = nx.from_numpy(p2) + pb = nx.from_numpy(p) + + Cb = ot.gromov.gromov_barycenters( + n_samples, [C1, C2], [p1, p2], p, [.5, .5], + 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42 + ) + Cbb = nx.to_numpy(ot.gromov.gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], + 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42 + )) + np.testing.assert_allclose(Cb, Cbb, atol=1e-06) + np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) + + Cb2 = ot.gromov.gromov_barycenters( + n_samples, [C1, C2], [p1, p2], p, [.5, .5], + 'kl_loss', max_iter=100, tol=1e-3, random_state=42 + ) + Cb2b = nx.to_numpy(ot.gromov.gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], + 'kl_loss', max_iter=100, tol=1e-3, random_state=42 + )) + np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06) + np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples)) @pytest.mark.filterwarnings("ignore:divide") -def test_gromov_entropic_barycenter(): +def test_gromov_entropic_barycenter(nx): ns = 10 nt = 20 @@ -216,26 +291,41 @@ def test_gromov_entropic_barycenter(): C1 = ot.dist(Xs) C2 = ot.dist(Xt) - + p1 = ot.unif(ns) + p2 = ot.unif(nt) n_samples = 2 - Cb = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2], - [ot.unif(ns), ot.unif(nt) - ], ot.unif(n_samples), [.5, 
.5], - 'square_loss', 1e-3, - max_iter=50, tol=1e-3, - verbose=True) - np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) - - Cb2 = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2], - [ot.unif(ns), ot.unif(nt) - ], ot.unif(n_samples), [.5, .5], - 'kl_loss', 1e-3, - max_iter=100, tol=1e-3) - np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples)) - - -def test_fgw(): + p = ot.unif(n_samples) + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + p1b = nx.from_numpy(p1) + p2b = nx.from_numpy(p2) + pb = nx.from_numpy(p) + + Cb = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1, C2], [p1, p2], p, [.5, .5], + 'square_loss', 1e-3, max_iter=50, tol=1e-3, verbose=True, random_state=42 + ) + Cbb = nx.to_numpy(ot.gromov.entropic_gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], + 'square_loss', 1e-3, max_iter=50, tol=1e-3, verbose=True, random_state=42 + )) + np.testing.assert_allclose(Cb, Cbb, atol=1e-06) + np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) + + Cb2 = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1, C2], [p1, p2], p, [.5, .5], + 'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42 + ) + Cb2b = nx.to_numpy(ot.gromov.entropic_gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], + 'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42 + )) + np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06) + np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples)) + + +def test_fgw(nx): n_samples = 50 # nb samples mu_s = np.array([0, 0]) @@ -260,33 +350,46 @@ def test_fgw(): M = ot.dist(ys, yt) M /= M.max() + Mb = nx.from_numpy(M) + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + pb = nx.from_numpy(p) + qb = nx.from_numpy(q) + G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True) + Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True) + Gb = nx.to_numpy(Gb) # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence fgw + p, Gb.sum(1), atol=1e-04) # cf convergence fgw np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence fgw + q, Gb.sum(0), atol=1e-04) # cf convergence fgw Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples) np.testing.assert_allclose( - G, np.flipud(Id), atol=1e-04) # cf convergence gromov + Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True) + fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True) G = log['T'] + Gb = nx.to_numpy(logb['T']) - np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) + np.testing.assert_allclose(fgw, fgwb, atol=1e-08) + np.testing.assert_allclose(fgwb, 0, atol=1e-1, rtol=1e-1) # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov -def test_fgw_barycenter(): +def test_fgw_barycenter(nx): np.random.seed(42) ns = 50 @@ -300,30 +403,44 @@ def test_fgw_barycenter(): C1 = ot.dist(Xs) C2 = ot.dist(Xt) - + p1, p2 = ot.unif(ns), ot.unif(nt) n_samples = 3 - X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), 
ot.unif(nt)], [.5, .5], 0.5, - fixed_structure=False, fixed_features=False, - p=ot.unif(n_samples), loss_fun='square_loss', - max_iter=100, tol=1e-3) - np.testing.assert_allclose(C.shape, (n_samples, n_samples)) - np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) + p = ot.unif(n_samples) + + ysb = nx.from_numpy(ys) + ytb = nx.from_numpy(yt) + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + p1b = nx.from_numpy(p1) + p2b = nx.from_numpy(p2) + pb = nx.from_numpy(p) + + Xb, Cb = ot.gromov.fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, fixed_structure=False, + fixed_features=False, p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, random_state=12345 + ) xalea = np.random.randn(n_samples, 2) init_C = ot.dist(xalea, xalea) - - X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], ps=[ot.unif(ns), ot.unif(nt)], lambdas=[.5, .5], alpha=0.5, - fixed_structure=True, init_C=init_C, fixed_features=False, - p=ot.unif(n_samples), loss_fun='square_loss', - max_iter=100, tol=1e-3) - np.testing.assert_allclose(C.shape, (n_samples, n_samples)) - np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) + init_Cb = nx.from_numpy(init_C) + + Xb, Cb = ot.gromov.fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=[.5, .5], + alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False, + p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3 + ) + Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb) + np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) + np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) init_X = np.random.randn(n_samples, ys.shape[1]) - - X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5, - fixed_structure=False, fixed_features=True, init_X=init_X, - p=ot.unif(n_samples), loss_fun='square_loss', - max_iter=100, tol=1e-3, log=True) - np.testing.assert_allclose(C.shape, (n_samples, n_samples)) - np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) + init_Xb = nx.from_numpy(init_X) + + Xb, Cb, logb = ot.gromov.fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, + fixed_structure=False, fixed_features=True, init_X=init_Xb, + p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, log=True, random_state=98765 + ) + Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb) + np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) + np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) -- cgit v1.2.3 From 2fe69eb130827560ada704bc25998397c4357821 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Thu, 4 Nov 2021 11:00:09 +0100 Subject: [MRG] Make gromov loss differentiable wrt matrices and weights (#302) * grmov differentable * new stuff * test gromov gradients * fgwdifferentiable * fgw tested * correc name test * add awesome example with gromov optimizatrion * pep8+ typos * damn pep8 * thunbnail * remove prints --- README.md | 9 +- examples/backends/plot_optim_gromov_pytorch.py | 260 +++++++++++++++++++++++++ ot/__init__.py | 2 + ot/gromov.py | 141 +++++++++++--- ot/optim.py | 3 +- test/test_gromov.py | 76 ++++++++ 6 files changed, 460 insertions(+), 31 deletions(-) create mode 100644 examples/backends/plot_optim_gromov_pytorch.py (limited to 'test/test_gromov.py') diff --git a/README.md b/README.md index ff32c53..08db003 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ POT provides the following generic OT solvers (links to examples): * Debiased Sinkhorn barycenters [Sinkhorn divergence 
barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37] * [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17]. * Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html)) with LP solver (only small scale). -* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]) +* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from * [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24] * [Stochastic solver](https://pythonot.github.io/auto_examples/plot_stochastic.html) for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19]) * [Stochastic solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33] @@ -295,5 +295,8 @@ You can also post bug reports and feature requests in Github issues. Make sure t via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on Machine Learning (pp. 4104-4113). PMLR. -[37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International -Conference on Machine Learning, PMLR 119:4692-4701, 2020 \ No newline at end of file +[37] Janati, H., Cuturi, M., Gramfort, A. [Debiased sinkhorn barycenters](http://proceedings.mlr.press/v119/janati20a/janati20a.pdf) Proceedings of the 37th International +Conference on Machine Learning, PMLR 119:4692-4701, 2020 + +[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph +Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021. \ No newline at end of file diff --git a/examples/backends/plot_optim_gromov_pytorch.py b/examples/backends/plot_optim_gromov_pytorch.py new file mode 100644 index 0000000..465f612 --- /dev/null +++ b/examples/backends/plot_optim_gromov_pytorch.py @@ -0,0 +1,260 @@ +r""" +================================= +Optimizing the Gromov-Wasserstein distance with PyTorch +================================= + +In this exemple we use the pytorch backend to optimize the Gromov-Wasserstein +(GW) loss between two graphs expressed as empirical distribution. + +In the first example we optimize the weights on the node of a simple template +graph so that it minimizes the GW with a given Stochastic Block Model graph. +We can see that this actually recovers the proportion of classes in the SBM +and allows for an accurate clustering of the nodes using the GW optimal plan. + +In a second example we optimize simultaneously the weights and the sructure of +the template graph which allows us to perform graph compression and to recover +other properties of the SBM. + +The backend actually uses the gradients expressed in [38] to optimize the +weights. + +[38] C. 
Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online Graph +Dictionary Learning, International Conference on Machine Learning (ICML), 2021. + +""" +# Author: Rémi Flamary +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 3 + +from sklearn.manifold import MDS +import numpy as np +import matplotlib.pylab as pl +import torch + +import ot +from ot.gromov import gromov_wasserstein2 + +# %% +# Graph generation +# --------------- + +rng = np.random.RandomState(42) + + +def get_sbm(n, nc, ratio, P): + nbpc = np.round(n * ratio).astype(int) + n = np.sum(nbpc) + C = np.zeros((n, n)) + for c1 in range(nc): + for c2 in range(c1 + 1): + if c1 == c2: + for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[:c1 + 1])): + for j in range(np.sum(nbpc[:c2]), i): + if rng.rand() <= P[c1, c2]: + C[i, j] = 1 + else: + for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[:c1 + 1])): + for j in range(np.sum(nbpc[:c2]), np.sum(nbpc[:c2 + 1])): + if rng.rand() <= P[c1, c2]: + C[i, j] = 1 + + return C + C.T + + +n = 100 +nc = 3 +ratio = np.array([.5, .3, .2]) +P = np.array(0.6 * np.eye(3) + 0.05 * np.ones((3, 3))) +C1 = get_sbm(n, nc, ratio, P) + +# get 2d position for nodes +x1 = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C1) + + +def plot_graph(x, C, color='C0', s=None): + for j in range(C.shape[0]): + for i in range(j): + if C[i, j] > 0: + pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color='k') + pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors='k', cmap='tab10', vmax=9) + + +pl.figure(1, (10, 5)) +pl.clf() +pl.subplot(1, 2, 1) +plot_graph(x1, C1, color='C0') +pl.title("SBM Graph") +pl.axis("off") +pl.subplot(1, 2, 2) +pl.imshow(C1, interpolation='nearest') +pl.title("Adjacency matrix") +pl.axis("off") + + +# %% +# Optimizing the weights of a simple template C0=eye(3) to fit Graph 1 +# ------------------------------------------------ +# The adajacency matrix C1 is block diagonal with 3 blocks. We want to +# optimize the weights of a simple template C0=eye(3) and see if we can +# recover the proportion of classes from the SBM (up to a permutation). + +C0 = np.eye(3) + + +def min_weight_gw(C1, C2, a2, nb_iter_max=100, lr=1e-2): + """ solve min_a GW(C1,C2,a, a2) by gradient descent""" + + # use pyTorch for our data + C1_torch = torch.tensor(C1) + C2_torch = torch.tensor(C2) + + a0 = rng.rand(C1.shape[0]) # random_init + a0 /= a0.sum() # on simplex + a1_torch = torch.tensor(a0).requires_grad_(True) + a2_torch = torch.tensor(a2) + + loss_iter = [] + + for i in range(nb_iter_max): + + loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch) + + loss_iter.append(loss.clone().detach().cpu().numpy()) + loss.backward() + + #print("{:03d} | {}".format(i, loss_iter[-1])) + + # performs a step of projected gradient descent + with torch.no_grad(): + grad = a1_torch.grad + a1_torch -= grad * lr # step + a1_torch.grad.zero_() + a1_torch.data = ot.utils.proj_simplex(a1_torch) + + a1 = a1_torch.clone().detach().cpu().numpy() + + return a1, loss_iter + + +a0_est, loss_iter0 = min_weight_gw(C0, C1, ot.unif(n), nb_iter_max=100, lr=1e-2) + +pl.figure(2) +pl.plot(loss_iter0) +pl.title("Loss along iterations") + +print("Estimated weights : ", a0_est) +print("True proportions : ", ratio) + + +# %% +# It is clear that the optimization has converged and that we recover the +# ratio of the different classes in the SBM graph up to a permutation. 
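[Editorial sketch, not part of the patch above: since the recovered weights are only defined up to a permutation of the template nodes, the remark can be made concrete by sorting both vectors before comparing them; this reuses a0_est and ratio from the example.]

    print("Sorted estimated weights :", np.sort(a0_est))
    print("Sorted true proportions  :", np.sort(ratio))
    # once node ordering is factored out, the two vectors should roughly coincide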
+ + +# %% +# Community clustering with uniform and estimated weights +# -------------------------------------------- +# The GW OT plan can be used to perform a clustering of the nodes of a graph +# when computing the GW with a simple template like C0 by labeling nodes in +# the original graph using by the index of the noe in the template receiving +# the most mass. +# +# We show here the result of such a clustering when using uniform weights on +# the template C0 and when using the optimal weights previously estimated. + + +T_unif = ot.gromov_wasserstein(C1, C0, ot.unif(n), ot.unif(3)) +label_unif = T_unif.argmax(1) + +T_est = ot.gromov_wasserstein(C1, C0, ot.unif(n), a0_est) +label_est = T_est.argmax(1) + +pl.figure(3, (10, 5)) +pl.clf() +pl.subplot(1, 2, 1) +plot_graph(x1, C1, color=label_unif) +pl.title("Graph clustering unif. weights") +pl.axis("off") +pl.subplot(1, 2, 2) +plot_graph(x1, C1, color=label_est) +pl.title("Graph clustering est. weights") +pl.axis("off") + + +# %% +# Graph compression with GW +# ------------------------- + +# Now we optimize both the weights and structure of a small graph that +# minimize the GW distance wrt our data graph. This can be seen as graph +# compression but can also recover important properties of an SBM such +# as its class proportion but also its matrix of probability of links between +# classes + + +def graph_compession_gw(nb_nodes, C2, a2, nb_iter_max=100, lr=1e-2): + """ solve min_a GW(C1,C2,a, a2) by gradient descent""" + + # use pyTorch for our data + + C2_torch = torch.tensor(C2) + a2_torch = torch.tensor(a2) + + a0 = rng.rand(nb_nodes) # random_init + a0 /= a0.sum() # on simplex + a1_torch = torch.tensor(a0).requires_grad_(True) + C0 = np.eye(nb_nodes) + C1_torch = torch.tensor(C0).requires_grad_(True) + + loss_iter = [] + + for i in range(nb_iter_max): + + loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch) + + loss_iter.append(loss.clone().detach().cpu().numpy()) + loss.backward() + + #print("{:03d} | {}".format(i, loss_iter[-1])) + + # performs a step of projected gradient descent + with torch.no_grad(): + grad = a1_torch.grad + a1_torch -= grad * lr # step + a1_torch.grad.zero_() + a1_torch.data = ot.utils.proj_simplex(a1_torch) + + grad = C1_torch.grad + C1_torch -= grad * lr # step + C1_torch.grad.zero_() + C1_torch.data = torch.clamp(C1_torch, 0, 1) + + a1 = a1_torch.clone().detach().cpu().numpy() + C1 = C1_torch.clone().detach().cpu().numpy() + + return a1, C1, loss_iter + + +nb_nodes = 3 +a0_est2, C0_est2, loss_iter2 = graph_compession_gw(nb_nodes, C1, ot.unif(n), + nb_iter_max=100, lr=5e-2) + +pl.figure(4) +pl.plot(loss_iter2) +pl.title("Loss along iterations") + + +print("Estimated weights : ", a0_est2) +print("True proportions : ", ratio) + +pl.figure(6, (10, 3.5)) +pl.clf() +pl.subplot(1, 2, 1) +pl.imshow(P, vmin=0, vmax=1) +pl.title('True SBM P matrix') +pl.subplot(1, 2, 2) +pl.imshow(C0_est2, vmin=0, vmax=1) +pl.title('Estimated C0 matrix') +pl.colorbar() diff --git a/ot/__init__.py b/ot/__init__.py index f20332c..4292b41 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -43,6 +43,8 @@ from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2) from .da import sinkhorn_lpl1_mm from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance +from .gromov import (gromov_wasserstein, gromov_wasserstein2, + gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) # utils functions from .utils import dist, unif, tic, toc, toq diff --git 
a/ot/gromov.py b/ot/gromov.py index 465693d..ea667e4 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -174,7 +174,7 @@ def tensor_product(constC, hC1, hC2, T): def gwloss(constC, hC1, hC2, T): - """Return the Loss for Gromov-Wasserstein + r"""Return the Loss for Gromov-Wasserstein The loss is computed as described in Proposition 1 Eq. (6) in :ref:`[12] ` @@ -213,7 +213,7 @@ def gwloss(constC, hC1, hC2, T): def gwggrad(constC, hC1, hC2, T): - """Return the gradient for Gromov-Wasserstein + r"""Return the gradient for Gromov-Wasserstein The gradient is computed as described in Proposition 2 in :ref:`[12] ` @@ -247,7 +247,7 @@ def gwggrad(constC, hC1, hC2, T): def update_square_loss(p, lambdas, T, Cs): - """ + r""" Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration @@ -284,7 +284,7 @@ def update_square_loss(p, lambdas, T, Cs): def update_kl_loss(p, lambdas, T, Cs): - """ + r""" Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration @@ -320,7 +320,7 @@ def update_kl_loss(p, lambdas, T, Cs): return nx.exp(tmpsum / ppt) -def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): +def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs): r""" Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` @@ -386,6 +386,14 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs """ p, q = list_to_array(p, q) + p0, q0, C10, C20 = p, q, C1, C2 + nx = get_backend(p0, q0, C10, C20) + + p = nx.to_numpy(p) + q = nx.to_numpy(q) + C1 = nx.to_numpy(C10) + C2 = nx.to_numpy(C20) + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) G0 = p[:, None] * q[None, :] @@ -398,13 +406,15 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs if log: res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) - log['gw_dist'] = gwloss(constC, hC1, hC2, res) - return res, log + log['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, res), type_as=C10) + log['u'] = nx.from_numpy(log['u'], type_as=C10) + log['v'] = nx.from_numpy(log['v'], type_as=C10) + return nx.from_numpy(res, type_as=C10), log else: - return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) + return nx.from_numpy(cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=False, **kwargs), type_as=C10) -def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): +def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs): r""" Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` @@ -420,7 +430,11 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg - :math:`\mathbf{C_2}`: Metric cost matrix in the target space - :math:`\mathbf{p}`: distribution in the source space - :math:`\mathbf{q}`: distribution in the target space - - `L`: loss function to account for the misfit between the similarity matrices + - `L`: loss function to account for the misfit between the similarity + matrices + + Note that when using backends, this loss function is differentiable wrt the + marices and weights for quadratic loss using the gradients from [38]_. 
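[Editorial note: the differentiability described in the docstring note above is exercised later in this patch by test_gromov2_gradients; a minimal sketch of the same idea, assuming a working torch install, could look like this.]

    import numpy as np
    import torch
    import ot

    n = 20
    xs = ot.datasets.make_2D_samples_gauss(n, np.array([0, 0]), np.eye(2), random_state=4)
    xt = ot.datasets.make_2D_samples_gauss(n, np.array([0, 0]), np.eye(2), random_state=5)
    C1 = ot.dist(xs, xs)
    C2 = ot.dist(xt, xt)
    C1 /= C1.max()
    C2 /= C2.max()

    p = torch.tensor(ot.unif(n), requires_grad=True)
    q = torch.tensor(ot.unif(n), requires_grad=True)
    C1t = torch.tensor(C1, requires_grad=True)
    C2t = torch.tensor(C2, requires_grad=True)

    # with the torch backend, gradients wrt the cost matrices and the weights
    # flow through the (square_loss) GW value
    val = ot.gromov_wasserstein2(C1t, C2t, p, q)
    val.backward()
    print(C1t.grad.shape, C2t.grad.shape, p.grad.shape, q.grad.shape)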
Parameters ---------- @@ -463,9 +477,21 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg metric approach to object matching. Foundations of computational mathematics 11.4 (2011): 417-487. + .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online + Graph Dictionary Learning, International Conference on Machine Learning + (ICML), 2021. + """ p, q = list_to_array(p, q) + p0, q0, C10, C20 = p, q, C1, C2 + nx = get_backend(p0, q0, C10, C20) + + p = nx.to_numpy(p) + q = nx.to_numpy(q) + C1 = nx.to_numpy(C10) + C2 = nx.to_numpy(C20) + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) G0 = p[:, None] * q[None, :] @@ -475,13 +501,28 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg def df(G): return gwggrad(constC, hC1, hC2, G) - res, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) - log_gw['gw_dist'] = gwloss(constC, hC1, hC2, res) - log_gw['T'] = res + + T, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) + + T0 = nx.from_numpy(T, type_as=C10) + + log_gw['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, T), type_as=C10) + log_gw['u'] = nx.from_numpy(log_gw['u'], type_as=C10) + log_gw['v'] = nx.from_numpy(log_gw['v'], type_as=C10) + log_gw['T'] = T0 + + gw = log_gw['gw_dist'] + + if loss_fun == 'square_loss': + gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)) + gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)) + gw = nx.set_gradients(gw, (p0, q0, C10, C20), + (log_gw['u'], log_gw['v'], gC1, gC2)) + if log: - return log_gw['gw_dist'], log_gw + return gw, log_gw else: - return log_gw['gw_dist'] + return gw def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): @@ -548,6 +589,15 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, """ p, q = list_to_array(p, q) + p0, q0, C10, C20, M0 = p, q, C1, C2, M + nx = get_backend(p0, q0, C10, C20, M0) + + p = nx.to_numpy(p) + q = nx.to_numpy(q) + C1 = nx.to_numpy(C10) + C2 = nx.to_numpy(C20) + M = nx.to_numpy(M0) + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) G0 = p[:, None] * q[None, :] @@ -560,10 +610,16 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, if log: res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) - log['fgw_dist'] = log['loss'][::-1][0] - return res, log + + fgw_dist = nx.from_numpy(log['loss'][-1], type_as=C10) + + log['fgw_dist'] = fgw_dist + log['u'] = nx.from_numpy(log['u'], type_as=C10) + log['v'] = nx.from_numpy(log['v'], type_as=C10) + return nx.from_numpy(res, type_as=C10), log + else: - return cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) + return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs), type_as=C10) def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): @@ -586,7 +642,11 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1) - `L` is a loss function to account for the misfit between the similarity matrices - The algorithm used for solving the problem is conditional gradient as 
discussed in :ref:`[24] ` + The algorithm used for solving the problem is conditional gradient as + discussed in :ref:`[24] ` + + Note that when using backends, this loss function is differentiable wrt the + marices and weights for quadratic loss using the gradients from [38]_. Parameters ---------- @@ -627,9 +687,22 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. + + .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online + Graph Dictionary Learning, International Conference on Machine Learning + (ICML), 2021. """ p, q = list_to_array(p, q) + p0, q0, C10, C20, M0 = p, q, C1, C2, M + nx = get_backend(p0, q0, C10, C20, M0) + + p = nx.to_numpy(p) + q = nx.to_numpy(q) + C1 = nx.to_numpy(C10) + C2 = nx.to_numpy(C20) + M = nx.to_numpy(M0) + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) G0 = p[:, None] * q[None, :] @@ -640,13 +713,27 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 def df(G): return gwggrad(constC, hC1, hC2, G) - res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) + T, log_fgw = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) + + fgw_dist = nx.from_numpy(log_fgw['loss'][-1], type_as=C10) + + T0 = nx.from_numpy(T, type_as=C10) + + log_fgw['fgw_dist'] = fgw_dist + log_fgw['u'] = nx.from_numpy(log_fgw['u'], type_as=C10) + log_fgw['v'] = nx.from_numpy(log_fgw['v'], type_as=C10) + log_fgw['T'] = T0 + + if loss_fun == 'square_loss': + gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)) + gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)) + fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0), + (log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0)) + if log: - log['fgw_dist'] = log['loss'][::-1][0] - log['T'] = res - return log['fgw_dist'], log + return fgw_dist, log_fgw else: - return log['fgw_dist'] + return fgw_dist def GW_distance_estimation(C1, C2, p, q, loss_fun, T, @@ -1447,7 +1534,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False, p=None, loss_fun='square_loss', max_iter=100, tol=1e-9, verbose=False, log=False, init_C=None, init_X=None, random_state=None): - """Compute the fgw barycenter as presented eq (5) in :ref:`[24] ` + r"""Compute the fgw barycenter as presented eq (5) in :ref:`[24] ` Parameters ---------- @@ -1604,7 +1691,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ def update_structure_matrix(p, lambdas, T, Cs): - """Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings. + r"""Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings. It is calculated at each iteration @@ -1640,7 +1727,7 @@ def update_structure_matrix(p, lambdas, T, Cs): def update_feature_matrix(lambdas, Ys, Ts, p): - """Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings. + r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings. 
See "Solving the barycenter problem with Block Coordinate Descent (BCD)" diff --git a/ot/optim.py b/ot/optim.py index cc286b6..bd8ca26 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -267,7 +267,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, Mi += nx.min(Mi) # solve linear program - Gc = emd(a, b, Mi, numItermax=numItermaxEmd) + Gc, logemd = emd(a, b, Mi, numItermax=numItermaxEmd, log=True) deltaG = Gc - G @@ -297,6 +297,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) if log: + log.update(logemd) return G, log else: return G diff --git a/test/test_gromov.py b/test/test_gromov.py index 509c54d..bcbcc3a 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -9,6 +9,7 @@ import numpy as np import ot from ot.backend import NumpyBackend +from ot.backend import torch import pytest @@ -74,6 +75,42 @@ def test_gromov(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov +def test_gromov2_gradients(): + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) + + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + if torch: + + p1 = torch.tensor(p, requires_grad=True) + q1 = torch.tensor(q, requires_grad=True) + C11 = torch.tensor(C1, requires_grad=True) + C12 = torch.tensor(C2, requires_grad=True) + + val = ot.gromov_wasserstein2(C11, C12, p1, q1) + + val.backward() + + assert q1.shape == q1.grad.shape + assert p1.shape == p1.grad.shape + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + + @pytest.skip_backend("jax", reason="test very slow with jax backend") def test_entropic_gromov(nx): n_samples = 50 # nb samples @@ -389,6 +426,45 @@ def test_fgw(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov +def test_fgw2_gradients(): + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) + + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + M = ot.dist(xs, xt) + + C1 /= C1.max() + C2 /= C2.max() + + if torch: + + p1 = torch.tensor(p, requires_grad=True) + q1 = torch.tensor(q, requires_grad=True) + C11 = torch.tensor(C1, requires_grad=True) + C12 = torch.tensor(C2, requires_grad=True) + M1 = torch.tensor(M, requires_grad=True) + + val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1) + + val.backward() + + assert q1.shape == q1.grad.shape + assert p1.shape == p1.grad.shape + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + assert M1.shape == M1.grad.shape + + def test_fgw_barycenter(nx): np.random.seed(42) -- cgit v1.2.3 From 0eac835c70cc1a13bb998f3b6cdb0515fafc05e1 Mon Sep 17 00:00:00 2001 From: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> Date: Fri, 5 Nov 2021 15:57:08 +0100 Subject: [MRG] Tests with types/device on sliced/bregman/gromov functions (#303) * First draft : making pytest use gpu for torch testing * bug solve * Revert "bug solve" This reverts commit 29b013abd162f8693128f26d8129186b79923609. 
* Revert "First draft : making pytest use gpu for torch testing" This reverts commit 2778175bcc338016c704efa4187d132fe5162e3a. * sliced * sliced * ot 1dsolver * bregman * better print * jax works with sinkhorn, sinkhorn_log and sinkhornn_stabilized, no need to skip them * gromov & entropic gromov --- ot/backend.py | 59 ++++++++++++++++++++++++++++++++++----- ot/sliced.py | 8 +++--- test/conftest.py | 25 +++++++++++++---- test/test_1d_solver.py | 16 ++++------- test/test_bregman.py | 45 ++++++++++++++++++++++++++++++ test/test_gromov.py | 75 ++++++++++++++++++++++++++++++++++++++++++++++++++ test/test_ot.py | 8 ++---- test/test_sliced.py | 44 +++++++++++++++++++++++++++++ 8 files changed, 247 insertions(+), 33 deletions(-) (limited to 'test/test_gromov.py') diff --git a/ot/backend.py b/ot/backend.py index 55e10d3..a044f84 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -653,6 +653,18 @@ class Backend(): """ raise NotImplementedError() + def dtype_device(self, a): + r""" + Returns the dtype and the device of the given tensor. + """ + raise NotImplementedError() + + def assert_same_dtype_device(self, a, b): + r""" + Checks whether or not the two given inputs have the same dtype as well as the same device + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -880,6 +892,16 @@ class NumpyBackend(Backend): def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): return np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + def dtype_device(self, a): + if hasattr(a, "dtype"): + return a.dtype, "cpu" + else: + return type(a), "cpu" + + def assert_same_dtype_device(self, a, b): + # numpy has implicit type conversion so we automatically validate the test + pass + class JaxBackend(Backend): """ @@ -899,17 +921,20 @@ class JaxBackend(Backend): self.rng_ = jax.random.PRNGKey(42) for d in jax.devices(): - self.__type_list__ = [jax.device_put(jnp.array(1, dtype=np.float32), d), - jax.device_put(jnp.array(1, dtype=np.float64), d)] + self.__type_list__ = [jax.device_put(jnp.array(1, dtype=jnp.float32), d), + jax.device_put(jnp.array(1, dtype=jnp.float64), d)] def to_numpy(self, a): return np.array(a) + def _change_device(self, a, type_as): + return jax.device_put(a, type_as.device_buffer.device()) + def from_numpy(self, a, type_as=None): if type_as is None: return jnp.array(a) else: - return jax.device_put(jnp.array(a).astype(type_as.dtype), type_as.device_buffer.device()) + return self._change_device(jnp.array(a).astype(type_as.dtype), type_as) def set_gradients(self, val, inputs, grads): from jax.flatten_util import ravel_pytree @@ -928,13 +953,13 @@ class JaxBackend(Backend): if type_as is None: return jnp.zeros(shape) else: - return jnp.zeros(shape, dtype=type_as.dtype) + return self._change_device(jnp.zeros(shape, dtype=type_as.dtype), type_as) def ones(self, shape, type_as=None): if type_as is None: return jnp.ones(shape) else: - return jnp.ones(shape, dtype=type_as.dtype) + return self._change_device(jnp.ones(shape, dtype=type_as.dtype), type_as) def arange(self, stop, start=0, step=1, type_as=None): return jnp.arange(start, stop, step) @@ -943,13 +968,13 @@ class JaxBackend(Backend): if type_as is None: return jnp.full(shape, fill_value) else: - return jnp.full(shape, fill_value, dtype=type_as.dtype) + return self._change_device(jnp.full(shape, fill_value, dtype=type_as.dtype), type_as) def eye(self, N, M=None, type_as=None): if type_as is None: return jnp.eye(N, M) else: - return jnp.eye(N, M, dtype=type_as.dtype) + return self._change_device(jnp.eye(N, M, 
dtype=type_as.dtype), type_as) def sum(self, a, axis=None, keepdims=False): return jnp.sum(a, axis, keepdims=keepdims) @@ -1127,6 +1152,16 @@ class JaxBackend(Backend): def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + def dtype_device(self, a): + return a.dtype, a.device_buffer.device() + + def assert_same_dtype_device(self, a, b): + a_dtype, a_device = self.dtype_device(a) + b_dtype, b_device = self.dtype_device(b) + + assert a_dtype == b_dtype, "Dtype discrepancy" + assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}" + class TorchBackend(Backend): """ @@ -1455,3 +1490,13 @@ class TorchBackend(Backend): def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + + def dtype_device(self, a): + return a.dtype, a.device + + def assert_same_dtype_device(self, a, b): + a_dtype, a_device = self.dtype_device(a) + b_dtype, b_device = self.dtype_device(b) + + assert a_dtype == b_dtype, "Dtype discrepancy" + assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}" diff --git a/ot/sliced.py b/ot/sliced.py index 7c09111..cf2d3be 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -139,9 +139,9 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, X_t.shape[1])) if a is None: - a = nx.full(n, 1 / n) + a = nx.full(n, 1 / n, type_as=X_s) if b is None: - b = nx.full(m, 1 / m) + b = nx.full(m, 1 / m, type_as=X_s) d = X_s.shape[1] @@ -238,9 +238,9 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, X_t.shape[1])) if a is None: - a = nx.full(n, 1 / n) + a = nx.full(n, 1 / n, type_as=X_s) if b is None: - b = nx.full(m, 1 / m) + b = nx.full(m, 1 / m, type_as=X_s) d = X_s.shape[1] diff --git a/test/conftest.py b/test/conftest.py index 876b525..987d98e 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -11,6 +11,7 @@ import functools if jax: from jax.config import config + config.update("jax_enable_x64", True) backend_list = get_backend_list() @@ -18,16 +19,25 @@ backend_list = get_backend_list() @pytest.fixture(params=backend_list) def nx(request): backend = request.param - if backend.__name__ == "jax": - config.update("jax_enable_x64", True) yield backend - if backend.__name__ == "jax": - config.update("jax_enable_x64", False) - def skip_arg(arg, value, reason=None, getter=lambda x: x): + if isinstance(arg, tuple) or isinstance(arg, list): + n = len(arg) + else: + arg = (arg, ) + n = 1 + if n != 1 and (isinstance(value, tuple) or isinstance(value, list)): + pass + else: + value = (value, ) + if isinstance(getter, tuple) or isinstance(value, list): + pass + else: + getter = [getter] * n + if reason is None: reason = f"Param {arg} should be skipped for value {value}" @@ -35,7 +45,10 @@ def skip_arg(arg, value, reason=None, getter=lambda x: x): @functools.wraps(function) def wrapped(*args, **kwargs): - if arg in kwargs.keys() and getter(kwargs[arg]) == value: + if all( + arg[i] in kwargs.keys() and getter[i](kwargs[arg[i]]) == value[i] + for i in range(n) + ): pytest.skip(reason) return function(*args, **kwargs) diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 77b1234..cb85cb9 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -85,7 +85,6 @@ def test_wasserstein_1d(nx): 
np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4) -@pytest.mark.parametrize('nx', backend_list) def test_wasserstein_1d_type_devices(nx): rng = np.random.RandomState(0) @@ -98,8 +97,7 @@ def test_wasserstein_1d_type_devices(nx): rho_v /= rho_v.sum() for tp in nx.__type_list__: - - print(tp.dtype) + print(nx.dtype_device(tp)) xb = nx.from_numpy(x, type_as=tp) rho_ub = nx.from_numpy(rho_u, type_as=tp) @@ -107,8 +105,7 @@ def test_wasserstein_1d_type_devices(nx): res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1) - if not str(nx) == 'numpy': - assert res.dtype == xb.dtype + nx.assert_same_dtype_device(xb, res) def test_emd_1d_emd2_1d(): @@ -162,17 +159,14 @@ def test_emd1d_type_devices(nx): rho_v /= rho_v.sum() for tp in nx.__type_list__: - - print(tp.dtype) + print(nx.dtype_device(tp)) xb = nx.from_numpy(x, type_as=tp) rho_ub = nx.from_numpy(rho_u, type_as=tp) rho_vb = nx.from_numpy(rho_v, type_as=tp) emd = ot.emd_1d(xb, xb, rho_ub, rho_vb) - emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb) - assert emd.dtype == xb.dtype - if not str(nx) == 'numpy': - assert emd2.dtype == xb.dtype + nx.assert_same_dtype_device(xb, emd) + nx.assert_same_dtype_device(xb, emd2) diff --git a/test/test_bregman.py b/test/test_bregman.py index edfe9c3..830052d 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -278,6 +278,51 @@ def test_sinkhorn_variants(nx): np.testing.assert_allclose(G0, G_green, atol=1e-5) +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", + "sinkhorn_epsilon_scaling", + "greenkhorn", + "sinkhorn_log"]) +@pytest.skip_arg(("nx", "method"), ("jax", "sinkhorn_epsilon_scaling"), reason="jax does not support sinkhorn_epsilon_scaling", getter=str) +@pytest.skip_arg(("nx", "method"), ("jax", "greenkhorn"), reason="jax does not support greenkhorn", getter=str) +def test_sinkhorn_variants_dtype_device(nx, method): + n = 100 + + x = np.random.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + ub = nx.from_numpy(u, type_as=tp) + Mb = nx.from_numpy(M, type_as=tp) + + Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10) + + nx.assert_same_dtype_device(Mb, Gb) + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"]) +def test_sinkhorn2_variants_dtype_device(nx, method): + n = 100 + + x = np.random.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + ub = nx.from_numpy(u, type_as=tp) + Mb = nx.from_numpy(M, type_as=tp) + + lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10) + + nx.assert_same_dtype_device(Mb, lossb) + + @pytest.skip_backend("jax") def test_sinkhorn_variants_multi_b(nx): # test sinkhorn diff --git a/test/test_gromov.py b/test/test_gromov.py index bcbcc3a..c4bc04c 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -75,6 +75,41 @@ def test_gromov(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov +def test_gromov_dtype_device(nx): + # setup + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + C1b = nx.from_numpy(C1, type_as=tp) + C2b = nx.from_numpy(C2, type_as=tp) + pb = nx.from_numpy(p, type_as=tp) + 
qb = nx.from_numpy(q, type_as=tp) + + Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True) + gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False) + + nx.assert_same_dtype_device(C1b, Gb) + nx.assert_same_dtype_device(C1b, gw_valb) + + def test_gromov2_gradients(): n_samples = 50 # nb samples @@ -168,6 +203,46 @@ def test_entropic_gromov(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov +@pytest.skip_backend("jax", reason="test very slow with jax backend") +def test_entropic_gromov_dtype_device(nx): + # setup + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + C1b = nx.from_numpy(C1, type_as=tp) + C2b = nx.from_numpy(C2, type_as=tp) + pb = nx.from_numpy(p, type_as=tp) + qb = nx.from_numpy(q, type_as=tp) + + Gb = ot.gromov.entropic_gromov_wasserstein( + C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True + ) + gw_valb = ot.gromov.entropic_gromov_wasserstein2( + C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True + ) + + nx.assert_same_dtype_device(C1b, Gb) + nx.assert_same_dtype_device(C1b, gw_valb) + + def test_pointwise_gromov(nx): n_samples = 50 # nb samples diff --git a/test/test_ot.py b/test/test_ot.py index dc3930a..92f26a7 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -88,8 +88,7 @@ def test_emd_emd2_types_devices(nx): M = ot.dist(x, y) for tp in nx.__type_list__: - - print(tp.dtype) + print(nx.dtype_device(tp)) ab = nx.from_numpy(a, type_as=tp) Mb = nx.from_numpy(M, type_as=tp) @@ -98,9 +97,8 @@ def test_emd_emd2_types_devices(nx): w = ot.emd2(ab, ab, Mb) - assert Gb.dtype == Mb.dtype - if not str(nx) == 'numpy': - assert w.dtype == Mb.dtype + nx.assert_same_dtype_device(Mb, Gb) + nx.assert_same_dtype_device(Mb, w) def test_emd2_gradients(): diff --git a/test/test_sliced.py b/test/test_sliced.py index 0bd74ec..245202c 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -139,6 +139,28 @@ def test_sliced_backend(nx): assert np.allclose(val0, valb) +def test_sliced_backend_type_devices(nx): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(2 * n, 2) + + P = rng.randn(2, 20) + P = P / np.sqrt((P**2).sum(0, keepdims=True)) + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb = nx.from_numpy(x, type_as=tp) + yb = nx.from_numpy(y, type_as=tp) + Pb = nx.from_numpy(P, type_as=tp) + + valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb) + + nx.assert_same_dtype_device(xb, valb) + + def test_max_sliced_backend(nx): n = 100 @@ -167,3 +189,25 @@ def test_max_sliced_backend(nx): valb = nx.to_numpy(ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)) assert np.allclose(val0, valb) + + +def test_max_sliced_backend_type_devices(nx): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(2 * n, 2) + + P = rng.randn(2, 20) + P = P / np.sqrt((P**2).sum(0, keepdims=True)) + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb = nx.from_numpy(x, type_as=tp) + yb = nx.from_numpy(y, type_as=tp) + Pb = nx.from_numpy(P, type_as=tp) + + valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb) + + nx.assert_same_dtype_device(xb, valb) -- cgit v1.2.3
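[Editorial sketch, not part of the patch: the dtype_device and assert_same_dtype_device helpers added to ot/backend.py in this commit can be used outside the test suite as follows; get_backend picks the backend matching the inputs, and for the numpy backend assert_same_dtype_device is a no-op while other backends raise on a dtype or device mismatch.]

    import numpy as np
    import ot
    from ot.backend import get_backend

    rng = np.random.RandomState(0)
    x = rng.randn(5, 2)
    a = ot.unif(5)
    M = ot.dist(x, x)

    nx = get_backend(a, M)             # NumpyBackend for these inputs
    print(nx.dtype_device(M))          # dtype of M and the device it lives on ("cpu" here)
    G = ot.emd(a, a, M)
    nx.assert_same_dtype_device(M, G)  # checks the plan shares dtype and device with M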