summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGard Spreemann <gspr@nonempty.org>2023-06-14 16:51:31 +0200
committerGard Spreemann <gspr@nonempty.org>2023-06-14 16:51:31 +0200
commit96788a3fe5601e4c3f49b592aa0d9c034247862e (patch)
tree5ee3ebcdea05f6766fc9858344913e40487e9067
parent35bd2c98b642df78638d7d733bc1a89d873db1de (diff)
parent89f1613861152432807077fbb146578611dc5888 (diff)
Merge tag '0.9.0' into dfsg/latestdfsg/latest
-rw-r--r--.circleci/config.yml4
-rw-r--r--.github/CONTRIBUTING.md1
-rw-r--r--.github/requirements_test_windows.txt3
-rw-r--r--.github/workflows/build_doc.yml44
-rw-r--r--CONTRIBUTORS.md7
-rw-r--r--Makefile2
-rw-r--r--README.md35
-rw-r--r--RELEASES.md206
-rw-r--r--codecov.yml2
-rw-r--r--docs/cache_nbrun1
-rw-r--r--docs/requirements_rtd.txt3
-rw-r--r--docs/source/_templates/module.rst2
-rw-r--r--docs/source/_templates/versions.html69
-rw-r--r--docs/source/all.rst32
-rw-r--r--docs/source/code_of_conduct.rst (renamed from docs/source/.github/CODE_OF_CONDUCT.rst)4
-rw-r--r--docs/source/conf.py7
-rw-r--r--docs/source/contributing.rst (renamed from docs/source/.github/CONTRIBUTING.rst)4
-rw-r--r--docs/source/index.rst4
-rw-r--r--docs/source/quickstart.rst98
-rw-r--r--examples/backends/plot_sliced_wass_grad_flow_pytorch.py10
-rw-r--r--examples/backends/plot_ssw_unif_torch.py153
-rw-r--r--examples/barycenters/plot_barycenter_1D.py4
-rw-r--r--examples/barycenters/plot_free_support_barycenter.py28
-rw-r--r--examples/barycenters/plot_free_support_sinkhorn_barycenter.py151
-rw-r--r--examples/barycenters/plot_generalized_free_support_barycenter.py155
-rw-r--r--examples/domain-adaptation/plot_otda_linear_mapping.py2
-rw-r--r--examples/gromov/plot_barycenter_fgw.py2
-rw-r--r--examples/gromov/plot_gromov.py1
-rwxr-xr-xexamples/gromov/plot_gromov_barycenter.py3
-rwxr-xr-xexamples/gromov/plot_gromov_wasserstein_dictionary_learning.py53
-rw-r--r--examples/gromov/plot_semirelaxed_fgw.py301
-rw-r--r--examples/others/plot_COOT.py97
-rw-r--r--examples/others/plot_learning_weights_with_COOT.py150
-rw-r--r--examples/plot_compute_wasserstein_circle.py161
-rw-r--r--examples/sliced-wasserstein/plot_variance_ssw.py111
-rw-r--r--examples/unbalanced-partial/plot_UOT_barycenter_1D.py4
-rw-r--r--ot/__init__.py24
-rw-r--r--ot/backend.py252
-rw-r--r--ot/bregman.py378
-rw-r--r--ot/coot.py434
-rw-r--r--ot/da.py136
-rw-r--r--ot/dr.py32
-rw-r--r--ot/gaussian.py333
-rw-r--r--ot/gromov.py2835
-rw-r--r--ot/gromov/__init__.py48
-rw-r--r--ot/gromov/_bregman.py348
-rw-r--r--ot/gromov/_dictionary.py1008
-rw-r--r--ot/gromov/_estimators.py425
-rw-r--r--ot/gromov/_gw.py978
-rw-r--r--ot/gromov/_semirelaxed.py543
-rw-r--r--ot/gromov/_utils.py413
-rw-r--r--ot/helpers/pre_build_helpers.py24
-rw-r--r--ot/lp/EMD.h5
-rw-r--r--ot/lp/EMD_wrapper.cpp40
-rw-r--r--ot/lp/__init__.py161
-rw-r--r--ot/lp/cvx.py2
-rw-r--r--ot/lp/emd_wrap.pyx9
-rw-r--r--ot/lp/network_simplex_simple.h12
-rw-r--r--ot/lp/network_simplex_simple_omp.h20
-rw-r--r--ot/lp/solver_1d.py629
-rw-r--r--ot/optim.py496
-rwxr-xr-xot/partial.py122
-rw-r--r--ot/sliced.py187
-rw-r--r--ot/smooth.py11
-rw-r--r--ot/solvers.py347
-rw-r--r--ot/unbalanced.py189
-rw-r--r--ot/utils.py238
-rw-r--r--ot/weak.py6
-rw-r--r--requirements.txt3
-rw-r--r--setup.py2
-rw-r--r--test/test_1d_solver.py127
-rw-r--r--test/test_backend.py52
-rw-r--r--test/test_bregman.py315
-rw-r--r--test/test_coot.py359
-rw-r--r--test/test_da.py79
-rw-r--r--test/test_gaussian.py98
-rw-r--r--test/test_gromov.py637
-rw-r--r--test/test_optim.py63
-rw-r--r--test/test_ot.py59
-rwxr-xr-xtest/test_partial.py124
-rw-r--r--test/test_sliced.py200
-rw-r--r--test/test_solvers.py133
-rw-r--r--test/test_unbalanced.py61
-rw-r--r--test/test_utils.py40
84 files changed, 11155 insertions, 3796 deletions
diff --git a/.circleci/config.yml b/.circleci/config.yml
index 7e15a65..d0e972c 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -117,10 +117,8 @@ jobs:
echo "Deploying dev docs for ${CIRCLE_BRANCH}.";
cd master
cp -a /tmp/build/html/* .;
- cp -a /tmp/build/html/.github/* .github/;
touch .nojekyll;
git add -A;
- git add -f .github/* ;
git commit -m "CircleCI update of dev docs (${CIRCLE_BUILD_NUM}).";
git push origin master;
else
@@ -157,10 +155,8 @@ jobs:
git clean -xdf
echo "Deploying dev docs for ${CIRCLE_BRANCH}.";
cp -a /tmp/build/html/* .;
- cp -a /tmp/build/html/.github/* .github/;
touch .nojekyll;
git add -A;
- git add -f .github/* ;
git commit -m "CircleCI update of dev docs (${CIRCLE_BUILD_NUM}).";
git push origin master;
diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md
index 9bc8e87..168ffb3 100644
--- a/.github/CONTRIBUTING.md
+++ b/.github/CONTRIBUTING.md
@@ -1,4 +1,3 @@
-
Contributing to POT
===================
diff --git a/.github/requirements_test_windows.txt b/.github/requirements_test_windows.txt
index b94392f..448a7fc 100644
--- a/.github/requirements_test_windows.txt
+++ b/.github/requirements_test_windows.txt
@@ -3,8 +3,7 @@ scipy>=1.3
cython
matplotlib
autograd
-pymanopt==0.2.4; python_version <'3'
-pymanopt==0.2.6rc1; python_version >= '3'
+pymanopt
cvxopt
scikit-learn
pytest \ No newline at end of file
diff --git a/.github/workflows/build_doc.yml b/.github/workflows/build_doc.yml
new file mode 100644
index 0000000..93bd113
--- /dev/null
+++ b/.github/workflows/build_doc.yml
@@ -0,0 +1,44 @@
+name: Build doc
+
+on:
+ workflow_dispatch:
+ pull_request:
+ push:
+ branches:
+ - 'master'
+
+jobs:
+ build:
+
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v1
+ # Standard drop-in approach that should work for most people.
+
+ - name: Set up Python 3.8
+ uses: actions/setup-python@v1
+ with:
+ python-version: 3.8
+
+ - name: Get Python running
+ run: |
+ python -m pip install --user --upgrade --progress-bar off pip
+ python -m pip install --user --upgrade --progress-bar off -r requirements.txt
+ python -m pip install --user --upgrade --progress-bar off -r docs/requirements.txt
+ python -m pip install --user --upgrade --progress-bar off ipython "https://api.github.com/repos/sphinx-gallery/sphinx-gallery/zipball/master" memory_profiler
+ python -m pip install --user -e .
+ # Look at what we have and fail early if there is some library conflict
+ - name: Check installation
+ run: |
+ which python
+ python -c "import ot"
+ # Build docs
+ - name: Generate HTML docs
+ uses: rickstaa/sphinx-action@master
+ with:
+ docs-folder: "docs/"
+ - uses: actions/upload-artifact@v1
+ with:
+ name: Documentation
+ path: docs/build/html/ \ No newline at end of file
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index ab64fba..6b35653 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -36,7 +36,12 @@ The contributors to this library are:
* [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein)
* [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance)
* [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends)
-* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning)
+* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning, semi-relaxed FGW)
+* [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters)
+* [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug)
+* [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter)
+* [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions)
+* [Clément Bonet](https://clbonet.github.io) (Wassertstein on circle, Spherical Sliced-Wasserstein)
## Acknowledgments
diff --git a/Makefile b/Makefile
index 315218d..7a5cbe1 100644
--- a/Makefile
+++ b/Makefile
@@ -42,7 +42,7 @@ clean : FORCE
$(PYTHON) setup.py clean
pep8 :
- flake8 examples/ ot/ test/
+ flake8 examples/ ot/ test/ --count --max-line-length=127 --statistics --show-source
test : FORCE pep8
$(PYTHON) -m pytest --durations=20 -v test/ --doctest-modules --ignore ot/gpu/
diff --git a/README.md b/README.md
index e2b33d9..2a81e95 100644
--- a/README.md
+++ b/README.md
@@ -26,8 +26,8 @@ POT provides the following generic OT solvers (links to examples):
* Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37]
* [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17].
* Weak OT solver between empirical distributions [39]
-* Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html)) with LP solver (only small scale).
-* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from
+* Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) with LP solver (only small scale).
+* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from Graph Dictionary Learning [38]
* [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24]
* [Stochastic
solver](https://pythonot.github.io/auto_examples/others/plot_stochastic.html) and
@@ -39,7 +39,10 @@ POT provides the following generic OT solvers (links to examples):
* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3]
formulations).
* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36].
+* [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/plot_compute_wasserstein_circle.html) [44, 45]
+* [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46]
* [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38].
+* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) [48].
* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.
POT provides the following Machine Learning related solvers:
@@ -189,7 +192,7 @@ POT has benefited from the financing or manpower from the following partners:
## Contributions and code of conduct
-Every contribution is welcome and should respect the [contribution guidelines](.github/CONTRIBUTING.md). Each member of the project is expected to follow the [code of conduct](.github/CODE_OF_CONDUCT.md).
+Every contribution is welcome and should respect the [contribution guidelines](https://pythonot.github.io/master/contributing.html). Each member of the project is expected to follow the [code of conduct](https://pythonot.github.io/master/code_of_conduct.html).
## Support
@@ -273,19 +276,35 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., ... & Schwing, A. G. (2019). [Max-sliced wasserstein distance and its use for gans](https://openaccess.thecvf.com/content_CVPR_2019/papers/Deshpande_Max-Sliced_Wasserstein_Distance_and_Its_Use_for_GANs_CVPR_2019_paper.pdf). In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656).
-[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R.
-(2019, May). [Sliced-Wasserstein flows: Nonparametric generative modeling
-via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on
+[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R.
+(2019, May). [Sliced-Wasserstein flows: Nonparametric generative modeling
+via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on
Machine Learning (pp. 4104-4113). PMLR.
[37] Janati, H., Cuturi, M., Gramfort, A. [Debiased sinkhorn barycenters](http://proceedings.mlr.press/v119/janati20a/janati20a.pdf) Proceedings of the 37th International
Conference on Machine Learning, PMLR 119:4692-4701, 2020
-[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph
+[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph
Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021.
[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405.
[40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., & Weed, J. (2019, April). [Statistical optimal transport via factored couplings](http://proceedings.mlr.press/v89/forrow19a/forrow19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2454-2465). PMLR.
-[41] Chapel*, L., Flamary*, R., Wu, H., Févotte, C., Gasso, G. (2021). [Unbalanced Optimal Transport through Non-negative Penalized Linear Regression](https://proceedings.neurips.cc/paper/2021/file/c3c617a9b80b3ae1ebd868b0017cc349-Paper.pdf) Advances in Neural Information Processing Systems (NeurIPS), 2020. (Two first co-authors) \ No newline at end of file
+[41] Chapel*, L., Flamary*, R., Wu, H., Févotte, C., Gasso, G. (2021). [Unbalanced Optimal Transport through Non-negative Penalized Linear Regression](https://proceedings.neurips.cc/paper/2021/file/c3c617a9b80b3ae1ebd868b0017cc349-Paper.pdf) Advances in Neural Information Processing Systems (NeurIPS), 2020. (Two first co-authors)
+
+[42] Delon, J., Gozlan, N., and Saint-Dizier, A. [Generalized Wasserstein barycenters between probability measures living on different subspaces](https://arxiv.org/pdf/2105.09755). arXiv preprint arXiv:2105.09755, 2021.
+
+[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
+
+[44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. [Fast transport optimization for Monge costs on the circle.](https://arxiv.org/abs/0902.3527) SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
+
+[45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. [The statistics of circular optimal transport.](https://arxiv.org/abs/2103.15426) Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82.
+
+[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). [Spherical Sliced-Wasserstein](https://openreview.net/forum?id=jXQ0ipgMdU). International Conference on Learning Representations.
+
+[47] Chowdhury, S., & Mémoli, F. (2019). [The gromov–wasserstein distance between networks and stable network invariants](https://academic.oup.com/imaiai/article/8/4/757/5627736). Information and Inference: A Journal of the IMA, 8(4), 757-787.
+
+[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty (2022). [Semi-relaxed Gromov-Wasserstein divergence and applications on graphs](https://openreview.net/pdf?id=RShaMexjc-x). International Conference on Learning Representations (ICLR), 2022.
+
+[49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). [CO-Optimal Transport](https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf). Advances in Neural Information Processing Systems, 33.
diff --git a/RELEASES.md b/RELEASES.md
index be2192e..e978905 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -1,5 +1,129 @@
# Releases
+## 0.9.0
+
+This new release contains so many new features and bug fixes since 0.8.2 that we
+decided to make it a new minor release at 0.9.0.
+
+The release contains many new features. First we did a major
+update of all Gromov-Wasserstein solvers that brings up to 30% gain in
+computation time (see PR #431) and allows the GW solvers to work on non symmetric
+matrices. It also brings novel solvers for the very
+efficient [semi-relaxed GW problem
+](https://pythonot.github.io/master/auto_examples/gromov/plot_semirelaxed_fgw.html#sphx-glr-auto-examples-gromov-plot-semirelaxed-fgw-py)
+that can be used to find the best re-weighting for one of the distributions. We
+also now have fast and differentiable solvers for [Wasserstein on the circle](https://pythonot.github.io/master/auto_examples/plot_compute_wasserstein_circle.html#sphx-glr-auto-examples-plot-compute-wasserstein-circle-py) and
+[sliced Wasserstein on the
+sphere](https://pythonot.github.io/master/auto_examples/backends/plot_ssw_unif_torch.html#sphx-glr-auto-examples-backends-plot-ssw-unif-torch-py).
+We are also very happy to provide new OT barycenter solvers such as the [Free
+support Sinkhorn
+barycenter](https://pythonot.github.io/master/auto_examples/barycenters/plot_free_support_sinkhorn_barycenter.html#sphx-glr-auto-examples-barycenters-plot-free-support-sinkhorn-barycenter-py)
+and the [Generalized Wasserstein
+barycenter](https://pythonot.github.io/master/auto_examples/barycenters/plot_generalized_free_support_barycenter.html#sphx-glr-auto-examples-barycenters-plot-generalized-free-support-barycenter-py).
+A new differentiable solver for OT across spaces that provides OT plans
+between samples and features simultaneously and
+called [Co-Optimal
+Transport](https://pythonot.github.io/master/auto_examples/others/plot_COOT.html)
+has also been implemented. Finally we began working on OT between Gaussian distributions and
+now provide differentiable estimation for the Bures-Wasserstein [divergence](https://pythonot.github.io/master/gen_modules/ot.gaussian.html#ot.gaussian.bures_wasserstein_distance) and
+[mappings](https://pythonot.github.io/master/auto_examples/domain-adaptation/plot_otda_linear_mapping.html#sphx-glr-auto-examples-domain-adaptation-plot-otda-linear-mapping-py).
+
+Another important first step toward POT 1.0 is the
+implementation of a unified API for OT solvers with introduction of [`ot.solve`](https://pythonot.github.io/master/all.html#ot.solve)
+function that can solve (depending on parameters) exact, regularized and
+unbalanced OT and return a new
+[`OTResult`](https://pythonot.github.io/master/gen_modules/ot.utils.html#ot.utils.OTResult)
+object. The idea behind this new API is to facilitate exploring different solvers
+with just a change of parameter and get a more unified API for them. We will keep
+the old solvers API for power users but it will be the preferred way to solve
+problems starting from release 1.0.0.
+We provide below some examples of use for the new function and how to
+recover different aspects of the solution (OT plan, full loss, linear part of the
+loss, dual variables) :
+```python
+#Solve exact ot
+sol = ot.solve(M)
+
+# get the results
+G = sol.plan # OT plan
+ot_loss = sol.value # OT value (full loss for regularized and unbalanced)
+ot_loss_linear = sol.value_linear # OT value for linear term np.sum(sol.plan*M)
+alpha, beta = sol.potentials # dual potentials
+
+# direct plan and loss computation
+G = ot.solve(M).plan
+ot_loss = ot.solve(M).value
+
+# OT exact with marginals a/b
+sol2 = ot.solve(M, a, b)
+
+# regularized and unbalanced OT
+sol_rkl = ot.solve(M, a, b, reg=1) # KL regularization
+sol_rl2 = ot.solve(M, a, b, reg=1, reg_type='L2')
+sol_ul2 = ot.solve(M, a, b, unbalanced=10, unbalanced_type='L2')
+sol_rkl_ukl = ot.solve(M, a, b, reg=10, unbalanced=10) # KL + KL
+
+```
+The function is fully compatible with backends and will be implemented for
+different types of distribution support (empirical distributions, grids) and OT
+problems (Gromov-Wasserstein) in the new releases. This new API is not yet
+presented in the kickstart part of the documentation as there is a small change
+that it might change
+when implementing new solvers but we encourage users to play with it.
+
+Finally, in addition to those many new this release fixes 20 issues (some long
+standing) and we want to thank all the contributors who made this release so
+big. More details below.
+
+
+#### New features
+- Added feature to (Fused) Gromov-Wasserstein solvers herited from `ot.optim` to support relative and absolute loss variations as stopping criterions (PR #431)
+- Added feature to (Fused) Gromov-Wasserstein solvers to handle asymmetric matrices (PR #431)
+- Added semi-relaxed (Fused) Gromov-Wasserstein solvers in `ot.gromov` + examples (PR #431)
+- Added the spherical sliced-Wasserstein discrepancy in `ot.sliced.sliced_wasserstein_sphere` and `ot.sliced.sliced_wasserstein_sphere_unif` + examples (PR #434)
+- Added the Wasserstein distance on the circle in ``ot.lp.solver_1d.wasserstein_circle`` (PR #434)
+- Added the Wasserstein distance on the circle (for p>=1) in `ot.lp.solver_1d.binary_search_circle` + examples (PR #434)
+- Added the 2-Wasserstein distance on the circle w.r.t a uniform distribution in `ot.lp.solver_1d.semidiscrete_wasserstein2_unif_circle` (PR #434)
+- Added Bures Wasserstein distance in `ot.gaussian` (PR ##428)
+- Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376)
+- Added Free Support Sinkhorn Barycenter + example (PR #387)
+- New API for OT solver using function `ot.solve` (PR #388)
+- Backend version of `ot.partial` and `ot.smooth` (PR #388 and #449)
+- Added argument for warmstart of dual potentials in Sinkhorn-based methods in `ot.bregman` (PR #437)
+- Added parameters method in `ot.da.SinkhornTransport` (PR #440)
+- `ot.dr` now uses the new Pymanopt API and POT is compatible with current
+ Pymanopt (PR #443)
+- Added CO-Optimal Transport solver + examples (PR # 447)
+- Remove the redundant `nx.abs()` at the end of `wasserstein_1d()` (PR #448)
+
+#### Closed issues
+
+
+- Fixed an issue with the documentation gallery sections (PR #395)
+- Fixed an issue where sinkhorn divergence did not have a gradients (Issue #393, PR #394)
+- Fixed an issue where we could not ask TorchBackend to place a random tensor on GPU
+ (Issue #371, PR #373)
+- Fixed an issue where Sinkhorn solver assumed a symmetric cost matrix (Issue #374, PR #375)
+- Fixed an issue where hitting iteration limits would be reported to stderr by std::cerr regardless of Python's stderr stream status (PR #377)
+- Fixed an issue where the metric argument in ot.dist did not allow a callable parameter (Issue #378, PR #379)
+- Fixed an issue where the max number of iterations in ot.emd was not allowed to go beyond 2^31 (PR #380)
+- Fixed an issue where pointers would overflow in the EMD solver, returning an
+incomplete transport plan above a certain size (slightly above 46k, its square being
+roughly 2^31) (PR #381)
+- Error raised when mass mismatch in emd2 (PR #386)
+- Fixed an issue where a pytorch example would throw an error if executed on a GPU (Issue #389, PR #391)
+- Added a work-around for scipy's bug, where you cannot compute the Hamming distance with a "None" weight attribute. (Issue #400, PR #402)
+- Fixed an issue where the doc could not be built due to some changes in matplotlib's API (Issue #403, PR #402)
+- Replaced Numpy C Compiler with Setuptools C Compiler due to deprecation issues (Issue #408, PR #409)
+- Fixed weak optimal transport docstring (Issue #404, PR #410)
+- Fixed error with parameter `log=True`for `SinkhornLpl1Transport` (Issue #412,
+PR #413)
+- Fixed an issue about `warn` parameter in `sinkhorn2` (PR #417)
+- Fix an issue where the parameter `stopThr` in `empirical_sinkhorn_divergence` was rendered useless by subcalls
+ that explicitly specified `stopThr=1e-9` (Issue #421, PR #422).
+- Fixed a bug breaking an example where we would try to make an array of arrays of different shapes (Issue #424, PR #425)
+- Fixed an issue with the documentation gallery section (PR #444)
+- Fixed issues with cuda variables for `line_search_armijo` and `entropic_gromov_wasserstein` (Issue #445, #PR 446)
## 0.8.2
@@ -46,7 +170,7 @@ and [Factored coupling OT](https://pythonot.github.io/auto_examples/others/plot_
- Remove deprecated `ot.gpu` submodule (PR #361)
- Update examples in the gallery (PR #359)
-- Add stochastic loss and OT plan computation for regularized OT and
+- Add stochastic loss and OT plan computation for regularized OT and
backend examples(PR #360)
- Implementation of factored OT with emd and sinkhorn (PR #358)
- A brand new logo for POT (PR #357)
@@ -62,9 +186,9 @@ and [Factored coupling OT](https://pythonot.github.io/auto_examples/others/plot_
#### Closed issues
-- Fix mass gradient of `ot.emd2` and `ot.gromov_wasserstein2` so that they are
+- Fix mass gradient of `ot.emd2` and `ot.gromov_wasserstein2` so that they are
centered (Issue #364, PR #363)
-- Fix bug in instantiating an `autograd` function `ValFunction` (Issue #337,
+- Fix bug in instantiating an `autograd` function `ValFunction` (Issue #337,
PR #338)
- Fix POT ABI compatibility with old and new numpy (Issue #346, PR #349)
- Warning when feeding integer cost matrix to EMD solver resulting in an integer transport plan (Issue #345, PR #343)
@@ -114,21 +238,21 @@ As always we want to that the contributors who helped make POT better (and bug f
- Fix bug in older Numpy ABI (<1.20) (Issue #308, PR #326)
- Fix bug in `ot.dist` function when non euclidean distance (Issue #305, PR #306)
-- Fix gradient scaling for functions using `nx.set_gradients` (Issue #309,
+- Fix gradient scaling for functions using `nx.set_gradients` (Issue #309,
PR #310)
-- Fix bug in generalized Conditional gradient solver and SinkhornL1L2
+- Fix bug in generalized Conditional gradient solver and SinkhornL1L2
(Issue #311, PR #313)
- Fix log error in `gromov_barycenters` (Issue #317, PR #3018)
## 0.8.0
*November 2021*
-This new stable release introduces several important features.
+This new stable release introduces several important features.
First we now have
an OpenMP compatible exact ot solver in `ot.emd`. The OpenMP version is used
when the parameter `numThreads` is greater than one and can lead to nice
-speedups on multi-core machines.
+speedups on multi-core machines.
Second we have introduced a backend mechanism that allows to use standard POT
function seamlessly on Numpy, Pytorch and Jax arrays. Other backends are coming
@@ -147,7 +271,7 @@ for a [sliced Wasserstein gradient
flow](https://PythonOT.github.io/auto_examples/backends/plot_sliced_wass_grad_flow_pytorch.html)
and [optimizing the Gromov-Wassersein distance](https://PythonOT.github.io/auto_examples/backends/plot_optim_gromov_pytorch.html). Note that the Jax backend is still in early development and quite
slow at the moment, we strongly recommend for Jax users to use the [OTT
-toolbox](https://github.com/google-research/ott) when possible.
+toolbox](https://github.com/google-research/ott) when possible.
As a result of this new feature,
the old `ot.gpu` submodule is now deprecated since GPU
implementations can be done using GPU arrays on the torch backends.
@@ -170,7 +294,7 @@ Finally POT was accepted for publication in the Journal of Machine Learning
Research (JMLR) open source software track and we ask the POT users to cite [this
paper](https://www.jmlr.org/papers/v22/20-451.html) from now on. The documentation has been improved in particular by adding a
"Why OT?" section to the quick start guide and several new examples illustrating
-the new features. The documentation now has two version : the stable version
+the new features. The documentation now has two version : the stable version
[https://pythonot.github.io/](https://pythonot.github.io/)
corresponding to the last release and the master version [https://pythonot.github.io/master](https://pythonot.github.io/master) that corresponds to the
current master branch on GitHub.
@@ -180,7 +304,7 @@ As usual, we want to thank all the POT contributors (now 37 people have
contributed to the toolbox). But for this release we thank in particular Nathan
Cassereau and Kamel Guerda from the AI support team at
[IDRIS](http://www.idris.fr/) for their support to the development of the
-backend and OpenMP implementations.
+backend and OpenMP implementations.
#### New features
@@ -247,7 +371,7 @@ repository for the new documentation is now hosted at
This is the first release where the Python 2.7 tests have been removed. Most of
the toolbox should still work but we do not offer support for Python 2.7 and
-will close related Issues.
+will close related Issues.
A lot of changes have been done to the documentation that is now hosted on
[https://PythonOT.github.io/](https://PythonOT.github.io/) instead of
@@ -280,7 +404,7 @@ problems.
This release is also the moment to thank all the POT contributors (old and new)
for helping making POT such a nice toolbox. A lot of changes (also in the API)
-are coming for the next versions.
+are coming for the next versions.
#### Features
@@ -309,14 +433,14 @@ are coming for the next versions.
- Log bugs for Gromov-Wassertein solver (Issue #107, fixed in PR #108)
- Weight issues in barycenter function (PR #106)
-## 0.6.0
+## 0.6.0
*July 2019*
-This is the first official stable release of POT and this means a jump to 0.6!
+This is the first official stable release of POT and this means a jump to 0.6!
The library has been used in
the wild for a while now and we have reached a state where a lot of fundamental
OT solvers are available and tested. It has been quite stable in the last months
-but kept the beta flag in its Pypi classifiers until now.
+but kept the beta flag in its Pypi classifiers until now.
Note that this release will be the last one supporting officially Python 2.7 (See
https://python3statement.org/ for more reasons). For next release we will keep
@@ -345,7 +469,7 @@ graphs](https://github.com/rflamary/POT/blob/master/notebooks/plot_barycenter_fg
A lot of work has been done on the documentation with several new
examples corresponding to the new features and a lot of corrections for the
-docstrings. But the most visible change is a new
+docstrings. But the most visible change is a new
[quick start guide](https://pot.readthedocs.io/en/latest/quickstart.html) for
POT that gives several pointers about which function or classes allow to solve which
specific OT problem. When possible a link is provided to relevant examples.
@@ -383,29 +507,29 @@ bring new features and solvers to the library.
- Issue #72 Macosx build problem
-## 0.5.0
+## 0.5.0
*Sep 2018*
-POT is 2 years old! This release brings numerous new features to the
+POT is 2 years old! This release brings numerous new features to the
toolbox as listed below but also several bug correction.
-Among the new features, we can highlight a [non-regularized Gromov-Wasserstein
-solver](https://github.com/rflamary/POT/blob/master/notebooks/plot_gromov.ipynb),
-a new [greedy variant of sinkhorn](https://pot.readthedocs.io/en/latest/all.html#ot.bregman.greenkhorn),
-[non-regularized](https://pot.readthedocs.io/en/latest/all.html#ot.lp.barycenter),
+Among the new features, we can highlight a [non-regularized Gromov-Wasserstein
+solver](https://github.com/rflamary/POT/blob/master/notebooks/plot_gromov.ipynb),
+a new [greedy variant of sinkhorn](https://pot.readthedocs.io/en/latest/all.html#ot.bregman.greenkhorn),
+[non-regularized](https://pot.readthedocs.io/en/latest/all.html#ot.lp.barycenter),
[convolutional (2D)](https://github.com/rflamary/POT/blob/master/notebooks/plot_convolutional_barycenter.ipynb)
and [free support](https://github.com/rflamary/POT/blob/master/notebooks/plot_free_support_barycenter.ipynb)
- Wasserstein barycenters and [smooth](https://github.com/rflamary/POT/blob/prV0.5/notebooks/plot_OT_1D_smooth.ipynb)
- and [stochastic](https://pot.readthedocs.io/en/latest/all.html#ot.stochastic.sgd_entropic_regularization)
+ Wasserstein barycenters and [smooth](https://github.com/rflamary/POT/blob/prV0.5/notebooks/plot_OT_1D_smooth.ipynb)
+ and [stochastic](https://pot.readthedocs.io/en/latest/all.html#ot.stochastic.sgd_entropic_regularization)
implementation of entropic OT.
-POT 0.5 also comes with a rewriting of ot.gpu using the cupy framework instead of
-the unmaintained cudamat. Note that while we tried to keed changes to the
-minimum, the OTDA classes were deprecated. If you are happy with the cudamat
+POT 0.5 also comes with a rewriting of ot.gpu using the cupy framework instead of
+the unmaintained cudamat. Note that while we tried to keed changes to the
+minimum, the OTDA classes were deprecated. If you are happy with the cudamat
implementation, we recommend you stay with stable release 0.4 for now.
-The code quality has also improved with 92% code coverage in tests that is now
-printed to the log in the Travis builds. The documentation has also been
+The code quality has also improved with 92% code coverage in tests that is now
+printed to the log in the Travis builds. The documentation has also been
greatly improved with new modules and examples/notebooks.
This new release is so full of new stuff and corrections thanks to the old
@@ -424,24 +548,24 @@ and new POT contributors (you can see the list in the [readme](https://github.co
* Stochastic OT in the dual and semi-dual (PR #52 and PR #62)
* Free support barycenters (PR #56)
* Speed-up Sinkhorn function (PR #57 and PR #58)
-* Add convolutional Wassersein barycenters for 2D images (PR #64)
+* Add convolutional Wassersein barycenters for 2D images (PR #64)
* Add Greedy Sinkhorn variant (Greenkhorn) (PR #66)
* Big ot.gpu update with cupy implementation (instead of un-maintained cudamat) (PR #67)
#### Deprecation
-Deprecated OTDA Classes were removed from ot.da and ot.gpu for version 0.5
-(PR #48 and PR #67). The deprecation message has been for a year here since
+Deprecated OTDA Classes were removed from ot.da and ot.gpu for version 0.5
+(PR #48 and PR #67). The deprecation message has been for a year here since
0.4 and it is time to pull the plug.
#### Closed issues
* Issue #35 : remove import plot from ot/__init__.py (See PR #41)
* Issue #43 : Unusable parameter log for EMDTransport (See PR #44)
-* Issue #55 : UnicodeDecodeError: 'ascii' while installing with pip
+* Issue #55 : UnicodeDecodeError: 'ascii' while installing with pip
-## 0.4
+## 0.4
*15 Sep 2017*
This release contains a lot of contribution from new contributors.
@@ -451,14 +575,14 @@ This release contains a lot of contribution from new contributors.
* Automatic notebooks and doc update (PR #27)
* Add gromov Wasserstein solver and Gromov Barycenters (PR #23)
-* emd and emd2 can now return dual variables and have max_iter (PR #29 and PR #25)
+* emd and emd2 can now return dual variables and have max_iter (PR #29 and PR #25)
* New domain adaptation classes compatible with scikit-learn (PR #22)
* Proper tests with pytest on travis (PR #19)
* PEP 8 tests (PR #13)
#### Closed issues
-* emd convergence problem du to fixed max iterations (#24)
+* emd convergence problem du to fixed max iterations (#24)
* Semi supervised DA error (#26)
## 0.3.1
@@ -466,7 +590,7 @@ This release contains a lot of contribution from new contributors.
* Correct bug in emd on windows
-## 0.3
+## 0.3
*7 Jul 2017*
* emd* and sinkhorn* are now performed in parallel for multiple target distributions
@@ -479,7 +603,7 @@ This release contains a lot of contribution from new contributors.
* GPU implementations for sinkhorn and group lasso regularization
-## V0.2
+## V0.2
*7 Apr 2017*
* New dimensionality reduction method (WDA)
@@ -487,7 +611,7 @@ This release contains a lot of contribution from new contributors.
-## 0.1.11
+## 0.1.11
*5 Jan 2017*
* Add sphinx gallery for better documentation
@@ -495,7 +619,7 @@ This release contains a lot of contribution from new contributors.
* Add simple tic() toc() functions for timing
-## 0.1.10
+## 0.1.10
*7 Nov 2016*
* numerical stabilization for sinkhorn (log domain and epsilon scaling)
@@ -524,4 +648,4 @@ It provides the following solvers:
* Optimal transport for domain adaptation with group lasso regularization
* Conditional gradient and Generalized conditional gradient for regularized OT.
-Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.
+Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder. \ No newline at end of file
diff --git a/codecov.yml b/codecov.yml
index 1447ced..dd301ab 100644
--- a/codecov.yml
+++ b/codecov.yml
@@ -38,7 +38,7 @@ parsers:
# https://docs.codecov.io/docs/ignoring-paths
ignore:
- - "ot/gpu/*"
+ - "ot/helpers/openmp_helpers.py"
# https://docs.codecov.io/docs/pull-request-comments
comment:
diff --git a/docs/cache_nbrun b/docs/cache_nbrun
deleted file mode 100644
index ac49515..0000000
--- a/docs/cache_nbrun
+++ /dev/null
@@ -1 +0,0 @@
-{"plot_otda_color_images.ipynb": "128d0435c08ebcf788913e4adcd7dd00", "plot_partial_wass_and_gromov.ipynb": "82242f8390df1d04806b333b745c72cf", "plot_WDA.ipynb": "27f8de4c6d7db46497076523673eedfb", "plot_screenkhorn_1D.ipynb": "af7b8a74a1be0f16f2c3908f5a178de0", "plot_otda_laplacian.ipynb": "d92cc0e528b9277f550daaa6f9d18415", "plot_OT_L1_vs_L2.ipynb": "288230c4e679d752a511353c96c134cb", "plot_otda_semi_supervised.ipynb": "568b39ffbdf6621dd6de162df42f4f21", "plot_fgw.ipynb": "f4de8e6939ce2b1339b3badc1fef0f37", "plot_otda_d2.ipynb": "07ef3212ff3123f16c32a5670e0167f8", "plot_compute_emd.ipynb": "299f6fffcdbf48b7c3268c0136e284f8", "plot_barycenter_fgw.ipynb": "9e813d3b07b7c0c0fcc35a778ca1243f", "plot_convolutional_barycenter.ipynb": "fdd259bfcd6d5fe8001efb4345795d2f", "plot_optim_OTreg.ipynb": "bddd8e49f092873d8980d41ae4974e19", "plot_UOT_1D.ipynb": "2658d5164165941b07539dae3cb80a0f", "plot_OT_1D_smooth.ipynb": "f3e1f0e362c9a78071a40c02b85d2305", "plot_barycenter_1D.ipynb": "f6fa5bc13d9811f09792f73b4de70aa0", "plot_otda_mapping.ipynb": "1bb321763f670fc945d77cfc91471e5e", "plot_OT_1D.ipynb": "0346a8c862606d11f36d0aa087ecab0d", "plot_gromov_barycenter.ipynb": "a7999fcc236d90a0adeb8da2c6370db3", "plot_UOT_barycenter_1D.ipynb": "dd9b857a8c66d71d0124d4a2c30a51dd", "plot_otda_mapping_colors_images.ipynb": "16faae80d6ea8b37d6b1f702149a10de", "plot_stochastic.ipynb": "64f23a8dcbab9823ae92f0fd6c3aceab", "plot_otda_linear_mapping.ipynb": "82417d9141e310bf1f2c2ecdb550094b", "plot_otda_classes.ipynb": "8836a924c9b562ef397af12034fa1abb", "plot_free_support_barycenter.ipynb": "be9d0823f9d7774a289311b9f14548eb", "plot_gromov.ipynb": "de06b1dbe8de99abae51c2e0b64b485d", "plot_otda_jcpot.ipynb": "65482cbfef5c6c1e5e73998aeb5f4b10", "plot_OT_2D_samples.ipynb": "9a9496792fa4216b1059fc70abca851a", "plot_barycenter_lp_vs_entropic.ipynb": "334840b69a86898813e50a6db0f3d0de"} \ No newline at end of file
diff --git a/docs/requirements_rtd.txt b/docs/requirements_rtd.txt
index 11957fb..30082bb 100644
--- a/docs/requirements_rtd.txt
+++ b/docs/requirements_rtd.txt
@@ -9,7 +9,6 @@ scipy>=1.0
cython
matplotlib
autograd
-pymanopt==0.2.4; python_version <'3'
-pymanopt; python_version >= '3'
+pymanopt
cvxopt
scikit-learn \ No newline at end of file
diff --git a/docs/source/_templates/module.rst b/docs/source/_templates/module.rst
index 5ad89be..495995e 100644
--- a/docs/source/_templates/module.rst
+++ b/docs/source/_templates/module.rst
@@ -2,6 +2,7 @@
{{ underline }}
.. automodule:: {{ fullname }}
+ :members:
{% block functions %}
{% if functions %}
@@ -12,6 +13,7 @@
{% for item in functions %}
.. autofunction:: {{ item }}
+
.. include:: backreferences/{{fullname}}.{{item}}.examples
diff --git a/docs/source/_templates/versions.html b/docs/source/_templates/versions.html
index f48ab86..5b1021a 100644
--- a/docs/source/_templates/versions.html
+++ b/docs/source/_templates/versions.html
@@ -1,47 +1,50 @@
<div class="rst-versions" data-toggle="rst-versions" role="note"
aria-label="versions">
- <!-- add shift_up to the class for force viewing -->
- <span class="rst-current-version" data-toggle="rst-current-version">
+ <!-- add shift_up to the class for force viewing ,
+ data-toggle="rst-current-version" -->
+ <span class="rst-current-version" style="margin-bottom:1mm;">
<span class="fa fa-book"> Python Optimal Transport</span>
- versions
- <span class="fa fa-caret-down"></span>
+ <hr style="margin-bottom:1.5mm;margin-top:5mm;">
+ <!-- versions
+ <span class="fa fa-caret-down"></span>-->
+ <span class="rst-current-version" style="display: inline-block;padding:
+ 0px;color:#fcfcfcab;float:left;font-size: 100%;">
+ Versions:
+ <a href="https://pythonot.github.io/"
+ style="padding: 3px;color:#fcfcfc;font-size: 100%;">Release</a>
+ <a href="https://pythonot.github.io/master"
+ style="padding: 3px;color:#fcfcfc;font-size: 100%;">Development</a>
+ <a href="https://github.com/PythonOT/POT"
+ style="padding: 3px;color:#fcfcfc;font-size: 100%;">Code</a>
+
+ </span>
+
+
</span>
- <div class="rst-other-versions"><!-- Inserted RTD Footer -->
+
+ <!--
+ <div class="rst-other-versions">
+
+
<div class="injected">
-
-
- <dl>
- <dt>Versions</dt>
-
- <dd><a href="https://pythonot.github.io/">Release</a></dd>
-
- <dd><a href="https://pythonot.github.io/master">Development</a></dd>
-
-
-
- </dl>
-
+
+ <dl>
+ <dt>Versions</dt>
+ <dd><a href="https://pythonot.github.io/">Release</a></dd>
-
- <dl>
- <dt>On GitHub</dt>
- <dd>
- <a href="https://github.com/PythonOT/POT">Code on Github</a>
- </dd>
-
- </dl>
-
+ <dd><a href="https://pythonot.github.io/master">Development</a></dd>
+
-
-
- <hr>
-
+ <dt><a href="https://github.com/PythonOT/POT">Code on Github</a></dt>
+
+ </dl>
+ <hr>
-</div>
-</div>
+</div>
+</div>-->
</div> \ No newline at end of file
diff --git a/docs/source/all.rst b/docs/source/all.rst
index 1ec6be3..a9d7fe2 100644
--- a/docs/source/all.rst
+++ b/docs/source/all.rst
@@ -13,28 +13,34 @@ API and modules
:toctree: gen_modules/
:template: module.rst
- lp
+
backend
bregman
- smooth
- gromov
- optim
+ coot
da
- dr
- utils
datasets
+ dr
+ factored
+ gaussian
+ gromov
+ lp
+ optim
+ partial
plot
- stochastic
- unbalanced
regpath
- partial
sliced
+ smooth
+ stochastic
+ unbalanced
+ utils
weak
- factored
+
-.. autosummary::
- :toctree: ../modules/generated/
- :template: module.rst
+Main :py:mod:`ot` functions
+---------------------------
.. automodule:: ot
:members:
+
+
+
diff --git a/docs/source/.github/CODE_OF_CONDUCT.rst b/docs/source/code_of_conduct.rst
index d4c5cec..40b432e 100644
--- a/docs/source/.github/CODE_OF_CONDUCT.rst
+++ b/docs/source/code_of_conduct.rst
@@ -1,6 +1,6 @@
-Code of Conduct
+Code of conduct
===============
-.. include:: ../../../.github/CODE_OF_CONDUCT.md
+.. include:: ../../.github/CODE_OF_CONDUCT.md
:parser: myst_parser.sphinx_
:start-line: 2
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 9526518..6e76291 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -119,7 +119,7 @@ release = __version__
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
-language = None
+language = "en"
# There are two options for replacing |today|: either, you set today to some
# non-false value, then it is used:
@@ -194,6 +194,7 @@ html_logo = '_static/images/logo_dark.svg'
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
+#html_css_files = ["css/custom.css"]
# Add any extra paths that contain custom files (such as robots.txt or
@@ -340,6 +341,9 @@ texinfo_documents = [
# If true, do not generate a @detailmenu in the "Top" node's menu.
#texinfo_no_detailmenu = False
+autodoc_default_options = {'autosummary': True,
+ 'autosummary_imported_members': True}
+
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {'python': ('https://docs.python.org/3', None),
@@ -351,6 +355,7 @@ intersphinx_mapping = {'python': ('https://docs.python.org/3', None),
sphinx_gallery_conf = {
'examples_dirs': ['../../examples', '../../examples/da'],
'gallery_dirs': 'auto_examples',
+ 'nested_sections' : False,
'backreferences_dir': 'gen_modules/backreferences',
'inspect_global_variables' : True,
'doc_module' : ('ot','numpy','scipy','pylab'),
diff --git a/docs/source/.github/CONTRIBUTING.rst b/docs/source/contributing.rst
index aef24e9..8dec19a 100644
--- a/docs/source/.github/CONTRIBUTING.rst
+++ b/docs/source/contributing.rst
@@ -1,6 +1,6 @@
Contributing to POT
===================
-.. include:: ../../../.github/CONTRIBUTING.md
+.. include:: ../../.github/CONTRIBUTING.md
:parser: myst_parser.sphinx_
- :start-line: 3
+ :start-line: 2
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 3d53ef4..0f04738 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -21,9 +21,9 @@ Contents
all
auto_examples/index
releases
- .github/CONTRIBUTING
contributors
- .github/CODE_OF_CONDUCT
+ contributing
+ code_of_conduct
.. include:: ../../README.md
diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst
index b4cc8ab..1dc9f71 100644
--- a/docs/source/quickstart.rst
+++ b/docs/source/quickstart.rst
@@ -127,14 +127,6 @@ been used to solve both graph Laplacian regularization OT and Gromov
Wasserstein [30]_.
-.. note::
-
- POT is originally designed to solve OT problems with Numpy interface and
- is not yet compatible with Pytorch API. We are currently working on a torch
- submodule that will provide OT solvers and losses for the most common deep
- learning configurations.
-
-
When not to use POT
"""""""""""""""""""
@@ -279,7 +271,7 @@ distributions. In this case there exists a close form solution given in Remark
2.29 in [15]_ and the Monge mapping is an affine function and can be
also computed from the covariances and means of the source and target
distributions. In the case when the finite sample dataset is supposed Gaussian,
-we provide :any:`ot.da.OT_mapping_linear` that returns the parameters for the
+we provide :any:`ot.gaussian.bures_wasserstein_mapping` that returns the parameters for the
Monge mapping.
@@ -628,7 +620,7 @@ approximate a Monge mapping from finite distributions.
First note that when the source and target distributions are supposed to be Gaussian
distributions, there exists a close form solution for the mapping and its an
affine function [14]_ of the form :math:`T(x)=Ax+b` . In this case we provide the function
-:any:`ot.da.OT_mapping_linear` that returns the operator :math:`A` and vector
+:any:`ot.gaussian.bures_wasserstein_mapping` that returns the operator :math:`A` and vector
:math:`b`. Note that if the number of samples is too small there is a parameter
:code:`reg` that provides a regularization for the covariance matrix estimation.
@@ -640,7 +632,7 @@ method proposed in [8]_ that estimates a continuous mapping approximating the
barycentric mapping is provided in :any:`ot.da.joint_OT_mapping_linear` for
linear mapping and :any:`ot.da.joint_OT_mapping_kernel` for non-linear mapping.
-.. minigallery:: ot.da.joint_OT_mapping_linear ot.da.joint_OT_mapping_linear ot.da.OT_mapping_linear
+.. minigallery:: ot.da.joint_OT_mapping_linear ot.da.joint_OT_mapping_linear ot.gaussian.bures_wasserstein_mapping
:add-heading: Examples of Monge mapping estimation
:heading-level: "
@@ -692,42 +684,8 @@ A list of the provided implementation is given in the following note.
:heading-level: "
-Other applications
-------------------
-
-We discuss in the following several OT related problems and tools that has been
-proposed in the OT and machine learning community.
-
-Wasserstein Discriminant Analysis
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-Wasserstein Discriminant Analysis [11]_ is a generalization of `Fisher Linear Discriminant
-Analysis <https://en.wikipedia.org/wiki/Linear_discriminant_analysis>`__ that
-allows discrimination between classes that are not linearly separable. It
-consists in finding a linear projector optimizing the following criterion
-
-.. math::
- P = \text{arg}\min_P \frac{\sum_i OT_e(\mu_i\#P,\mu_i\#P)}{\sum_{i,j\neq i}
- OT_e(\mu_i\#P,\mu_j\#P)}
-
-where :math:`\#` is the push-forward operator, :math:`OT_e` is the entropic OT
-loss and :math:`\mu_i` is the
-distribution of samples from class :math:`i`. :math:`P` is also constrained to
-be in the Stiefel manifold. WDA can be solved in POT using function
-:any:`ot.dr.wda`. It requires to have installed :code:`pymanopt` and
-:code:`autograd` for manifold optimization and automatic differentiation
-respectively. Note that we also provide the Fisher discriminant estimator in
-:any:`ot.dr.fda` for easy comparison.
-
-.. warning::
-
- Note that due to the hard dependency on :code:`pymanopt` and
- :code:`autograd`, :any:`ot.dr` is not imported by default. If you want to
- use it you have to specifically import it with :code:`import ot.dr` .
-
-.. minigallery:: ot.dr.wda
- :add-heading: Examples of the use of WDA
- :heading-level: "
+Unbalanced and partial OT
+-------------------------
@@ -845,10 +803,11 @@ regularization of the problem.
:heading-level: "
+Gromov Wasserstein and extensions
+---------------------------------
-
-Gromov-Wasserstein
-^^^^^^^^^^^^^^^^^^
+Gromov Wasserstein(GW)
+^^^^^^^^^^^^^^^^^^^^^^
Gromov Wasserstein (GW) is a generalization of OT to distributions that do not lie in
the same space [13]_. In this case one cannot compute distance between samples
@@ -877,6 +836,8 @@ There also exists an entropic regularized variant of GW that has been proposed i
:add-heading: Examples of computation of GW, regularized G and FGW
:heading-level: "
+Gromov Wasserstein barycenters
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Note that similarly to Wasserstein distance GW allows for the definition of GW
barycenters that can be expressed as
@@ -905,6 +866,43 @@ The implementations of FGW and FGW barycenter is provided in functions
:heading-level: "
+Other applications
+------------------
+
+We discuss in the following several OT related problems and tools that has been
+proposed in the OT and machine learning community.
+
+Wasserstein Discriminant Analysis
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Wasserstein Discriminant Analysis [11]_ is a generalization of `Fisher Linear Discriminant
+Analysis <https://en.wikipedia.org/wiki/Linear_discriminant_analysis>`__ that
+allows discrimination between classes that are not linearly separable. It
+consists in finding a linear projector optimizing the following criterion
+
+.. math::
+ P = \text{arg}\min_P \frac{\sum_i OT_e(\mu_i\#P,\mu_i\#P)}{\sum_{i,j\neq i}
+ OT_e(\mu_i\#P,\mu_j\#P)}
+
+where :math:`\#` is the push-forward operator, :math:`OT_e` is the entropic OT
+loss and :math:`\mu_i` is the
+distribution of samples from class :math:`i`. :math:`P` is also constrained to
+be in the Stiefel manifold. WDA can be solved in POT using function
+:any:`ot.dr.wda`. It requires to have installed :code:`pymanopt` and
+:code:`autograd` for manifold optimization and automatic differentiation
+respectively. Note that we also provide the Fisher discriminant estimator in
+:any:`ot.dr.fda` for easy comparison.
+
+.. warning::
+
+ Note that due to the hard dependency on :code:`pymanopt` and
+ :code:`autograd`, :any:`ot.dr` is not imported by default. If you want to
+ use it you have to specifically import it with :code:`import ot.dr` .
+
+.. minigallery:: ot.dr.wda
+ :add-heading: Examples of the use of WDA
+ :heading-level: "
+
Solving OT with Multiple backends on CPU/GPU
--------------------------------------------
diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
index cf5d64d..f00de50 100644
--- a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
+++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
@@ -74,7 +74,7 @@ x_all = np.zeros((nb_iter_max, x1.shape[0], 2))
loss_iter = []
# generator for random permutations
-gen = torch.Generator()
+gen = torch.Generator(device=device)
gen.manual_seed(42)
for i in range(nb_iter_max):
@@ -103,7 +103,7 @@ ax = pl.axis()
# %%
# Animate trajectories of the gradient flow along iteration
-# -------------------------------------------------------
+# ---------------------------------------------------------
pl.figure(3, (8, 4))
@@ -122,7 +122,7 @@ ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100,
# %%
# Compute the Sliced Wasserstein Barycenter
-#
+# -----------------------------------------
x1_torch = torch.tensor(x1).to(device=device)
x3_torch = torch.tensor(x3).to(device=device)
xbinit = np.random.randn(500, 2) * 10 + 16
@@ -136,7 +136,7 @@ x_all = np.zeros((nb_iter_max, xbary_torch.shape[0], 2))
loss_iter = []
# generator for random permutations
-gen = torch.Generator()
+gen = torch.Generator(device=device)
gen.manual_seed(42)
alpha = 0.5
@@ -169,7 +169,7 @@ ax = pl.axis()
# %%
# Animate trajectories of the barycenter along gradient descent
-# -------------------------------------------------------
+# -------------------------------------------------------------
pl.figure(5, (8, 4))
diff --git a/examples/backends/plot_ssw_unif_torch.py b/examples/backends/plot_ssw_unif_torch.py
new file mode 100644
index 0000000..7ccc2af
--- /dev/null
+++ b/examples/backends/plot_ssw_unif_torch.py
@@ -0,0 +1,153 @@
+# -*- coding: utf-8 -*-
+r"""
+================================================
+Spherical Sliced-Wasserstein Embedding on Sphere
+================================================
+
+Here, we aim at transforming samples into a uniform
+distribution on the sphere by minimizing SSW:
+
+.. math::
+ \min_{x} SSW_2(\nu, \frac{1}{n}\sum_{i=1}^n \delta_{x_i})
+
+where :math:`\nu=\mathrm{Unif}(S^1)`.
+
+"""
+
+# Author: Clément Bonet <clement.bonet@univ-ubs.fr>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 3
+
+import numpy as np
+import matplotlib.pyplot as pl
+import matplotlib.animation as animation
+import torch
+import torch.nn.functional as F
+
+import ot
+
+
+# %%
+# Data generation
+# ---------------
+
+torch.manual_seed(1)
+
+N = 1000
+x0 = torch.rand(N, 3)
+x0 = F.normalize(x0, dim=-1)
+
+
+# %%
+# Plot data
+# ---------
+
+def plot_sphere(ax):
+ xlist = np.linspace(-1.0, 1.0, 50)
+ ylist = np.linspace(-1.0, 1.0, 50)
+ r = np.linspace(1.0, 1.0, 50)
+ X, Y = np.meshgrid(xlist, ylist)
+
+ Z = np.sqrt(np.maximum(r**2 - X**2 - Y**2, 0))
+
+ ax.plot_wireframe(X, Y, Z, color="gray", alpha=.3)
+ ax.plot_wireframe(X, Y, -Z, color="gray", alpha=.3) # Now plot the bottom half
+
+
+# plot the distributions
+pl.figure(1)
+ax = pl.axes(projection='3d')
+plot_sphere(ax)
+ax.scatter(x0[:, 0], x0[:, 1], x0[:, 2], label='Data samples', alpha=0.5)
+ax.set_title('Data distribution')
+ax.legend()
+
+
+# %%
+# Gradient descent
+# ----------------
+
+x = x0.clone()
+x.requires_grad_(True)
+
+n_iter = 500
+lr = 100
+
+losses = []
+xvisu = torch.zeros(n_iter, N, 3)
+
+for i in range(n_iter):
+ sw = ot.sliced_wasserstein_sphere_unif(x, n_projections=500)
+ grad_x = torch.autograd.grad(sw, x)[0]
+
+ x = x - lr * grad_x
+ x = F.normalize(x, p=2, dim=1)
+
+ losses.append(sw.item())
+ xvisu[i, :, :] = x.detach().clone()
+
+ if i % 100 == 0:
+ print("Iter: {:3d}, loss={}".format(i, losses[-1]))
+
+pl.figure(1)
+pl.semilogy(losses)
+pl.grid()
+pl.title('SSW')
+pl.xlabel("Iterations")
+
+
+# %%
+# Plot trajectories of generated samples along iterations
+# -------------------------------------------------------
+
+ivisu = [0, 25, 50, 75, 100, 150, 200, 350, 499]
+
+fig = pl.figure(3, (10, 10))
+for i in range(9):
+ # pl.subplot(3, 3, i + 1)
+ # ax = pl.axes(projection='3d')
+ ax = fig.add_subplot(3, 3, i + 1, projection='3d')
+ plot_sphere(ax)
+ ax.scatter(xvisu[ivisu[i], :, 0], xvisu[ivisu[i], :, 1], xvisu[ivisu[i], :, 2], label='Data samples', alpha=0.5)
+ ax.set_title('Iter. {}'.format(ivisu[i]))
+ #ax.axis("off")
+ if i == 0:
+ ax.legend()
+
+
+# %%
+# Animate trajectories of generated samples along iteration
+# -------------------------------------------------------
+
+pl.figure(4, (8, 8))
+
+
+def _update_plot(i):
+ i = 3 * i
+ pl.clf()
+ ax = pl.axes(projection='3d')
+ plot_sphere(ax)
+ ax.scatter(xvisu[i, :, 0], xvisu[i, :, 1], xvisu[i, :, 2], label='Data samples$', alpha=0.5)
+ ax.axis("off")
+ ax.set_xlim((-1.5, 1.5))
+ ax.set_ylim((-1.5, 1.5))
+ ax.set_title('Iter. {}'.format(i))
+ return 1
+
+
+print(xvisu.shape)
+
+i = 0
+ax = pl.axes(projection='3d')
+plot_sphere(ax)
+ax.scatter(xvisu[i, :, 0], xvisu[i, :, 1], xvisu[i, :, 2], label='Data samples from $G\#\mu_n$', alpha=0.5)
+ax.axis("off")
+ax.set_xlim((-1.5, 1.5))
+ax.set_ylim((-1.5, 1.5))
+ax.set_title('Iter. {}'.format(ivisu[i]))
+
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 5, interval=100, repeat_delay=2000)
+# %%
diff --git a/examples/barycenters/plot_barycenter_1D.py b/examples/barycenters/plot_barycenter_1D.py
index 2373e99..8096245 100644
--- a/examples/barycenters/plot_barycenter_1D.py
+++ b/examples/barycenters/plot_barycenter_1D.py
@@ -106,7 +106,7 @@ for i, z in enumerate(zs):
ys = B_l2[:, i]
verts.append(list(zip(x, ys)))
-ax = plt.gcf().gca(projection='3d')
+ax = plt.gcf().add_subplot(projection='3d')
poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
@@ -128,7 +128,7 @@ for i, z in enumerate(zs):
ys = B_wass[:, i]
verts.append(list(zip(x, ys)))
-ax = plt.gcf().gca(projection='3d')
+ax = plt.gcf().add_subplot(projection='3d')
poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
diff --git a/examples/barycenters/plot_free_support_barycenter.py b/examples/barycenters/plot_free_support_barycenter.py
index 226dfeb..f4a13dd 100644
--- a/examples/barycenters/plot_free_support_barycenter.py
+++ b/examples/barycenters/plot_free_support_barycenter.py
@@ -4,13 +4,14 @@
2D free support Wasserstein barycenters of distributions
========================================================
-Illustration of 2D Wasserstein barycenters if distributions are weighted
+Illustration of 2D Wasserstein and Sinkhorn barycenters if distributions are weighted
sum of diracs.
"""
# Authors: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>
# Rémi Flamary <remi.flamary@polytechnique.edu>
+# Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr>
#
# License: MIT License
@@ -48,7 +49,7 @@ pl.title('Distributions')
# %%
-# Compute free support barycenter
+# Compute free support Wasserstein barycenter
# -------------------------------
k = 200 # number of Diracs of the barycenter
@@ -58,7 +59,28 @@ b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, on
X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b)
# %%
-# Plot the barycenter
+# Plot the Wasserstein barycenter
+# ---------
+
+pl.figure(2, (8, 3))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)
+pl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker='s', label='2-Wasserstein barycenter')
+pl.title('Data measures and their barycenter')
+pl.legend(loc="lower right")
+pl.show()
+
+# %%
+# Compute free support Sinkhorn barycenter
+
+k = 200 # number of Diracs of the barycenter
+X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations
+b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized)
+
+X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, 20, b, numItermax=15)
+
+# %%
+# Plot the Wasserstein barycenter
# ---------
pl.figure(2, (8, 3))
diff --git a/examples/barycenters/plot_free_support_sinkhorn_barycenter.py b/examples/barycenters/plot_free_support_sinkhorn_barycenter.py
new file mode 100644
index 0000000..ebe1f3b
--- /dev/null
+++ b/examples/barycenters/plot_free_support_sinkhorn_barycenter.py
@@ -0,0 +1,151 @@
+# -*- coding: utf-8 -*-
+"""
+========================================================
+2D free support Sinkhorn barycenters of distributions
+========================================================
+
+Illustration of Sinkhorn barycenter calculation between empirical distributions understood as point clouds
+
+"""
+
+# Authors: Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pyplot as plt
+import ot
+
+# %%
+# General Parameters
+# ------------------
+reg = 1e-2 # Entropic Regularization
+numItermax = 20 # Maximum number of iterations for the Barycenter algorithm
+numInnerItermax = 50 # Maximum number of sinkhorn iterations
+n_samples = 200
+
+# %%
+# Generate Data
+# -------------
+
+X1 = np.random.randn(200, 2)
+X2 = 2 * np.concatenate([
+ np.concatenate([- np.ones([50, 1]), np.linspace(-1, 1, 50)[:, None]], axis=1),
+ np.concatenate([np.linspace(-1, 1, 50)[:, None], np.ones([50, 1])], axis=1),
+ np.concatenate([np.ones([50, 1]), np.linspace(1, -1, 50)[:, None]], axis=1),
+ np.concatenate([np.linspace(1, -1, 50)[:, None], - np.ones([50, 1])], axis=1),
+], axis=0)
+X3 = np.random.randn(200, 2)
+X3 = 2 * (X3 / np.linalg.norm(X3, axis=1)[:, None])
+X4 = np.random.multivariate_normal(np.array([0, 0]), np.array([[1., 0.5], [0.5, 1.]]), size=200)
+
+a1, a2, a3, a4 = ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1))
+
+# %%
+# Inspect generated distributions
+# -------------------------------
+
+fig, axes = plt.subplots(1, 4, figsize=(16, 4))
+
+axes[0].scatter(x=X1[:, 0], y=X1[:, 1], c='steelblue', edgecolor='k')
+axes[1].scatter(x=X2[:, 0], y=X2[:, 1], c='steelblue', edgecolor='k')
+axes[2].scatter(x=X3[:, 0], y=X3[:, 1], c='steelblue', edgecolor='k')
+axes[3].scatter(x=X4[:, 0], y=X4[:, 1], c='steelblue', edgecolor='k')
+
+axes[0].set_xlim([-3, 3])
+axes[0].set_ylim([-3, 3])
+axes[0].set_title('Distribution 1')
+
+axes[1].set_xlim([-3, 3])
+axes[1].set_ylim([-3, 3])
+axes[1].set_title('Distribution 2')
+
+axes[2].set_xlim([-3, 3])
+axes[2].set_ylim([-3, 3])
+axes[2].set_title('Distribution 3')
+
+axes[3].set_xlim([-3, 3])
+axes[3].set_ylim([-3, 3])
+axes[3].set_title('Distribution 4')
+
+plt.tight_layout()
+plt.show()
+
+# %%
+# Interpolating Empirical Distributions
+# -------------------------------------
+
+fig = plt.figure(figsize=(10, 10))
+
+weights = np.array([
+ [3 / 3, 0 / 3],
+ [2 / 3, 1 / 3],
+ [1 / 3, 2 / 3],
+ [0 / 3, 3 / 3],
+]).astype(np.float32)
+
+for k in range(4):
+ XB_init = np.random.randn(n_samples, 2)
+ XB = ot.bregman.free_support_sinkhorn_barycenter(
+ measures_locations=[X1, X2],
+ measures_weights=[a1, a2],
+ weights=weights[k],
+ X_init=XB_init,
+ reg=reg,
+ numItermax=numItermax,
+ numInnerItermax=numInnerItermax
+ )
+ ax = plt.subplot2grid((4, 4), (0, k))
+ ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
+ ax.set_xlim([-3, 3])
+ ax.set_ylim([-3, 3])
+
+for k in range(1, 4, 1):
+ XB_init = np.random.randn(n_samples, 2)
+ XB = ot.bregman.free_support_sinkhorn_barycenter(
+ measures_locations=[X1, X3],
+ measures_weights=[a1, a2],
+ weights=weights[k],
+ X_init=XB_init,
+ reg=reg,
+ numItermax=numItermax,
+ numInnerItermax=numInnerItermax
+ )
+ ax = plt.subplot2grid((4, 4), (k, 0))
+ ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
+ ax.set_xlim([-3, 3])
+ ax.set_ylim([-3, 3])
+
+for k in range(1, 4, 1):
+ XB_init = np.random.randn(n_samples, 2)
+ XB = ot.bregman.free_support_sinkhorn_barycenter(
+ measures_locations=[X3, X4],
+ measures_weights=[a1, a2],
+ weights=weights[k],
+ X_init=XB_init,
+ reg=reg,
+ numItermax=numItermax,
+ numInnerItermax=numInnerItermax
+ )
+ ax = plt.subplot2grid((4, 4), (3, k))
+ ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
+ ax.set_xlim([-3, 3])
+ ax.set_ylim([-3, 3])
+
+for k in range(1, 3, 1):
+ XB_init = np.random.randn(n_samples, 2)
+ XB = ot.bregman.free_support_sinkhorn_barycenter(
+ measures_locations=[X2, X4],
+ measures_weights=[a1, a2],
+ weights=weights[k],
+ X_init=XB_init,
+ reg=reg,
+ numItermax=numItermax,
+ numInnerItermax=numInnerItermax
+ )
+ ax = plt.subplot2grid((4, 4), (k, 3))
+ ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
+ ax.set_xlim([-3, 3])
+ ax.set_ylim([-3, 3])
+
+plt.show()
diff --git a/examples/barycenters/plot_generalized_free_support_barycenter.py b/examples/barycenters/plot_generalized_free_support_barycenter.py
new file mode 100644
index 0000000..e685ec7
--- /dev/null
+++ b/examples/barycenters/plot_generalized_free_support_barycenter.py
@@ -0,0 +1,155 @@
+# -*- coding: utf-8 -*-
+"""
+=======================================
+Generalized Wasserstein Barycenter Demo
+=======================================
+
+This example illustrates the computation of Generalized Wasserstein Barycenter
+as proposed in [42].
+
+
+[42] Delon, J., Gozlan, N., and Saint-Dizier, A..
+Generalized Wasserstein barycenters between probability measures living on different subspaces.
+arXiv preprint arXiv:2105.09755, 2021.
+
+"""
+
+# Author: Eloi Tanguy <eloi.tanguy@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 2
+
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib.pylab as pl
+import ot
+import matplotlib.animation as animation
+
+########################
+# Generate and plot data
+# ----------------------
+
+# Input measures
+sub_sample_factor = 8
+I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2]
+I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::-sub_sample_factor, ::sub_sample_factor, 2]
+I3 = pl.imread('../../data/heart.png').astype(np.float64)[::-sub_sample_factor, ::sub_sample_factor, 2]
+
+sz = I1.shape[0]
+UU, VV = np.meshgrid(np.arange(sz), np.arange(sz))
+
+# Input measure locations in their respective 2D spaces
+X_list = [np.stack((UU[im == 0], VV[im == 0]), 1) * 1.0 for im in [I1, I2, I3]]
+
+# Input measure weights
+a_list = [ot.unif(x.shape[0]) for x in X_list]
+
+# Projections 3D -> 2D
+P1 = np.array([[1, 0, 0], [0, 1, 0]])
+P2 = np.array([[0, 1, 0], [0, 0, 1]])
+P3 = np.array([[1, 0, 0], [0, 0, 1]])
+P_list = [P1, P2, P3]
+
+# Barycenter weights
+weights = np.array([1 / 3, 1 / 3, 1 / 3])
+
+# Number of barycenter points to compute
+n_samples_bary = 150
+
+# Send the input measures into 3D space for visualisation
+X_visu = [Xi @ Pi for (Xi, Pi) in zip(X_list, P_list)]
+
+# Plot the input data
+fig = plt.figure(figsize=(3, 3))
+axis = fig.add_subplot(1, 1, 1, projection="3d")
+for Xi in X_visu:
+ axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
+axis.view_init(azim=45)
+axis.set_xticks([])
+axis.set_yticks([])
+axis.set_zticks([])
+plt.show()
+
+#################################
+# Barycenter computation and plot
+# -------------------------------
+
+Y = ot.lp.generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary)
+fig = plt.figure(figsize=(3, 3))
+
+axis = fig.add_subplot(1, 1, 1, projection="3d")
+for Xi in X_visu:
+ axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
+axis.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
+axis.view_init(azim=45)
+axis.set_xticks([])
+axis.set_yticks([])
+axis.set_zticks([])
+plt.show()
+
+
+#############################
+# Plotting projection matches
+# ---------------------------
+
+fig = plt.figure(figsize=(9, 3))
+
+ax = fig.add_subplot(1, 3, 1, projection='3d')
+for Xi in X_visu:
+ ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
+ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
+ax.view_init(elev=0, azim=0)
+ax.set_xticks([])
+ax.set_yticks([])
+ax.set_zticks([])
+
+ax = fig.add_subplot(1, 3, 2, projection='3d')
+for Xi in X_visu:
+ ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
+ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
+ax.view_init(elev=0, azim=90)
+ax.set_xticks([])
+ax.set_yticks([])
+ax.set_zticks([])
+
+ax = fig.add_subplot(1, 3, 3, projection='3d')
+for Xi in X_visu:
+ ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
+ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
+ax.view_init(elev=90, azim=0)
+ax.set_xticks([])
+ax.set_yticks([])
+ax.set_zticks([])
+
+plt.tight_layout()
+plt.show()
+
+##############################################
+# Rotation animation
+# --------------------------------------------
+
+fig = plt.figure(figsize=(7, 7))
+ax = fig.add_subplot(1, 1, 1, projection="3d")
+
+
+def _init():
+ for Xi in X_visu:
+ ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
+ ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
+ ax.view_init(elev=0, azim=0)
+ ax.set_xticks([])
+ ax.set_yticks([])
+ ax.set_zticks([])
+ return fig,
+
+
+def _update_plot(i):
+ if i < 45:
+ ax.view_init(elev=0, azim=4 * i)
+ else:
+ ax.view_init(elev=i - 45, azim=4 * i)
+ return fig,
+
+
+ani = animation.FuncAnimation(fig, _update_plot, init_func=_init, frames=136, interval=50, blit=True, repeat_delay=2000)
diff --git a/examples/domain-adaptation/plot_otda_linear_mapping.py b/examples/domain-adaptation/plot_otda_linear_mapping.py
index a44096a..8284a2a 100644
--- a/examples/domain-adaptation/plot_otda_linear_mapping.py
+++ b/examples/domain-adaptation/plot_otda_linear_mapping.py
@@ -61,7 +61,7 @@ plt.plot(xt[:, 0], xt[:, 1], 'o')
# Estimate linear mapping and transport
# -------------------------------------
-Ae, be = ot.da.OT_mapping_linear(xs, xt)
+Ae, be = ot.gaussian.empirical_bures_wasserstein_mapping(xs, xt)
xst = xs.dot(Ae) + be
diff --git a/examples/gromov/plot_barycenter_fgw.py b/examples/gromov/plot_barycenter_fgw.py
index 556e08f..dc3c6aa 100644
--- a/examples/gromov/plot_barycenter_fgw.py
+++ b/examples/gromov/plot_barycenter_fgw.py
@@ -174,7 +174,7 @@ A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95, log=True)
# -------------------------
#%% Create the barycenter
-bary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0]))
+bary = nx.from_numpy_array(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0]))
for i, v in enumerate(A.ravel()):
bary.add_node(i, attr_name=v)
diff --git a/examples/gromov/plot_gromov.py b/examples/gromov/plot_gromov.py
index 5a362cf..05074dc 100644
--- a/examples/gromov/plot_gromov.py
+++ b/examples/gromov/plot_gromov.py
@@ -3,7 +3,6 @@
==========================
Gromov-Wasserstein example
==========================
-
This example is designed to show how to use the Gromov-Wassertsein distance
computation in POT.
"""
diff --git a/examples/gromov/plot_gromov_barycenter.py b/examples/gromov/plot_gromov_barycenter.py
index 7fe081f..08ec610 100755
--- a/examples/gromov/plot_gromov_barycenter.py
+++ b/examples/gromov/plot_gromov_barycenter.py
@@ -110,8 +110,7 @@ for nb in range(4):
if shapes[nb][i, j] < 0.95:
xs[nb].append([j, 8 - i])
-xs = np.array([np.array(xs[0]), np.array(xs[1]),
- np.array(xs[2]), np.array(xs[3])])
+xs = [np.array(xs[s]) for s in range(S)]
##############################################################################
# Barycenter computation
diff --git a/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py
index 1fdc3b9..7585944 100755
--- a/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py
+++ b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py
@@ -45,10 +45,11 @@ from ot.gromov import gromov_wasserstein_linear_unmixing, gromov_wasserstein_dic
import ot
import networkx
from networkx.generators.community import stochastic_block_model as sbm
-# %%
-# =============================================================================
+
+#############################################################################
+#
# Generate a dataset composed of graphs following Stochastic Block models of 1, 2 and 3 clusters.
-# =============================================================================
+# ---------------------------------------------
np.random.seed(42)
@@ -109,10 +110,10 @@ for idx_c, c in enumerate(clusters):
pl.tight_layout()
pl.show()
-# %%
-# =============================================================================
+#############################################################################
+#
# Estimate the gromov-wasserstein dictionary from the dataset
-# =============================================================================
+# ---------------------------------------------
np.random.seed(0)
@@ -140,10 +141,10 @@ pl.ylabel('loss', fontsize=12)
pl.tight_layout()
pl.show()
-# %%
-# =============================================================================
+#############################################################################
+#
# Visualization of the estimated dictionary atoms
-# =============================================================================
+# ---------------------------------------------
# Continuous connections between nodes of the atoms are colored in shades of grey (1: dark / 2: white)
@@ -164,10 +165,11 @@ for idx_atom, atom in enumerate(Cdict_GW):
pl.axis("off")
pl.tight_layout()
pl.show()
-#%%
-# =============================================================================
+
+#############################################################################
+#
# Visualization of the embedding space
-# =============================================================================
+# ---------------------------------------------
unmixings = []
reconstruction_errors = []
@@ -211,11 +213,11 @@ pl.axis('off')
pl.legend(fontsize=11)
pl.tight_layout()
pl.show()
-# %%
-# =============================================================================
-# Endow the dataset with node features
-# =============================================================================
+#############################################################################
+#
+# Endow the dataset with node features
+# ---------------------------------------------
# We follow this feature assignment on all nodes of a graph depending on its label/number of clusters
# 1 cluster --> 0 as nodes feature
# 2 clusters --> 1 as nodes feature
@@ -251,10 +253,11 @@ for idx_c, c in enumerate(clusters):
pl.axis("off")
pl.tight_layout()
pl.show()
-# %%
-# =============================================================================
+
+#############################################################################
+#
# Estimate a Fused Gromov-Wasserstein dictionary from the dataset of attributed graphs
-# =============================================================================
+# ---------------------------------------------
np.random.seed(0)
ps = [ot.unif(C.shape[0]) for C in dataset]
D = 3 # 6 atoms instead of 3
@@ -280,10 +283,10 @@ pl.ylabel('loss', fontsize=12)
pl.tight_layout()
pl.show()
-# %%
-# =============================================================================
+#############################################################################
+#
# Visualization of the estimated dictionary atoms
-# =============================================================================
+# ---------------------------------------------
pl.figure(7, (12, 8))
pl.clf()
@@ -307,10 +310,10 @@ for idx_atom, (Catom, Fatom) in enumerate(zip(Cdict_FGW, Ydict_FGW)):
pl.tight_layout()
pl.show()
-# %%
-# =============================================================================
+#############################################################################
+#
# Visualization of the embedding space
-# =============================================================================
+# ---------------------------------------------
unmixings = []
reconstruction_errors = []
diff --git a/examples/gromov/plot_semirelaxed_fgw.py b/examples/gromov/plot_semirelaxed_fgw.py
new file mode 100644
index 0000000..ef4b286
--- /dev/null
+++ b/examples/gromov/plot_semirelaxed_fgw.py
@@ -0,0 +1,301 @@
+# -*- coding: utf-8 -*-
+"""
+==========================
+Semi-relaxed (Fused) Gromov-Wasserstein example
+==========================
+
+This example is designed to show how to use the semi-relaxed Gromov-Wasserstein
+and the semi-relaxed Fused Gromov-Wasserstein divergences.
+
+sr(F)GW between two graphs G1 and G2 searches for a reweighing of the nodes of
+G2 at a minimal (F)GW distance from G1.
+
+First, we generate two graphs following Stochastic Block Models, then show
+how to compute their srGW matchings and illustrate them. These graphs are then
+endowed with node features and we follow the same process with srFGW.
+
+[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+International Conference on Learning Representations (ICLR), 2021.
+"""
+
+# Author: Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 1
+
+import numpy as np
+import matplotlib.pylab as pl
+from ot.gromov import semirelaxed_gromov_wasserstein, semirelaxed_fused_gromov_wasserstein, gromov_wasserstein, fused_gromov_wasserstein
+import networkx
+from networkx.generators.community import stochastic_block_model as sbm
+
+#############################################################################
+#
+# Generate two graphs following Stochastic Block models of 2 and 3 clusters.
+# ---------------------------------------------
+
+
+N2 = 20 # 2 communities
+N3 = 30 # 3 communities
+p2 = [[1., 0.1],
+ [0.1, 0.9]]
+p3 = [[1., 0.1, 0.],
+ [0.1, 0.95, 0.1],
+ [0., 0.1, 0.9]]
+G2 = sbm(seed=0, sizes=[N2 // 2, N2 // 2], p=p2)
+G3 = sbm(seed=0, sizes=[N3 // 3, N3 // 3, N3 // 3], p=p3)
+
+
+C2 = networkx.to_numpy_array(G2)
+C3 = networkx.to_numpy_array(G3)
+
+h2 = np.ones(C2.shape[0]) / C2.shape[0]
+h3 = np.ones(C3.shape[0]) / C3.shape[0]
+
+# Add weights on the edges for visualization later on
+weight_intra_G2 = 5
+weight_inter_G2 = 0.5
+weight_intra_G3 = 1.
+weight_inter_G3 = 1.5
+
+weightedG2 = networkx.Graph()
+part_G2 = [G2.nodes[i]['block'] for i in range(N2)]
+
+for node in G2.nodes():
+ weightedG2.add_node(node)
+for i, j in G2.edges():
+ if part_G2[i] == part_G2[j]:
+ weightedG2.add_edge(i, j, weight=weight_intra_G2)
+ else:
+ weightedG2.add_edge(i, j, weight=weight_inter_G2)
+
+weightedG3 = networkx.Graph()
+part_G3 = [G3.nodes[i]['block'] for i in range(N3)]
+
+for node in G3.nodes():
+ weightedG3.add_node(node)
+for i, j in G3.edges():
+ if part_G3[i] == part_G3[j]:
+ weightedG3.add_edge(i, j, weight=weight_intra_G3)
+ else:
+ weightedG3.add_edge(i, j, weight=weight_inter_G3)
+
+#############################################################################
+#
+# Compute their semi-relaxed Gromov-Wasserstein divergences
+# ---------------------------------------------
+
+# 0) GW(C2, h2, C3, h3) for reference
+OT, log = gromov_wasserstein(C2, C3, h2, h3, symmetric=True, log=True)
+gw = log['gw_dist']
+
+# 1) srGW(C2, h2, C3)
+OT_23, log_23 = semirelaxed_gromov_wasserstein(C2, C3, h2, symmetric=True,
+ log=True, G0=None)
+srgw_23 = log_23['srgw_dist']
+
+# 2) srGW(C3, h3, C2)
+
+OT_32, log_32 = semirelaxed_gromov_wasserstein(C3, C2, h3, symmetric=None,
+ log=True, G0=OT.T)
+srgw_32 = log_32['srgw_dist']
+
+print('GW(C2, C3) = ', gw)
+print('srGW(C2, h2, C3) = ', srgw_23)
+print('srGW(C3, h3, C2) = ', srgw_32)
+
+
+#############################################################################
+#
+# Visualization of the semi-relaxed Gromov-Wasserstein matchings
+# ---------------------------------------------
+#
+# We color nodes of the graph on the right - then project its node colors
+# based on the optimal transport plan from the srGW matching
+
+
+def draw_graph(G, C, nodes_color_part, Gweights=None,
+ pos=None, edge_color='black', node_size=None,
+ shiftx=0, seed=0):
+
+ if (pos is None):
+ pos = networkx.spring_layout(G, scale=1., seed=seed)
+
+ if shiftx != 0:
+ for k, v in pos.items():
+ v[0] = v[0] + shiftx
+
+ alpha_edge = 0.7
+ width_edge = 1.8
+ if Gweights is None:
+ networkx.draw_networkx_edges(G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color)
+ else:
+ # We make more visible connections between activated nodes
+ n = len(Gweights)
+ edgelist_activated = []
+ edgelist_deactivated = []
+ for i in range(n):
+ for j in range(n):
+ if Gweights[i] * Gweights[j] * C[i, j] > 0:
+ edgelist_activated.append((i, j))
+ elif C[i, j] > 0:
+ edgelist_deactivated.append((i, j))
+
+ networkx.draw_networkx_edges(G, pos, edgelist=edgelist_activated,
+ width=width_edge, alpha=alpha_edge,
+ edge_color=edge_color)
+ networkx.draw_networkx_edges(G, pos, edgelist=edgelist_deactivated,
+ width=width_edge, alpha=0.1,
+ edge_color=edge_color)
+
+ if Gweights is None:
+ for node, node_color in enumerate(nodes_color_part):
+ networkx.draw_networkx_nodes(G, pos, nodelist=[node],
+ node_size=node_size, alpha=1,
+ node_color=node_color)
+ else:
+ scaled_Gweights = Gweights / (0.5 * Gweights.max())
+ nodes_size = node_size * scaled_Gweights
+ for node, node_color in enumerate(nodes_color_part):
+ networkx.draw_networkx_nodes(G, pos, nodelist=[node],
+ node_size=nodes_size[node], alpha=1,
+ node_color=node_color)
+ return pos
+
+
+def draw_transp_colored_srGW(G1, C1, G2, C2, part_G1,
+ p1, p2, T, pos1=None, pos2=None,
+ shiftx=4, switchx=False, node_size=70,
+ seed_G1=0, seed_G2=0):
+ starting_color = 0
+ # get graphs partition and their coloring
+ part1 = part_G1.copy()
+ unique_colors = ['C%s' % (starting_color + i) for i in np.unique(part1)]
+ nodes_color_part1 = []
+ for cluster in part1:
+ nodes_color_part1.append(unique_colors[cluster])
+
+ nodes_color_part2 = []
+ # T: getting colors assignment from argmin of columns
+ for i in range(len(G2.nodes())):
+ j = np.argmax(T[:, i])
+ nodes_color_part2.append(nodes_color_part1[j])
+ pos1 = draw_graph(G1, C1, nodes_color_part1, Gweights=p1,
+ pos=pos1, node_size=node_size, shiftx=0, seed=seed_G1)
+ pos2 = draw_graph(G2, C2, nodes_color_part2, Gweights=p2, pos=pos2,
+ node_size=node_size, shiftx=shiftx, seed=seed_G2)
+ for k1, v1 in pos1.items():
+ for k2, v2 in pos2.items():
+ if (T[k1, k2] > 0):
+ pl.plot([pos1[k1][0], pos2[k2][0]],
+ [pos1[k1][1], pos2[k2][1]],
+ '-', lw=0.8, alpha=0.5,
+ color=nodes_color_part1[k1])
+ return pos1, pos2
+
+
+node_size = 40
+fontsize = 10
+seed_G2 = 0
+seed_G3 = 4
+
+pl.figure(1, figsize=(8, 2.5))
+pl.clf()
+pl.subplot(121)
+pl.axis('off')
+pl.axis
+pl.title(r'srGW$(\mathbf{C_2},\mathbf{h_2},\mathbf{C_3}) =%s$' % (np.round(srgw_23, 3)), fontsize=fontsize)
+
+hbar2 = OT_23.sum(axis=0)
+pos1, pos2 = draw_transp_colored_srGW(
+ weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23,
+ shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3)
+pl.subplot(122)
+pl.axis('off')
+hbar3 = OT_32.sum(axis=0)
+pl.title(r'srGW$(\mathbf{C_3}, \mathbf{h_3},\mathbf{C_2}) =%s$' % (np.round(srgw_32, 3)), fontsize=fontsize)
+pos1, pos2 = draw_transp_colored_srGW(
+ weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32,
+ pos1=pos2, pos2=pos1, shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0)
+pl.tight_layout()
+
+pl.show()
+
+#############################################################################
+#
+# Add node features
+# ---------------------------------------------
+
+# We add node features with given mean - by clusters
+# and inversely proportional to clusters' intra-connectivity
+
+F2 = np.zeros((N2, 1))
+for i, c in enumerate(part_G2):
+ F2[i, 0] = np.random.normal(loc=c, scale=0.01)
+
+F3 = np.zeros((N3, 1))
+for i, c in enumerate(part_G3):
+ F3[i, 0] = np.random.normal(loc=2. - c, scale=0.01)
+
+#############################################################################
+#
+# Compute their semi-relaxed Fused Gromov-Wasserstein divergences
+# ---------------------------------------------
+
+alpha = 0.5
+# Compute pairwise euclidean distance between node features
+M = (F2 ** 2).dot(np.ones((1, N3))) + np.ones((N2, 1)).dot((F3 ** 2).T) - 2 * F2.dot(F3.T)
+
+# 0) FGW_alpha(C2, F2, h2, C3, F3, h3) for reference
+
+OT, log = fused_gromov_wasserstein(
+ M, C2, C3, h2, h3, symmetric=True, alpha=alpha, log=True)
+fgw = log['fgw_dist']
+
+# 1) srFGW(C2, F2, h2, C3, F3)
+OT_23, log_23 = semirelaxed_fused_gromov_wasserstein(
+ M, C2, C3, h2, symmetric=True, alpha=0.5, log=True, G0=None)
+srfgw_23 = log_23['srfgw_dist']
+
+# 2) srFGW(C3, F3, h3, C2, F2)
+
+OT_32, log_32 = semirelaxed_fused_gromov_wasserstein(
+ M.T, C3, C2, h3, symmetric=None, alpha=alpha, log=True, G0=None)
+srfgw_32 = log_32['srfgw_dist']
+
+print('FGW(C2, F2, C3, F3) = ', fgw)
+print('srGW(C2, F2, h2, C3, F3) = ', srfgw_23)
+print('srGW(C3, F3, h3, C2, F2) = ', srfgw_32)
+
+#############################################################################
+#
+# Visualization of the semi-relaxed Fused Gromov-Wasserstein matchings
+# ---------------------------------------------
+#
+# We color nodes of the graph on the right - then project its node colors
+# based on the optimal transport plan from the srFGW matching
+# NB: colors refer to clusters - not to node features
+
+pl.figure(2, figsize=(8, 2.5))
+pl.clf()
+pl.subplot(121)
+pl.axis('off')
+pl.axis
+pl.title(r'srFGW$(\mathbf{C_2},\mathbf{F_2},\mathbf{h_2},\mathbf{C_3},\mathbf{F_3}) =%s$' % (np.round(srfgw_23, 3)), fontsize=fontsize)
+
+hbar2 = OT_23.sum(axis=0)
+pos1, pos2 = draw_transp_colored_srGW(
+ weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23,
+ shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3)
+pl.subplot(122)
+pl.axis('off')
+hbar3 = OT_32.sum(axis=0)
+pl.title(r'srFGW$(\mathbf{C_3}, \mathbf{F_3}, \mathbf{h_3}, \mathbf{C_2}, \mathbf{F_2}) =%s$' % (np.round(srfgw_32, 3)), fontsize=fontsize)
+pos1, pos2 = draw_transp_colored_srGW(
+ weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32,
+ pos1=pos2, pos2=pos1, shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0)
+pl.tight_layout()
+
+pl.show()
diff --git a/examples/others/plot_COOT.py b/examples/others/plot_COOT.py
new file mode 100644
index 0000000..98c1ce1
--- /dev/null
+++ b/examples/others/plot_COOT.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+r"""
+===================================================
+Row and column alignments with CO-Optimal Transport
+===================================================
+
+This example is designed to show how to use the CO-Optimal Transport [47]_ in POT.
+CO-Optimal Transport allows to calculate the distance between two **arbitrary-size**
+matrices, and to align their rows and columns. In this example, we consider two
+random matrices :math:`X_1` and :math:`X_2` defined by
+:math:`(X_1)_{i,j} = \cos(\frac{i}{n_1} \pi) + \cos(\frac{j}{d_1} \pi) + \sigma \mathcal N(0,1)`
+and :math:`(X_2)_{i,j} = \cos(\frac{i}{n_2} \pi) + \cos(\frac{j}{d_2} \pi) + \sigma \mathcal N(0,1)`.
+
+.. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020).
+ `CO-Optimal Transport <https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf>`_.
+ Advances in Neural Information Processing Systems, 33.
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
+# License: MIT License
+
+from matplotlib.patches import ConnectionPatch
+import matplotlib.pylab as pl
+import numpy as np
+from ot.coot import co_optimal_transport as coot
+from ot.coot import co_optimal_transport2 as coot2
+
+# %%
+# Generating two random matrices
+
+n1 = 20
+n2 = 10
+d1 = 16
+d2 = 8
+sigma = 0.2
+
+X1 = (
+ np.cos(np.arange(n1) * np.pi / n1)[:, None] +
+ np.cos(np.arange(d1) * np.pi / d1)[None, :] +
+ sigma * np.random.randn(n1, d1)
+)
+X2 = (
+ np.cos(np.arange(n2) * np.pi / n2)[:, None] +
+ np.cos(np.arange(d2) * np.pi / d2)[None, :] +
+ sigma * np.random.randn(n2, d2)
+)
+
+# %%
+# Visualizing the matrices
+
+pl.figure(1, (8, 5))
+pl.subplot(1, 2, 1)
+pl.imshow(X1)
+pl.title('$X_1$')
+
+pl.subplot(1, 2, 2)
+pl.imshow(X2)
+pl.title("$X_2$")
+
+pl.tight_layout()
+
+# %%
+# Visualizing the alignments of rows and columns, and calculating the CO-Optimal Transport distance
+
+pi_sample, pi_feature, log = coot(X1, X2, log=True, verbose=True)
+coot_distance = coot2(X1, X2)
+print('CO-Optimal Transport distance = {:.5f}'.format(coot_distance))
+
+fig = pl.figure(4, (9, 7))
+pl.clf()
+
+ax1 = pl.subplot(2, 2, 3)
+pl.imshow(X1)
+pl.xlabel('$X_1$')
+
+ax2 = pl.subplot(2, 2, 2)
+ax2.yaxis.tick_right()
+pl.imshow(np.transpose(X2))
+pl.title("Transpose($X_2$)")
+ax2.xaxis.tick_top()
+
+for i in range(n1):
+ j = np.argmax(pi_sample[i, :])
+ xyA = (d1 - .5, i)
+ xyB = (j, d2 - .5)
+ con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData,
+ coordsB=ax2.transData, color="black")
+ fig.add_artist(con)
+
+for i in range(d1):
+ j = np.argmax(pi_feature[i, :])
+ xyA = (i, -.5)
+ xyB = (-.5, j)
+ con = ConnectionPatch(
+ xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue")
+ fig.add_artist(con)
diff --git a/examples/others/plot_learning_weights_with_COOT.py b/examples/others/plot_learning_weights_with_COOT.py
new file mode 100644
index 0000000..cb115c3
--- /dev/null
+++ b/examples/others/plot_learning_weights_with_COOT.py
@@ -0,0 +1,150 @@
+# -*- coding: utf-8 -*-
+r"""
+===============================================================
+Learning sample marginal distribution with CO-Optimal Transport
+===============================================================
+
+In this example, we illustrate how to estimate the sample marginal distribution which minimizes
+the CO-Optimal Transport distance [47]_ between two matrices. More precisely, given a source data
+:math:`(X, \mu_x^{(s)}, \mu_x^{(f)})` and a target matrix :math:`Y` associated with a fixed
+histogram on features :math:`\mu_y^{(f)}`, we want to solve the following problem
+
+.. math::
+ \min_{\mu_y^{(s)} \in \Delta} \text{COOT}\left( (X, \mu_x^{(s)}, \mu_x^{(f)}), (Y, \mu_y^{(s)}, \mu_y^{(f)}) \right)
+
+where :math:`\Delta` is the probability simplex. This minimization is done with a
+simple projected gradient descent in PyTorch. We use the automatic backend of POT that
+allows us to compute the CO-Optimal Transport distance with :func:`ot.coot.co_optimal_transport2`
+with differentiable losses.
+
+.. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020).
+ `CO-Optimal Transport <https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf>`_.
+ Advances in Neural Information Processing Systems, 33.
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
+# License: MIT License
+
+from matplotlib.patches import ConnectionPatch
+import torch
+import numpy as np
+
+import matplotlib.pyplot as pl
+import ot
+
+from ot.coot import co_optimal_transport as coot
+from ot.coot import co_optimal_transport2 as coot2
+
+
+# %%
+# Generate data
+# -------------
+# The source and clean target matrices are generated by
+# :math:`X_{i,j} = \cos(\frac{i}{n_1} \pi) + \cos(\frac{j}{d_1} \pi)` and
+# :math:`Y_{i,j} = \cos(\frac{i}{n_2} \pi) + \cos(\frac{j}{d_2} \pi)`.
+# The target matrix is then contaminated by adding 5 row outliers.
+# Intuitively, we expect that the estimated sample distribution should ignore these outliers,
+# i.e. their weights should be zero.
+
+np.random.seed(182)
+
+n1, d1 = 20, 16
+n2, d2 = 10, 8
+n = 15
+
+X = (
+ torch.cos(torch.arange(n1) * torch.pi / n1)[:, None] +
+ torch.cos(torch.arange(d1) * torch.pi / d1)[None, :]
+)
+
+# Generate clean target data mixed with outliers
+Y_noisy = torch.randn((n, d2)) * 10.0
+Y_noisy[:n2, :] = (
+ torch.cos(torch.arange(n2) * torch.pi / n2)[:, None] +
+ torch.cos(torch.arange(d2) * torch.pi / d2)[None, :]
+)
+Y = Y_noisy[:n2, :]
+
+X, Y_noisy, Y = X.double(), Y_noisy.double(), Y.double()
+
+fig, axes = pl.subplots(nrows=1, ncols=3, figsize=(12, 5))
+axes[0].imshow(X, vmin=-2, vmax=2)
+axes[0].set_title('$X$')
+
+axes[1].imshow(Y, vmin=-2, vmax=2)
+axes[1].set_title('Clean $Y$')
+
+axes[2].imshow(Y_noisy, vmin=-2, vmax=2)
+axes[2].set_title('Noisy $Y$')
+
+pl.tight_layout()
+
+# %%
+# Optimize the COOT distance with respect to the sample marginal distribution
+# ---------------------------------------------------------------------------
+
+losses = []
+lr = 1e-3
+niter = 1000
+
+b = torch.tensor(ot.unif(n), requires_grad=True)
+
+for i in range(niter):
+
+ loss = coot2(X, Y_noisy, wy_samp=b, log=False, verbose=False)
+ losses.append(float(loss))
+
+ loss.backward()
+
+ with torch.no_grad():
+ b -= lr * b.grad # gradient step
+ b[:] = ot.utils.proj_simplex(b) # projection on the simplex
+
+ b.grad.zero_()
+
+# Estimated sample marginal distribution and training loss curve
+pl.plot(losses[10:])
+pl.title('CO-Optimal Transport distance')
+
+print(f"Marginal distribution = {b.detach().numpy()}")
+
+# %%
+# Visualizing the row and column alignments with the estimated sample marginal distribution
+# -----------------------------------------------------------------------------------------
+#
+# Clearly, the learned marginal distribution completely and successfully ignores the 5 outliers.
+
+X, Y_noisy = X.numpy(), Y_noisy.numpy()
+b = b.detach().numpy()
+
+pi_sample, pi_feature = coot(X, Y_noisy, wy_samp=b, log=False, verbose=True)
+
+fig = pl.figure(4, (9, 7))
+pl.clf()
+
+ax1 = pl.subplot(2, 2, 3)
+pl.imshow(X, vmin=-2, vmax=2)
+pl.xlabel('$X$')
+
+ax2 = pl.subplot(2, 2, 2)
+ax2.yaxis.tick_right()
+pl.imshow(np.transpose(Y_noisy), vmin=-2, vmax=2)
+pl.title("Transpose(Noisy $Y$)")
+ax2.xaxis.tick_top()
+
+for i in range(n1):
+ j = np.argmax(pi_sample[i, :])
+ xyA = (d1 - .5, i)
+ xyB = (j, d2 - .5)
+ con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData,
+ coordsB=ax2.transData, color="black")
+ fig.add_artist(con)
+
+for i in range(d1):
+ j = np.argmax(pi_feature[i, :])
+ xyA = (i, -.5)
+ xyB = (-.5, j)
+ con = ConnectionPatch(
+ xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue")
+ fig.add_artist(con)
diff --git a/examples/plot_compute_wasserstein_circle.py b/examples/plot_compute_wasserstein_circle.py
new file mode 100644
index 0000000..3ede96f
--- /dev/null
+++ b/examples/plot_compute_wasserstein_circle.py
@@ -0,0 +1,161 @@
+# -*- coding: utf-8 -*-
+"""
+=========================
+OT distance on the Circle
+=========================
+
+Shows how to compute the Wasserstein distance on the circle
+
+
+"""
+
+# Author: Clément Bonet <clement.bonet@univ-ubs.fr>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 2
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+
+from scipy.special import iv
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% plot the distributions
+
+
+def pdf_von_Mises(theta, mu, kappa):
+ pdf = np.exp(kappa * np.cos(theta - mu)) / (2.0 * np.pi * iv(0, kappa))
+ return pdf
+
+
+t = np.linspace(0, 2 * np.pi, 1000, endpoint=False)
+
+mu1 = 1
+kappa1 = 20
+
+mu_targets = np.linspace(mu1, mu1 + 2 * np.pi, 10)
+
+
+pdf1 = pdf_von_Mises(t, mu1, kappa1)
+
+
+pl.figure(1)
+for k, mu in enumerate(mu_targets):
+ pdf_t = pdf_von_Mises(t, mu, kappa1)
+ if k == 0:
+ label = "Source distributions"
+ else:
+ label = None
+ pl.plot(t / (2 * np.pi), pdf_t, c='b', label=label)
+
+pl.plot(t / (2 * np.pi), pdf1, c="r", label="Target distribution")
+pl.legend()
+
+mu2 = 0
+kappa2 = kappa1
+
+x1 = np.random.vonmises(mu1, kappa1, size=(10,)) + np.pi
+x2 = np.random.vonmises(mu2, kappa2, size=(10,)) + np.pi
+
+angles = np.linspace(0, 2 * np.pi, 150)
+
+pl.figure(2)
+pl.plot(np.cos(angles), np.sin(angles), c="k")
+pl.xlim(-1.25, 1.25)
+pl.ylim(-1.25, 1.25)
+pl.scatter(np.cos(x1), np.sin(x1), c="b")
+pl.scatter(np.cos(x2), np.sin(x2), c="r")
+
+#########################################################################################
+# Compare the Euclidean Wasserstein distance with the Wasserstein distance on the circle
+# ---------------------------------------------------------------------------------------
+# This examples illustrates the periodicity of the Wasserstein distance on the circle.
+# We choose as target distribution a von Mises distribution with mean :math:`\mu_{\mathrm{target}}`
+# and :math:`\kappa=20`. Then, we compare the distances with samples obtained from a von Mises distribution
+# with parameters :math:`\mu_{\mathrm{source}}` and :math:`\kappa=20`.
+# The Wasserstein distance on the circle takes into account the periodicity
+# and attains its maximum in :math:`\mu_{\mathrm{target}}+1` (the antipodal point) contrary to the
+# Euclidean version.
+
+#%% Compute and plot distributions
+
+mu_targets = np.linspace(0, 2 * np.pi, 200)
+xs = np.random.vonmises(mu1 - np.pi, kappa1, size=(500,)) + np.pi
+
+n_try = 5
+
+xts = np.zeros((n_try, 200, 500))
+for i in range(n_try):
+ for k, mu in enumerate(mu_targets):
+ # np.random.vonmises deals with data on [-pi, pi[
+ xt = np.random.vonmises(mu - np.pi, kappa2, size=(500,)) + np.pi
+ xts[i, k] = xt
+
+# Put data on S^1=[0,1[
+xts2 = xts / (2 * np.pi)
+xs2 = np.concatenate([xs[None] for k in range(200)], axis=0) / (2 * np.pi)
+
+L_w2_circle = np.zeros((n_try, 200))
+L_w2 = np.zeros((n_try, 200))
+
+for i in range(n_try):
+ w2_circle = ot.wasserstein_circle(xs2.T, xts2[i].T, p=2)
+ w2 = ot.wasserstein_1d(xs2.T, xts2[i].T, p=2)
+
+ L_w2_circle[i] = w2_circle
+ L_w2[i] = w2
+
+m_w2_circle = np.mean(L_w2_circle, axis=0)
+std_w2_circle = np.std(L_w2_circle, axis=0)
+
+m_w2 = np.mean(L_w2, axis=0)
+std_w2 = np.std(L_w2, axis=0)
+
+pl.figure(1)
+pl.plot(mu_targets / (2 * np.pi), m_w2_circle, label="Wasserstein circle")
+pl.fill_between(mu_targets / (2 * np.pi), m_w2_circle - 2 * std_w2_circle, m_w2_circle + 2 * std_w2_circle, alpha=0.5)
+pl.plot(mu_targets / (2 * np.pi), m_w2, label="Euclidean Wasserstein")
+pl.fill_between(mu_targets / (2 * np.pi), m_w2 - 2 * std_w2, m_w2 + 2 * std_w2, alpha=0.5)
+pl.vlines(x=[mu1 / (2 * np.pi)], ymin=0, ymax=np.max(w2), linestyle="--", color="k", label=r"$\mu_{\mathrm{target}}$")
+pl.legend()
+pl.xlabel(r"$\mu_{\mathrm{source}}$")
+pl.show()
+
+
+########################################################################
+# Wasserstein distance between von Mises and uniform for different kappa
+# ----------------------------------------------------------------------
+# When :math:`\kappa=0`, the von Mises distribution is the uniform distribution on :math:`S^1`.
+
+#%% Compute Wasserstein between Von Mises and uniform
+
+kappas = np.logspace(-5, 2, 100)
+n_try = 20
+
+xts = np.zeros((n_try, 100, 500))
+for i in range(n_try):
+ for k, kappa in enumerate(kappas):
+ # np.random.vonmises deals with data on [-pi, pi[
+ xt = np.random.vonmises(0, kappa, size=(500,)) + np.pi
+ xts[i, k] = xt / (2 * np.pi)
+
+L_w2 = np.zeros((n_try, 100))
+for i in range(n_try):
+ L_w2[i] = ot.semidiscrete_wasserstein2_unif_circle(xts[i].T)
+
+m_w2 = np.mean(L_w2, axis=0)
+std_w2 = np.std(L_w2, axis=0)
+
+pl.figure(1)
+pl.plot(kappas, m_w2)
+pl.fill_between(kappas, m_w2 - std_w2, m_w2 + std_w2, alpha=0.5)
+pl.title(r"Evolution of $W_2^2(vM(0,\kappa), Unif(S^1))$")
+pl.xlabel(r"$\kappa$")
+pl.show()
+
+# %%
diff --git a/examples/sliced-wasserstein/plot_variance_ssw.py b/examples/sliced-wasserstein/plot_variance_ssw.py
new file mode 100644
index 0000000..83d458f
--- /dev/null
+++ b/examples/sliced-wasserstein/plot_variance_ssw.py
@@ -0,0 +1,111 @@
+# -*- coding: utf-8 -*-
+"""
+====================================================
+Spherical Sliced Wasserstein on distributions in S^2
+====================================================
+
+This example illustrates the computation of the spherical sliced Wasserstein discrepancy as
+proposed in [46].
+
+[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). 'Spherical Sliced-Wasserstein". International Conference on Learning Representations.
+
+"""
+
+# Author: Clément Bonet <clement.bonet@univ-ubs.fr>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 2
+
+import matplotlib.pylab as pl
+import numpy as np
+
+import ot
+
+##############################################################################
+# Generate data
+# -------------
+
+# %% parameters and data generation
+
+n = 500 # nb samples
+
+xs = np.random.randn(n, 3)
+xt = np.random.randn(n, 3)
+
+xs = xs / np.sqrt(np.sum(xs**2, -1, keepdims=True))
+xt = xt / np.sqrt(np.sum(xt**2, -1, keepdims=True))
+
+a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples
+
+##############################################################################
+# Plot data
+# ---------
+
+# %% plot samples
+
+fig = pl.figure(figsize=(10, 10))
+ax = pl.axes(projection='3d')
+ax.grid(False)
+
+u, v = np.mgrid[0:2 * np.pi:30j, 0:np.pi:30j]
+x = np.cos(u) * np.sin(v)
+y = np.sin(u) * np.sin(v)
+z = np.cos(v)
+ax.plot_surface(x, y, z, color="gray", alpha=0.03)
+ax.plot_wireframe(x, y, z, linewidth=1, alpha=0.25, color="gray")
+
+ax.scatter(xs[:, 0], xs[:, 1], xs[:, 2], label="Source")
+ax.scatter(xt[:, 0], xt[:, 1], xt[:, 2], label="Target")
+
+fs = 10
+# Labels
+ax.set_xlabel('x', fontsize=fs)
+ax.set_ylabel('y', fontsize=fs)
+ax.set_zlabel('z', fontsize=fs)
+
+ax.view_init(20, 120)
+ax.set_xlim(-1.5, 1.5)
+ax.set_ylim(-1.5, 1.5)
+ax.set_zlim(-1.5, 1.5)
+
+# Ticks
+ax.set_xticks([-1, 0, 1])
+ax.set_yticks([-1, 0, 1])
+ax.set_zticks([-1, 0, 1])
+
+pl.legend(loc=0)
+pl.title("Source and Target distribution")
+
+###############################################################################
+# Spherical Sliced Wasserstein for different seeds and number of projections
+# --------------------------------------------------------------------------
+
+n_seed = 50
+n_projections_arr = np.logspace(0, 3, 25, dtype=int)
+res = np.empty((n_seed, 25))
+
+# %% Compute statistics
+for seed in range(n_seed):
+ for i, n_projections in enumerate(n_projections_arr):
+ res[seed, i] = ot.sliced_wasserstein_sphere(xs, xt, a, b, n_projections, seed=seed, p=1)
+
+res_mean = np.mean(res, axis=0)
+res_std = np.std(res, axis=0)
+
+###############################################################################
+# Plot Spherical Sliced Wasserstein
+# ---------------------------------
+
+pl.figure(2)
+pl.plot(n_projections_arr, res_mean, label=r"$SSW_1$")
+pl.fill_between(n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5)
+
+pl.legend()
+pl.xscale('log')
+
+pl.xlabel("Number of projections")
+pl.ylabel("Distance")
+pl.title('Spherical Sliced Wasserstein Distance with 95% confidence inverval')
+
+pl.show()
diff --git a/examples/unbalanced-partial/plot_UOT_barycenter_1D.py b/examples/unbalanced-partial/plot_UOT_barycenter_1D.py
index 931798b..8d227c0 100644
--- a/examples/unbalanced-partial/plot_UOT_barycenter_1D.py
+++ b/examples/unbalanced-partial/plot_UOT_barycenter_1D.py
@@ -127,7 +127,7 @@ for i, z in enumerate(zs):
ys = B_l2[:, i]
verts.append(list(zip(x, ys)))
-ax = pl.gcf().gca(projection='3d')
+ax = pl.gcf().add_subplot(projection='3d')
poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])
poly.set_alpha(0.7)
@@ -149,7 +149,7 @@ for i, z in enumerate(zs):
ys = B_wass[:, i]
verts.append(list(zip(x, ys)))
-ax = pl.gcf().gca(projection='3d')
+ax = pl.gcf().add_subplot(projection='3d')
poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])
poly.set_alpha(0.7)
diff --git a/ot/__init__.py b/ot/__init__.py
index 86ed94e..1a685b6 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -8,7 +8,6 @@
, :py:mod:`ot.unbalanced`.
The following sub-modules are not imported due to additional dependencies:
- :any:`ot.dr` : depends on :code:`pymanopt` and :code:`autograd`.
- - :any:`ot.gpu` : depends on :code:`cupy` and a CUDA GPU.
- :any:`ot.plot` : depends on :code:`matplotlib`
"""
@@ -34,32 +33,39 @@ from . import backend
from . import regpath
from . import weak
from . import factored
+from . import solvers
+from . import gaussian
# OT functions
-from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
+from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d,
+ binary_search_circle, wasserstein_circle,
+ semidiscrete_wasserstein2_unif_circle)
from .bregman import sinkhorn, sinkhorn2, barycenter
from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced,
sinkhorn_unbalanced2)
from .da import sinkhorn_lpl1_mm
-from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance
+from .sliced import (sliced_wasserstein_distance, max_sliced_wasserstein_distance,
+ sliced_wasserstein_sphere, sliced_wasserstein_sphere_unif)
from .gromov import (gromov_wasserstein, gromov_wasserstein2,
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
from .weak import weak_optimal_transport
from .factored import factored_optimal_transport
-
+from .solvers import solve
# utils functions
from .utils import dist, unif, tic, toc, toq
-__version__ = "0.8.2"
+__version__ = "0.9.0"
__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
- 'emd2_1d', 'wasserstein_1d', 'backend',
+ 'emd2_1d', 'wasserstein_1d', 'backend', 'gaussian',
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
'sinkhorn_unbalanced', 'barycenter_unbalanced',
- 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance',
+ 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere',
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2',
'max_sliced_wasserstein_distance', 'weak_optimal_transport',
- 'factored_optimal_transport',
- 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath']
+ 'factored_optimal_transport', 'solve',
+ 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers',
+ 'binary_search_circle', 'wasserstein_circle',
+ 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif']
diff --git a/ot/backend.py b/ot/backend.py
index 361ffba..0779243 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -534,9 +534,9 @@ class Backend():
"""
raise NotImplementedError()
- def zero_pad(self, a, pad_width):
+ def zero_pad(self, a, pad_width, value=0):
r"""
- Pads a tensor.
+ Pads a tensor with a given value (0 by default).
This function follows the api from :any:`numpy.pad`
@@ -854,6 +854,21 @@ class Backend():
"""
raise NotImplementedError()
+ def kl_div(self, p, q, eps=1e-16):
+ r"""
+ Computes the Kullback-Leibler divergence.
+
+ This function follows the api from :any:`scipy.stats.entropy`.
+
+ Parameter eps is used to avoid numerical errors and is added in the log.
+
+ .. math::
+ KL(p,q) = \sum_i p(i) \log (\frac{p(i)}{q(i)}+\epsilon)
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html
+ """
+ raise NotImplementedError()
+
def isfinite(self, a):
r"""
Tests element-wise for finiteness (not infinity and not Not a Number).
@@ -880,6 +895,62 @@ class Backend():
"""
raise NotImplementedError()
+ def tile(self, a, reps):
+ r"""
+ Construct an array by repeating a the number of times given by reps
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.tile.html
+ """
+ raise NotImplementedError()
+
+ def floor(self, a):
+ r"""
+ Return the floor of the input element-wise
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.floor.html
+ """
+ raise NotImplementedError()
+
+ def prod(self, a, axis=None):
+ r"""
+ Return the product of all elements.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.prod.html
+ """
+ raise NotImplementedError()
+
+ def sort2(self, a, axis=None):
+ r"""
+ Return the sorted array and the indices to sort the array
+
+ See: https://pytorch.org/docs/stable/generated/torch.sort.html
+ """
+ raise NotImplementedError()
+
+ def qr(self, a):
+ r"""
+ Return the QR factorization
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.qr.html
+ """
+ raise NotImplementedError()
+
+ def atan2(self, a, b):
+ r"""
+ Element wise arctangent
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.arctan2.html
+ """
+ raise NotImplementedError()
+
+ def transpose(self, a, axes=None):
+ r"""
+ Returns a tensor that is a transposed version of a. The given dimensions dim0 and dim1 are swapped.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.transpose.html
+ """
+ raise NotImplementedError()
+
class NumpyBackend(Backend):
"""
@@ -1024,8 +1095,8 @@ class NumpyBackend(Backend):
def concatenate(self, arrays, axis=0):
return np.concatenate(arrays, axis)
- def zero_pad(self, a, pad_width):
- return np.pad(a, pad_width)
+ def zero_pad(self, a, pad_width, value=0):
+ return np.pad(a, pad_width, constant_values=value)
def argmax(self, a, axis=None):
return np.argmax(a, axis=axis)
@@ -1158,6 +1229,9 @@ class NumpyBackend(Backend):
def sqrtm(self, a):
return scipy.linalg.sqrtm(a)
+ def kl_div(self, p, q, eps=1e-16):
+ return np.sum(p * np.log(p / q + eps))
+
def isfinite(self, a):
return np.isfinite(a)
@@ -1167,6 +1241,44 @@ class NumpyBackend(Backend):
def is_floating_point(self, a):
return a.dtype.kind == "f"
+ def tile(self, a, reps):
+ return np.tile(a, reps)
+
+ def floor(self, a):
+ return np.floor(a)
+
+ def prod(self, a, axis=0):
+ return np.prod(a, axis=axis)
+
+ def sort2(self, a, axis=-1):
+ return self.sort(a, axis), self.argsort(a, axis)
+
+ def qr(self, a):
+ np_version = tuple([int(k) for k in np.__version__.split(".")])
+ if np_version < (1, 22, 0):
+ M, N = a.shape[-2], a.shape[-1]
+ K = min(M, N)
+
+ if len(a.shape) >= 3:
+ n = a.shape[0]
+
+ qs, rs = np.zeros((n, M, K)), np.zeros((n, K, N))
+
+ for i in range(a.shape[0]):
+ qs[i], rs[i] = np.linalg.qr(a[i])
+
+ else:
+ return np.linalg.qr(a)
+
+ return qs, rs
+ return np.linalg.qr(a)
+
+ def atan2(self, a, b):
+ return np.arctan2(a, b)
+
+ def transpose(self, a, axes=None):
+ return np.transpose(a, axes)
+
class JaxBackend(Backend):
"""
@@ -1333,8 +1445,8 @@ class JaxBackend(Backend):
def concatenate(self, arrays, axis=0):
return jnp.concatenate(arrays, axis)
- def zero_pad(self, a, pad_width):
- return jnp.pad(a, pad_width)
+ def zero_pad(self, a, pad_width, value=0):
+ return jnp.pad(a, pad_width, constant_values=value)
def argmax(self, a, axis=None):
return jnp.argmax(a, axis=axis)
@@ -1481,6 +1593,9 @@ class JaxBackend(Backend):
L, V = jnp.linalg.eigh(a)
return (V * jnp.sqrt(L)[None, :]) @ V.T
+ def kl_div(self, p, q, eps=1e-16):
+ return jnp.sum(p * jnp.log(p / q + eps))
+
def isfinite(self, a):
return jnp.isfinite(a)
@@ -1490,6 +1605,27 @@ class JaxBackend(Backend):
def is_floating_point(self, a):
return a.dtype.kind == "f"
+ def tile(self, a, reps):
+ return jnp.tile(a, reps)
+
+ def floor(self, a):
+ return jnp.floor(a)
+
+ def prod(self, a, axis=0):
+ return jnp.prod(a, axis=axis)
+
+ def sort2(self, a, axis=-1):
+ return self.sort(a, axis), self.argsort(a, axis)
+
+ def qr(self, a):
+ return jnp.linalg.qr(a)
+
+ def atan2(self, a, b):
+ return jnp.arctan2(a, b)
+
+ def transpose(self, a, axes=None):
+ return jnp.transpose(a, axes)
+
class TorchBackend(Backend):
"""
@@ -1507,15 +1643,19 @@ class TorchBackend(Backend):
def __init__(self):
- self.rng_ = torch.Generator()
+ self.rng_ = torch.Generator("cpu")
self.rng_.seed()
self.__type_list__ = [torch.tensor(1, dtype=torch.float32),
torch.tensor(1, dtype=torch.float64)]
if torch.cuda.is_available():
+ self.rng_cuda_ = torch.Generator("cuda")
+ self.rng_cuda_.seed()
self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda'))
self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda'))
+ else:
+ self.rng_cuda_ = torch.Generator("cpu")
from torch.autograd import Function
@@ -1704,13 +1844,13 @@ class TorchBackend(Backend):
def concatenate(self, arrays, axis=0):
return torch.cat(arrays, dim=axis)
- def zero_pad(self, a, pad_width):
+ def zero_pad(self, a, pad_width, value=0):
from torch.nn.functional import pad
# pad_width is an array of ndim tuples indicating how many 0 before and after
# we need to add. We first need to make it compliant with torch syntax, that
# starts with the last dim, then second last, etc.
how_pad = tuple(element for tupl in pad_width[::-1] for element in tupl)
- return pad(a, how_pad)
+ return pad(a, how_pad, value=value)
def argmax(self, a, axis=None):
return torch.argmax(a, dim=axis)
@@ -1761,20 +1901,26 @@ class TorchBackend(Backend):
def seed(self, seed=None):
if isinstance(seed, int):
self.rng_.manual_seed(seed)
+ self.rng_cuda_.manual_seed(seed)
elif isinstance(seed, torch.Generator):
- self.rng_ = seed
+ if self.device_type(seed) == "GPU":
+ self.rng_cuda_ = seed
+ else:
+ self.rng_ = seed
else:
raise ValueError("Non compatible seed : {}".format(seed))
def rand(self, *size, type_as=None):
if type_as is not None:
- return torch.rand(size=size, generator=self.rng_, dtype=type_as.dtype, device=type_as.device)
+ generator = self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_
+ return torch.rand(size=size, generator=generator, dtype=type_as.dtype, device=type_as.device)
else:
return torch.rand(size=size, generator=self.rng_)
def randn(self, *size, type_as=None):
if type_as is not None:
- return torch.randn(size=size, dtype=type_as.dtype, generator=self.rng_, device=type_as.device)
+ generator = self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_
+ return torch.randn(size=size, dtype=type_as.dtype, generator=generator, device=type_as.device)
else:
return torch.randn(size=size, generator=self.rng_)
@@ -1891,6 +2037,9 @@ class TorchBackend(Backend):
L, V = torch.linalg.eigh(a)
return (V * torch.sqrt(L)[None, :]) @ V.T
+ def kl_div(self, p, q, eps=1e-16):
+ return torch.sum(p * torch.log(p / q + eps))
+
def isfinite(self, a):
return torch.isfinite(a)
@@ -1900,6 +2049,29 @@ class TorchBackend(Backend):
def is_floating_point(self, a):
return a.dtype.is_floating_point
+ def tile(self, a, reps):
+ return a.repeat(reps)
+
+ def floor(self, a):
+ return torch.floor(a)
+
+ def prod(self, a, axis=0):
+ return torch.prod(a, dim=axis)
+
+ def sort2(self, a, axis=-1):
+ return torch.sort(a, axis)
+
+ def qr(self, a):
+ return torch.linalg.qr(a)
+
+ def atan2(self, a, b):
+ return torch.atan2(a, b)
+
+ def transpose(self, a, axes=None):
+ if axes is None:
+ axes = tuple(range(a.ndim)[::-1])
+ return a.permute(axes)
+
class CupyBackend(Backend): # pragma: no cover
"""
@@ -2062,8 +2234,8 @@ class CupyBackend(Backend): # pragma: no cover
def concatenate(self, arrays, axis=0):
return cp.concatenate(arrays, axis)
- def zero_pad(self, a, pad_width):
- return cp.pad(a, pad_width)
+ def zero_pad(self, a, pad_width, value=0):
+ return cp.pad(a, pad_width, constant_values=value)
def argmax(self, a, axis=None):
return cp.argmax(a, axis=axis)
@@ -2238,6 +2410,9 @@ class CupyBackend(Backend): # pragma: no cover
L, V = cp.linalg.eigh(a)
return (V * self.sqrt(L)[None, :]) @ V.T
+ def kl_div(self, p, q, eps=1e-16):
+ return cp.sum(p * cp.log(p / q + eps))
+
def isfinite(self, a):
return cp.isfinite(a)
@@ -2247,6 +2422,27 @@ class CupyBackend(Backend): # pragma: no cover
def is_floating_point(self, a):
return a.dtype.kind == "f"
+ def tile(self, a, reps):
+ return cp.tile(a, reps)
+
+ def floor(self, a):
+ return cp.floor(a)
+
+ def prod(self, a, axis=0):
+ return cp.prod(a, axis=axis)
+
+ def sort2(self, a, axis=-1):
+ return self.sort(a, axis), self.argsort(a, axis)
+
+ def qr(self, a):
+ return cp.linalg.qr(a)
+
+ def atan2(self, a, b):
+ return cp.arctan2(a, b)
+
+ def transpose(self, a, axes=None):
+ return cp.transpose(a, axes)
+
class TensorflowBackend(Backend):
@@ -2417,8 +2613,8 @@ class TensorflowBackend(Backend):
def concatenate(self, arrays, axis=0):
return tnp.concatenate(arrays, axis)
- def zero_pad(self, a, pad_width):
- return tnp.pad(a, pad_width, mode="constant")
+ def zero_pad(self, a, pad_width, value=0):
+ return tnp.pad(a, pad_width, mode="constant", constant_values=value)
def argmax(self, a, axis=None):
return tnp.argmax(a, axis=axis)
@@ -2598,6 +2794,9 @@ class TensorflowBackend(Backend):
def sqrtm(self, a):
return tf.linalg.sqrtm(a)
+ def kl_div(self, p, q, eps=1e-16):
+ return tnp.sum(p * tnp.log(p / q + eps))
+
def isfinite(self, a):
return tnp.isfinite(a)
@@ -2606,3 +2805,24 @@ class TensorflowBackend(Backend):
def is_floating_point(self, a):
return a.dtype.is_floating
+
+ def tile(self, a, reps):
+ return tnp.tile(a, reps)
+
+ def floor(self, a):
+ return tf.floor(a)
+
+ def prod(self, a, axis=0):
+ return tnp.prod(a, axis=axis)
+
+ def sort2(self, a, axis=-1):
+ return self.sort(a, axis), self.argsort(a, axis)
+
+ def qr(self, a):
+ return tf.linalg.qr(a)
+
+ def atan2(self, a, b):
+ return tf.math.atan2(a, b)
+
+ def transpose(self, a, axes=None):
+ return tf.transpose(a, perm=axes)
diff --git a/ot/bregman.py b/ot/bregman.py
index c06af2f..20bef7e 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -24,9 +24,8 @@ from ot.utils import unif, dist, list_to_array
from .backend import get_backend
-def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, warn=True,
- **kwargs):
+def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9,
+ verbose=False, log=False, warn=True, warmstart=None, **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -101,6 +100,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't convergence.
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
+ Initialization of dual potentials. If provided, the dual potentials should be given
+ (that is the logarithm of the u,v sinkhorn scaling vectors)
Returns
-------
@@ -156,34 +158,33 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
if method.lower() == 'sinkhorn':
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
- warn=warn,
+ warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'sinkhorn_log':
return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
- warn=warn,
+ warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'greenkhorn':
return greenkhorn(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
- warn=warn)
+ warn=warn, warmstart=warmstart)
elif method.lower() == 'sinkhorn_stabilized':
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, warn=warn,
+ stopThr=stopThr, warmstart=warmstart,
+ verbose=verbose, log=log, warn=warn,
**kwargs)
elif method.lower() == 'sinkhorn_epsilon_scaling':
- return sinkhorn_epsilon_scaling(a, b, M, reg,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, warn=warn,
+ return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, warmstart=warmstart,
+ verbose=verbose, log=log, warn=warn,
**kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, warn=False, **kwargs):
+ stopThr=1e-9, verbose=False, log=False, warn=False, warmstart=None, **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the loss
@@ -207,6 +208,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
weights (histograms, both sum to 1)
+ and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F` (without
+ the entropic contribution).
+
.. note:: This function is backend-compatible and will work on arrays
from all compatible backends.
@@ -257,6 +261,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't convergence.
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
+ Initialization of dual potentials. If provided, the dual potentials should be given
+ (that is the logarithm of the u,v sinkhorn scaling vectors)
Returns
-------
@@ -320,15 +327,18 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
if len(b.shape) < 2:
if method.lower() == 'sinkhorn':
res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'sinkhorn_log':
res = sinkhorn_log(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'sinkhorn_stabilized':
res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, warmstart=warmstart,
+ verbose=verbose, log=log, warn=warn,
**kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
@@ -341,23 +351,25 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
if method.lower() == 'sinkhorn':
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'sinkhorn_log':
return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'sinkhorn_stabilized':
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, warmstart=warmstart,
+ verbose=verbose, log=log, warn=warn,
**kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
- verbose=False, log=False, warn=True,
- **kwargs):
+ verbose=False, log=False, warn=True, warmstart=None, **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -406,6 +418,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't convergence.
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
+ Initialization of dual potentials. If provided, the dual potentials should be given
+ (that is the logarithm of the u,v sinkhorn scaling vectors)
Returns
-------
@@ -465,12 +480,15 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
# we assume that no distances are null except those of the diagonal of
# distances
- if n_hists:
- u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
- v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
+ if warmstart is None:
+ if n_hists:
+ u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
+ v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
+ else:
+ u = nx.ones(dim_a, type_as=M) / dim_a
+ v = nx.ones(dim_b, type_as=M) / dim_b
else:
- u = nx.ones(dim_a, type_as=M) / dim_a
- v = nx.ones(dim_b, type_as=M) / dim_b
+ u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])
K = nx.exp(M / (-reg))
@@ -538,7 +556,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
- log=False, warn=True, **kwargs):
+ log=False, warn=True, warmstart=None, **kwargs):
r"""
Solve the entropic regularization optimal transport problem in log space
and return the OT matrix
@@ -587,6 +605,9 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't convergence.
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
+ Initialization of dual potentials. If provided, the dual potentials should be given
+ (that is the logarithm of the u,v sinkhorn scaling vectors)
Returns
-------
@@ -647,6 +668,10 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
else:
n_hists = 0
+ # in case of multiple historgrams
+ if n_hists > 1 and warmstart is None:
+ warmstart = [None] * n_hists
+
if n_hists: # we do not want to use tensors sor we do a loop
lst_loss = []
@@ -654,8 +679,8 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
lst_v = []
for k in range(n_hists):
- res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax, stopThr=stopThr,
+ verbose=verbose, log=log, warmstart=warmstart[k], **kwargs)
if log:
lst_loss.append(nx.sum(M * res[0]))
@@ -682,9 +707,11 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
# we assume that no distances are null except those of the diagonal of
# distances
-
- u = nx.zeros(dim_a, type_as=M)
- v = nx.zeros(dim_b, type_as=M)
+ if warmstart is None:
+ u = nx.zeros(dim_a, type_as=M)
+ v = nx.zeros(dim_b, type_as=M)
+ else:
+ u, v = warmstart
def get_logT(u, v):
if n_hists:
@@ -738,7 +765,7 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
- log=False, warn=True):
+ log=False, warn=True, warmstart=None):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -786,6 +813,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't convergence.
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
+ Initialization of dual potentials. If provided, the dual potentials should be given
+ (that is the logarithm of the u,v sinkhorn scaling vectors)
Returns
-------
@@ -844,8 +874,11 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
K = nx.exp(-M / reg)
- u = nx.full((dim_a,), 1. / dim_a, type_as=K)
- v = nx.full((dim_b,), 1. / dim_b, type_as=K)
+ if warmstart is None:
+ u = nx.full((dim_a,), 1. / dim_a, type_as=K)
+ v = nx.full((dim_b,), 1. / dim_b, type_as=K)
+ else:
+ u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])
G = u[:, None] * K * v[None, :]
viol = nx.sum(G, axis=1) - a
@@ -1065,7 +1098,8 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
# remove numerical problems and store them in K
if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau:
if n_hists:
- alpha, beta = alpha + reg * nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v))
+ alpha, beta = alpha + reg * \
+ nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v))
else:
alpha, beta = alpha + reg * nx.log(u), beta + reg * nx.log(v)
if n_hists:
@@ -1278,7 +1312,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
regi = get_reg(ii)
G, logi = sinkhorn_stabilized(a, b, M, regi,
- numItermax=numInnerItermax, stopThr=1e-9,
+ numItermax=numInnerItermax, stopThr=stopThr,
warmstart=(alpha, beta), verbose=False,
print_period=20, tau=tau, log=True)
@@ -1289,13 +1323,15 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
# we can speed up the process by checking for the error only all
# the 10th iterations
transp = G
- err = nx.norm(nx.sum(transp, axis=0) - b) ** 2 + nx.norm(nx.sum(transp, axis=1) - a) ** 2
+ err = nx.norm(nx.sum(transp, axis=0) - b) ** 2 + \
+ nx.norm(nx.sum(transp, axis=1) - a) ** 2
if log:
log['err'].append(err)
if verbose:
if ii % (print_period * 10) == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(ii, err))
if err <= stopThr and ii > numItermin:
@@ -1511,7 +1547,7 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
for ii in range(numItermax):
- UKv = u * nx.dot(K, A / nx.dot(K, u))
+ UKv = u * nx.dot(K.T, A / nx.dot(K, u))
u = (u.T * geometricBar(weights, UKv)).T / UKv
if ii % 10 == 1:
@@ -1540,6 +1576,129 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
return geometricBar(weights, UKv)
+def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg, b=None, weights=None,
+ numItermax=100, numInnerItermax=1000, stopThr=1e-7, verbose=False, log=None,
+ **kwargs):
+ r"""
+ Solves the free support (locations of the barycenters are optimized, not the weights) regularized Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Sinkhorn divergence), formally:
+
+ .. math::
+ \min_\mathbf{X} \quad \sum_{i=1}^N w_i W_{reg}^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i)
+
+ where :
+
+ - :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one
+ - `measure_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex)
+ - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations
+ - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter
+
+ This problem is considered in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).
+ There are two differences with the following codes:
+
+ - we do not optimize over the weights
+ - we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in
+ :ref:`[20] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete
+ implementation of the fixed-point algorithm of
+ :ref:`[43] <references-free-support-barycenter>` proposed in the continuous setting.
+ - at each iteration, instead of solving an exact OT problem, we use the Sinkhorn algorithm for calculating the
+ transport plan in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).
+
+ Parameters
+ ----------
+ measures_locations : list of N (k_i,d) array-like
+ The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space
+ (:math:`k_i` can be different for each element of the list)
+ measures_weights : list of N (k_i,) array-like
+ Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one
+ representing the weights of each discrete input measure
+
+ X_init : (k,d) array-like
+ Initialization of the support locations (on `k` atoms) of the barycenter
+ reg : float
+ Regularization term >0
+ b : (k,) array-like
+ Initialization of the weights of the barycenter (non-negatives, sum to 1)
+ weights : (N,) array-like
+ Initialization of the coefficients of the barycenter (non-negatives, sum to 1)
+
+ numItermax : int, optional
+ Max number of iterations
+ numInnerItermax : int, optional
+ Max number of iterations when calculating the transport plans with Sinkhorn
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ X : (k,d) array-like
+ Support locations (on k atoms) of the barycenter
+
+ See Also
+ --------
+ ot.bregman.sinkhorn : Entropic regularized OT solver
+ ot.lp.free_support_barycenter : Barycenter solver based on Linear Programming
+
+ .. _references-free-support-barycenter:
+ References
+ ----------
+ .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
+
+ .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
+
+ """
+ nx = get_backend(*measures_locations, *measures_weights, X_init)
+
+ iter_count = 0
+
+ N = len(measures_locations)
+ k = X_init.shape[0]
+ d = X_init.shape[1]
+ if b is None:
+ b = nx.ones((k,), type_as=X_init) / k
+ if weights is None:
+ weights = nx.ones((N,), type_as=X_init) / N
+
+ X = X_init
+
+ log_dict = {}
+ displacement_square_norms = []
+
+ displacement_square_norm = stopThr + 1.
+
+ while (displacement_square_norm > stopThr and iter_count < numItermax):
+
+ T_sum = nx.zeros((k, d), type_as=X_init)
+
+ for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights):
+ M_i = dist(X, measure_locations_i)
+ T_i = sinkhorn(b, measure_weights_i, M_i, reg=reg,
+ numItermax=numInnerItermax, **kwargs)
+ T_sum = T_sum + weight_i * 1. / \
+ b[:, None] * nx.dot(T_i, measure_locations_i)
+
+ displacement_square_norm = nx.sum((T_sum - X) ** 2)
+ if log:
+ displacement_square_norms.append(displacement_square_norm)
+
+ X = T_sum
+
+ if verbose:
+ print('iteration %d, displacement_square_norm=%f\n',
+ iter_count, displacement_square_norm)
+
+ iter_count += 1
+
+ if log:
+ log_dict['displacement_square_norms'] = displacement_square_norms
+ return X, log_dict
+ else:
+ return X
+
+
def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000,
stopThr=1e-4, verbose=False, log=False, warn=True):
r"""Compute the entropic wasserstein barycenter in log-domain
@@ -2084,7 +2243,8 @@ def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
if verbose:
if ii % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(ii, err))
if err < stopThr:
break
@@ -2162,7 +2322,8 @@ def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000,
if verbose:
if ii % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(ii, err))
if err < stopThr:
break
@@ -2321,7 +2482,8 @@ def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000,
if verbose:
if ii % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(ii, err))
# debiased Sinkhorn does not converge monotonically
@@ -2401,7 +2563,8 @@ def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10
if verbose:
if ii % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(ii, err))
if err < stopThr and ii > 20:
break
@@ -2729,7 +2892,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False,
- log=False, warn=True, **kwargs):
+ log=False, warn=True, warmstart=None, **kwargs):
r'''
Solve the entropic regularization optimal transport problem and return the
OT matrix from empirical data
@@ -2782,6 +2945,9 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't convergence.
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
+ Initialization of dual potentials. If provided, the dual potentials should be given
+ (that is the logarithm of the u,v sinkhorn scaling vectors)
Returns
@@ -2832,14 +2998,18 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
dict_log = {"err": []}
log_a, log_b = nx.log(a), nx.log(b)
- f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
+ if warmstart is None:
+ f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
+ else:
+ f, g = warmstart
if isinstance(batchSize, int):
bs, bt = batchSize, batchSize
elif isinstance(batchSize, tuple) and len(batchSize) == 2:
bs, bt = batchSize[0], batchSize[1]
else:
- raise ValueError("Batch size must be in integer or a tuple of two integers")
+ raise ValueError(
+ "Batch size must be in integer or a tuple of two integers")
range_s, range_t = range(0, ns, bs), range(0, nt, bt)
@@ -2877,7 +3047,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric)
M = nx.from_numpy(M, type_as=a)
m1_cols.append(
- nx.sum(nx.exp(f[i:i + bs, None] + g[None, :] - M / reg), axis=1)
+ nx.sum(nx.exp(f[i:i + bs, None] +
+ g[None, :] - M / reg), axis=1)
)
m1 = nx.concatenate(m1_cols, axis=0)
err = nx.sum(nx.abs(m1 - a))
@@ -2885,7 +3056,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
dict_log["err"].append(err)
if verbose and (i_ot + 1) % 100 == 0:
- print("Error in marginal at iteration {} = {}".format(i_ot + 1, err))
+ print("Error in marginal at iteration {} = {}".format(
+ i_ot + 1, err))
if err <= stopThr:
break
@@ -2905,17 +3077,17 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
M = dist(X_s, X_t, metric=metric)
if log:
pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
- verbose=verbose, log=True, **kwargs)
+ verbose=verbose, log=True, warmstart=warmstart, **kwargs)
return pi, log
else:
pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
- verbose=verbose, log=False, **kwargs)
+ verbose=verbose, log=False, warmstart=warmstart, **kwargs)
return pi
def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
- numIterMax=10000, stopThr=1e-9, isLazy=False,
- batchSize=100, verbose=False, log=False, warn=True, **kwargs):
+ numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100,
+ verbose=False, log=False, warn=True, warmstart=None, **kwargs):
r'''
Solve the entropic regularization optimal transport problem from empirical
data and return the OT loss
@@ -2939,6 +3111,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
:math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
+ and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F` (without
+ the entropic contribution).
+
Parameters
----------
@@ -2969,7 +3144,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't convergence.
-
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
+ Initialization of dual potentials. If provided, the dual potentials should be given
+ (that is the logarithm of the u,v sinkhorn scaling vectors)
Returns
-------
@@ -3025,13 +3202,16 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
isLazy=isLazy,
batchSize=batchSize,
verbose=verbose, log=log,
- warn=warn)
+ warn=warn,
+ warmstart=warmstart)
else:
f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric,
- numIterMax=numIterMax, stopThr=stopThr,
+ numIterMax=numIterMax,
+ stopThr=stopThr,
isLazy=isLazy, batchSize=batchSize,
verbose=verbose, log=log,
- warn=warn)
+ warn=warn,
+ warmstart=warmstart)
bs = batchSize if isinstance(batchSize, int) else batchSize[0]
range_s = range(0, ns, bs)
@@ -3053,25 +3233,23 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
return loss
else:
- M = dist(nx.to_numpy(X_s), nx.to_numpy(X_t), metric=metric)
- M = nx.from_numpy(M, type_as=a)
+ M = dist(X_s, X_t, metric=metric)
if log:
sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax,
stopThr=stopThr, verbose=verbose, log=log,
- warn=warn, **kwargs)
+ warn=warn, warmstart=warmstart, **kwargs)
return sinkhorn_loss, log
else:
sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax,
stopThr=stopThr, verbose=verbose, log=log,
- warn=warn, **kwargs)
+ warn=warn, warmstart=warmstart, **kwargs)
return sinkhorn_loss
def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
- numIterMax=10000, stopThr=1e-9,
- verbose=False, log=False, warn=True,
- **kwargs):
+ numIterMax=10000, stopThr=1e-9, verbose=False,
+ log=False, warn=True, warmstart=None, **kwargs):
r'''
Compute the sinkhorn divergence loss from empirical data
@@ -3118,6 +3296,13 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
:math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
+ and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F -(\langle \gamma^*_a, \mathbf{M_a} \rangle_F + \langle
+ \gamma^*_b , \mathbf{M_b} \rangle_F)/2`.
+
+ .. note: The current implementation does not account for the entropic contributions and thus differs from the
+ Sinkhorn divergence as introduced in the literature. The possibility to account for the entropic contributions
+ will be provided in a future release.
+
Parameters
----------
@@ -3141,6 +3326,9 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't convergence.
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
+ Initialization of dual potentials. If provided, the dual potentials should be given
+ (that is the logarithm of the u,v sinkhorn scaling vectors)
Returns
-------
@@ -3167,23 +3355,34 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
International Conference on Artficial Intelligence and Statistics,
(AISTATS) 21, 2018
'''
+ X_s, X_t = list_to_array(X_s, X_t)
+
+ nx = get_backend(X_s, X_t)
+ if warmstart is None:
+ warmstart_a, warmstart_b = None, None
+ else:
+ u, v = warmstart
+ warmstart_a = (u, u)
+ warmstart_b = (v, v)
+
if log:
sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
- numIterMax=numIterMax,
- stopThr=1e-9, verbose=verbose,
- log=log, warn=warn, **kwargs)
+ numIterMax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart, **kwargs)
sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
- numIterMax=numIterMax,
- stopThr=1e-9, verbose=verbose,
- log=log, warn=warn, **kwargs)
+ numIterMax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart_a, **kwargs)
sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
- numIterMax=numIterMax,
- stopThr=1e-9, verbose=verbose,
- log=log, warn=warn, **kwargs)
+ numIterMax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart_b, **kwargs)
- sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
+ sinkhorn_div = sinkhorn_loss_ab - 0.5 * \
+ (sinkhorn_loss_a + sinkhorn_loss_b)
log = {}
log['sinkhorn_loss_ab'] = sinkhorn_loss_ab
@@ -3193,26 +3392,27 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
log['log_sinkhorn_a'] = log_a
log['log_sinkhorn_b'] = log_b
- return max(0, sinkhorn_div), log
+ return nx.maximum(0, sinkhorn_div), log
else:
sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
- numIterMax=numIterMax, stopThr=1e-9,
- verbose=verbose, log=log,
- warn=warn, **kwargs)
+ numIterMax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart, **kwargs)
sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
- numIterMax=numIterMax, stopThr=1e-9,
- verbose=verbose, log=log,
- warn=warn, **kwargs)
+ numIterMax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart_a, **kwargs)
sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
- numIterMax=numIterMax, stopThr=1e-9,
- verbose=verbose, log=log,
- warn=warn, **kwargs)
+ numIterMax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart_b, **kwargs)
- sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
- return max(0, sinkhorn_div)
+ sinkhorn_div = sinkhorn_loss_ab - 0.5 * \
+ (sinkhorn_loss_a + sinkhorn_loss_b)
+ return nx.maximum(0, sinkhorn_div)
def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
@@ -3379,7 +3579,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
epsilon_u_square = a[0] / aK_sort[ns_budget - 1]
else:
aK_sort = nx.from_numpy(
- bottleneck.partition(nx.to_numpy(K_sum_cols), ns_budget - 1)[ns_budget - 1],
+ bottleneck.partition(nx.to_numpy(
+ K_sum_cols), ns_budget - 1)[ns_budget - 1],
type_as=M
)
epsilon_u_square = a[0] / aK_sort
@@ -3389,7 +3590,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
epsilon_v_square = b[0] / bK_sort[nt_budget - 1]
else:
bK_sort = nx.from_numpy(
- bottleneck.partition(nx.to_numpy(K_sum_rows), nt_budget - 1)[nt_budget - 1],
+ bottleneck.partition(nx.to_numpy(
+ K_sum_rows), nt_budget - 1)[nt_budget - 1],
type_as=M
)
epsilon_v_square = b[0] / bK_sort
diff --git a/ot/coot.py b/ot/coot.py
new file mode 100644
index 0000000..66dd2c8
--- /dev/null
+++ b/ot/coot.py
@@ -0,0 +1,434 @@
+# -*- coding: utf-8 -*-
+"""
+CO-Optimal Transport solver
+"""
+
+# Author: Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
+#
+# License: MIT License
+
+import warnings
+from .lp import emd
+from .utils import list_to_array
+from .backend import get_backend
+from .bregman import sinkhorn
+
+
+def co_optimal_transport(X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat=None,
+ epsilon=0, alpha=0, M_samp=None, M_feat=None,
+ warmstart=None, nits_bcd=100, tol_bcd=1e-7, eval_bcd=1,
+ nits_ot=500, tol_sinkhorn=1e-7, method_sinkhorn="sinkhorn",
+ early_stopping_tol=1e-6, log=False, verbose=False):
+ r"""Compute the CO-Optimal Transport between two matrices.
+
+ Return the sample and feature transport plans between
+ :math:`(\mathbf{X}, \mathbf{w}_{xs}, \mathbf{w}_{xf})` and
+ :math:`(\mathbf{Y}, \mathbf{w}_{ys}, \mathbf{w}_{yf})`.
+
+ The function solves the following CO-Optimal Transport (COOT) problem:
+
+ .. math::
+ \mathbf{COOT}_{\alpha, \varepsilon} = \mathop{\arg \min}_{\mathbf{P}, \mathbf{Q}}
+ &\quad \sum_{i,j,k,l}
+ (\mathbf{X}_{i,k} - \mathbf{Y}_{j,l})^2 \mathbf{P}_{i,j} \mathbf{Q}_{k,l}
+ + \alpha_s \sum_{i,j} \mathbf{P}_{i,j} \mathbf{M^{(s)}}_{i, j} \\
+ &+ \alpha_f \sum_{k, l} \mathbf{Q}_{k,l} \mathbf{M^{(f)}}_{k, l}
+ + \varepsilon_s \mathbf{KL}(\mathbf{P} | \mathbf{w}_{xs} \mathbf{w}_{ys}^T)
+ + \varepsilon_f \mathbf{KL}(\mathbf{Q} | \mathbf{w}_{xf} \mathbf{w}_{yf}^T)
+
+ Where :
+
+ - :math:`\mathbf{X}`: Data matrix in the source space
+ - :math:`\mathbf{Y}`: Data matrix in the target space
+ - :math:`\mathbf{M^{(s)}}`: Additional sample matrix
+ - :math:`\mathbf{M^{(f)}}`: Additional feature matrix
+ - :math:`\mathbf{w}_{xs}`: Distribution of the samples in the source space
+ - :math:`\mathbf{w}_{xf}`: Distribution of the features in the source space
+ - :math:`\mathbf{w}_{ys}`: Distribution of the samples in the target space
+ - :math:`\mathbf{w}_{yf}`: Distribution of the features in the target space
+
+ .. note:: This function allows epsilon to be zero.
+ In that case, the :any:`ot.lp.emd` solver of POT will be used.
+
+ Parameters
+ ----------
+ X : (n_sample_x, n_feature_x) array-like, float
+ First input matrix.
+ Y : (n_sample_y, n_feature_y) array-like, float
+ Second input matrix.
+ wx_samp : (n_sample_x, ) array-like, float, optional (default = None)
+ Histogram assigned on rows (samples) of matrix X.
+ Uniform distribution by default.
+ wx_feat : (n_feature_x, ) array-like, float, optional (default = None)
+ Histogram assigned on columns (features) of matrix X.
+ Uniform distribution by default.
+ wy_samp : (n_sample_y, ) array-like, float, optional (default = None)
+ Histogram assigned on rows (samples) of matrix Y.
+ Uniform distribution by default.
+ wy_feat : (n_feature_y, ) array-like, float, optional (default = None)
+ Histogram assigned on columns (features) of matrix Y.
+ Uniform distribution by default.
+ epsilon : scalar or indexable object of length 2, float or int, optional (default = 0)
+ Regularization parameters for entropic approximation of sample and feature couplings.
+ Allow the case where epsilon contains 0. In that case, the EMD solver is used instead of
+ Sinkhorn solver. If epsilon is scalar, then the same epsilon is applied to
+ both regularization of sample and feature couplings.
+ alpha : scalar or indexable object of length 2, float or int, optional (default = 0)
+ Coeffficient parameter of linear terms with respect to the sample and feature couplings.
+ If alpha is scalar, then the same alpha is applied to both linear terms.
+ M_samp : (n_sample_x, n_sample_y), float, optional (default = None)
+ Sample matrix with respect to the linear term on sample coupling.
+ M_feat : (n_feature_x, n_feature_y), float, optional (default = None)
+ Feature matrix with respect to the linear term on feature coupling.
+ warmstart : dictionary, optional (default = None)
+ Contains 4 keys:
+ - "duals_sample" and "duals_feature" whose values are
+ tuples of 2 vectors of size (n_sample_x, n_sample_y) and (n_feature_x, n_feature_y).
+ Initialization of sample and feature dual vectors
+ if using Sinkhorn algorithm. Zero vectors by default.
+
+ - "pi_sample" and "pi_feature" whose values are matrices
+ of size (n_sample_x, n_sample_y) and (n_feature_x, n_feature_y).
+ Initialization of sample and feature couplings.
+ Uniform distributions by default.
+ nits_bcd : int, optional (default = 100)
+ Number of Block Coordinate Descent (BCD) iterations to solve COOT.
+ tol_bcd : float, optional (default = 1e-7)
+ Tolerance of BCD scheme. If the L1-norm between the current and previous
+ sample couplings is under this threshold, then stop BCD scheme.
+ eval_bcd : int, optional (default = 1)
+ Multiplier of iteration at which the COOT cost is evaluated. For example,
+ if `eval_bcd = 8`, then the cost is calculated at iterations 8, 16, 24, etc...
+ nits_ot : int, optional (default = 100)
+ Number of iterations to solve each of the
+ two optimal transport problems in each BCD iteration.
+ tol_sinkhorn : float, optional (default = 1e-7)
+ Tolerance of Sinkhorn algorithm to stop the Sinkhorn scheme for
+ entropic optimal transport problem (if any) in each BCD iteration.
+ Only triggered when Sinkhorn solver is used.
+ method_sinkhorn : string, optional (default = "sinkhorn")
+ Method used in POT's `ot.sinkhorn` solver.
+ Only support "sinkhorn" and "sinkhorn_log".
+ early_stopping_tol : float, optional (default = 1e-6)
+ Tolerance for the early stopping. If the absolute difference between
+ the last 2 recorded COOT distances is under this tolerance, then stop BCD scheme.
+ log : bool, optional (default = False)
+ If True then the cost and 4 dual vectors, including
+ 2 from sample and 2 from feature couplings, are recorded.
+ verbose : bool, optional (default = False)
+ If True then print the COOT cost at every multiplier of `eval_bcd`-th iteration.
+
+ Returns
+ -------
+ pi_samp : (n_sample_x, n_sample_y) array-like, float
+ Sample coupling matrix.
+ pi_feat : (n_feature_x, n_feature_y) array-like, float
+ Feature coupling matrix.
+ log : dictionary, optional
+ Returned if `log` is True. The keys are:
+ duals_sample : (n_sample_x, n_sample_y) tuple, float
+ Pair of dual vectors when solving OT problem w.r.t the sample coupling.
+ duals_feature : (n_feature_x, n_feature_y) tuple, float
+ Pair of dual vectors when solving OT problem w.r.t the feature coupling.
+ distances : list, float
+ List of COOT distances.
+
+ References
+ ----------
+ .. [49] I. Redko, T. Vayer, R. Flamary, and N. Courty, CO-Optimal Transport,
+ Advances in Neural Information Processing ny_sampstems, 33 (2020).
+ """
+
+ def compute_kl(p, q):
+ kl = nx.sum(p * nx.log(p + 1.0 * (p == 0))) - nx.sum(p * nx.log(q))
+ return kl
+
+ # Main function
+
+ if method_sinkhorn not in ["sinkhorn", "sinkhorn_log"]:
+ raise ValueError(
+ "Method {} is not supported in CO-Optimal Transport.".format(method_sinkhorn))
+
+ X, Y = list_to_array(X, Y)
+ nx = get_backend(X, Y)
+
+ if isinstance(epsilon, float) or isinstance(epsilon, int):
+ eps_samp, eps_feat = epsilon, epsilon
+ else:
+ if len(epsilon) != 2:
+ raise ValueError("Epsilon must be either a scalar or an indexable object of length 2.")
+ else:
+ eps_samp, eps_feat = epsilon[0], epsilon[1]
+
+ if isinstance(alpha, float) or isinstance(alpha, int):
+ alpha_samp, alpha_feat = alpha, alpha
+ else:
+ if len(alpha) != 2:
+ raise ValueError("Alpha must be either a scalar or an indexable object of length 2.")
+ else:
+ alpha_samp, alpha_feat = alpha[0], alpha[1]
+
+ # constant input variables
+ if M_samp is None or alpha_samp == 0:
+ M_samp, alpha_samp = 0, 0
+ if M_feat is None or alpha_feat == 0:
+ M_feat, alpha_feat = 0, 0
+
+ nx_samp, nx_feat = X.shape
+ ny_samp, ny_feat = Y.shape
+
+ # measures on rows and columns
+ if wx_samp is None:
+ wx_samp = nx.ones(nx_samp, type_as=X) / nx_samp
+ if wx_feat is None:
+ wx_feat = nx.ones(nx_feat, type_as=X) / nx_feat
+ if wy_samp is None:
+ wy_samp = nx.ones(ny_samp, type_as=Y) / ny_samp
+ if wy_feat is None:
+ wy_feat = nx.ones(ny_feat, type_as=Y) / ny_feat
+
+ wxy_samp = wx_samp[:, None] * wy_samp[None, :]
+ wxy_feat = wx_feat[:, None] * wy_feat[None, :]
+
+ # pre-calculate cost constants
+ XY_sqr = (X ** 2 @ wx_feat)[:, None] + (Y ** 2 @
+ wy_feat)[None, :] + alpha_samp * M_samp
+ XY_sqr_T = ((X.T)**2 @ wx_samp)[:, None] + ((Y.T)
+ ** 2 @ wy_samp)[None, :] + alpha_feat * M_feat
+
+ # initialize coupling and dual vectors
+ if warmstart is None:
+ pi_samp, pi_feat = wxy_samp, wxy_feat # shape nx_samp x ny_samp and nx_feat x ny_feat
+ duals_samp = (nx.zeros(nx_samp, type_as=X), nx.zeros(
+ ny_samp, type_as=Y)) # shape nx_samp, ny_samp
+ duals_feat = (nx.zeros(nx_feat, type_as=X), nx.zeros(
+ ny_feat, type_as=Y)) # shape nx_feat, ny_feat
+ else:
+ pi_samp, pi_feat = warmstart["pi_sample"], warmstart["pi_feature"]
+ duals_samp, duals_feat = warmstart["duals_sample"], warmstart["duals_feature"]
+
+ # initialize log
+ list_coot = [float("inf")]
+ err = tol_bcd + 1e-3
+
+ for idx in range(nits_bcd):
+ pi_samp_prev = nx.copy(pi_samp)
+
+ # update sample coupling
+ ot_cost = XY_sqr - 2 * X @ pi_feat @ Y.T # size nx_samp x ny_samp
+ if eps_samp > 0:
+ pi_samp, dict_log = sinkhorn(a=wx_samp, b=wy_samp, M=ot_cost, reg=eps_samp, method=method_sinkhorn,
+ numItermax=nits_ot, stopThr=tol_sinkhorn, log=True, warmstart=duals_samp)
+ duals_samp = (nx.log(dict_log["u"]), nx.log(dict_log["v"]))
+ elif eps_samp == 0:
+ pi_samp, dict_log = emd(
+ a=wx_samp, b=wy_samp, M=ot_cost, numItermax=nits_ot, log=True)
+ duals_samp = (dict_log["u"], dict_log["v"])
+ # update feature coupling
+ ot_cost = XY_sqr_T - 2 * X.T @ pi_samp @ Y # size nx_feat x ny_feat
+ if eps_feat > 0:
+ pi_feat, dict_log = sinkhorn(a=wx_feat, b=wy_feat, M=ot_cost, reg=eps_feat, method=method_sinkhorn,
+ numItermax=nits_ot, stopThr=tol_sinkhorn, log=True, warmstart=duals_feat)
+ duals_feat = (nx.log(dict_log["u"]), nx.log(dict_log["v"]))
+ elif eps_feat == 0:
+ pi_feat, dict_log = emd(
+ a=wx_feat, b=wy_feat, M=ot_cost, numItermax=nits_ot, log=True)
+ duals_feat = (dict_log["u"], dict_log["v"])
+
+ if idx % eval_bcd == 0:
+ # update error
+ err = nx.sum(nx.abs(pi_samp - pi_samp_prev))
+
+ # COOT part
+ coot = nx.sum(ot_cost * pi_feat)
+ if alpha_samp != 0:
+ coot = coot + alpha_samp * nx.sum(M_samp * pi_samp)
+ # Entropic part
+ if eps_samp != 0:
+ coot = coot + eps_samp * compute_kl(pi_samp, wxy_samp)
+ if eps_feat != 0:
+ coot = coot + eps_feat * compute_kl(pi_feat, wxy_feat)
+ list_coot.append(coot)
+
+ if err < tol_bcd or abs(list_coot[-2] - list_coot[-1]) < early_stopping_tol:
+ break
+
+ if verbose:
+ print(
+ "CO-Optimal Transport cost at iteration {}: {}".format(idx + 1, coot))
+
+ # sanity check
+ if nx.sum(nx.isnan(pi_samp)) > 0 or nx.sum(nx.isnan(pi_feat)) > 0:
+ warnings.warn("There is NaN in coupling.")
+
+ if log:
+ dict_log = {"duals_sample": duals_samp,
+ "duals_feature": duals_feat,
+ "distances": list_coot[1:]}
+
+ return pi_samp, pi_feat, dict_log
+
+ else:
+ return pi_samp, pi_feat
+
+
+def co_optimal_transport2(X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat=None,
+ epsilon=0, alpha=0, M_samp=None, M_feat=None,
+ warmstart=None, log=False, verbose=False, early_stopping_tol=1e-6,
+ nits_bcd=100, tol_bcd=1e-7, eval_bcd=1,
+ nits_ot=500, tol_sinkhorn=1e-7,
+ method_sinkhorn="sinkhorn"):
+ r"""Compute the CO-Optimal Transport distance between two measures.
+
+ Returns the CO-Optimal Transport distance between
+ :math:`(\mathbf{X}, \mathbf{w}_{xs}, \mathbf{w}_{xf})` and
+ :math:`(\mathbf{Y}, \mathbf{w}_{ys}, \mathbf{w}_{yf})`.
+
+ The function solves the following CO-Optimal Transport (COOT) problem:
+
+ .. math::
+ \mathbf{COOT}_{\alpha, \varepsilon} = \mathop{\arg \min}_{\mathbf{P}, \mathbf{Q}}
+ &\quad \sum_{i,j,k,l}
+ (\mathbf{X}_{i,k} - \mathbf{Y}_{j,l})^2 \mathbf{P}_{i,j} \mathbf{Q}_{k,l}
+ + \alpha_1 \sum_{i,j} \mathbf{P}_{i,j} \mathbf{M^{(s)}}_{i, j} \\
+ &+ \alpha_2 \sum_{k, l} \mathbf{Q}_{k,l} \mathbf{M^{(f)}}_{k, l}
+ + \varepsilon_1 \mathbf{KL}(\mathbf{P} | \mathbf{w}_{xs} \mathbf{w}_{ys}^T)
+ + \varepsilon_2 \mathbf{KL}(\mathbf{Q} | \mathbf{w}_{xf} \mathbf{w}_{yf}^T)
+
+ Where :
+
+ - :math:`\mathbf{X}`: Data matrix in the source space
+ - :math:`\mathbf{Y}`: Data matrix in the target space
+ - :math:`\mathbf{M^{(s)}}`: Additional sample matrix
+ - :math:`\mathbf{M^{(f)}}`: Additional feature matrix
+ - :math:`\mathbf{w}_{xs}`: Distribution of the samples in the source space
+ - :math:`\mathbf{w}_{xf}`: Distribution of the features in the source space
+ - :math:`\mathbf{w}_{ys}`: Distribution of the samples in the target space
+ - :math:`\mathbf{w}_{yf}`: Distribution of the features in the target space
+
+ .. note:: This function allows epsilon to be zero.
+ In that case, the :any:`ot.lp.emd` solver of POT will be used.
+
+ Parameters
+ ----------
+ X : (n_sample_x, n_feature_x) array-like, float
+ First input matrix.
+ Y : (n_sample_y, n_feature_y) array-like, float
+ Second input matrix.
+ wx_samp : (n_sample_x, ) array-like, float, optional (default = None)
+ Histogram assigned on rows (samples) of matrix X.
+ Uniform distribution by default.
+ wx_feat : (n_feature_x, ) array-like, float, optional (default = None)
+ Histogram assigned on columns (features) of matrix X.
+ Uniform distribution by default.
+ wy_samp : (n_sample_y, ) array-like, float, optional (default = None)
+ Histogram assigned on rows (samples) of matrix Y.
+ Uniform distribution by default.
+ wy_feat : (n_feature_y, ) array-like, float, optional (default = None)
+ Histogram assigned on columns (features) of matrix Y.
+ Uniform distribution by default.
+ epsilon : scalar or indexable object of length 2, float or int, optional (default = 0)
+ Regularization parameters for entropic approximation of sample and feature couplings.
+ Allow the case where epsilon contains 0. In that case, the EMD solver is used instead of
+ Sinkhorn solver. If epsilon is scalar, then the same epsilon is applied to
+ both regularization of sample and feature couplings.
+ alpha : scalar or indexable object of length 2, float or int, optional (default = 0)
+ Coeffficient parameter of linear terms with respect to the sample and feature couplings.
+ If alpha is scalar, then the same alpha is applied to both linear terms.
+ M_samp : (n_sample_x, n_sample_y), float, optional (default = None)
+ Sample matrix with respect to the linear term on sample coupling.
+ M_feat : (n_feature_x, n_feature_y), float, optional (default = None)
+ Feature matrix with respect to the linear term on feature coupling.
+ warmstart : dictionary, optional (default = None)
+ Contains 4 keys:
+ - "duals_sample" and "duals_feature" whose values are
+ tuples of 2 vectors of size (n_sample_x, n_sample_y) and (n_feature_x, n_feature_y).
+ Initialization of sample and feature dual vectors
+ if using Sinkhorn algorithm. Zero vectors by default.
+
+ - "pi_sample" and "pi_feature" whose values are matrices
+ of size (n_sample_x, n_sample_y) and (n_feature_x, n_feature_y).
+ Initialization of sample and feature couplings.
+ Uniform distributions by default.
+ nits_bcd : int, optional (default = 100)
+ Number of Block Coordinate Descent (BCD) iterations to solve COOT.
+ tol_bcd : float, optional (default = 1e-7)
+ Tolerance of BCD scheme. If the L1-norm between the current and previous
+ sample couplings is under this threshold, then stop BCD scheme.
+ eval_bcd : int, optional (default = 1)
+ Multiplier of iteration at which the COOT cost is evaluated. For example,
+ if `eval_bcd = 8`, then the cost is calculated at iterations 8, 16, 24, etc...
+ nits_ot : int, optional (default = 100)
+ Number of iterations to solve each of the
+ two optimal transport problems in each BCD iteration.
+ tol_sinkhorn : float, optional (default = 1e-7)
+ Tolerance of Sinkhorn algorithm to stop the Sinkhorn scheme for
+ entropic optimal transport problem (if any) in each BCD iteration.
+ Only triggered when Sinkhorn solver is used.
+ method_sinkhorn : string, optional (default = "sinkhorn")
+ Method used in POT's `ot.sinkhorn` solver.
+ Only support "sinkhorn" and "sinkhorn_log".
+ early_stopping_tol : float, optional (default = 1e-6)
+ Tolerance for the early stopping. If the absolute difference between
+ the last 2 recorded COOT distances is under this tolerance, then stop BCD scheme.
+ log : bool, optional (default = False)
+ If True then the cost and 4 dual vectors, including
+ 2 from sample and 2 from feature couplings, are recorded.
+ verbose : bool, optional (default = False)
+ If True then print the COOT cost at every multiplier of `eval_bcd`-th iteration.
+
+ Returns
+ -------
+ float
+ CO-Optimal Transport distance.
+ dict
+ Contains logged informations from :any:`co_optimal_transport` solver.
+ Only returned if `log` parameter is True
+
+ References
+ ----------
+ .. [47] I. Redko, T. Vayer, R. Flamary, and N. Courty, CO-Optimal Transport,
+ Advances in Neural Information Processing ny_sampstems, 33 (2020).
+ """
+
+ pi_samp, pi_feat, dict_log = co_optimal_transport(X=X, Y=Y, wx_samp=wx_samp, wx_feat=wx_feat, wy_samp=wy_samp,
+ wy_feat=wy_feat, epsilon=epsilon, alpha=alpha, M_samp=M_samp,
+ M_feat=M_feat, warmstart=warmstart, nits_bcd=nits_bcd,
+ tol_bcd=tol_bcd, eval_bcd=eval_bcd, nits_ot=nits_ot,
+ tol_sinkhorn=tol_sinkhorn, method_sinkhorn=method_sinkhorn,
+ early_stopping_tol=early_stopping_tol,
+ log=True, verbose=verbose)
+
+ X, Y = list_to_array(X, Y)
+ nx = get_backend(X, Y)
+
+ nx_samp, nx_feat = X.shape
+ ny_samp, ny_feat = Y.shape
+
+ # measures on rows and columns
+ if wx_samp is None:
+ wx_samp = nx.ones(nx_samp, type_as=X) / nx_samp
+ if wx_feat is None:
+ wx_feat = nx.ones(nx_feat, type_as=X) / nx_feat
+ if wy_samp is None:
+ wy_samp = nx.ones(ny_samp, type_as=Y) / ny_samp
+ if wy_feat is None:
+ wy_feat = nx.ones(ny_feat, type_as=Y) / ny_feat
+
+ vx_samp, vy_samp = dict_log["duals_sample"]
+ vx_feat, vy_feat = dict_log["duals_feature"]
+
+ gradX = 2 * X * (wx_samp[:, None] * wx_feat[None, :]) - \
+ 2 * pi_samp @ Y @ pi_feat.T # shape (nx_samp, nx_feat)
+ gradY = 2 * Y * (wy_samp[:, None] * wy_feat[None, :]) - \
+ 2 * pi_samp.T @ X @ pi_feat # shape (ny_samp, ny_feat)
+
+ coot = dict_log["distances"][-1]
+ coot = nx.set_gradients(coot, (wx_samp, wx_feat, wy_samp, wy_feat, X, Y),
+ (vx_samp, vx_feat, vy_samp, vy_feat, gradX, gradY))
+
+ if log:
+ return coot, dict_log
+
+ else:
+ return coot
diff --git a/ot/da.py b/ot/da.py
index 0b9737e..5067a69 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -17,8 +17,9 @@ from .backend import get_backend
from .bregman import sinkhorn, jcpot_barycenter
from .lp import emd
from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots
-from .utils import list_to_array, check_params, BaseEstimator
+from .utils import list_to_array, check_params, BaseEstimator, deprecated
from .unbalanced import sinkhorn_unbalanced
+from .gaussian import empirical_bures_wasserstein_mapping
from .optim import cg
from .optim import gcg
@@ -126,8 +127,12 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
W = nx.zeros(M.shape, type_as=M)
for cpt in range(numItermax):
Mreg = M + eta * W
- transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
- stopThr=stopInnerThr)
+ if log:
+ transp, log = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
+ stopThr=stopInnerThr, log=True)
+ else:
+ transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
+ stopThr=stopInnerThr)
# the transport has been computed. Check if classes are really
# separated
W = nx.ones(M.shape, type_as=M)
@@ -136,7 +141,10 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
majs = p * ((majs + epsilon) ** (p - 1))
W[indices_labels[i]] = majs
- return transp
+ if log:
+ return transp, log
+ else:
+ return transp
def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
@@ -672,112 +680,7 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
return G, L
-def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
- wt=None, bias=True, log=False):
- r"""Return OT linear operator between samples.
-
- The function estimates the optimal linear operator that aligns the two
- empirical distributions. This is equivalent to estimating the closed
- form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)`
- and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in
- :ref:`[14] <references-OT-mapping-linear>` and discussed in remark 2.29 in
- :ref:`[15] <references-OT-mapping-linear>`.
-
- The linear operator from source to target :math:`M`
-
- .. math::
- M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b}
-
- where :
-
- .. math::
- \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2}
- \Sigma_s^{-1/2}
-
- \mathbf{b} &= \mu_t - \mathbf{A} \mu_s
-
- Parameters
- ----------
- xs : array-like (ns,d)
- samples in the source domain
- xt : array-like (nt,d)
- samples in the target domain
- reg : float,optional
- regularization added to the diagonals of covariances (>0)
- ws : array-like (ns,1), optional
- weights for the source samples
- wt : array-like (ns,1), optional
- weights for the target samples
- bias: boolean, optional
- estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True)
- log : bool, optional
- record log if True
-
-
- Returns
- -------
- A : (d, d) array-like
- Linear operator
- b : (1, d) array-like
- bias
- log : dict
- log dictionary return only if log==True in parameters
-
-
- .. _references-OT-mapping-linear:
- References
- ----------
- .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of
- distributions", Journal of Optimization Theory and Applications
- Vol 43, 1984
-
- .. [15] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
- Transport", 2018.
-
-
- """
- xs, xt = list_to_array(xs, xt)
- nx = get_backend(xs, xt)
-
- d = xs.shape[1]
-
- if bias:
- mxs = nx.mean(xs, axis=0)[None, :]
- mxt = nx.mean(xt, axis=0)[None, :]
-
- xs = xs - mxs
- xt = xt - mxt
- else:
- mxs = nx.zeros((1, d), type_as=xs)
- mxt = nx.zeros((1, d), type_as=xs)
-
- if ws is None:
- ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0]
-
- if wt is None:
- wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0]
-
- Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs)
- Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt)
-
- Cs12 = nx.sqrtm(Cs)
- Cs_12 = nx.inv(Cs12)
-
- M0 = nx.sqrtm(dots(Cs12, Ct, Cs12))
-
- A = dots(Cs_12, M0, Cs_12)
-
- b = mxt - nx.dot(mxs, A)
-
- if log:
- log = {}
- log['Cs'] = Cs
- log['Ct'] = Ct
- log['Cs12'] = Cs12
- log['Cs_12'] = Cs_12
- return A, b, log
- else:
- return A, b
+OT_mapping_linear = deprecated(empirical_bures_wasserstein_mapping)
def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, alpha=.5,
@@ -1371,10 +1274,10 @@ class LinearTransport(BaseTransport):
self.mu_t = self.distribution_estimation(Xt)
# coupling estimation
- returned_ = OT_mapping_linear(Xs, Xt, reg=self.reg,
- ws=nx.reshape(self.mu_s, (-1, 1)),
- wt=nx.reshape(self.mu_t, (-1, 1)),
- bias=self.bias, log=self.log)
+ returned_ = empirical_bures_wasserstein_mapping(Xs, Xt, reg=self.reg,
+ ws=nx.reshape(self.mu_s, (-1, 1)),
+ wt=nx.reshape(self.mu_t, (-1, 1)),
+ bias=self.bias, log=self.log)
# deal with the value of log
if self.log:
@@ -1514,12 +1417,13 @@ class SinkhornTransport(BaseTransport):
Sciences, 7(3), 1853-1882.
"""
- def __init__(self, reg_e=1., max_iter=1000,
+ def __init__(self, reg_e=1., method="sinkhorn", max_iter=1000,
tol=10e-9, verbose=False, log=False,
metric="sqeuclidean", norm=None,
distribution_estimation=distribution_estimation_uniform,
out_of_sample_map='ferradans', limit_max=np.infty):
self.reg_e = reg_e
+ self.method = method
self.max_iter = max_iter
self.tol = tol
self.verbose = verbose
@@ -1560,7 +1464,7 @@ class SinkhornTransport(BaseTransport):
# coupling estimation
returned_ = sinkhorn(
a=self.mu_s, b=self.mu_t, M=self.cost_, reg=self.reg_e,
- numItermax=self.max_iter, stopThr=self.tol,
+ method=self.method, numItermax=self.max_iter, stopThr=self.tol,
verbose=self.verbose, log=self.log)
# deal with the value of log
diff --git a/ot/dr.py b/ot/dr.py
index 0955c55..b92cd14 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -17,10 +17,10 @@ Dimension reduction with OT
from scipy import linalg
import autograd.numpy as np
-from pymanopt.function import Autograd
-from pymanopt.manifolds import Stiefel
-from pymanopt import Problem
-from pymanopt.solvers import SteepestDescent, TrustRegions
+
+import pymanopt
+import pymanopt.manifolds
+import pymanopt.optimizers
def dist(x1, x2):
@@ -38,8 +38,8 @@ def sinkhorn(w1, w2, M, reg, k):
ui = np.ones((M.shape[0],))
vi = np.ones((M.shape[1],))
for i in range(k):
- vi = w2 / (np.dot(K.T, ui))
- ui = w1 / (np.dot(K, vi))
+ vi = w2 / (np.dot(K.T, ui) + 1e-50)
+ ui = w1 / (np.dot(K, vi) + 1e-50)
G = ui.reshape((M.shape[0], 1)) * K * vi.reshape((1, M.shape[1]))
return G
@@ -167,7 +167,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter
Size of dimensionnality reduction.
reg : float, optional
Regularization term >0 (entropic regularization)
- solver : None | str, optional
+ solver : None | str, optional
None for steepest descent or 'TrustRegions' for trust regions algorithm
else should be a pymanopt.solvers
sinkhorn_method : str
@@ -222,7 +222,9 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter
else:
regmean = np.ones((len(xc), len(xc)))
- @Autograd
+ manifold = pymanopt.manifolds.Stiefel(d, p)
+
+ @pymanopt.function.autograd(manifold)
def cost(P):
# wda loss
loss_b = 0
@@ -243,21 +245,21 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter
return loss_w / loss_b
# declare manifold and problem
- manifold = Stiefel(d, p)
- problem = Problem(manifold=manifold, cost=cost)
+
+ problem = pymanopt.Problem(manifold=manifold, cost=cost)
# declare solver and solve
if solver is None:
- solver = SteepestDescent(maxiter=maxiter, logverbosity=verbose)
+ solver = pymanopt.optimizers.SteepestDescent(max_iterations=maxiter, log_verbosity=verbose)
elif solver in ['tr', 'TrustRegions']:
- solver = TrustRegions(maxiter=maxiter, logverbosity=verbose)
+ solver = pymanopt.optimizers.TrustRegions(max_iterations=maxiter, log_verbosity=verbose)
- Popt = solver.solve(problem, x=P0)
+ Popt = solver.run(problem, initial_point=P0)
def proj(X):
- return (X - mx.reshape((1, -1))).dot(Popt)
+ return (X - mx.reshape((1, -1))).dot(Popt.point)
- return Popt, proj
+ return Popt.point, proj
def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0):
diff --git a/ot/gaussian.py b/ot/gaussian.py
new file mode 100644
index 0000000..4ffb726
--- /dev/null
+++ b/ot/gaussian.py
@@ -0,0 +1,333 @@
+# -*- coding: utf-8 -*-
+"""
+Optimal transport for Gaussian distributions
+"""
+
+# Author: Theo Gnassounou <theo.gnassounou@inria.fr>
+# Remi Flamary <remi.flamary@polytehnique.edu>
+#
+# License: MIT License
+
+from .backend import get_backend
+from .utils import dots
+from .utils import list_to_array
+
+
+def bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False):
+ r"""Return OT linear operator between samples.
+
+ The function estimates the optimal linear operator that aligns the two
+ empirical distributions. This is equivalent to estimating the closed
+ form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)`
+ and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in
+ :ref:`[1] <references-OT-mapping-linear>` and discussed in remark 2.29 in
+ :ref:`[2] <references-OT-mapping-linear>`.
+
+ The linear operator from source to target :math:`M`
+
+ .. math::
+ M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b}
+
+ where :
+
+ .. math::
+ \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2}
+ \Sigma_s^{-1/2}
+
+ \mathbf{b} &= \mu_t - \mathbf{A} \mu_s
+
+ Parameters
+ ----------
+ ms : array-like (d,)
+ mean of the source distribution
+ mt : array-like (d,)
+ mean of the target distribution
+ Cs : array-like (d,)
+ covariance of the source distribution
+ Ct : array-like (d,)
+ covariance of the target distribution
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ A : (d, d) array-like
+ Linear operator
+ b : (1, d) array-like
+ bias
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ .. _references-OT-mapping-linear:
+ References
+ ----------
+ .. [1] Knott, M. and Smith, C. S. "On the optimal mapping of
+ distributions", Journal of Optimization Theory and Applications
+ Vol 43, 1984
+
+ .. [2] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+ Transport", 2018.
+ """
+ ms, mt, Cs, Ct = list_to_array(ms, mt, Cs, Ct)
+ nx = get_backend(ms, mt, Cs, Ct)
+
+ Cs12 = nx.sqrtm(Cs)
+ Cs12inv = nx.inv(Cs12)
+
+ M0 = nx.sqrtm(dots(Cs12, Ct, Cs12))
+
+ A = dots(Cs12inv, M0, Cs12inv)
+
+ b = mt - nx.dot(ms, A)
+
+ if log:
+ log = {}
+ log['Cs12'] = Cs12
+ log['Cs12inv'] = Cs12inv
+ return A, b, log
+ else:
+ return A, b
+
+
+def empirical_bures_wasserstein_mapping(xs, xt, reg=1e-6, ws=None,
+ wt=None, bias=True, log=False):
+ r"""Return OT linear operator between samples.
+
+ The function estimates the optimal linear operator that aligns the two
+ empirical distributions. This is equivalent to estimating the closed
+ form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)`
+ and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in
+ :ref:`[1] <references-OT-mapping-linear>` and discussed in remark 2.29 in
+ :ref:`[2] <references-OT-mapping-linear>`.
+
+ The linear operator from source to target :math:`M`
+
+ .. math::
+ M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b}
+
+ where :
+
+ .. math::
+ \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2}
+ \Sigma_s^{-1/2}
+
+ \mathbf{b} &= \mu_t - \mathbf{A} \mu_s
+
+ Parameters
+ ----------
+ xs : array-like (ns,d)
+ samples in the source domain
+ xt : array-like (nt,d)
+ samples in the target domain
+ reg : float,optional
+ regularization added to the diagonals of covariances (>0)
+ ws : array-like (ns,1), optional
+ weights for the source samples
+ wt : array-like (ns,1), optional
+ weights for the target samples
+ bias: boolean, optional
+ estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True)
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ A : (d, d) array-like
+ Linear operator
+ b : (1, d) array-like
+ bias
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ .. _references-OT-mapping-linear:
+ References
+ ----------
+ .. [1] Knott, M. and Smith, C. S. "On the optimal mapping of
+ distributions", Journal of Optimization Theory and Applications
+ Vol 43, 1984
+
+ .. [2] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+ Transport", 2018.
+ """
+ xs, xt = list_to_array(xs, xt)
+ nx = get_backend(xs, xt)
+
+ d = xs.shape[1]
+
+ if bias:
+ mxs = nx.mean(xs, axis=0)[None, :]
+ mxt = nx.mean(xt, axis=0)[None, :]
+
+ xs = xs - mxs
+ xt = xt - mxt
+ else:
+ mxs = nx.zeros((1, d), type_as=xs)
+ mxt = nx.zeros((1, d), type_as=xs)
+
+ if ws is None:
+ ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0]
+
+ if wt is None:
+ wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0]
+
+ Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs)
+ Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt)
+
+ if log:
+ A, b, log = bures_wasserstein_mapping(mxs, mxt, Cs, Ct, log=log)
+ log['Cs'] = Cs
+ log['Ct'] = Ct
+ return A, b, log
+ else:
+ A, b = bures_wasserstein_mapping(mxs, mxt, Cs, Ct)
+ return A, b
+
+
+def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
+ r"""Return Bures Wasserstein distance between samples.
+
+ The function estimates the Bures-Wasserstein distance between two
+ empirical distributions source :math:`\mu_s` and target :math:`\mu_t`,
+ discussed in remark 2.31 :ref:`[1] <references-bures-wasserstein-distance>`.
+
+ The Bures Wasserstein distance between source and target distribution :math:`\mathcal{W}`
+
+ .. math::
+ \mathcal{W}(\mu_s, \mu_t)_2^2= \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2}
+
+ where :
+
+ .. math::
+ \mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s^{1/2} + \Sigma_t^{1/2} - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right)
+
+ Parameters
+ ----------
+ ms : array-like (d,)
+ mean of the source distribution
+ mt : array-like (d,)
+ mean of the target distribution
+ Cs : array-like (d,)
+ covariance of the source distribution
+ Ct : array-like (d,)
+ covariance of the target distribution
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ W : float
+ Bures Wasserstein distance
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ .. _references-bures-wasserstein-distance:
+ References
+ ----------
+
+ .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+ Transport", 2018.
+ """
+ ms, mt, Cs, Ct = list_to_array(ms, mt, Cs, Ct)
+ nx = get_backend(ms, mt, Cs, Ct)
+
+ Cs12 = nx.sqrtm(Cs)
+
+ B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12)))
+ W = nx.sqrt(nx.norm(ms - mt)**2 + B)
+ if log:
+ log = {}
+ log['Cs12'] = Cs12
+ return W, log
+ else:
+ return W
+
+
+def empirical_bures_wasserstein_distance(xs, xt, reg=1e-6, ws=None,
+ wt=None, bias=True, log=False):
+ r"""Return Bures Wasserstein distance from mean and covariance of distribution.
+
+ The function estimates the Bures-Wasserstein distance between two
+ empirical distributions source :math:`\mu_s` and target :math:`\mu_t`,
+ discussed in remark 2.31 :ref:`[1] <references-bures-wasserstein-distance>`.
+
+ The Bures Wasserstein distance between source and target distribution :math:`\mathcal{W}`
+
+ .. math::
+ \mathcal{W}(\mu_s, \mu_t)_2^2= \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2}
+
+ where :
+
+ .. math::
+ \mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s^{1/2} + \Sigma_t^{1/2} - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right)
+
+ Parameters
+ ----------
+ xs : array-like (ns,d)
+ samples in the source domain
+ xt : array-like (nt,d)
+ samples in the target domain
+ reg : float,optional
+ regularization added to the diagonals of covariances (>0)
+ ws : array-like (ns,1), optional
+ weights for the source samples
+ wt : array-like (ns,1), optional
+ weights for the target samples
+ bias: boolean, optional
+ estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True)
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ W : float
+ Bures Wasserstein distance
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ .. _references-bures-wasserstein-distance:
+ References
+ ----------
+
+ .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+ Transport", 2018.
+ """
+ xs, xt = list_to_array(xs, xt)
+ nx = get_backend(xs, xt)
+
+ d = xs.shape[1]
+
+ if bias:
+ mxs = nx.mean(xs, axis=0)[None, :]
+ mxt = nx.mean(xt, axis=0)[None, :]
+
+ xs = xs - mxs
+ xt = xt - mxt
+ else:
+ mxs = nx.zeros((1, d), type_as=xs)
+ mxt = nx.zeros((1, d), type_as=xs)
+
+ if ws is None:
+ ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0]
+
+ if wt is None:
+ wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0]
+
+ Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs)
+ Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt)
+
+ if log:
+ W, log = bures_wasserstein_distance(mxs, mxt, Cs, Ct, log=log)
+ log['Cs'] = Cs
+ log['Ct'] = Ct
+ return W, log
+ else:
+ W = bures_wasserstein_distance(mxs, mxt, Cs, Ct)
+ return W
diff --git a/ot/gromov.py b/ot/gromov.py
deleted file mode 100644
index 55ab0bd..0000000
--- a/ot/gromov.py
+++ /dev/null
@@ -1,2835 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Gromov-Wasserstein and Fused-Gromov-Wasserstein solvers
-"""
-
-# Author: Erwan Vautier <erwan.vautier@gmail.com>
-# Nicolas Courty <ncourty@irisa.fr>
-# Rémi Flamary <remi.flamary@unice.fr>
-# Titouan Vayer <titouan.vayer@irisa.fr>
-# Cédric Vincent-Cuaz <cedric.vincent-cuaz@inria.fr>
-#
-# License: MIT License
-
-import numpy as np
-
-
-from .bregman import sinkhorn
-from .utils import dist, UndefinedParameter, list_to_array
-from .optim import cg
-from .lp import emd_1d, emd
-from .utils import check_random_state, unif
-from .backend import get_backend
-
-
-def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
- r"""Return loss matrices and tensors for Gromov-Wasserstein fast computation
-
- Returns the value of :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` with the
- selected loss function as the loss function of Gromow-Wasserstein discrepancy.
-
- The matrices are computed as described in Proposition 1 in :ref:`[12] <references-init-matrix>`
-
- Where :
-
- - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
- - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
- - :math:`\mathbf{T}`: A coupling between those two spaces
-
- The square-loss function :math:`L(a, b) = |a - b|^2` is read as :
-
- .. math::
-
- L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b)
-
- \mathrm{with} \ f_1(a) &= a^2
-
- f_2(b) &= b^2
-
- h_1(a) &= a
-
- h_2(b) &= 2b
-
- The kl-loss function :math:`L(a, b) = a \log\left(\frac{a}{b}\right) - a + b` is read as :
-
- .. math::
-
- L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b)
-
- \mathrm{with} \ f_1(a) &= a \log(a) - a
-
- f_2(b) &= b
-
- h_1(a) &= a
-
- h_2(b) &= \log(b)
-
- Parameters
- ----------
- C1 : array-like, shape (ns, ns)
- Metric cost matrix in the source space
- C2 : array-like, shape (nt, nt)
- Metric cost matrix in the target space
- T : array-like, shape (ns, nt)
- Coupling between source and target spaces
- p : array-like, shape (ns,)
-
- Returns
- -------
- constC : array-like, shape (ns, nt)
- Constant :math:`\mathbf{C}` matrix in Eq. (6)
- hC1 : array-like, shape (ns, ns)
- :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
- hC2 : array-like, shape (nt, nt)
- :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
-
-
- .. _references-init-matrix:
- References
- ----------
- .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
- "Gromov-Wasserstein averaging of kernel and distance matrices."
- International Conference on Machine Learning (ICML). 2016.
-
- """
- C1, C2, p, q = list_to_array(C1, C2, p, q)
- nx = get_backend(C1, C2, p, q)
-
- if loss_fun == 'square_loss':
- def f1(a):
- return (a**2)
-
- def f2(b):
- return (b**2)
-
- def h1(a):
- return a
-
- def h2(b):
- return 2 * b
- elif loss_fun == 'kl_loss':
- def f1(a):
- return a * nx.log(a + 1e-15) - a
-
- def f2(b):
- return b
-
- def h1(a):
- return a
-
- def h2(b):
- return nx.log(b + 1e-15)
-
- constC1 = nx.dot(
- nx.dot(f1(C1), nx.reshape(p, (-1, 1))),
- nx.ones((1, len(q)), type_as=q)
- )
- constC2 = nx.dot(
- nx.ones((len(p), 1), type_as=p),
- nx.dot(nx.reshape(q, (1, -1)), f2(C2).T)
- )
- constC = constC1 + constC2
- hC1 = h1(C1)
- hC2 = h2(C2)
-
- return constC, hC1, hC2
-
-
-def tensor_product(constC, hC1, hC2, T):
- r"""Return the tensor for Gromov-Wasserstein fast computation
-
- The tensor is computed as described in Proposition 1 Eq. (6) in :ref:`[12] <references-tensor-product>`
-
- Parameters
- ----------
- constC : array-like, shape (ns, nt)
- Constant :math:`\mathbf{C}` matrix in Eq. (6)
- hC1 : array-like, shape (ns, ns)
- :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
- hC2 : array-like, shape (nt, nt)
- :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
-
- Returns
- -------
- tens : array-like, shape (`ns`, `nt`)
- :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` tensor-matrix multiplication result
-
-
- .. _references-tensor-product:
- References
- ----------
- .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
- "Gromov-Wasserstein averaging of kernel and distance matrices."
- International Conference on Machine Learning (ICML). 2016.
-
- """
- constC, hC1, hC2, T = list_to_array(constC, hC1, hC2, T)
- nx = get_backend(constC, hC1, hC2, T)
-
- A = - nx.dot(
- nx.dot(hC1, T), hC2.T
- )
- tens = constC + A
- # tens -= tens.min()
- return tens
-
-
-def gwloss(constC, hC1, hC2, T):
- r"""Return the Loss for Gromov-Wasserstein
-
- The loss is computed as described in Proposition 1 Eq. (6) in :ref:`[12] <references-gwloss>`
-
- Parameters
- ----------
- constC : array-like, shape (ns, nt)
- Constant :math:`\mathbf{C}` matrix in Eq. (6)
- hC1 : array-like, shape (ns, ns)
- :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
- hC2 : array-like, shape (nt, nt)
- :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
- T : array-like, shape (ns, nt)
- Current value of transport matrix :math:`\mathbf{T}`
-
- Returns
- -------
- loss : float
- Gromov Wasserstein loss
-
-
- .. _references-gwloss:
- References
- ----------
- .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
- "Gromov-Wasserstein averaging of kernel and distance matrices."
- International Conference on Machine Learning (ICML). 2016.
-
- """
-
- tens = tensor_product(constC, hC1, hC2, T)
-
- tens, T = list_to_array(tens, T)
- nx = get_backend(tens, T)
-
- return nx.sum(tens * T)
-
-
-def gwggrad(constC, hC1, hC2, T):
- r"""Return the gradient for Gromov-Wasserstein
-
- The gradient is computed as described in Proposition 2 in :ref:`[12] <references-gwggrad>`
-
- Parameters
- ----------
- constC : array-like, shape (ns, nt)
- Constant :math:`\mathbf{C}` matrix in Eq. (6)
- hC1 : array-like, shape (ns, ns)
- :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
- hC2 : array-like, shape (nt, nt)
- :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
- T : array-like, shape (ns, nt)
- Current value of transport matrix :math:`\mathbf{T}`
-
- Returns
- -------
- grad : array-like, shape (`ns`, `nt`)
- Gromov Wasserstein gradient
-
-
- .. _references-gwggrad:
- References
- ----------
- .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
- "Gromov-Wasserstein averaging of kernel and distance matrices."
- International Conference on Machine Learning (ICML). 2016.
-
- """
- return 2 * tensor_product(constC, hC1, hC2,
- T) # [12] Prop. 2 misses a 2 factor
-
-
-def update_square_loss(p, lambdas, T, Cs):
- r"""
- Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s`
- couplings calculated at each iteration
-
- Parameters
- ----------
- p : array-like, shape (N,)
- Masses in the targeted barycenter.
- lambdas : list of float
- List of the `S` spaces' weights.
- T : list of S array-like of shape (ns,N)
- The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
- Cs : list of S array-like, shape(ns,ns)
- Metric cost matrices.
-
- Returns
- ----------
- C : array-like, shape (`nt`, `nt`)
- Updated :math:`\mathbf{C}` matrix.
- """
- T = list_to_array(*T)
- Cs = list_to_array(*Cs)
- p = list_to_array(p)
- nx = get_backend(p, *T, *Cs)
-
- tmpsum = sum([
- lambdas[s] * nx.dot(
- nx.dot(T[s].T, Cs[s]),
- T[s]
- ) for s in range(len(T))
- ])
- ppt = nx.outer(p, p)
-
- return tmpsum / ppt
-
-
-def update_kl_loss(p, lambdas, T, Cs):
- r"""
- Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
-
-
- Parameters
- ----------
- p : array-like, shape (N,)
- Weights in the targeted barycenter.
- lambdas : list of float
- List of the `S` spaces' weights
- T : list of S array-like of shape (ns,N)
- The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
- Cs : list of S array-like, shape(ns,ns)
- Metric cost matrices.
-
- Returns
- ----------
- C : array-like, shape (`ns`, `ns`)
- updated :math:`\mathbf{C}` matrix
- """
- Cs = list_to_array(*Cs)
- T = list_to_array(*T)
- p = list_to_array(p)
- nx = get_backend(p, *T, *Cs)
-
- tmpsum = sum([
- lambdas[s] * nx.dot(
- nx.dot(T[s].T, Cs[s]),
- T[s]
- ) for s in range(len(T))
- ])
- ppt = nx.outer(p, p)
-
- return nx.exp(tmpsum / ppt)
-
-
-def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, G0=None, **kwargs):
- r"""
- Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
-
- The function solves the following optimization problem:
-
- .. math::
- \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
- L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
-
- Where :
-
- - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
- - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
- - :math:`\mathbf{p}`: distribution in the source space
- - :math:`\mathbf{q}`: distribution in the target space
- - `L`: loss function to account for the misfit between the similarity matrices
-
- .. note:: This function is backend-compatible and will work on arrays
- from all compatible backends. But the algorithm uses the C++ CPU backend
- which can lead to copy overhead on GPU arrays.
-
- Parameters
- ----------
- C1 : array-like, shape (ns, ns)
- Metric cost matrix in the source space
- C2 : array-like, shape (nt, nt)
- Metric cost matrix in the target space
- p : array-like, shape (ns,)
- Distribution in the source space
- q : array-like, shape (nt,)
- Distribution in the target space
- loss_fun : str
- loss function used for the solver either 'square_loss' or 'kl_loss'
- max_iter : int, optional
- Max number of iterations
- tol : float, optional
- Stop threshold on error (>0)
- verbose : bool, optional
- Print information along iterations
- log : bool, optional
- record log if True
- armijo : bool, optional
- If True the step of the line-search is found via an armijo research. Else closed form is used.
- If there are convergence issues use False.
- G0: array-like, shape (ns,nt), optional
- If None the initial transport plan of the solver is pq^T.
- Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
- **kwargs : dict
- parameters can be directly passed to the ot.optim.cg solver
-
- Returns
- -------
- T : array-like, shape (`ns`, `nt`)
- Coupling between the two spaces that minimizes:
-
- :math:`\sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}`
- log : dict
- Convergence information and loss.
-
- References
- ----------
- .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
- "Gromov-Wasserstein averaging of kernel and distance matrices."
- International Conference on Machine Learning (ICML). 2016.
-
- .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
- metric approach to object matching. Foundations of computational
- mathematics 11.4 (2011): 417-487.
-
- """
- p, q = list_to_array(p, q)
- p0, q0, C10, C20 = p, q, C1, C2
- if G0 is None:
- nx = get_backend(p0, q0, C10, C20)
- else:
- G0_ = G0
- nx = get_backend(p0, q0, C10, C20, G0_)
- p = nx.to_numpy(p)
- q = nx.to_numpy(q)
- C1 = nx.to_numpy(C10)
- C2 = nx.to_numpy(C20)
-
- if G0 is None:
- G0 = p[:, None] * q[None, :]
- else:
- G0 = nx.to_numpy(G0_)
- # Check marginals of G0
- np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
- np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
-
- constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
-
- def f(G):
- return gwloss(constC, hC1, hC2, G)
-
- def df(G):
- return gwggrad(constC, hC1, hC2, G)
-
- if log:
- res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
- log['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, res), type_as=C10)
- log['u'] = nx.from_numpy(log['u'], type_as=C10)
- log['v'] = nx.from_numpy(log['v'], type_as=C10)
- return nx.from_numpy(res, type_as=C10), log
- else:
- return nx.from_numpy(cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=False, **kwargs), type_as=C10)
-
-
-def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, G0=None, **kwargs):
- r"""
- Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
-
- The function solves the following optimization problem:
-
- .. math::
- GW = \min_\mathbf{T} \quad \sum_{i,j,k,l}
- L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
-
- Where :
-
- - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
- - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
- - :math:`\mathbf{p}`: distribution in the source space
- - :math:`\mathbf{q}`: distribution in the target space
- - `L`: loss function to account for the misfit between the similarity
- matrices
-
- Note that when using backends, this loss function is differentiable wrt the
- marices and weights for quadratic loss using the gradients from [38]_.
-
- .. note:: This function is backend-compatible and will work on arrays
- from all compatible backends. But the algorithm uses the C++ CPU backend
- which can lead to copy overhead on GPU arrays.
-
- Parameters
- ----------
- C1 : array-like, shape (ns, ns)
- Metric cost matrix in the source space
- C2 : array-like, shape (nt, nt)
- Metric cost matrix in the target space
- p : array-like, shape (ns,)
- Distribution in the source space.
- q : array-like, shape (nt,)
- Distribution in the target space.
- loss_fun : str
- loss function used for the solver either 'square_loss' or 'kl_loss'
- max_iter : int, optional
- Max number of iterations
- tol : float, optional
- Stop threshold on error (>0)
- verbose : bool, optional
- Print information along iterations
- log : bool, optional
- record log if True
- armijo : bool, optional
- If True the step of the line-search is found via an armijo research. Else closed form is used.
- If there are convergence issues use False.
- G0: array-like, shape (ns,nt), optional
- If None the initial transport plan of the solver is pq^T.
- Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
-
- Returns
- -------
- gw_dist : float
- Gromov-Wasserstein distance
- log : dict
- convergence information and Coupling marix
-
- References
- ----------
- .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
- "Gromov-Wasserstein averaging of kernel and distance matrices."
- International Conference on Machine Learning (ICML). 2016.
-
- .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
- metric approach to object matching. Foundations of computational
- mathematics 11.4 (2011): 417-487.
-
- .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
- Graph Dictionary Learning, International Conference on Machine Learning
- (ICML), 2021.
-
- """
- p, q = list_to_array(p, q)
- p0, q0, C10, C20 = p, q, C1, C2
- if G0 is None:
- nx = get_backend(p0, q0, C10, C20)
- else:
- G0_ = G0
- nx = get_backend(p0, q0, C10, C20, G0_)
-
- p = nx.to_numpy(p)
- q = nx.to_numpy(q)
- C1 = nx.to_numpy(C10)
- C2 = nx.to_numpy(C20)
-
- constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
-
- if G0 is None:
- G0 = p[:, None] * q[None, :]
- else:
- G0 = nx.to_numpy(G0_)
- # Check marginals of G0
- np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
- np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
-
- def f(G):
- return gwloss(constC, hC1, hC2, G)
-
- def df(G):
- return gwggrad(constC, hC1, hC2, G)
-
- T, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
-
- T0 = nx.from_numpy(T, type_as=C10)
-
- log_gw['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, T), type_as=C10)
- log_gw['u'] = nx.from_numpy(log_gw['u'], type_as=C10)
- log_gw['v'] = nx.from_numpy(log_gw['v'], type_as=C10)
- log_gw['T'] = T0
-
- gw = log_gw['gw_dist']
-
- if loss_fun == 'square_loss':
- gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)
- gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)
- gC1 = nx.from_numpy(gC1, type_as=C10)
- gC2 = nx.from_numpy(gC2, type_as=C10)
- gw = nx.set_gradients(gw, (p0, q0, C10, C20),
- (log_gw['u'] - nx.mean(log_gw['u']),
- log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2))
-
- if log:
- return gw, log_gw
- else:
- return gw
-
-
-def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, G0=None, log=False, **kwargs):
- r"""
- Computes the FGW transport between two graphs (see :ref:`[24] <references-fused-gromov-wasserstein>`)
-
- .. math::
- \gamma = \mathop{\arg \min}_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F +
- \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
-
- s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
-
- \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q}
-
- \mathbf{\gamma} &\geq 0
-
- where :
-
- - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
- - `L` is a loss function to account for the misfit between the similarity matrices
-
- .. note:: This function is backend-compatible and will work on arrays
- from all compatible backends. But the algorithm uses the C++ CPU backend
- which can lead to copy overhead on GPU arrays.
-
- The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] <references-fused-gromov-wasserstein>`
-
- Parameters
- ----------
- M : array-like, shape (ns, nt)
- Metric cost matrix between features across domains
- C1 : array-like, shape (ns, ns)
- Metric cost matrix representative of the structure in the source space
- C2 : array-like, shape (nt, nt)
- Metric cost matrix representative of the structure in the target space
- p : array-like, shape (ns,)
- Distribution in the source space
- q : array-like, shape (nt,)
- Distribution in the target space
- loss_fun : str, optional
- Loss function used for the solver
- alpha : float, optional
- Trade-off parameter (0 < alpha < 1)
- armijo : bool, optional
- If True the step of the line-search is found via an armijo research. Else closed form is used.
- If there are convergence issues use False.
- G0: array-like, shape (ns,nt), optional
- If None the initial transport plan of the solver is pq^T.
- Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
- log : bool, optional
- record log if True
- **kwargs : dict
- parameters can be directly passed to the ot.optim.cg solver
-
- Returns
- -------
- gamma : array-like, shape (`ns`, `nt`)
- Optimal transportation matrix for the given parameters.
- log : dict
- Log dictionary return only if log==True in parameters.
-
-
- .. _references-fused-gromov-wasserstein:
- References
- ----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
- and Courty Nicolas "Optimal Transport for structured data with
- application on graphs", International Conference on Machine Learning
- (ICML). 2019.
- """
- p, q = list_to_array(p, q)
- p0, q0, C10, C20, M0 = p, q, C1, C2, M
- if G0 is None:
- nx = get_backend(p0, q0, C10, C20, M0)
- else:
- G0_ = G0
- nx = get_backend(p0, q0, C10, C20, M0, G0_)
-
- p = nx.to_numpy(p)
- q = nx.to_numpy(q)
- C1 = nx.to_numpy(C10)
- C2 = nx.to_numpy(C20)
- M = nx.to_numpy(M0)
- if G0 is None:
- G0 = p[:, None] * q[None, :]
- else:
- G0 = nx.to_numpy(G0_)
- # Check marginals of G0
- np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
- np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
-
- constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
-
- def f(G):
- return gwloss(constC, hC1, hC2, G)
-
- def df(G):
- return gwggrad(constC, hC1, hC2, G)
-
- if log:
- res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
- fgw_dist = nx.from_numpy(log['loss'][-1], type_as=C10)
- log['fgw_dist'] = fgw_dist
- log['u'] = nx.from_numpy(log['u'], type_as=C10)
- log['v'] = nx.from_numpy(log['v'], type_as=C10)
- return nx.from_numpy(res, type_as=C10), log
- else:
- return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs), type_as=C10)
-
-
-def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, G0=None, log=False, **kwargs):
- r"""
- Computes the FGW distance between two graphs see (see :ref:`[24] <references-fused-gromov-wasserstein2>`)
-
- .. math::
- \min_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l}
- L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
-
- s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
-
- \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q}
-
- \mathbf{\gamma} &\geq 0
-
- where :
-
- - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
- - `L` is a loss function to account for the misfit between the similarity matrices
-
- The algorithm used for solving the problem is conditional gradient as
- discussed in :ref:`[24] <references-fused-gromov-wasserstein2>`
-
- .. note:: This function is backend-compatible and will work on arrays
- from all compatible backends. But the algorithm uses the C++ CPU backend
- which can lead to copy overhead on GPU arrays.
-
- Note that when using backends, this loss function is differentiable wrt the
- marices and weights for quadratic loss using the gradients from [38]_.
-
- Parameters
- ----------
- M : array-like, shape (ns, nt)
- Metric cost matrix between features across domains
- C1 : array-like, shape (ns, ns)
- Metric cost matrix representative of the structure in the source space.
- C2 : array-like, shape (nt, nt)
- Metric cost matrix representative of the structure in the target space.
- p : array-like, shape (ns,)
- Distribution in the source space.
- q : array-like, shape (nt,)
- Distribution in the target space.
- loss_fun : str, optional
- Loss function used for the solver.
- alpha : float, optional
- Trade-off parameter (0 < alpha < 1)
- armijo : bool, optional
- If True the step of the line-search is found via an armijo research.
- Else closed form is used. If there are convergence issues use False.
- G0: array-like, shape (ns,nt), optional
- If None the initial transport plan of the solver is pq^T.
- Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
- log : bool, optional
- Record log if True.
- **kwargs : dict
- Parameters can be directly passed to the ot.optim.cg solver.
-
- Returns
- -------
- fgw-distance : float
- Fused gromov wasserstein distance for the given parameters.
- log : dict
- Log dictionary return only if log==True in parameters.
-
-
- .. _references-fused-gromov-wasserstein2:
- References
- ----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
- and Courty Nicolas
- "Optimal Transport for structured data with application on graphs"
- International Conference on Machine Learning (ICML). 2019.
-
- .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
- Graph Dictionary Learning, International Conference on Machine Learning
- (ICML), 2021.
- """
- p, q = list_to_array(p, q)
-
- p0, q0, C10, C20, M0 = p, q, C1, C2, M
- if G0 is None:
- nx = get_backend(p0, q0, C10, C20, M0)
- else:
- G0_ = G0
- nx = get_backend(p0, q0, C10, C20, M0, G0_)
-
- p = nx.to_numpy(p)
- q = nx.to_numpy(q)
- C1 = nx.to_numpy(C10)
- C2 = nx.to_numpy(C20)
- M = nx.to_numpy(M0)
-
- constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
-
- if G0 is None:
- G0 = p[:, None] * q[None, :]
- else:
- G0 = nx.to_numpy(G0_)
- # Check marginals of G0
- np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
- np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
-
- def f(G):
- return gwloss(constC, hC1, hC2, G)
-
- def df(G):
- return gwggrad(constC, hC1, hC2, G)
-
- T, log_fgw = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
-
- fgw_dist = nx.from_numpy(log_fgw['loss'][-1], type_as=C10)
-
- T0 = nx.from_numpy(T, type_as=C10)
-
- log_fgw['fgw_dist'] = fgw_dist
- log_fgw['u'] = nx.from_numpy(log_fgw['u'], type_as=C10)
- log_fgw['v'] = nx.from_numpy(log_fgw['v'], type_as=C10)
- log_fgw['T'] = T0
-
- if loss_fun == 'square_loss':
- gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)
- gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)
- gC1 = nx.from_numpy(gC1, type_as=C10)
- gC2 = nx.from_numpy(gC2, type_as=C10)
- fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0),
- (log_fgw['u'] - nx.mean(log_fgw['u']),
- log_fgw['v'] - nx.mean(log_fgw['v']),
- alpha * gC1, alpha * gC2, (1 - alpha) * T0))
-
- if log:
- return fgw_dist, log_fgw
- else:
- return fgw_dist
-
-
-def GW_distance_estimation(C1, C2, p, q, loss_fun, T,
- nb_samples_p=None, nb_samples_q=None, std=True, random_state=None):
- r"""
- Returns an approximation of the gromov-wasserstein cost between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
- with a fixed transport plan :math:`\mathbf{T}`.
-
- The function gives an unbiased approximation of the following equation:
-
- .. math::
-
- GW = \sum_{i,j,k,l} L(\mathbf{C_{1}}_{i,k}, \mathbf{C_{2}}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
-
- Where :
-
- - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
- - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
- - `L` : Loss function to account for the misfit between the similarity matrices
- - :math:`\mathbf{T}`: Matrix with marginal :math:`\mathbf{p}` and :math:`\mathbf{q}`
-
- Parameters
- ----------
- C1 : array-like, shape (ns, ns)
- Metric cost matrix in the source space
- C2 : array-like, shape (nt, nt)
- Metric cost matrix in the target space
- p : array-like, shape (ns,)
- Distribution in the source space
- q : array-like, shape (nt,)
- Distribution in the target space
- loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}`
- Loss function used for the distance, the transport plan does not depend on the loss function
- T : csr or array-like, shape (ns, nt)
- Transport plan matrix, either a sparse csr or a dense matrix
- nb_samples_p : int, optional
- `nb_samples_p` is the number of samples (without replacement) along the first dimension of :math:`\mathbf{T}`
- nb_samples_q : int, optional
- `nb_samples_q` is the number of samples along the second dimension of :math:`\mathbf{T}`, for each sample along the first
- std : bool, optional
- Standard deviation associated with the prediction of the gromov-wasserstein cost
- random_state : int or RandomState instance, optional
- Fix the seed for reproducibility
-
- Returns
- -------
- : float
- Gromov-wasserstein cost
-
- References
- ----------
- .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
- "Sampled Gromov Wasserstein."
- Machine Learning Journal (MLJ). 2021.
-
- """
- C1, C2, p, q = list_to_array(C1, C2, p, q)
- nx = get_backend(C1, C2, p, q)
-
- generator = check_random_state(random_state)
-
- len_p = p.shape[0]
- len_q = q.shape[0]
-
- # It is always better to sample from the biggest distribution first.
- if len_p < len_q:
- p, q = q, p
- len_p, len_q = len_q, len_p
- C1, C2 = C2, C1
- T = T.T
-
- if nb_samples_p is None:
- if nx.issparse(T):
- # If T is sparse, it probably mean that PoGroW was used, thus the number of sample is reduced
- nb_samples_p = min(int(5 * (len_p * np.log(len_p)) ** 0.5), len_p)
- else:
- nb_samples_p = len_p
- else:
- # The number of sample along the first dimension is without replacement.
- nb_samples_p = min(nb_samples_p, len_p)
- if nb_samples_q is None:
- nb_samples_q = 1
- if std:
- nb_samples_q = max(2, nb_samples_q)
-
- index_k = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
- index_l = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
-
- index_i = generator.choice(
- len_p, size=nb_samples_p, p=nx.to_numpy(p), replace=False
- )
- index_j = generator.choice(
- len_p, size=nb_samples_p, p=nx.to_numpy(p), replace=False
- )
-
- for i in range(nb_samples_p):
- if nx.issparse(T):
- T_indexi = nx.reshape(nx.todense(T[index_i[i], :]), (-1,))
- T_indexj = nx.reshape(nx.todense(T[index_j[i], :]), (-1,))
- else:
- T_indexi = T[index_i[i], :]
- T_indexj = T[index_j[i], :]
- # For each of the row sampled, the column is sampled.
- index_k[i] = generator.choice(
- len_q,
- size=nb_samples_q,
- p=nx.to_numpy(T_indexi / nx.sum(T_indexi)),
- replace=True
- )
- index_l[i] = generator.choice(
- len_q,
- size=nb_samples_q,
- p=nx.to_numpy(T_indexj / nx.sum(T_indexj)),
- replace=True
- )
-
- list_value_sample = nx.stack([
- loss_fun(
- C1[np.ix_(index_i, index_j)],
- C2[np.ix_(index_k[:, n], index_l[:, n])]
- ) for n in range(nb_samples_q)
- ], axis=2)
-
- if std:
- std_value = nx.sum(nx.std(list_value_sample, axis=2) ** 2) ** 0.5
- return nx.mean(list_value_sample), std_value / (nb_samples_p * nb_samples_p)
- else:
- return nx.mean(list_value_sample)
-
-
-def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun,
- alpha=1, max_iter=100, threshold_plan=0, log=False, verbose=False, random_state=None):
- r"""
- Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a stochastic Frank-Wolfe.
- This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times PN^2)` time complexity with `P` the number of Sinkhorn iterations.
-
- The function solves the following optimization problem:
-
- .. math::
- \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
- L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
-
- s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
-
- \mathbf{T}^T \mathbf{1} &= \mathbf{q}
-
- \mathbf{T} &\geq 0
-
- Where :
-
- - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
- - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
- - :math:`\mathbf{p}`: distribution in the source space
- - :math:`\mathbf{q}`: distribution in the target space
- - `L`: loss function to account for the misfit between the similarity matrices
-
- Parameters
- ----------
- C1 : array-like, shape (ns, ns)
- Metric cost matrix in the source space
- C2 : array-like, shape (nt, nt)
- Metric cost matrix in the target space
- p : array-like, shape (ns,)
- Distribution in the source space
- q : array-like, shape (nt,)
- Distribution in the target space
- loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}`
- Loss function used for the distance, the transport plan does not depend on the loss function
- alpha : float
- Step of the Frank-Wolfe algorithm, should be between 0 and 1
- max_iter : int, optional
- Max number of iterations
- threshold_plan : float, optional
- Deleting very small values in the transport plan. If above zero, it violates the marginal constraints.
- verbose : bool, optional
- Print information along iterations
- log : bool, optional
- Gives the distance estimated and the standard deviation
- random_state : int or RandomState instance, optional
- Fix the seed for reproducibility
-
- Returns
- -------
- T : array-like, shape (`ns`, `nt`)
- Optimal coupling between the two spaces
-
- References
- ----------
- .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
- "Sampled Gromov Wasserstein."
- Machine Learning Journal (MLJ). 2021.
-
- """
- C1, C2, p, q = list_to_array(C1, C2, p, q)
- nx = get_backend(C1, C2, p, q)
-
- len_p = p.shape[0]
- len_q = q.shape[0]
-
- generator = check_random_state(random_state)
-
- index = np.zeros(2, dtype=int)
-
- # Initialize with default marginal
- index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p))
- index[1] = generator.choice(len_q, size=1, p=nx.to_numpy(q))
- T = nx.tocsr(emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False))
-
- best_gw_dist_estimated = np.inf
- for cpt in range(max_iter):
- index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p))
- T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,))
- index[1] = generator.choice(
- len_q, size=1, p=nx.to_numpy(T_index0 / nx.sum(T_index0))
- )
-
- if alpha == 1:
- T = nx.tocsr(
- emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False)
- )
- else:
- new_T = nx.tocsr(
- emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False)
- )
- T = (1 - alpha) * T + alpha * new_T
- # To limit the number of non 0, the values below the threshold are set to 0.
- T = nx.eliminate_zeros(T, threshold=threshold_plan)
-
- if cpt % 10 == 0 or cpt == (max_iter - 1):
- gw_dist_estimated = GW_distance_estimation(
- C1=C1, C2=C2, loss_fun=loss_fun,
- p=p, q=q, T=T, std=False, random_state=generator
- )
-
- if gw_dist_estimated < best_gw_dist_estimated:
- best_gw_dist_estimated = gw_dist_estimated
- best_T = nx.copy(T)
-
- if verbose:
- if cpt % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', 'Best gw estimated') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, best_gw_dist_estimated))
-
- if log:
- log = {}
- log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(
- C1=C1, C2=C2, loss_fun=loss_fun,
- p=p, q=q, T=best_T, random_state=generator
- )
- return best_T, log
- return best_T
-
-
-def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun,
- nb_samples_grad=100, epsilon=1, max_iter=500, log=False, verbose=False,
- random_state=None):
- r"""
- Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a 1-stochastic Frank-Wolfe.
- This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times N \log(N))` time complexity by relying on the 1D Optimal Transport solver.
-
- The function solves the following optimization problem:
-
- .. math::
- \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
- L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
-
- s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
-
- \mathbf{T}^T \mathbf{1} &= \mathbf{q}
-
- \mathbf{T} &\geq 0
-
- Where :
-
- - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
- - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
- - :math:`\mathbf{p}`: distribution in the source space
- - :math:`\mathbf{q}`: distribution in the target space
- - `L`: loss function to account for the misfit between the similarity matrices
-
- Parameters
- ----------
- C1 : array-like, shape (ns, ns)
- Metric cost matrix in the source space
- C2 : array-like, shape (nt, nt)
- Metric cost matrix in the target space
- p : array-like, shape (ns,)
- Distribution in the source space
- q : array-like, shape (nt,)
- Distribution in the target space
- loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}`
- Loss function used for the distance, the transport plan does not depend on the loss function
- nb_samples_grad : int
- Number of samples to approximate the gradient
- epsilon : float
- Weight of the Kullback-Leibler regularization
- max_iter : int, optional
- Max number of iterations
- verbose : bool, optional
- Print information along iterations
- log : bool, optional
- Gives the distance estimated and the standard deviation
- random_state : int or RandomState instance, optional
- Fix the seed for reproducibility
-
- Returns
- -------
- T : array-like, shape (`ns`, `nt`)
- Optimal coupling between the two spaces
-
- References
- ----------
- .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
- "Sampled Gromov Wasserstein."
- Machine Learning Journal (MLJ). 2021.
-
- """
- C1, C2, p, q = list_to_array(C1, C2, p, q)
- nx = get_backend(C1, C2, p, q)
-
- len_p = p.shape[0]
- len_q = q.shape[0]
-
- generator = check_random_state(random_state)
-
- # The most natural way to define nb_sample is with a simple integer.
- if isinstance(nb_samples_grad, int):
- if nb_samples_grad > len_p:
- # As the sampling along the first dimension is done without replacement, the rest is reported to the second
- # dimension.
- nb_samples_grad_p, nb_samples_grad_q = len_p, nb_samples_grad // len_p
- else:
- nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad, 1
- else:
- nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad
- T = nx.outer(p, q)
- # continue_loop allows to stop the loop if there is several successive small modification of T.
- continue_loop = 0
-
- # The gradient of GW is more complex if the two matrices are not symmetric.
- C_are_symmetric = nx.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and nx.allclose(C2, C2.T, rtol=1e-10, atol=1e-10)
-
- for cpt in range(max_iter):
- index0 = generator.choice(
- len_p, size=nb_samples_grad_p, p=nx.to_numpy(p), replace=False
- )
- Lik = 0
- for i, index0_i in enumerate(index0):
- index1 = generator.choice(
- len_q, size=nb_samples_grad_q,
- p=nx.to_numpy(T[index0_i, :] / nx.sum(T[index0_i, :])),
- replace=False
- )
- # If the matrices C are not symmetric, the gradient has 2 terms, thus the term is chosen randomly.
- if (not C_are_symmetric) and generator.rand(1) > 0.5:
- Lik += nx.mean(loss_fun(
- C1[:, [index0[i]] * nb_samples_grad_q][:, None, :],
- C2[:, index1][None, :, :]
- ), axis=2)
- else:
- Lik += nx.mean(loss_fun(
- C1[[index0[i]] * nb_samples_grad_q, :][:, :, None],
- C2[index1, :][:, None, :]
- ), axis=0)
-
- max_Lik = nx.max(Lik)
- if max_Lik == 0:
- continue
- # This division by the max is here to facilitate the choice of epsilon.
- Lik /= max_Lik
-
- if epsilon > 0:
- # Set to infinity all the numbers below exp(-200) to avoid log of 0.
- log_T = nx.log(nx.clip(T, np.exp(-200), 1))
- log_T = nx.where(log_T == -200, -np.inf, log_T)
- Lik = Lik - epsilon * log_T
-
- try:
- new_T = sinkhorn(a=p, b=q, M=Lik, reg=epsilon)
- except (RuntimeWarning, UserWarning):
- print("Warning catched in Sinkhorn: Return last stable T")
- break
- else:
- new_T = emd(a=p, b=q, M=Lik)
-
- change_T = nx.mean((T - new_T) ** 2)
- if change_T <= 10e-20:
- continue_loop += 1
- if continue_loop > 100: # Number max of low modifications of T
- T = nx.copy(new_T)
- break
- else:
- continue_loop = 0
-
- if verbose and cpt % 10 == 0:
- if cpt % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', '||T_n - T_{n+1}||') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, change_T))
- T = nx.copy(new_T)
-
- if log:
- log = {}
- log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(
- C1=C1, C2=C2, loss_fun=loss_fun,
- p=p, q=q, T=T, random_state=generator
- )
- return T, log
- return T
-
-
-def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
- max_iter=1000, tol=1e-9, verbose=False, log=False):
- r"""
- Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
-
- The function solves the following optimization problem:
-
- .. math::
- \mathbf{GW} = \mathop{\arg\min}_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T}))
-
- s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
-
- \mathbf{T}^T \mathbf{1} &= \mathbf{q}
-
- \mathbf{T} &\geq 0
-
- Where :
-
- - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
- - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
- - :math:`\mathbf{p}`: distribution in the source space
- - :math:`\mathbf{q}`: distribution in the target space
- - `L`: loss function to account for the misfit between the similarity matrices
- - `H`: entropy
-
- Parameters
- ----------
- C1 : array-like, shape (ns, ns)
- Metric cost matrix in the source space
- C2 : array-like, shape (nt, nt)
- Metric cost matrix in the target space
- p : array-like, shape (ns,)
- Distribution in the source space
- q : array-like, shape (nt,)
- Distribution in the target space
- loss_fun : string
- Loss function used for the solver either 'square_loss' or 'kl_loss'
- epsilon : float
- Regularization term >0
- max_iter : int, optional
- Max number of iterations
- tol : float, optional
- Stop threshold on error (>0)
- verbose : bool, optional
- Print information along iterations
- log : bool, optional
- Record log if True.
-
- Returns
- -------
- T : array-like, shape (`ns`, `nt`)
- Optimal coupling between the two spaces
-
- References
- ----------
- .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
- "Gromov-Wasserstein averaging of kernel and distance matrices."
- International Conference on Machine Learning (ICML). 2016.
-
- """
- C1, C2, p, q = list_to_array(C1, C2, p, q)
- nx = get_backend(C1, C2, p, q)
-
- T = nx.outer(p, q)
-
- constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
-
- cpt = 0
- err = 1
-
- if log:
- log = {'err': []}
-
- while (err > tol and cpt < max_iter):
-
- Tprev = T
-
- # compute the gradient
- tens = gwggrad(constC, hC1, hC2, T)
-
- T = sinkhorn(p, q, tens, epsilon, method='sinkhorn')
-
- if cpt % 10 == 0:
- # we can speed up the process by checking for the error only all
- # the 10th iterations
- err = nx.norm(T - Tprev)
-
- if log:
- log['err'].append(err)
-
- if verbose:
- if cpt % 200 == 0:
- print('{:5s}|{:12s}'.format(
- 'It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
-
- cpt += 1
-
- if log:
- log['gw_dist'] = gwloss(constC, hC1, hC2, T)
- return T, log
- else:
- return T
-
-
-def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
- max_iter=1000, tol=1e-9, verbose=False, log=False):
- r"""
- Returns the entropic gromov-wasserstein discrepancy between the two measured similarity matrices :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
-
- The function solves the following optimization problem:
-
- .. math::
- GW = \min_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})
- \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T}))
-
- Where :
-
- - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
- - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
- - :math:`\mathbf{p}`: distribution in the source space
- - :math:`\mathbf{q}`: distribution in the target space
- - `L`: loss function to account for the misfit between the similarity matrices
- - `H`: entropy
-
- Parameters
- ----------
- C1 : array-like, shape (ns, ns)
- Metric cost matrix in the source space
- C2 : array-like, shape (nt, nt)
- Metric cost matrix in the target space
- p : array-like, shape (ns,)
- Distribution in the source space
- q : array-like, shape (nt,)
- Distribution in the target space
- loss_fun : str
- Loss function used for the solver either 'square_loss' or 'kl_loss'
- epsilon : float
- Regularization term >0
- max_iter : int, optional
- Max number of iterations
- tol : float, optional
- Stop threshold on error (>0)
- verbose : bool, optional
- Print information along iterations
- log : bool, optional
- Record log if True.
-
- Returns
- -------
- gw_dist : float
- Gromov-Wasserstein distance
-
- References
- ----------
- .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
- "Gromov-Wasserstein averaging of kernel and distance matrices."
- International Conference on Machine Learning (ICML). 2016.
-
- """
- gw, logv = entropic_gromov_wasserstein(
- C1, C2, p, q, loss_fun, epsilon, max_iter, tol, verbose, log=True)
-
- logv['T'] = gw
-
- if log:
- return logv['gw_dist'], logv
- else:
- return logv['gw_dist']
-
-
-def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
- max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None):
- r"""
- Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}`
-
- The function solves the following optimization problem:
-
- .. math::
-
- \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s)
-
- Where :
-
- - :math:`\mathbf{C}_s`: metric cost matrix
- - :math:`\mathbf{p}_s`: distribution
-
- Parameters
- ----------
- N : int
- Size of the targeted barycenter
- Cs : list of S array-like of shape (ns,ns)
- Metric cost matrices
- ps : list of S array-like of shape (ns,)
- Sample weights in the `S` spaces
- p : array-like, shape(N,)
- Weights in the targeted barycenter
- lambdas : list of float
- List of the `S` spaces' weights.
- loss_fun : callable
- Tensor-matrix multiplication function based on specific loss function.
- update : callable
- function(:math:`\mathbf{p}`, lambdas, :math:`\mathbf{T}`, :math:`\mathbf{Cs}`) that updates
- :math:`\mathbf{C}` according to a specific Kernel with the `S` :math:`\mathbf{T}_s` couplings
- calculated at each iteration
- epsilon : float
- Regularization term >0
- max_iter : int, optional
- Max number of iterations
- tol : float, optional
- Stop threshold on error (>0)
- verbose : bool, optional
- Print information along iterations.
- log : bool, optional
- Record log if True.
- init_C : bool | array-like, shape (N, N)
- Random initial value for the :math:`\mathbf{C}` matrix provided by user.
- random_state : int or RandomState instance, optional
- Fix the seed for reproducibility
-
- Returns
- -------
- C : array-like, shape (`N`, `N`)
- Similarity matrix in the barycenter space (permutated arbitrarily)
- log : dict
- Log dictionary of error during iterations. Return only if `log=True` in parameters.
-
- References
- ----------
- .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
- "Gromov-Wasserstein averaging of kernel and distance matrices."
- International Conference on Machine Learning (ICML). 2016.
- """
- Cs = list_to_array(*Cs)
- ps = list_to_array(*ps)
- p = list_to_array(p)
- nx = get_backend(*Cs, *ps, p)
-
- S = len(Cs)
-
- # Initialization of C : random SPD matrix (if not provided by user)
- if init_C is None:
- generator = check_random_state(random_state)
- xalea = generator.randn(N, 2)
- C = dist(xalea, xalea)
- C /= C.max()
- C = nx.from_numpy(C, type_as=p)
- else:
- C = init_C
-
- cpt = 0
- err = 1
-
- error = []
-
- while (err > tol) and (cpt < max_iter):
- Cprev = C
-
- T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
- max_iter, 1e-4, verbose, log=False) for s in range(S)]
- if loss_fun == 'square_loss':
- C = update_square_loss(p, lambdas, T, Cs)
-
- elif loss_fun == 'kl_loss':
- C = update_kl_loss(p, lambdas, T, Cs)
-
- if cpt % 10 == 0:
- # we can speed up the process by checking for the error only all
- # the 10th iterations
- err = nx.norm(C - Cprev)
- error.append(err)
-
- if verbose:
- if cpt % 200 == 0:
- print('{:5s}|{:12s}'.format(
- 'It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
-
- cpt += 1
-
- if log:
- return C, {"err": error}
- else:
- return C
-
-
-def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
- max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None):
- r"""
- Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}`
-
- The function solves the following optimization problem with block coordinate descent:
-
- .. math::
-
- \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s)
-
- Where :
-
- - :math:`\mathbf{C}_s`: metric cost matrix
- - :math:`\mathbf{p}_s`: distribution
-
- Parameters
- ----------
- N : int
- Size of the targeted barycenter
- Cs : list of S array-like of shape (ns, ns)
- Metric cost matrices
- ps : list of S array-like of shape (ns,)
- Sample weights in the `S` spaces
- p : array-like, shape (N,)
- Weights in the targeted barycenter
- lambdas : list of float
- List of the `S` spaces' weights
- loss_fun : callable
- tensor-matrix multiplication function based on specific loss function
- update : callable
- function(:math:`\mathbf{p}`, lambdas, :math:`\mathbf{T}`, :math:`\mathbf{Cs}`) that updates
- :math:`\mathbf{C}` according to a specific Kernel with the `S` :math:`\mathbf{T}_s` couplings
- calculated at each iteration
- max_iter : int, optional
- Max number of iterations
- tol : float, optional
- Stop threshold on error (>0).
- verbose : bool, optional
- Print information along iterations.
- log : bool, optional
- Record log if True.
- init_C : bool | array-like, shape(N,N)
- Random initial value for the :math:`\mathbf{C}` matrix provided by user.
- random_state : int or RandomState instance, optional
- Fix the seed for reproducibility
-
- Returns
- -------
- C : array-like, shape (`N`, `N`)
- Similarity matrix in the barycenter space (permutated arbitrarily)
- log : dict
- Log dictionary of error during iterations. Return only if `log=True` in parameters.
-
- References
- ----------
- .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
- "Gromov-Wasserstein averaging of kernel and distance matrices."
- International Conference on Machine Learning (ICML). 2016.
-
- """
- Cs = list_to_array(*Cs)
- ps = list_to_array(*ps)
- p = list_to_array(p)
- nx = get_backend(*Cs, *ps, p)
-
- S = len(Cs)
-
- # Initialization of C : random SPD matrix (if not provided by user)
- if init_C is None:
- generator = check_random_state(random_state)
- xalea = generator.randn(N, 2)
- C = dist(xalea, xalea)
- C /= C.max()
- C = nx.from_numpy(C, type_as=p)
- else:
- C = init_C
-
- cpt = 0
- err = 1
-
- error = []
-
- while(err > tol and cpt < max_iter):
- Cprev = C
-
- T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun,
- numItermax=max_iter, stopThr=1e-5, verbose=verbose, log=False) for s in range(S)]
- if loss_fun == 'square_loss':
- C = update_square_loss(p, lambdas, T, Cs)
-
- elif loss_fun == 'kl_loss':
- C = update_kl_loss(p, lambdas, T, Cs)
-
- if cpt % 10 == 0:
- # we can speed up the process by checking for the error only all
- # the 10th iterations
- err = nx.norm(C - Cprev)
- error.append(err)
-
- if verbose:
- if cpt % 200 == 0:
- print('{:5s}|{:12s}'.format(
- 'It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
-
- cpt += 1
-
- if log:
- return C, {"err": error}
- else:
- return C
-
-
-def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
- p=None, loss_fun='square_loss', max_iter=100, tol=1e-9,
- verbose=False, log=False, init_C=None, init_X=None, random_state=None):
- r"""Compute the fgw barycenter as presented eq (5) in :ref:`[24] <references-fgw-barycenters>`
-
- Parameters
- ----------
- N : int
- Desired number of samples of the target barycenter
- Ys: list of array-like, each element has shape (ns,d)
- Features of all samples
- Cs : list of array-like, each element has shape (ns,ns)
- Structure matrices of all samples
- ps : list of array-like, each element has shape (ns,)
- Masses of all samples.
- lambdas : list of float
- List of the `S` spaces' weights
- alpha : float
- Alpha parameter for the fgw distance
- fixed_structure : bool
- Whether to fix the structure of the barycenter during the updates
- fixed_features : bool
- Whether to fix the feature of the barycenter during the updates
- loss_fun : str
- Loss function used for the solver either 'square_loss' or 'kl_loss'
- max_iter : int, optional
- Max number of iterations
- tol : float, optional
- Stop threshold on error (>0).
- verbose : bool, optional
- Print information along iterations.
- log : bool, optional
- Record log if True.
- init_C : array-like, shape (N,N), optional
- Initialization for the barycenters' structure matrix. If not set
- a random init is used.
- init_X : array-like, shape (N,d), optional
- Initialization for the barycenters' features. If not set a
- random init is used.
- random_state : int or RandomState instance, optional
- Fix the seed for reproducibility
-
- Returns
- -------
- X : array-like, shape (`N`, `d`)
- Barycenters' features
- C : array-like, shape (`N`, `N`)
- Barycenters' structure matrix
- log : dict
- Only returned when log=True. It contains the keys:
-
- - :math:`\mathbf{T}`: list of (`N`, `ns`) transport matrices
- - :math:`(\mathbf{M}_s)_s`: all distance matrices between the feature of the barycenter and the other features :math:`(dist(\mathbf{X}, \mathbf{Y}_s))_s` shape (`N`, `ns`)
-
-
- .. _references-fgw-barycenters:
- References
- ----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
- and Courty Nicolas
- "Optimal Transport for structured data with application on graphs"
- International Conference on Machine Learning (ICML). 2019.
- """
- Cs = list_to_array(*Cs)
- ps = list_to_array(*ps)
- Ys = list_to_array(*Ys)
- p = list_to_array(p)
- nx = get_backend(*Cs, *Ys, *ps)
-
- S = len(Cs)
- d = Ys[0].shape[1] # dimension on the node features
- if p is None:
- p = nx.ones(N, type_as=Cs[0]) / N
-
- if fixed_structure:
- if init_C is None:
- raise UndefinedParameter('If C is fixed it must be initialized')
- else:
- C = init_C
- else:
- if init_C is None:
- generator = check_random_state(random_state)
- xalea = generator.randn(N, 2)
- C = dist(xalea, xalea)
- C = nx.from_numpy(C, type_as=ps[0])
- else:
- C = init_C
-
- if fixed_features:
- if init_X is None:
- raise UndefinedParameter('If X is fixed it must be initialized')
- else:
- X = init_X
- else:
- if init_X is None:
- X = nx.zeros((N, d), type_as=ps[0])
- else:
- X = init_X
-
- T = [nx.outer(p, q) for q in ps]
-
- Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
-
- cpt = 0
- err_feature = 1
- err_structure = 1
-
- if log:
- log_ = {}
- log_['err_feature'] = []
- log_['err_structure'] = []
- log_['Ts_iter'] = []
-
- while((err_feature > tol or err_structure > tol) and cpt < max_iter):
- Cprev = C
- Xprev = X
-
- if not fixed_features:
- Ys_temp = [y.T for y in Ys]
- X = update_feature_matrix(lambdas, Ys_temp, T, p).T
-
- Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
-
- if not fixed_structure:
- if loss_fun == 'square_loss':
- T_temp = [t.T for t in T]
- C = update_structure_matrix(p, lambdas, T_temp, Cs)
-
- T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha,
- numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
-
- # T is N,ns
- err_feature = nx.norm(X - nx.reshape(Xprev, (N, d)))
- err_structure = nx.norm(C - Cprev)
- if log:
- log_['err_feature'].append(err_feature)
- log_['err_structure'].append(err_structure)
- log_['Ts_iter'].append(T)
-
- if verbose:
- if cpt % 200 == 0:
- print('{:5s}|{:12s}'.format(
- 'It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err_structure))
- print('{:5d}|{:8e}|'.format(cpt, err_feature))
-
- cpt += 1
-
- if log:
- log_['T'] = T # from target to Ys
- log_['p'] = p
- log_['Ms'] = Ms
-
- if log:
- return X, C, log_
- else:
- return X, C
-
-
-def update_structure_matrix(p, lambdas, T, Cs):
- r"""Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings.
-
- It is calculated at each iteration
-
- Parameters
- ----------
- p : array-like, shape (N,)
- Masses in the targeted barycenter.
- lambdas : list of float
- List of the `S` spaces' weights.
- T : list of S array-like of shape (ns, N)
- The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
- Cs : list of S array-like, shape (ns, ns)
- Metric cost matrices.
-
- Returns
- -------
- C : array-like, shape (`nt`, `nt`)
- Updated :math:`\mathbf{C}` matrix.
- """
- p = list_to_array(p)
- T = list_to_array(*T)
- Cs = list_to_array(*Cs)
- nx = get_backend(*Cs, *T, p)
-
- tmpsum = sum([
- lambdas[s] * nx.dot(
- nx.dot(T[s].T, Cs[s]),
- T[s]
- ) for s in range(len(T))
- ])
- ppt = nx.outer(p, p)
- return tmpsum / ppt
-
-
-def update_feature_matrix(lambdas, Ys, Ts, p):
- r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings.
-
-
- See "Solving the barycenter problem with Block Coordinate Descent (BCD)"
- in :ref:`[24] <references-update-feature-matrix>` calculated at each iteration
-
- Parameters
- ----------
- p : array-like, shape (N,)
- masses in the targeted barycenter
- lambdas : list of float
- List of the `S` spaces' weights
- Ts : list of S array-like, shape (ns,N)
- The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
- Ys : list of S array-like, shape (d,ns)
- The features.
-
- Returns
- -------
- X : array-like, shape (`d`, `N`)
-
-
- .. _references-update-feature-matrix:
- References
- ----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
- "Optimal Transport for structured data with application on graphs"
- International Conference on Machine Learning (ICML). 2019.
- """
- p = list_to_array(p)
- Ts = list_to_array(*Ts)
- Ys = list_to_array(*Ys)
- nx = get_backend(*Ys, *Ts, p)
-
- p = 1. / p
- tmpsum = sum([
- lambdas[s] * nx.dot(Ys[s], Ts[s].T) * p[None, :]
- for s in range(len(Ts))
- ])
- return tmpsum
-
-
-def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate=1., Cdict_init=None, projection='nonnegative_symmetric', use_log=True,
- tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs):
- r"""
- Infer Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, q) \}_{d \in [D]}` from the list of structures :math:`\{ (\mathbf{C_s},\mathbf{p_s}) \}_s`
-
- .. math::
- \min_{\mathbf{C_{dict}}, \{\mathbf{w_s} \}_{s \leq S}} \sum_{s=1}^S GW_2(\mathbf{C_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) - reg\| \mathbf{w_s} \|_2^2
-
- such that, :math:`\forall s \leq S` :
-
- - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1`
- - :math:`\mathbf{w_s} \geq \mathbf{0}_D`
-
- Where :
-
- - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns.
- - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt.
- - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}`
- - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space.
- - reg is the regularization coefficient.
-
- The stochastic algorithm used for estimating the graph dictionary atoms as proposed in [38]
-
- Parameters
- ----------
- Cs : list of S symmetric array-like, shape (ns, ns)
- List of Metric/Graph cost matrices of variable size (ns, ns).
- D: int
- Number of dictionary atoms to learn
- nt: int
- Number of samples within each dictionary atoms
- reg : float, optional
- Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0.
- ps : list of S array-like, shape (ns,), optional
- Distribution in each source space C of Cs. Default is None and corresponds to uniform distibutions.
- q : array-like, shape (nt,), optional
- Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions.
- epochs: int, optional
- Number of epochs used to learn the dictionary. Default is 32.
- batch_size: int, optional
- Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32.
- learning_rate: float, optional
- Learning rate used for the stochastic gradient descent. Default is 1.
- Cdict_init: list of D array-like with shape (nt, nt), optional
- Used to initialize the dictionary.
- If set to None (Default), the dictionary will be initialized randomly.
- Else Cdict must have shape (D, nt, nt) i.e match provided shape features.
- projection: str , optional
- If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary
- Else the set of atoms is :math:`R^{nt * nt}`. Default is 'nonnegative_symmetric'
- log: bool, optional
- If set to True, losses evolution by batches and epochs are tracked. Default is False.
- use_adam_optimizer: bool, optional
- If set to True, adam optimizer with default settings is used as adaptative learning rate strategy.
- Else perform SGD with fixed learning rate. Default is True.
- tol_outer : float, optional
- Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
- tol_inner : float, optional
- Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
- max_iter_outer : int, optional
- Maximum number of iterations for the BCD. Default is 20.
- max_iter_inner : int, optional
- Maximum number of iterations for the Conjugate Gradient. Default is 200.
- verbose : bool, optional
- Print the reconstruction loss every epoch. Default is False.
-
- Returns
- -------
-
- Cdict_best_state : D array-like, shape (D,nt,nt)
- Metric/Graph cost matrices composing the dictionary.
- The dictionary leading to the best loss over an epoch is saved and returned.
- log: dict
- If use_log is True, contains loss evolutions by batches and epochs.
- References
- -------
-
- ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty.
- "Online Graph Dictionary Learning"
- International Conference on Machine Learning (ICML). 2021.
- """
- # Handle backend of non-optional arguments
- Cs0 = Cs
- nx = get_backend(*Cs0)
- Cs = [nx.to_numpy(C) for C in Cs0]
- dataset_size = len(Cs)
- # Handle backend of optional arguments
- if ps is None:
- ps = [unif(C.shape[0]) for C in Cs]
- else:
- ps = [nx.to_numpy(p) for p in ps]
- if q is None:
- q = unif(nt)
- else:
- q = nx.to_numpy(q)
- if Cdict_init is None:
- # Initialize randomly structures of dictionary atoms based on samples
- dataset_means = [C.mean() for C in Cs]
- Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
- else:
- Cdict = nx.to_numpy(Cdict_init).copy()
- assert Cdict.shape == (D, nt, nt)
-
- if 'symmetric' in projection:
- Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
- if 'nonnegative' in projection:
- Cdict[Cdict < 0.] = 0
- if use_adam_optimizer:
- adam_moments = _initialize_adam_optimizer(Cdict)
-
- log = {'loss_batches': [], 'loss_epochs': []}
- const_q = q[:, None] * q[None, :]
- Cdict_best_state = Cdict.copy()
- loss_best_state = np.inf
- if batch_size > dataset_size:
- batch_size = dataset_size
- iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0)
-
- for epoch in range(epochs):
- cumulated_loss_over_epoch = 0.
-
- for _ in range(iter_by_epoch):
- # batch sampling
- batch = np.random.choice(range(dataset_size), size=batch_size, replace=False)
- cumulated_loss_over_batch = 0.
- unmixings = np.zeros((batch_size, D))
- Cs_embedded = np.zeros((batch_size, nt, nt))
- Ts = [None] * batch_size
-
- for batch_idx, C_idx in enumerate(batch):
- # BCD solver for Gromov-Wassersteisn linear unmixing used independently on each structure of the sampled batch
- unmixings[batch_idx], Cs_embedded[batch_idx], Ts[batch_idx], current_loss = gromov_wasserstein_linear_unmixing(
- Cs[C_idx], Cdict, reg=reg, p=ps[C_idx], q=q, tol_outer=tol_outer, tol_inner=tol_inner,
- max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner
- )
- cumulated_loss_over_batch += current_loss
- cumulated_loss_over_epoch += cumulated_loss_over_batch
-
- if use_log:
- log['loss_batches'].append(cumulated_loss_over_batch)
-
- # Stochastic projected gradient step over dictionary atoms
- grad_Cdict = np.zeros_like(Cdict)
- for batch_idx, C_idx in enumerate(batch):
- shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx])
- grad_Cdict += unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :]
- grad_Cdict *= 2 / batch_size
- if use_adam_optimizer:
- Cdict, adam_moments = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate, adam_moments)
- else:
- Cdict -= learning_rate * grad_Cdict
- if 'symmetric' in projection:
- Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
- if 'nonnegative' in projection:
- Cdict[Cdict < 0.] = 0.
-
- if use_log:
- log['loss_epochs'].append(cumulated_loss_over_epoch)
- if loss_best_state > cumulated_loss_over_epoch:
- loss_best_state = cumulated_loss_over_epoch
- Cdict_best_state = Cdict.copy()
- if verbose:
- print('--- epoch =', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch)
-
- return nx.from_numpy(Cdict_best_state), log
-
-
-def _initialize_adam_optimizer(variable):
-
- # Initialization for our numpy implementation of adam optimizer
- atoms_adam_m = np.zeros_like(variable) # Initialize first moment tensor
- atoms_adam_v = np.zeros_like(variable) # Initialize second moment tensor
- atoms_adam_count = 1
-
- return {'mean': atoms_adam_m, 'var': atoms_adam_v, 'count': atoms_adam_count}
-
-
-def _adam_stochastic_updates(variable, grad, learning_rate, adam_moments, beta_1=0.9, beta_2=0.99, eps=1e-09):
-
- adam_moments['mean'] = beta_1 * adam_moments['mean'] + (1 - beta_1) * grad
- adam_moments['var'] = beta_2 * adam_moments['var'] + (1 - beta_2) * (grad**2)
- unbiased_m = adam_moments['mean'] / (1 - beta_1**adam_moments['count'])
- unbiased_v = adam_moments['var'] / (1 - beta_2**adam_moments['count'])
- variable -= learning_rate * unbiased_m / (np.sqrt(unbiased_v) + eps)
- adam_moments['count'] += 1
-
- return variable, adam_moments
-
-
-def gromov_wasserstein_linear_unmixing(C, Cdict, reg=0., p=None, q=None, tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, **kwargs):
- r"""
- Returns the Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}`.
-
- .. math::
- \min_{ \mathbf{w}} GW_2(\mathbf{C}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2
-
- such that:
-
- - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
- - :math:`\mathbf{w} \geq \mathbf{0}_D`
-
- Where :
-
- - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix.
- - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of size nt.
- - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights.
- - reg is the regularization coefficient.
-
- The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38], algorithm 1.
-
- Parameters
- ----------
- C : array-like, shape (ns, ns)
- Metric/Graph cost matrix.
- Cdict : D array-like, shape (D,nt,nt)
- Metric/Graph cost matrices composing the dictionary on which to embed C.
- reg : float, optional.
- Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0.
- p : array-like, shape (ns,), optional
- Distribution in the source space C. Default is None and corresponds to uniform distribution.
- q : array-like, shape (nt,), optional
- Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution.
- tol_outer : float, optional
- Solver precision for the BCD algorithm.
- tol_inner : float, optional
- Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`.
- max_iter_outer : int, optional
- Maximum number of iterations for the BCD. Default is 20.
- max_iter_inner : int, optional
- Maximum number of iterations for the Conjugate Gradient. Default is 200.
-
- Returns
- -------
- w: array-like, shape (D,)
- gromov-wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the span of the dictionary.
- Cembedded: array-like, shape (nt,nt)
- embedded structure of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`.
- T: array-like (ns, nt)
- Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \mathbf{q})`
- current_loss: float
- reconstruction error
- References
- -------
-
- ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty.
- "Online Graph Dictionary Learning"
- International Conference on Machine Learning (ICML). 2021.
- """
- C0, Cdict0 = C, Cdict
- nx = get_backend(C0, Cdict0)
- C = nx.to_numpy(C0)
- Cdict = nx.to_numpy(Cdict0)
- if p is None:
- p = unif(C.shape[0])
- else:
- p = nx.to_numpy(p)
-
- if q is None:
- q = unif(Cdict.shape[-1])
- else:
- q = nx.to_numpy(q)
-
- T = p[:, None] * q[None, :]
- D = len(Cdict)
-
- w = unif(D) # Initialize uniformly the unmixing w
- Cembedded = np.sum(w[:, None, None] * Cdict, axis=0)
-
- const_q = q[:, None] * q[None, :]
- # Trackers for BCD convergence
- convergence_criterion = np.inf
- current_loss = 10**15
- outer_count = 0
-
- while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer):
- previous_loss = current_loss
- # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w
- T, log = gromov_wasserstein(C1=C, C2=Cembedded, p=p, q=q, loss_fun='square_loss', G0=T, log=True, armijo=False, **kwargs)
- current_loss = log['gw_dist']
- if reg != 0:
- current_loss -= reg * np.sum(w**2)
-
- # 2. Solve linear unmixing problem over w with a fixed transport plan T
- w, Cembedded, current_loss = _cg_gromov_wasserstein_unmixing(
- C=C, Cdict=Cdict, Cembedded=Cembedded, w=w, const_q=const_q, T=T,
- starting_loss=current_loss, reg=reg, tol=tol_inner, max_iter=max_iter_inner, **kwargs
- )
-
- if previous_loss != 0:
- convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
- else: # handle numerical issues around 0
- convergence_criterion = abs(previous_loss - current_loss) / 10**(-15)
- outer_count += 1
-
- return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(T), nx.from_numpy(current_loss)
-
-
-def _cg_gromov_wasserstein_unmixing(C, Cdict, Cembedded, w, const_q, T, starting_loss, reg=0., tol=10**(-5), max_iter=200, **kwargs):
- r"""
- Returns for a fixed admissible transport plan,
- the linear unmixing w minimizing the Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w[d]*\mathbf{C_{dict}[d]}, \mathbf{q})`
-
- .. math::
- \min_{\mathbf{w}} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d*C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg* \| \mathbf{w} \|_2^2
-
-
- Such that:
-
- - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
- - :math:`\mathbf{w} \geq \mathbf{0}_D`
-
- Where :
-
- - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix.
- - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of nt points.
- - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights.
- - :math:`\mathbf{w}` is the linear unmixing of :math:`(\mathbf{C}, \mathbf{p})` onto :math:`(\sum_d w_d \mathbf{Cdict[d]}, \mathbf{q})`.
- - :math:`\mathbf{T}` is the optimal transport plan conditioned by the current state of :math:`\mathbf{w}`.
- - reg is the regularization coefficient.
-
- The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38]
-
- Parameters
- ----------
-
- C : array-like, shape (ns, ns)
- Metric/Graph cost matrix.
- Cdict : list of D array-like, shape (nt,nt)
- Metric/Graph cost matrices composing the dictionary on which to embed C.
- Each matrix in the dictionary must have the same size (nt,nt).
- Cembedded: array-like, shape (nt,nt)
- Embedded structure :math:`(\sum_d w[d]*Cdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations.
- w: array-like, shape (D,)
- Linear unmixing of the input structure onto the dictionary
- const_q: array-like, shape (nt,nt)
- product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations.
- T: array-like, shape (ns,nt)
- fixed transport plan between the input structure and its representation in the dictionary.
- p : array-like, shape (ns,)
- Distribution in the source space.
- q : array-like, shape (nt,)
- Distribution in the embedding space depicted by the dictionary.
- reg : float, optional.
- Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0.
-
- Returns
- -------
- w: ndarray (D,)
- optimal unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary span given OT starting from previously optimal unmixing.
- """
- convergence_criterion = np.inf
- current_loss = starting_loss
- count = 0
- const_TCT = np.transpose(C.dot(T)).dot(T)
-
- while (convergence_criterion > tol) and (count < max_iter):
-
- previous_loss = current_loss
- # 1) Compute gradient at current point w
- grad_w = 2 * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2))
- grad_w -= 2 * reg * w
-
- # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w
- min_ = np.min(grad_w)
- x = (grad_w == min_).astype(np.float64)
- x /= np.sum(x)
-
- # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c
- gamma, a, b, Cembedded_diff = _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg)
-
- # 4) Updates: w <-- (1-gamma)*w + gamma*x
- w += gamma * (x - w)
- Cembedded += gamma * Cembedded_diff
- current_loss += a * (gamma**2) + b * gamma
-
- if previous_loss != 0: # not that the loss can be negative if reg >0
- convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
- else: # handle numerical issues around 0
- convergence_criterion = abs(previous_loss - current_loss) / 10**(-15)
- count += 1
-
- return w, Cembedded, current_loss
-
-
-def _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg, **kwargs):
- r"""
- Compute optimal steps for the line search problem of Gromov-Wasserstein linear unmixing
- .. math::
- \min_{\gamma \in [0,1]} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg\| \mathbf{z}(\gamma) \|_2^2
-
-
- Such that:
-
- - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}`
-
- Parameters
- ----------
-
- w : array-like, shape (D,)
- Unmixing.
- grad_w : array-like, shape (D, D)
- Gradient of the reconstruction loss with respect to w.
- x: array-like, shape (D,)
- Conditional gradient direction.
- Cdict : list of D array-like, shape (nt,nt)
- Metric/Graph cost matrices composing the dictionary on which to embed C.
- Each matrix in the dictionary must have the same size (nt,nt).
- Cembedded: array-like, shape (nt,nt)
- Embedded structure :math:`(\sum_d w_dCdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations.
- const_q: array-like, shape (nt,nt)
- product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations.
- const_TCT: array-like, shape (nt, nt)
- :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations.
- Returns
- -------
- gamma: float
- Optimal value for the line-search step
- a: float
- Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss
- b: float
- Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss
- Cembedded_diff: numpy array, shape (nt, nt)
- Difference between models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`.
- reg : float, optional.
- Coefficient of the negative quadratic regularization used to promote sparsity of :math:`\mathbf{w}`.
- """
-
- # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c
- Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0)
- Cembedded_diff = Cembedded_x - Cembedded
- trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q)
- trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q)
- a = trace_diffx - trace_diffw
- b = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT))
- if reg != 0:
- a -= reg * np.sum((x - w)**2)
- b -= 2 * reg * np.sum(w * (x - w))
-
- if a > 0:
- gamma = min(1, max(0, - b / (2 * a)))
- elif a + b < 0:
- gamma = 1
- else:
- gamma = 0
-
- return gamma, a, b, Cembedded_diff
-
-
-def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate_C=1., learning_rate_Y=1.,
- Cdict_init=None, Ydict_init=None, projection='nonnegative_symmetric', use_log=False,
- tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs):
- r"""
- Infer Fused Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` from the list of S attributed structures :math:`\{ (\mathbf{C_s}, \mathbf{Y_s},\mathbf{p_s}) \}_s`
-
- .. math::
- \min_{\mathbf{C_{dict}},\mathbf{Y_{dict}}, \{\mathbf{w_s}\}_{s}} \sum_{s=1}^S FGW_{2,\alpha}(\mathbf{C_s}, \mathbf{Y_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]},\sum_{d=1}^D w_{s,d}\mathbf{Y_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) \\ - reg\| \mathbf{w_s} \|_2^2
-
-
- Such that :math:`\forall s \leq S` :
-
- - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1`
- - :math:`\mathbf{w_s} \geq \mathbf{0}_D`
-
- Where :
-
- - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns.
- - :math:`\forall s \leq S, \mathbf{Y_s}` is a (ns,d) features matrix of variable size ns and fixed dimension d.
- - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt.
- - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d.
- - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}`
- - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space.
- - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein
- - reg is the regularization coefficient.
-
-
- The stochastic algorithm used for estimating the attributed graph dictionary atoms as proposed in [38]
-
- Parameters
- ----------
- Cs : list of S symmetric array-like, shape (ns, ns)
- List of Metric/Graph cost matrices of variable size (ns,ns).
- Ys : list of S array-like, shape (ns, d)
- List of feature matrix of variable size (ns,d) with d fixed.
- D: int
- Number of dictionary atoms to learn
- nt: int
- Number of samples within each dictionary atoms
- alpha : float
- Trade-off parameter of Fused Gromov-Wasserstein
- reg : float, optional
- Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0.
- ps : list of S array-like, shape (ns,), optional
- Distribution in each source space C of Cs. Default is None and corresponds to uniform distibutions.
- q : array-like, shape (nt,), optional
- Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions.
- epochs: int, optional
- Number of epochs used to learn the dictionary. Default is 32.
- batch_size: int, optional
- Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32.
- learning_rate_C: float, optional
- Learning rate used for the stochastic gradient descent on Cdict. Default is 1.
- learning_rate_Y: float, optional
- Learning rate used for the stochastic gradient descent on Ydict. Default is 1.
- Cdict_init: list of D array-like with shape (nt, nt), optional
- Used to initialize the dictionary structures Cdict.
- If set to None (Default), the dictionary will be initialized randomly.
- Else Cdict must have shape (D, nt, nt) i.e match provided shape features.
- Ydict_init: list of D array-like with shape (nt, d), optional
- Used to initialize the dictionary features Ydict.
- If set to None, the dictionary features will be initialized randomly.
- Else Ydict must have shape (D, nt, d) where d is the features dimension of inputs Ys and also match provided shape features.
- projection: str, optional
- If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary
- Else the set of atoms is :math:`R^{nt * nt}`. Default is 'nonnegative_symmetric'
- log: bool, optional
- If set to True, losses evolution by batches and epochs are tracked. Default is False.
- use_adam_optimizer: bool, optional
- If set to True, adam optimizer with default settings is used as adaptative learning rate strategy.
- Else perform SGD with fixed learning rate. Default is True.
- tol_outer : float, optional
- Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
- tol_inner : float, optional
- Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
- max_iter_outer : int, optional
- Maximum number of iterations for the BCD. Default is 20.
- max_iter_inner : int, optional
- Maximum number of iterations for the Conjugate Gradient. Default is 200.
- verbose : bool, optional
- Print the reconstruction loss every epoch. Default is False.
-
- Returns
- -------
-
- Cdict_best_state : D array-like, shape (D,nt,nt)
- Metric/Graph cost matrices composing the dictionary.
- The dictionary leading to the best loss over an epoch is saved and returned.
- Ydict_best_state : D array-like, shape (D,nt,d)
- Feature matrices composing the dictionary.
- The dictionary leading to the best loss over an epoch is saved and returned.
- log: dict
- If use_log is True, contains loss evolutions by batches and epoches.
- References
- -------
-
- ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty.
- "Online Graph Dictionary Learning"
- International Conference on Machine Learning (ICML). 2021.
- """
- Cs0, Ys0 = Cs, Ys
- nx = get_backend(*Cs0, *Ys0)
- Cs = [nx.to_numpy(C) for C in Cs0]
- Ys = [nx.to_numpy(Y) for Y in Ys0]
-
- d = Ys[0].shape[-1]
- dataset_size = len(Cs)
-
- if ps is None:
- ps = [unif(C.shape[0]) for C in Cs]
- else:
- ps = [nx.to_numpy(p) for p in ps]
- if q is None:
- q = unif(nt)
- else:
- q = nx.to_numpy(q)
-
- if Cdict_init is None:
- # Initialize randomly structures of dictionary atoms based on samples
- dataset_means = [C.mean() for C in Cs]
- Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
- else:
- Cdict = nx.to_numpy(Cdict_init).copy()
- assert Cdict.shape == (D, nt, nt)
- if Ydict_init is None:
- # Initialize randomly features of dictionary atoms based on samples distribution by feature component
- dataset_feature_means = np.stack([F.mean(axis=0) for F in Ys])
- Ydict = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(D, nt, d))
- else:
- Ydict = nx.to_numpy(Ydict_init).copy()
- assert Ydict.shape == (D, nt, d)
-
- if 'symmetric' in projection:
- Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
- if 'nonnegative' in projection:
- Cdict[Cdict < 0.] = 0.
-
- if use_adam_optimizer:
- adam_moments_C = _initialize_adam_optimizer(Cdict)
- adam_moments_Y = _initialize_adam_optimizer(Ydict)
-
- log = {'loss_batches': [], 'loss_epochs': []}
- const_q = q[:, None] * q[None, :]
- diag_q = np.diag(q)
- Cdict_best_state = Cdict.copy()
- Ydict_best_state = Ydict.copy()
- loss_best_state = np.inf
- if batch_size > dataset_size:
- batch_size = dataset_size
- iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0)
-
- for epoch in range(epochs):
- cumulated_loss_over_epoch = 0.
-
- for _ in range(iter_by_epoch):
-
- # Batch iterations
- batch = np.random.choice(range(dataset_size), size=batch_size, replace=False)
- cumulated_loss_over_batch = 0.
- unmixings = np.zeros((batch_size, D))
- Cs_embedded = np.zeros((batch_size, nt, nt))
- Ys_embedded = np.zeros((batch_size, nt, d))
- Ts = [None] * batch_size
-
- for batch_idx, C_idx in enumerate(batch):
- # BCD solver for Gromov-Wassersteisn linear unmixing used independently on each structure of the sampled batch
- unmixings[batch_idx], Cs_embedded[batch_idx], Ys_embedded[batch_idx], Ts[batch_idx], current_loss = fused_gromov_wasserstein_linear_unmixing(
- Cs[C_idx], Ys[C_idx], Cdict, Ydict, alpha, reg=reg, p=ps[C_idx], q=q,
- tol_outer=tol_outer, tol_inner=tol_inner, max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner
- )
- cumulated_loss_over_batch += current_loss
- cumulated_loss_over_epoch += cumulated_loss_over_batch
- if use_log:
- log['loss_batches'].append(cumulated_loss_over_batch)
-
- # Stochastic projected gradient step over dictionary atoms
- grad_Cdict = np.zeros_like(Cdict)
- grad_Ydict = np.zeros_like(Ydict)
-
- for batch_idx, C_idx in enumerate(batch):
- shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx])
- shared_term_features = diag_q.dot(Ys_embedded[batch_idx]) - Ts[batch_idx].T.dot(Ys[C_idx])
- grad_Cdict += alpha * unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :]
- grad_Ydict += (1 - alpha) * unmixings[batch_idx][:, None, None] * shared_term_features[None, :, :]
- grad_Cdict *= 2 / batch_size
- grad_Ydict *= 2 / batch_size
-
- if use_adam_optimizer:
- Cdict, adam_moments_C = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate_C, adam_moments_C)
- Ydict, adam_moments_Y = _adam_stochastic_updates(Ydict, grad_Ydict, learning_rate_Y, adam_moments_Y)
- else:
- Cdict -= learning_rate_C * grad_Cdict
- Ydict -= learning_rate_Y * grad_Ydict
-
- if 'symmetric' in projection:
- Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
- if 'nonnegative' in projection:
- Cdict[Cdict < 0.] = 0.
-
- if use_log:
- log['loss_epochs'].append(cumulated_loss_over_epoch)
- if loss_best_state > cumulated_loss_over_epoch:
- loss_best_state = cumulated_loss_over_epoch
- Cdict_best_state = Cdict.copy()
- Ydict_best_state = Ydict.copy()
- if verbose:
- print('--- epoch: ', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch)
-
- return nx.from_numpy(Cdict_best_state), nx.from_numpy(Ydict_best_state), log
-
-
-def fused_gromov_wasserstein_linear_unmixing(C, Y, Cdict, Ydict, alpha, reg=0., p=None, q=None, tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, **kwargs):
- r"""
- Returns the Fused Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the attributed dictionary atoms :math:`\{ (\mathbf{C_{dict}[d]},\mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}`
-
- .. math::
- \min_{\mathbf{w}} FGW_{2,\alpha}(\mathbf{C},\mathbf{Y}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]},\sum_{d=1}^D w_d\mathbf{Y_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2
-
- such that, :math:`\forall s \leq S` :
-
- - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1`
- - :math:`\mathbf{w_s} \geq \mathbf{0}_D`
-
- Where :
-
- - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns.
- - :math:`\mathbf{Y}` is a (ns,d) features matrix of variable size ns and fixed dimension d.
- - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt.
- - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d.
- - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C_s}`
- - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space.
- - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein
- - reg is the regularization coefficient.
-
- The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38], algorithm 6.
-
- Parameters
- ----------
- C : array-like, shape (ns, ns)
- Metric/Graph cost matrix.
- Y : array-like, shape (ns, d)
- Feature matrix.
- Cdict : D array-like, shape (D,nt,nt)
- Metric/Graph cost matrices composing the dictionary on which to embed (C,Y).
- Ydict : D array-like, shape (D,nt,d)
- Feature matrices composing the dictionary on which to embed (C,Y).
- alpha: float,
- Trade-off parameter of Fused Gromov-Wasserstein.
- reg : float, optional
- Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0.
- p : array-like, shape (ns,), optional
- Distribution in the source space C. Default is None and corresponds to uniform distribution.
- q : array-like, shape (nt,), optional
- Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution.
- tol_outer : float, optional
- Solver precision for the BCD algorithm.
- tol_inner : float, optional
- Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`.
- max_iter_outer : int, optional
- Maximum number of iterations for the BCD. Default is 20.
- max_iter_inner : int, optional
- Maximum number of iterations for the Conjugate Gradient. Default is 200.
-
- Returns
- -------
- w: array-like, shape (D,)
- fused gromov-wasserstein linear unmixing of (C,Y,p) onto the span of the dictionary.
- Cembedded: array-like, shape (nt,nt)
- embedded structure of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`.
- Yembedded: array-like, shape (nt,d)
- embedded features of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{Y_{dict}[d]}`.
- T: array-like (ns,nt)
- Fused Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \sum_d w_d\mathbf{Y_{dict}[d]},\mathbf{q})`.
- current_loss: float
- reconstruction error
- References
- -------
-
- ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty.
- "Online Graph Dictionary Learning"
- International Conference on Machine Learning (ICML). 2021.
- """
- C0, Y0, Cdict0, Ydict0 = C, Y, Cdict, Ydict
- nx = get_backend(C0, Y0, Cdict0, Ydict0)
- C = nx.to_numpy(C0)
- Y = nx.to_numpy(Y0)
- Cdict = nx.to_numpy(Cdict0)
- Ydict = nx.to_numpy(Ydict0)
-
- if p is None:
- p = unif(C.shape[0])
- else:
- p = nx.to_numpy(p)
- if q is None:
- q = unif(Cdict.shape[-1])
- else:
- q = nx.to_numpy(q)
-
- T = p[:, None] * q[None, :]
- D = len(Cdict)
- d = Y.shape[-1]
- w = unif(D) # Initialize with uniform weights
- ns = C.shape[-1]
- nt = Cdict.shape[-1]
-
- # modeling (C,Y)
- Cembedded = np.sum(w[:, None, None] * Cdict, axis=0)
- Yembedded = np.sum(w[:, None, None] * Ydict, axis=0)
-
- # constants depending on q
- const_q = q[:, None] * q[None, :]
- diag_q = np.diag(q)
- # Trackers for BCD convergence
- convergence_criterion = np.inf
- current_loss = 10**15
- outer_count = 0
- Ys_constM = (Y**2).dot(np.ones((d, nt))) # constant in computing euclidean pairwise feature matrix
-
- while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer):
- previous_loss = current_loss
-
- # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w
- Yt_varM = (np.ones((ns, d))).dot((Yembedded**2).T)
- M = Ys_constM + Yt_varM - 2 * Y.dot(Yembedded.T) # euclidean distance matrix between features
- T, log = fused_gromov_wasserstein(M, C, Cembedded, p, q, loss_fun='square_loss', alpha=alpha, armijo=False, G0=T, log=True)
- current_loss = log['fgw_dist']
- if reg != 0:
- current_loss -= reg * np.sum(w**2)
-
- # 2. Solve linear unmixing problem over w with a fixed transport plan T
- w, Cembedded, Yembedded, current_loss = _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w,
- T, p, q, const_q, diag_q, current_loss, alpha, reg,
- tol=tol_inner, max_iter=max_iter_inner, **kwargs)
- if previous_loss != 0:
- convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
- else:
- convergence_criterion = abs(previous_loss - current_loss) / 10**(-12)
- outer_count += 1
-
- return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(Yembedded), nx.from_numpy(T), nx.from_numpy(current_loss)
-
-
-def _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w, T, p, q, const_q, diag_q, starting_loss, alpha, reg, tol=10**(-6), max_iter=200, **kwargs):
- r"""
- Returns for a fixed admissible transport plan,
- the optimal linear unmixing :math:`\mathbf{w}` minimizing the Fused Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` and :math:`(\sum_d w_d \mathbf{C_{dict}[d]},\sum_d w_d*\mathbf{Y_{dict}[d]}, \mathbf{q})`
-
- .. math::
- \min_{\mathbf{w}} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\+ (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d w_d \mathbf{Y_{dict}[d]_j} \|_2^2 T_{ij}- reg \| \mathbf{w} \|_2^2
-
- Such that :
-
- - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
- - :math:`\mathbf{w} \geq \mathbf{0}_D`
-
- Where :
-
- - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns.
- - :math:`\mathbf{Y}` is a (ns,d) features matrix of variable size ns and fixed dimension d.
- - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt.
- - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d.
- - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C_s}`
- - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space.
- - :math:`\mathbf{T}` is the optimal transport plan conditioned by the previous state of :math:`\mathbf{w}`
- - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein
- - reg is the regularization coefficient.
-
- The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38], algorithm 7.
-
- Parameters
- ----------
-
- C : array-like, shape (ns, ns)
- Metric/Graph cost matrix.
- Y : array-like, shape (ns, d)
- Feature matrix.
- Cdict : list of D array-like, shape (nt,nt)
- Metric/Graph cost matrices composing the dictionary on which to embed (C,Y).
- Each matrix in the dictionary must have the same size (nt,nt).
- Ydict : list of D array-like, shape (nt,d)
- Feature matrices composing the dictionary on which to embed (C,Y).
- Each matrix in the dictionary must have the same size (nt,d).
- Cembedded: array-like, shape (nt,nt)
- Embedded structure of (C,Y) onto the dictionary
- Yembedded: array-like, shape (nt,d)
- Embedded features of (C,Y) onto the dictionary
- w: array-like, shape (n_D,)
- Linear unmixing of (C,Y) onto (Cdict,Ydict)
- const_q: array-like, shape (nt,nt)
- product matrix :math:`\mathbf{qq}^\top` where :math:`\mathbf{q}` is the target space distribution.
- diag_q: array-like, shape (nt,nt)
- diagonal matrix with values of q on the diagonal.
- T: array-like, shape (ns,nt)
- fixed transport plan between (C,Y) and its model
- p : array-like, shape (ns,)
- Distribution in the source space (C,Y).
- q : array-like, shape (nt,)
- Distribution in the embedding space depicted by the dictionary.
- alpha: float,
- Trade-off parameter of Fused Gromov-Wasserstein.
- reg : float, optional
- Coefficient of the negative quadratic regularization used to promote sparsity of w.
-
- Returns
- -------
- w: ndarray (D,)
- linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the span of :math:`(C_{dict},Y_{dict})` given OT corresponding to previous unmixing.
- """
- convergence_criterion = np.inf
- current_loss = starting_loss
- count = 0
- const_TCT = np.transpose(C.dot(T)).dot(T)
- ones_ns_d = np.ones(Y.shape)
-
- while (convergence_criterion > tol) and (count < max_iter):
- previous_loss = current_loss
-
- # 1) Compute gradient at current point w
- # structure
- grad_w = alpha * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2))
- # feature
- grad_w += (1 - alpha) * np.sum(Ydict * (diag_q.dot(Yembedded)[None, :, :] - T.T.dot(Y)[None, :, :]), axis=(1, 2))
- grad_w -= reg * w
- grad_w *= 2
-
- # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w
- min_ = np.min(grad_w)
- x = (grad_w == min_).astype(np.float64)
- x /= np.sum(x)
-
- # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c
- gamma, a, b, Cembedded_diff, Yembedded_diff = _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg)
-
- # 4) Updates: w <-- (1-gamma)*w + gamma*x
- w += gamma * (x - w)
- Cembedded += gamma * Cembedded_diff
- Yembedded += gamma * Yembedded_diff
- current_loss += a * (gamma**2) + b * gamma
-
- if previous_loss != 0:
- convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
- else:
- convergence_criterion = abs(previous_loss - current_loss) / 10**(-12)
- count += 1
-
- return w, Cembedded, Yembedded, current_loss
-
-
-def _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg, **kwargs):
- r"""
- Compute optimal steps for the line search problem of Fused Gromov-Wasserstein linear unmixing
- .. math::
- \min_{\gamma \in [0,1]} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\ + (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d z_d(\gamma) \mathbf{Y_{dict}[d]_j} \|_2^2 - reg\| \mathbf{z}(\gamma) \|_2^2
-
-
- Such that :
-
- - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}`
-
- Parameters
- ----------
-
- w : array-like, shape (D,)
- Unmixing.
- grad_w : array-like, shape (D, D)
- Gradient of the reconstruction loss with respect to w.
- x: array-like, shape (D,)
- Conditional gradient direction.
- Y: arrat-like, shape (ns,d)
- Feature matrix of the input space
- Cdict : list of D array-like, shape (nt, nt)
- Metric/Graph cost matrices composing the dictionary on which to embed (C,Y).
- Each matrix in the dictionary must have the same size (nt,nt).
- Ydict : list of D array-like, shape (nt, d)
- Feature matrices composing the dictionary on which to embed (C,Y).
- Each matrix in the dictionary must have the same size (nt,d).
- Cembedded: array-like, shape (nt, nt)
- Embedded structure of (C,Y) onto the dictionary
- Yembedded: array-like, shape (nt, d)
- Embedded features of (C,Y) onto the dictionary
- T: array-like, shape (ns, nt)
- Fixed transport plan between (C,Y) and its current model.
- const_q: array-like, shape (nt,nt)
- product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations.
- const_TCT: array-like, shape (nt, nt)
- :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations.
- ones_ns_d: array-like, shape (ns, d)
- :math:`\mathbf{1}_{ ns \times d}`. Used to avoid redundant computations.
- alpha: float,
- Trade-off parameter of Fused Gromov-Wasserstein.
- reg : float, optional
- Coefficient of the negative quadratic regularization used to promote sparsity of w.
-
- Returns
- -------
- gamma: float
- Optimal value for the line-search step
- a: float
- Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss
- b: float
- Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss
- Cembedded_diff: numpy array, shape (nt, nt)
- Difference between structure matrix of models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`.
- Yembedded_diff: numpy array, shape (nt, nt)
- Difference between feature matrix of models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`.
- """
- # polynomial coefficients from quadratic objective (with respect to w) on structures
- Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0)
- Cembedded_diff = Cembedded_x - Cembedded
- trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q)
- trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q)
- # Constant factor appearing in the factorization a*gamma^2 + b*g + c of the Gromov-Wasserstein reconstruction loss
- a_gw = trace_diffx - trace_diffw
- b_gw = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT))
-
- # polynomial coefficient from quadratic objective (with respect to w) on features
- Yembedded_x = np.sum(x[:, None, None] * Ydict, axis=0)
- Yembedded_diff = Yembedded_x - Yembedded
- # Constant factor appearing in the factorization a*gamma^2 + b*g + c of the Gromov-Wasserstein reconstruction loss
- a_w = np.sum(ones_ns_d.dot((Yembedded_diff**2).T) * T)
- b_w = 2 * np.sum(T * (ones_ns_d.dot((Yembedded * Yembedded_diff).T) - Y.dot(Yembedded_diff.T)))
-
- a = alpha * a_gw + (1 - alpha) * a_w
- b = alpha * b_gw + (1 - alpha) * b_w
- if reg != 0:
- a -= reg * np.sum((x - w)**2)
- b -= 2 * reg * np.sum(w * (x - w))
- if a > 0:
- gamma = min(1, max(0, -b / (2 * a)))
- elif a + b < 0:
- gamma = 1
- else:
- gamma = 0
-
- return gamma, a, b, Cembedded_diff, Yembedded_diff
diff --git a/ot/gromov/__init__.py b/ot/gromov/__init__.py
new file mode 100644
index 0000000..6184edf
--- /dev/null
+++ b/ot/gromov/__init__.py
@@ -0,0 +1,48 @@
+# -*- coding: utf-8 -*-
+"""
+Solvers related to Gromov-Wasserstein problems.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Cedric Vincent-Cuaz <cedvincentcuaz@gmail.com>
+#
+# License: MIT License
+
+# All submodules and packages
+from ._utils import (init_matrix, tensor_product, gwloss, gwggrad,
+ update_square_loss, update_kl_loss,
+ init_matrix_semirelaxed)
+from ._gw import (gromov_wasserstein, gromov_wasserstein2,
+ fused_gromov_wasserstein, fused_gromov_wasserstein2,
+ solve_gromov_linesearch, gromov_barycenters, fgw_barycenters,
+ update_structure_matrix, update_feature_matrix)
+from ._bregman import (entropic_gromov_wasserstein,
+ entropic_gromov_wasserstein2,
+ entropic_gromov_barycenters)
+from ._estimators import (GW_distance_estimation, pointwise_gromov_wasserstein,
+ sampled_gromov_wasserstein)
+from ._semirelaxed import (semirelaxed_gromov_wasserstein,
+ semirelaxed_gromov_wasserstein2,
+ semirelaxed_fused_gromov_wasserstein,
+ semirelaxed_fused_gromov_wasserstein2,
+ solve_semirelaxed_gromov_linesearch)
+from ._dictionary import (gromov_wasserstein_dictionary_learning,
+ gromov_wasserstein_linear_unmixing,
+ fused_gromov_wasserstein_dictionary_learning,
+ fused_gromov_wasserstein_linear_unmixing)
+
+
+__all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad',
+ 'update_square_loss', 'update_kl_loss', 'init_matrix_semirelaxed',
+ 'gromov_wasserstein', 'gromov_wasserstein2', 'fused_gromov_wasserstein',
+ 'fused_gromov_wasserstein2', 'solve_gromov_linesearch', 'gromov_barycenters',
+ 'fgw_barycenters', 'update_structure_matrix', 'update_feature_matrix',
+ 'entropic_gromov_wasserstein', 'entropic_gromov_wasserstein2',
+ 'entropic_gromov_barycenters', 'GW_distance_estimation',
+ 'pointwise_gromov_wasserstein', 'sampled_gromov_wasserstein',
+ 'semirelaxed_gromov_wasserstein', 'semirelaxed_gromov_wasserstein2',
+ 'semirelaxed_fused_gromov_wasserstein', 'semirelaxed_fused_gromov_wasserstein2',
+ 'solve_semirelaxed_gromov_linesearch', 'gromov_wasserstein_dictionary_learning',
+ 'gromov_wasserstein_linear_unmixing', 'fused_gromov_wasserstein_dictionary_learning',
+ 'fused_gromov_wasserstein_linear_unmixing']
diff --git a/ot/gromov/_bregman.py b/ot/gromov/_bregman.py
new file mode 100644
index 0000000..b0cccfb
--- /dev/null
+++ b/ot/gromov/_bregman.py
@@ -0,0 +1,348 @@
+# -*- coding: utf-8 -*-
+"""
+Bregman projections solvers for entropic Gromov-Wasserstein
+"""
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+# Rémi Flamary <remi.flamary@unice.fr>
+# Titouan Vayer <titouan.vayer@irisa.fr>
+# Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
+#
+# License: MIT License
+
+from ..bregman import sinkhorn
+from ..utils import dist, list_to_array, check_random_state
+from ..backend import get_backend
+
+from ._utils import init_matrix, gwloss, gwggrad
+from ._utils import update_square_loss, update_kl_loss
+
+
+def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, symmetric=None, G0=None,
+ max_iter=1000, tol=1e-9, verbose=False, log=False):
+ r"""
+ Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{GW} = \mathop{\arg\min}_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T}))
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{T} &\geq 0
+
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity matrices
+ - `H`: entropy
+
+ .. note:: If the inner solver `ot.sinkhorn` did not convergence, the
+ optimal coupling :math:`\mathbf{T}` returned by this function does not
+ necessarily satisfy the marginal constraints
+ :math:`\mathbf{T}\mathbf{1}=\mathbf{p}` and
+ :math:`\mathbf{T}^T\mathbf{1}=\mathbf{q}`. So the returned
+ Gromov-Wasserstein loss does not necessarily satisfy distance
+ properties and may be negative.
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space
+ q : array-like, shape (nt,)
+ Distribution in the target space
+ loss_fun : string
+ Loss function used for the solver either 'square_loss' or 'kl_loss'
+ epsilon : float
+ Regularization term >0
+ symmetric : bool, optional
+ Either C1 and C2 are to be assumed symmetric or not.
+ If let to its default None value, a symmetry test will be conducted.
+ Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric).
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ Record log if True.
+
+ Returns
+ -------
+ T : array-like, shape (`ns`, `nt`)
+ Optimal coupling between the two spaces
+
+ References
+ ----------
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ .. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein
+ distance between networks and stable network invariants.
+ Information and Inference: A Journal of the IMA, 8(4), 757-787.
+ """
+ C1, C2, p, q = list_to_array(C1, C2, p, q)
+ if G0 is None:
+ nx = get_backend(p, q, C1, C2)
+ G0 = nx.outer(p, q)
+ else:
+ nx = get_backend(p, q, C1, C2, G0)
+ T = G0
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, nx)
+ if symmetric is None:
+ symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10)
+ if not symmetric:
+ constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, nx)
+ cpt = 0
+ err = 1
+
+ if log:
+ log = {'err': []}
+
+ while (err > tol and cpt < max_iter):
+
+ Tprev = T
+
+ # compute the gradient
+ if symmetric:
+ tens = gwggrad(constC, hC1, hC2, T, nx)
+ else:
+ tens = 0.5 * (gwggrad(constC, hC1, hC2, T, nx) + gwggrad(constCt, hC1t, hC2t, T, nx))
+ T = sinkhorn(p, q, tens, epsilon, method='sinkhorn')
+
+ if cpt % 10 == 0:
+ # we can speed up the process by checking for the error only all
+ # the 10th iterations
+ err = nx.norm(T - Tprev)
+
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+
+ if log:
+ log['gw_dist'] = gwloss(constC, hC1, hC2, T, nx)
+ return T, log
+ else:
+ return T
+
+
+def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, symmetric=None, G0=None,
+ max_iter=1000, tol=1e-9, verbose=False, log=False):
+ r"""
+ Returns the entropic gromov-wasserstein discrepancy between the two measured similarity matrices :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
+
+ The function solves the following optimization problem:
+
+ .. math::
+ GW = \min_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})
+ \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T}))
+
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity matrices
+ - `H`: entropy
+
+ .. note:: If the inner solver `ot.sinkhorn` did not convergence, the
+ optimal coupling :math:`\mathbf{T}` returned by this function does not
+ necessarily satisfy the marginal constraints
+ :math:`\mathbf{T}\mathbf{1}=\mathbf{p}` and
+ :math:`\mathbf{T}^T\mathbf{1}=\mathbf{q}`. So the returned
+ Gromov-Wasserstein loss does not necessarily satisfy distance
+ properties and may be negative.
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space
+ q : array-like, shape (nt,)
+ Distribution in the target space
+ loss_fun : str
+ Loss function used for the solver either 'square_loss' or 'kl_loss'
+ epsilon : float
+ Regularization term >0
+ symmetric : bool, optional
+ Either C1 and C2 are to be assumed symmetric or not.
+ If let to its default None value, a symmetry test will be conducted.
+ Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric).
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ Record log if True.
+
+ Returns
+ -------
+ gw_dist : float
+ Gromov-Wasserstein distance
+
+ References
+ ----------
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
+ gw, logv = entropic_gromov_wasserstein(
+ C1, C2, p, q, loss_fun, epsilon, symmetric, G0, max_iter, tol, verbose, log=True)
+
+ logv['T'] = gw
+
+ if log:
+ return logv['gw_dist'], logv
+ else:
+ return logv['gw_dist']
+
+
+def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, symmetric=True,
+ max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None):
+ r"""
+ Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}`
+
+ The function solves the following optimization problem:
+
+ .. math::
+
+ \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s)
+
+ Where :
+
+ - :math:`\mathbf{C}_s`: metric cost matrix
+ - :math:`\mathbf{p}_s`: distribution
+
+ Parameters
+ ----------
+ N : int
+ Size of the targeted barycenter
+ Cs : list of S array-like of shape (ns,ns)
+ Metric cost matrices
+ ps : list of S array-like of shape (ns,)
+ Sample weights in the `S` spaces
+ p : array-like, shape(N,)
+ Weights in the targeted barycenter
+ lambdas : list of float
+ List of the `S` spaces' weights.
+ loss_fun : callable
+ Tensor-matrix multiplication function based on specific loss function.
+ update : callable
+ function(:math:`\mathbf{p}`, lambdas, :math:`\mathbf{T}`, :math:`\mathbf{Cs}`) that updates
+ :math:`\mathbf{C}` according to a specific Kernel with the `S` :math:`\mathbf{T}_s` couplings
+ calculated at each iteration
+ epsilon : float
+ Regularization term >0
+ symmetric : bool, optional.
+ Either structures are to be assumed symmetric or not. Default value is True.
+ Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations.
+ log : bool, optional
+ Record log if True.
+ init_C : bool | array-like, shape (N, N)
+ Random initial value for the :math:`\mathbf{C}` matrix provided by user.
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
+
+ Returns
+ -------
+ C : array-like, shape (`N`, `N`)
+ Similarity matrix in the barycenter space (permutated arbitrarily)
+ log : dict
+ Log dictionary of error during iterations. Return only if `log=True` in parameters.
+
+ References
+ ----------
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+ """
+ Cs = list_to_array(*Cs)
+ ps = list_to_array(*ps)
+ p = list_to_array(p)
+ nx = get_backend(*Cs, *ps, p)
+
+ S = len(Cs)
+
+ # Initialization of C : random SPD matrix (if not provided by user)
+ if init_C is None:
+ generator = check_random_state(random_state)
+ xalea = generator.randn(N, 2)
+ C = dist(xalea, xalea)
+ C /= C.max()
+ C = nx.from_numpy(C, type_as=p)
+ else:
+ C = init_C
+
+ cpt = 0
+ err = 1
+
+ error = []
+
+ while (err > tol) and (cpt < max_iter):
+ Cprev = C
+
+ T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, None,
+ max_iter, 1e-4, verbose, log=False) for s in range(S)]
+ if loss_fun == 'square_loss':
+ C = update_square_loss(p, lambdas, T, Cs)
+
+ elif loss_fun == 'kl_loss':
+ C = update_kl_loss(p, lambdas, T, Cs)
+
+ if cpt % 10 == 0:
+ # we can speed up the process by checking for the error only all
+ # the 10th iterations
+ err = nx.norm(C - Cprev)
+ error.append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+
+ if log:
+ return C, {"err": error}
+ else:
+ return C
diff --git a/ot/gromov/_dictionary.py b/ot/gromov/_dictionary.py
new file mode 100644
index 0000000..5b32671
--- /dev/null
+++ b/ot/gromov/_dictionary.py
@@ -0,0 +1,1008 @@
+# -*- coding: utf-8 -*-
+"""
+(Fused) Gromov-Wasserstein dictionary learning.
+"""
+
+# Author: Rémi Flamary <remi.flamary@unice.fr>
+# Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
+#
+# License: MIT License
+
+import numpy as np
+
+
+from ..utils import unif
+from ..backend import get_backend
+from ._gw import gromov_wasserstein, fused_gromov_wasserstein
+
+
+def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate=1., Cdict_init=None, projection='nonnegative_symmetric', use_log=True,
+ tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs):
+ r"""
+ Infer Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, q) \}_{d \in [D]}` from the list of structures :math:`\{ (\mathbf{C_s},\mathbf{p_s}) \}_s`
+
+ .. math::
+ \min_{\mathbf{C_{dict}}, \{\mathbf{w_s} \}_{s \leq S}} \sum_{s=1}^S GW_2(\mathbf{C_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) - reg\| \mathbf{w_s} \|_2^2
+
+ such that, :math:`\forall s \leq S` :
+
+ - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w_s} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt.
+ - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}`
+ - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space.
+ - reg is the regularization coefficient.
+
+ The stochastic algorithm used for estimating the graph dictionary atoms as proposed in [38]_
+
+ Parameters
+ ----------
+ Cs : list of S symmetric array-like, shape (ns, ns)
+ List of Metric/Graph cost matrices of variable size (ns, ns).
+ D: int
+ Number of dictionary atoms to learn
+ nt: int
+ Number of samples within each dictionary atoms
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0.
+ ps : list of S array-like, shape (ns,), optional
+ Distribution in each source space C of Cs. Default is None and corresponds to uniform distibutions.
+ q : array-like, shape (nt,), optional
+ Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions.
+ epochs: int, optional
+ Number of epochs used to learn the dictionary. Default is 32.
+ batch_size: int, optional
+ Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32.
+ learning_rate: float, optional
+ Learning rate used for the stochastic gradient descent. Default is 1.
+ Cdict_init: list of D array-like with shape (nt, nt), optional
+ Used to initialize the dictionary.
+ If set to None (Default), the dictionary will be initialized randomly.
+ Else Cdict must have shape (D, nt, nt) i.e match provided shape features.
+ projection: str , optional
+ If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary
+ Else the set of atoms is :math:`R^{nt * nt}`. Default is 'nonnegative_symmetric'
+ log: bool, optional
+ If set to True, losses evolution by batches and epochs are tracked. Default is False.
+ use_adam_optimizer: bool, optional
+ If set to True, adam optimizer with default settings is used as adaptative learning rate strategy.
+ Else perform SGD with fixed learning rate. Default is True.
+ tol_outer : float, optional
+ Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
+ tol_inner : float, optional
+ Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
+ max_iter_outer : int, optional
+ Maximum number of iterations for the BCD. Default is 20.
+ max_iter_inner : int, optional
+ Maximum number of iterations for the Conjugate Gradient. Default is 200.
+ verbose : bool, optional
+ Print the reconstruction loss every epoch. Default is False.
+
+ Returns
+ -------
+
+ Cdict_best_state : D array-like, shape (D,nt,nt)
+ Metric/Graph cost matrices composing the dictionary.
+ The dictionary leading to the best loss over an epoch is saved and returned.
+ log: dict
+ If use_log is True, contains loss evolutions by batches and epochs.
+ References
+ -------
+ .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
+ Graph Dictionary Learning, International Conference on Machine Learning
+ (ICML), 2021.
+ """
+ # Handle backend of non-optional arguments
+ Cs0 = Cs
+ nx = get_backend(*Cs0)
+ Cs = [nx.to_numpy(C) for C in Cs0]
+ dataset_size = len(Cs)
+ # Handle backend of optional arguments
+ if ps is None:
+ ps = [unif(C.shape[0]) for C in Cs]
+ else:
+ ps = [nx.to_numpy(p) for p in ps]
+ if q is None:
+ q = unif(nt)
+ else:
+ q = nx.to_numpy(q)
+ if Cdict_init is None:
+ # Initialize randomly structures of dictionary atoms based on samples
+ dataset_means = [C.mean() for C in Cs]
+ Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
+ else:
+ Cdict = nx.to_numpy(Cdict_init).copy()
+ assert Cdict.shape == (D, nt, nt)
+
+ if 'symmetric' in projection:
+ Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
+ symmetric = True
+ else:
+ symmetric = False
+ if 'nonnegative' in projection:
+ Cdict[Cdict < 0.] = 0
+ if use_adam_optimizer:
+ adam_moments = _initialize_adam_optimizer(Cdict)
+
+ log = {'loss_batches': [], 'loss_epochs': []}
+ const_q = q[:, None] * q[None, :]
+ Cdict_best_state = Cdict.copy()
+ loss_best_state = np.inf
+ if batch_size > dataset_size:
+ batch_size = dataset_size
+ iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0)
+
+ for epoch in range(epochs):
+ cumulated_loss_over_epoch = 0.
+
+ for _ in range(iter_by_epoch):
+ # batch sampling
+ batch = np.random.choice(range(dataset_size), size=batch_size, replace=False)
+ cumulated_loss_over_batch = 0.
+ unmixings = np.zeros((batch_size, D))
+ Cs_embedded = np.zeros((batch_size, nt, nt))
+ Ts = [None] * batch_size
+
+ for batch_idx, C_idx in enumerate(batch):
+ # BCD solver for Gromov-Wassersteisn linear unmixing used independently on each structure of the sampled batch
+ unmixings[batch_idx], Cs_embedded[batch_idx], Ts[batch_idx], current_loss = gromov_wasserstein_linear_unmixing(
+ Cs[C_idx], Cdict, reg=reg, p=ps[C_idx], q=q, tol_outer=tol_outer, tol_inner=tol_inner,
+ max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner, symmetric=symmetric, **kwargs
+ )
+ cumulated_loss_over_batch += current_loss
+ cumulated_loss_over_epoch += cumulated_loss_over_batch
+
+ if use_log:
+ log['loss_batches'].append(cumulated_loss_over_batch)
+
+ # Stochastic projected gradient step over dictionary atoms
+ grad_Cdict = np.zeros_like(Cdict)
+ for batch_idx, C_idx in enumerate(batch):
+ shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx])
+ grad_Cdict += unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :]
+ grad_Cdict *= 2 / batch_size
+ if use_adam_optimizer:
+ Cdict, adam_moments = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate, adam_moments)
+ else:
+ Cdict -= learning_rate * grad_Cdict
+ if 'symmetric' in projection:
+ Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
+ if 'nonnegative' in projection:
+ Cdict[Cdict < 0.] = 0.
+
+ if use_log:
+ log['loss_epochs'].append(cumulated_loss_over_epoch)
+ if loss_best_state > cumulated_loss_over_epoch:
+ loss_best_state = cumulated_loss_over_epoch
+ Cdict_best_state = Cdict.copy()
+ if verbose:
+ print('--- epoch =', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch)
+
+ return nx.from_numpy(Cdict_best_state), log
+
+
+def _initialize_adam_optimizer(variable):
+
+ # Initialization for our numpy implementation of adam optimizer
+ atoms_adam_m = np.zeros_like(variable) # Initialize first moment tensor
+ atoms_adam_v = np.zeros_like(variable) # Initialize second moment tensor
+ atoms_adam_count = 1
+
+ return {'mean': atoms_adam_m, 'var': atoms_adam_v, 'count': atoms_adam_count}
+
+
+def _adam_stochastic_updates(variable, grad, learning_rate, adam_moments, beta_1=0.9, beta_2=0.99, eps=1e-09):
+
+ adam_moments['mean'] = beta_1 * adam_moments['mean'] + (1 - beta_1) * grad
+ adam_moments['var'] = beta_2 * adam_moments['var'] + (1 - beta_2) * (grad**2)
+ unbiased_m = adam_moments['mean'] / (1 - beta_1**adam_moments['count'])
+ unbiased_v = adam_moments['var'] / (1 - beta_2**adam_moments['count'])
+ variable -= learning_rate * unbiased_m / (np.sqrt(unbiased_v) + eps)
+ adam_moments['count'] += 1
+
+ return variable, adam_moments
+
+
+def gromov_wasserstein_linear_unmixing(C, Cdict, reg=0., p=None, q=None, tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, symmetric=None, **kwargs):
+ r"""
+ Returns the Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}`.
+
+ .. math::
+ \min_{ \mathbf{w}} GW_2(\mathbf{C}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2
+
+ such that:
+
+ - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of size nt.
+ - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights.
+ - reg is the regularization coefficient.
+
+ The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38]_ , algorithm 1.
+
+ Parameters
+ ----------
+ C : array-like, shape (ns, ns)
+ Metric/Graph cost matrix.
+ Cdict : D array-like, shape (D,nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed C.
+ reg : float, optional.
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0.
+ p : array-like, shape (ns,), optional
+ Distribution in the source space C. Default is None and corresponds to uniform distribution.
+ q : array-like, shape (nt,), optional
+ Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution.
+ tol_outer : float, optional
+ Solver precision for the BCD algorithm.
+ tol_inner : float, optional
+ Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`.
+ max_iter_outer : int, optional
+ Maximum number of iterations for the BCD. Default is 20.
+ max_iter_inner : int, optional
+ Maximum number of iterations for the Conjugate Gradient. Default is 200.
+
+ Returns
+ -------
+ w: array-like, shape (D,)
+ gromov-wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the span of the dictionary.
+ Cembedded: array-like, shape (nt,nt)
+ embedded structure of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`.
+ T: array-like (ns, nt)
+ Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \mathbf{q})`
+ current_loss: float
+ reconstruction error
+ References
+ -------
+ .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
+ Graph Dictionary Learning, International Conference on Machine Learning
+ (ICML), 2021.
+ """
+ C0, Cdict0 = C, Cdict
+ nx = get_backend(C0, Cdict0)
+ C = nx.to_numpy(C0)
+ Cdict = nx.to_numpy(Cdict0)
+ if p is None:
+ p = unif(C.shape[0])
+ else:
+ p = nx.to_numpy(p)
+
+ if q is None:
+ q = unif(Cdict.shape[-1])
+ else:
+ q = nx.to_numpy(q)
+
+ T = p[:, None] * q[None, :]
+ D = len(Cdict)
+
+ w = unif(D) # Initialize uniformly the unmixing w
+ Cembedded = np.sum(w[:, None, None] * Cdict, axis=0)
+
+ const_q = q[:, None] * q[None, :]
+ # Trackers for BCD convergence
+ convergence_criterion = np.inf
+ current_loss = 10**15
+ outer_count = 0
+
+ while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer):
+ previous_loss = current_loss
+ # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w
+ T, log = gromov_wasserstein(
+ C1=C, C2=Cembedded, p=p, q=q, loss_fun='square_loss', G0=T,
+ max_iter=max_iter_inner, tol_rel=tol_inner, tol_abs=0., log=True, armijo=False, symmetric=symmetric, **kwargs)
+ current_loss = log['gw_dist']
+ if reg != 0:
+ current_loss -= reg * np.sum(w**2)
+
+ # 2. Solve linear unmixing problem over w with a fixed transport plan T
+ w, Cembedded, current_loss = _cg_gromov_wasserstein_unmixing(
+ C=C, Cdict=Cdict, Cembedded=Cembedded, w=w, const_q=const_q, T=T,
+ starting_loss=current_loss, reg=reg, tol=tol_inner, max_iter=max_iter_inner, **kwargs
+ )
+
+ if previous_loss != 0:
+ convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
+ else: # handle numerical issues around 0
+ convergence_criterion = abs(previous_loss - current_loss) / 10**(-15)
+ outer_count += 1
+
+ return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(T), nx.from_numpy(current_loss)
+
+
+def _cg_gromov_wasserstein_unmixing(C, Cdict, Cembedded, w, const_q, T, starting_loss, reg=0., tol=10**(-5), max_iter=200, **kwargs):
+ r"""
+ Returns for a fixed admissible transport plan,
+ the linear unmixing w minimizing the Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w[d]*\mathbf{C_{dict}[d]}, \mathbf{q})`
+
+ .. math::
+ \min_{\mathbf{w}} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d*C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg* \| \mathbf{w} \|_2^2
+
+
+ Such that:
+
+ - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of nt points.
+ - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights.
+ - :math:`\mathbf{w}` is the linear unmixing of :math:`(\mathbf{C}, \mathbf{p})` onto :math:`(\sum_d w_d \mathbf{Cdict[d]}, \mathbf{q})`.
+ - :math:`\mathbf{T}` is the optimal transport plan conditioned by the current state of :math:`\mathbf{w}`.
+ - reg is the regularization coefficient.
+
+ The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38]_
+
+ Parameters
+ ----------
+
+ C : array-like, shape (ns, ns)
+ Metric/Graph cost matrix.
+ Cdict : list of D array-like, shape (nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed C.
+ Each matrix in the dictionary must have the same size (nt,nt).
+ Cembedded: array-like, shape (nt,nt)
+ Embedded structure :math:`(\sum_d w[d]*Cdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations.
+ w: array-like, shape (D,)
+ Linear unmixing of the input structure onto the dictionary
+ const_q: array-like, shape (nt,nt)
+ product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations.
+ T: array-like, shape (ns,nt)
+ fixed transport plan between the input structure and its representation in the dictionary.
+ p : array-like, shape (ns,)
+ Distribution in the source space.
+ q : array-like, shape (nt,)
+ Distribution in the embedding space depicted by the dictionary.
+ reg : float, optional.
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0.
+
+ Returns
+ -------
+ w: ndarray (D,)
+ optimal unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary span given OT starting from previously optimal unmixing.
+ """
+ convergence_criterion = np.inf
+ current_loss = starting_loss
+ count = 0
+ const_TCT = np.transpose(C.dot(T)).dot(T)
+
+ while (convergence_criterion > tol) and (count < max_iter):
+
+ previous_loss = current_loss
+ # 1) Compute gradient at current point w
+ grad_w = 2 * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2))
+ grad_w -= 2 * reg * w
+
+ # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w
+ min_ = np.min(grad_w)
+ x = (grad_w == min_).astype(np.float64)
+ x /= np.sum(x)
+
+ # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c
+ gamma, a, b, Cembedded_diff = _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg)
+
+ # 4) Updates: w <-- (1-gamma)*w + gamma*x
+ w += gamma * (x - w)
+ Cembedded += gamma * Cembedded_diff
+ current_loss += a * (gamma**2) + b * gamma
+
+ if previous_loss != 0: # not that the loss can be negative if reg >0
+ convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
+ else: # handle numerical issues around 0
+ convergence_criterion = abs(previous_loss - current_loss) / 10**(-15)
+ count += 1
+
+ return w, Cembedded, current_loss
+
+
+def _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg, **kwargs):
+ r"""
+ Compute optimal steps for the line search problem of Gromov-Wasserstein linear unmixing
+ .. math::
+ \min_{\gamma \in [0,1]} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg\| \mathbf{z}(\gamma) \|_2^2
+
+
+ Such that:
+
+ - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}`
+
+ Parameters
+ ----------
+
+ w : array-like, shape (D,)
+ Unmixing.
+ grad_w : array-like, shape (D, D)
+ Gradient of the reconstruction loss with respect to w.
+ x: array-like, shape (D,)
+ Conditional gradient direction.
+ Cdict : list of D array-like, shape (nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed C.
+ Each matrix in the dictionary must have the same size (nt,nt).
+ Cembedded: array-like, shape (nt,nt)
+ Embedded structure :math:`(\sum_d w_dCdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations.
+ const_q: array-like, shape (nt,nt)
+ product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations.
+ const_TCT: array-like, shape (nt, nt)
+ :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations.
+ Returns
+ -------
+ gamma: float
+ Optimal value for the line-search step
+ a: float
+ Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss
+ b: float
+ Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss
+ Cembedded_diff: numpy array, shape (nt, nt)
+ Difference between models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`.
+ reg : float, optional.
+ Coefficient of the negative quadratic regularization used to promote sparsity of :math:`\mathbf{w}`.
+ """
+
+ # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c
+ Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0)
+ Cembedded_diff = Cembedded_x - Cembedded
+ trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q)
+ trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q)
+ a = trace_diffx - trace_diffw
+ b = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT))
+ if reg != 0:
+ a -= reg * np.sum((x - w)**2)
+ b -= 2 * reg * np.sum(w * (x - w))
+
+ if a > 0:
+ gamma = min(1, max(0, - b / (2 * a)))
+ elif a + b < 0:
+ gamma = 1
+ else:
+ gamma = 0
+
+ return gamma, a, b, Cembedded_diff
+
+
+def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate_C=1., learning_rate_Y=1.,
+ Cdict_init=None, Ydict_init=None, projection='nonnegative_symmetric', use_log=False,
+ tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs):
+ r"""
+ Infer Fused Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` from the list of S attributed structures :math:`\{ (\mathbf{C_s}, \mathbf{Y_s},\mathbf{p_s}) \}_s`
+
+ .. math::
+ \min_{\mathbf{C_{dict}},\mathbf{Y_{dict}}, \{\mathbf{w_s}\}_{s}} \sum_{s=1}^S FGW_{2,\alpha}(\mathbf{C_s}, \mathbf{Y_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]},\sum_{d=1}^D w_{s,d}\mathbf{Y_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) \\ - reg\| \mathbf{w_s} \|_2^2
+
+
+ Such that :math:`\forall s \leq S` :
+
+ - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w_s} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns.
+ - :math:`\forall s \leq S, \mathbf{Y_s}` is a (ns,d) features matrix of variable size ns and fixed dimension d.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt.
+ - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d.
+ - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}`
+ - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space.
+ - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein
+ - reg is the regularization coefficient.
+
+
+ The stochastic algorithm used for estimating the attributed graph dictionary atoms as proposed in [38]_
+
+ Parameters
+ ----------
+ Cs : list of S symmetric array-like, shape (ns, ns)
+ List of Metric/Graph cost matrices of variable size (ns,ns).
+ Ys : list of S array-like, shape (ns, d)
+ List of feature matrix of variable size (ns,d) with d fixed.
+ D: int
+ Number of dictionary atoms to learn
+ nt: int
+ Number of samples within each dictionary atoms
+ alpha : float
+ Trade-off parameter of Fused Gromov-Wasserstein
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0.
+ ps : list of S array-like, shape (ns,), optional
+ Distribution in each source space C of Cs. Default is None and corresponds to uniform distibutions.
+ q : array-like, shape (nt,), optional
+ Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions.
+ epochs: int, optional
+ Number of epochs used to learn the dictionary. Default is 32.
+ batch_size: int, optional
+ Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32.
+ learning_rate_C: float, optional
+ Learning rate used for the stochastic gradient descent on Cdict. Default is 1.
+ learning_rate_Y: float, optional
+ Learning rate used for the stochastic gradient descent on Ydict. Default is 1.
+ Cdict_init: list of D array-like with shape (nt, nt), optional
+ Used to initialize the dictionary structures Cdict.
+ If set to None (Default), the dictionary will be initialized randomly.
+ Else Cdict must have shape (D, nt, nt) i.e match provided shape features.
+ Ydict_init: list of D array-like with shape (nt, d), optional
+ Used to initialize the dictionary features Ydict.
+ If set to None, the dictionary features will be initialized randomly.
+ Else Ydict must have shape (D, nt, d) where d is the features dimension of inputs Ys and also match provided shape features.
+ projection: str, optional
+ If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary
+ Else the set of atoms is :math:`R^{nt * nt}`. Default is 'nonnegative_symmetric'
+ log: bool, optional
+ If set to True, losses evolution by batches and epochs are tracked. Default is False.
+ use_adam_optimizer: bool, optional
+ If set to True, adam optimizer with default settings is used as adaptative learning rate strategy.
+ Else perform SGD with fixed learning rate. Default is True.
+ tol_outer : float, optional
+ Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
+ tol_inner : float, optional
+ Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
+ max_iter_outer : int, optional
+ Maximum number of iterations for the BCD. Default is 20.
+ max_iter_inner : int, optional
+ Maximum number of iterations for the Conjugate Gradient. Default is 200.
+ verbose : bool, optional
+ Print the reconstruction loss every epoch. Default is False.
+
+ Returns
+ -------
+
+ Cdict_best_state : D array-like, shape (D,nt,nt)
+ Metric/Graph cost matrices composing the dictionary.
+ The dictionary leading to the best loss over an epoch is saved and returned.
+ Ydict_best_state : D array-like, shape (D,nt,d)
+ Feature matrices composing the dictionary.
+ The dictionary leading to the best loss over an epoch is saved and returned.
+ log: dict
+ If use_log is True, contains loss evolutions by batches and epoches.
+ References
+ -------
+ .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
+ Graph Dictionary Learning, International Conference on Machine Learning
+ (ICML), 2021.
+ """
+ Cs0, Ys0 = Cs, Ys
+ nx = get_backend(*Cs0, *Ys0)
+ Cs = [nx.to_numpy(C) for C in Cs0]
+ Ys = [nx.to_numpy(Y) for Y in Ys0]
+
+ d = Ys[0].shape[-1]
+ dataset_size = len(Cs)
+
+ if ps is None:
+ ps = [unif(C.shape[0]) for C in Cs]
+ else:
+ ps = [nx.to_numpy(p) for p in ps]
+ if q is None:
+ q = unif(nt)
+ else:
+ q = nx.to_numpy(q)
+
+ if Cdict_init is None:
+ # Initialize randomly structures of dictionary atoms based on samples
+ dataset_means = [C.mean() for C in Cs]
+ Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
+ else:
+ Cdict = nx.to_numpy(Cdict_init).copy()
+ assert Cdict.shape == (D, nt, nt)
+ if Ydict_init is None:
+ # Initialize randomly features of dictionary atoms based on samples distribution by feature component
+ dataset_feature_means = np.stack([F.mean(axis=0) for F in Ys])
+ Ydict = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(D, nt, d))
+ else:
+ Ydict = nx.to_numpy(Ydict_init).copy()
+ assert Ydict.shape == (D, nt, d)
+
+ if 'symmetric' in projection:
+ Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
+ symmetric = True
+ else:
+ symmetric = False
+ if 'nonnegative' in projection:
+ Cdict[Cdict < 0.] = 0.
+
+ if use_adam_optimizer:
+ adam_moments_C = _initialize_adam_optimizer(Cdict)
+ adam_moments_Y = _initialize_adam_optimizer(Ydict)
+
+ log = {'loss_batches': [], 'loss_epochs': []}
+ const_q = q[:, None] * q[None, :]
+ diag_q = np.diag(q)
+ Cdict_best_state = Cdict.copy()
+ Ydict_best_state = Ydict.copy()
+ loss_best_state = np.inf
+ if batch_size > dataset_size:
+ batch_size = dataset_size
+ iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0)
+
+ for epoch in range(epochs):
+ cumulated_loss_over_epoch = 0.
+
+ for _ in range(iter_by_epoch):
+
+ # Batch iterations
+ batch = np.random.choice(range(dataset_size), size=batch_size, replace=False)
+ cumulated_loss_over_batch = 0.
+ unmixings = np.zeros((batch_size, D))
+ Cs_embedded = np.zeros((batch_size, nt, nt))
+ Ys_embedded = np.zeros((batch_size, nt, d))
+ Ts = [None] * batch_size
+
+ for batch_idx, C_idx in enumerate(batch):
+ # BCD solver for Gromov-Wassersteisn linear unmixing used independently on each structure of the sampled batch
+ unmixings[batch_idx], Cs_embedded[batch_idx], Ys_embedded[batch_idx], Ts[batch_idx], current_loss = fused_gromov_wasserstein_linear_unmixing(
+ Cs[C_idx], Ys[C_idx], Cdict, Ydict, alpha, reg=reg, p=ps[C_idx], q=q,
+ tol_outer=tol_outer, tol_inner=tol_inner, max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner, symmetric=symmetric, **kwargs
+ )
+ cumulated_loss_over_batch += current_loss
+ cumulated_loss_over_epoch += cumulated_loss_over_batch
+ if use_log:
+ log['loss_batches'].append(cumulated_loss_over_batch)
+
+ # Stochastic projected gradient step over dictionary atoms
+ grad_Cdict = np.zeros_like(Cdict)
+ grad_Ydict = np.zeros_like(Ydict)
+
+ for batch_idx, C_idx in enumerate(batch):
+ shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx])
+ shared_term_features = diag_q.dot(Ys_embedded[batch_idx]) - Ts[batch_idx].T.dot(Ys[C_idx])
+ grad_Cdict += alpha * unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :]
+ grad_Ydict += (1 - alpha) * unmixings[batch_idx][:, None, None] * shared_term_features[None, :, :]
+ grad_Cdict *= 2 / batch_size
+ grad_Ydict *= 2 / batch_size
+
+ if use_adam_optimizer:
+ Cdict, adam_moments_C = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate_C, adam_moments_C)
+ Ydict, adam_moments_Y = _adam_stochastic_updates(Ydict, grad_Ydict, learning_rate_Y, adam_moments_Y)
+ else:
+ Cdict -= learning_rate_C * grad_Cdict
+ Ydict -= learning_rate_Y * grad_Ydict
+
+ if 'symmetric' in projection:
+ Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
+ if 'nonnegative' in projection:
+ Cdict[Cdict < 0.] = 0.
+
+ if use_log:
+ log['loss_epochs'].append(cumulated_loss_over_epoch)
+ if loss_best_state > cumulated_loss_over_epoch:
+ loss_best_state = cumulated_loss_over_epoch
+ Cdict_best_state = Cdict.copy()
+ Ydict_best_state = Ydict.copy()
+ if verbose:
+ print('--- epoch: ', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch)
+
+ return nx.from_numpy(Cdict_best_state), nx.from_numpy(Ydict_best_state), log
+
+
+def fused_gromov_wasserstein_linear_unmixing(C, Y, Cdict, Ydict, alpha, reg=0., p=None, q=None, tol_outer=10**(-5),
+ tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, symmetric=True, **kwargs):
+ r"""
+ Returns the Fused Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the attributed dictionary atoms :math:`\{ (\mathbf{C_{dict}[d]},\mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}`
+
+ .. math::
+ \min_{\mathbf{w}} FGW_{2,\alpha}(\mathbf{C},\mathbf{Y}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]},\sum_{d=1}^D w_d\mathbf{Y_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2
+
+ such that, :math:`\forall s \leq S` :
+
+ - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w_s} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns.
+ - :math:`\mathbf{Y}` is a (ns,d) features matrix of variable size ns and fixed dimension d.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt.
+ - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d.
+ - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C_s}`
+ - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space.
+ - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein
+ - reg is the regularization coefficient.
+
+ The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38]_, algorithm 6.
+
+ Parameters
+ ----------
+ C : array-like, shape (ns, ns)
+ Metric/Graph cost matrix.
+ Y : array-like, shape (ns, d)
+ Feature matrix.
+ Cdict : D array-like, shape (D,nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed (C,Y).
+ Ydict : D array-like, shape (D,nt,d)
+ Feature matrices composing the dictionary on which to embed (C,Y).
+ alpha: float,
+ Trade-off parameter of Fused Gromov-Wasserstein.
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0.
+ p : array-like, shape (ns,), optional
+ Distribution in the source space C. Default is None and corresponds to uniform distribution.
+ q : array-like, shape (nt,), optional
+ Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution.
+ tol_outer : float, optional
+ Solver precision for the BCD algorithm.
+ tol_inner : float, optional
+ Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`.
+ max_iter_outer : int, optional
+ Maximum number of iterations for the BCD. Default is 20.
+ max_iter_inner : int, optional
+ Maximum number of iterations for the Conjugate Gradient. Default is 200.
+
+ Returns
+ -------
+ w: array-like, shape (D,)
+ fused gromov-wasserstein linear unmixing of (C,Y,p) onto the span of the dictionary.
+ Cembedded: array-like, shape (nt,nt)
+ embedded structure of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`.
+ Yembedded: array-like, shape (nt,d)
+ embedded features of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{Y_{dict}[d]}`.
+ T: array-like (ns,nt)
+ Fused Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \sum_d w_d\mathbf{Y_{dict}[d]},\mathbf{q})`.
+ current_loss: float
+ reconstruction error
+ References
+ -------
+ .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
+ Graph Dictionary Learning, International Conference on Machine Learning
+ (ICML), 2021.
+ """
+ C0, Y0, Cdict0, Ydict0 = C, Y, Cdict, Ydict
+ nx = get_backend(C0, Y0, Cdict0, Ydict0)
+ C = nx.to_numpy(C0)
+ Y = nx.to_numpy(Y0)
+ Cdict = nx.to_numpy(Cdict0)
+ Ydict = nx.to_numpy(Ydict0)
+
+ if p is None:
+ p = unif(C.shape[0])
+ else:
+ p = nx.to_numpy(p)
+ if q is None:
+ q = unif(Cdict.shape[-1])
+ else:
+ q = nx.to_numpy(q)
+
+ T = p[:, None] * q[None, :]
+ D = len(Cdict)
+ d = Y.shape[-1]
+ w = unif(D) # Initialize with uniform weights
+ ns = C.shape[-1]
+ nt = Cdict.shape[-1]
+
+ # modeling (C,Y)
+ Cembedded = np.sum(w[:, None, None] * Cdict, axis=0)
+ Yembedded = np.sum(w[:, None, None] * Ydict, axis=0)
+
+ # constants depending on q
+ const_q = q[:, None] * q[None, :]
+ diag_q = np.diag(q)
+ # Trackers for BCD convergence
+ convergence_criterion = np.inf
+ current_loss = 10**15
+ outer_count = 0
+ Ys_constM = (Y**2).dot(np.ones((d, nt))) # constant in computing euclidean pairwise feature matrix
+
+ while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer):
+ previous_loss = current_loss
+
+ # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w
+ Yt_varM = (np.ones((ns, d))).dot((Yembedded**2).T)
+ M = Ys_constM + Yt_varM - 2 * Y.dot(Yembedded.T) # euclidean distance matrix between features
+ T, log = fused_gromov_wasserstein(
+ M, C, Cembedded, p, q, loss_fun='square_loss', alpha=alpha,
+ max_iter=max_iter_inner, tol_rel=tol_inner, tol_abs=0., armijo=False, G0=T, log=True, symmetric=symmetric, **kwargs)
+ current_loss = log['fgw_dist']
+ if reg != 0:
+ current_loss -= reg * np.sum(w**2)
+
+ # 2. Solve linear unmixing problem over w with a fixed transport plan T
+ w, Cembedded, Yembedded, current_loss = _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w,
+ T, p, q, const_q, diag_q, current_loss, alpha, reg,
+ tol=tol_inner, max_iter=max_iter_inner, **kwargs)
+ if previous_loss != 0:
+ convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
+ else:
+ convergence_criterion = abs(previous_loss - current_loss) / 10**(-12)
+ outer_count += 1
+
+ return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(Yembedded), nx.from_numpy(T), nx.from_numpy(current_loss)
+
+
+def _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w, T, p, q, const_q, diag_q, starting_loss, alpha, reg, tol=10**(-6), max_iter=200, **kwargs):
+ r"""
+ Returns for a fixed admissible transport plan,
+ the optimal linear unmixing :math:`\mathbf{w}` minimizing the Fused Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` and :math:`(\sum_d w_d \mathbf{C_{dict}[d]},\sum_d w_d*\mathbf{Y_{dict}[d]}, \mathbf{q})`
+
+ .. math::
+ \min_{\mathbf{w}} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\+ (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d w_d \mathbf{Y_{dict}[d]_j} \|_2^2 T_{ij}- reg \| \mathbf{w} \|_2^2
+
+ Such that :
+
+ - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns.
+ - :math:`\mathbf{Y}` is a (ns,d) features matrix of variable size ns and fixed dimension d.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt.
+ - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d.
+ - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C_s}`
+ - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space.
+ - :math:`\mathbf{T}` is the optimal transport plan conditioned by the previous state of :math:`\mathbf{w}`
+ - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein
+ - reg is the regularization coefficient.
+
+ The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38]_, algorithm 7.
+
+ Parameters
+ ----------
+
+ C : array-like, shape (ns, ns)
+ Metric/Graph cost matrix.
+ Y : array-like, shape (ns, d)
+ Feature matrix.
+ Cdict : list of D array-like, shape (nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed (C,Y).
+ Each matrix in the dictionary must have the same size (nt,nt).
+ Ydict : list of D array-like, shape (nt,d)
+ Feature matrices composing the dictionary on which to embed (C,Y).
+ Each matrix in the dictionary must have the same size (nt,d).
+ Cembedded: array-like, shape (nt,nt)
+ Embedded structure of (C,Y) onto the dictionary
+ Yembedded: array-like, shape (nt,d)
+ Embedded features of (C,Y) onto the dictionary
+ w: array-like, shape (n_D,)
+ Linear unmixing of (C,Y) onto (Cdict,Ydict)
+ const_q: array-like, shape (nt,nt)
+ product matrix :math:`\mathbf{qq}^\top` where :math:`\mathbf{q}` is the target space distribution.
+ diag_q: array-like, shape (nt,nt)
+ diagonal matrix with values of q on the diagonal.
+ T: array-like, shape (ns,nt)
+ fixed transport plan between (C,Y) and its model
+ p : array-like, shape (ns,)
+ Distribution in the source space (C,Y).
+ q : array-like, shape (nt,)
+ Distribution in the embedding space depicted by the dictionary.
+ alpha: float,
+ Trade-off parameter of Fused Gromov-Wasserstein.
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w.
+
+ Returns
+ -------
+ w: ndarray (D,)
+ linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the span of :math:`(C_{dict},Y_{dict})` given OT corresponding to previous unmixing.
+ """
+ convergence_criterion = np.inf
+ current_loss = starting_loss
+ count = 0
+ const_TCT = np.transpose(C.dot(T)).dot(T)
+ ones_ns_d = np.ones(Y.shape)
+
+ while (convergence_criterion > tol) and (count < max_iter):
+ previous_loss = current_loss
+
+ # 1) Compute gradient at current point w
+ # structure
+ grad_w = alpha * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2))
+ # feature
+ grad_w += (1 - alpha) * np.sum(Ydict * (diag_q.dot(Yembedded)[None, :, :] - T.T.dot(Y)[None, :, :]), axis=(1, 2))
+ grad_w -= reg * w
+ grad_w *= 2
+
+ # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w
+ min_ = np.min(grad_w)
+ x = (grad_w == min_).astype(np.float64)
+ x /= np.sum(x)
+
+ # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c
+ gamma, a, b, Cembedded_diff, Yembedded_diff = _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg)
+
+ # 4) Updates: w <-- (1-gamma)*w + gamma*x
+ w += gamma * (x - w)
+ Cembedded += gamma * Cembedded_diff
+ Yembedded += gamma * Yembedded_diff
+ current_loss += a * (gamma**2) + b * gamma
+
+ if previous_loss != 0:
+ convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
+ else:
+ convergence_criterion = abs(previous_loss - current_loss) / 10**(-12)
+ count += 1
+
+ return w, Cembedded, Yembedded, current_loss
+
+
+def _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg, **kwargs):
+ r"""
+ Compute optimal steps for the line search problem of Fused Gromov-Wasserstein linear unmixing
+ .. math::
+ \min_{\gamma \in [0,1]} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\ + (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d z_d(\gamma) \mathbf{Y_{dict}[d]_j} \|_2^2 - reg\| \mathbf{z}(\gamma) \|_2^2
+
+
+ Such that :
+
+ - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}`
+
+ Parameters
+ ----------
+
+ w : array-like, shape (D,)
+ Unmixing.
+ grad_w : array-like, shape (D, D)
+ Gradient of the reconstruction loss with respect to w.
+ x: array-like, shape (D,)
+ Conditional gradient direction.
+ Y: arrat-like, shape (ns,d)
+ Feature matrix of the input space
+ Cdict : list of D array-like, shape (nt, nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed (C,Y).
+ Each matrix in the dictionary must have the same size (nt,nt).
+ Ydict : list of D array-like, shape (nt, d)
+ Feature matrices composing the dictionary on which to embed (C,Y).
+ Each matrix in the dictionary must have the same size (nt,d).
+ Cembedded: array-like, shape (nt, nt)
+ Embedded structure of (C,Y) onto the dictionary
+ Yembedded: array-like, shape (nt, d)
+ Embedded features of (C,Y) onto the dictionary
+ T: array-like, shape (ns, nt)
+ Fixed transport plan between (C,Y) and its current model.
+ const_q: array-like, shape (nt,nt)
+ product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations.
+ const_TCT: array-like, shape (nt, nt)
+ :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations.
+ ones_ns_d: array-like, shape (ns, d)
+ :math:`\mathbf{1}_{ ns \times d}`. Used to avoid redundant computations.
+ alpha: float,
+ Trade-off parameter of Fused Gromov-Wasserstein.
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w.
+
+ Returns
+ -------
+ gamma: float
+ Optimal value for the line-search step
+ a: float
+ Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss
+ b: float
+ Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss
+ Cembedded_diff: numpy array, shape (nt, nt)
+ Difference between structure matrix of models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`.
+ Yembedded_diff: numpy array, shape (nt, nt)
+ Difference between feature matrix of models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`.
+ """
+ # polynomial coefficients from quadratic objective (with respect to w) on structures
+ Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0)
+ Cembedded_diff = Cembedded_x - Cembedded
+ trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q)
+ trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q)
+ # Constant factor appearing in the factorization a*gamma^2 + b*g + c of the Gromov-Wasserstein reconstruction loss
+ a_gw = trace_diffx - trace_diffw
+ b_gw = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT))
+
+ # polynomial coefficient from quadratic objective (with respect to w) on features
+ Yembedded_x = np.sum(x[:, None, None] * Ydict, axis=0)
+ Yembedded_diff = Yembedded_x - Yembedded
+ # Constant factor appearing in the factorization a*gamma^2 + b*g + c of the Gromov-Wasserstein reconstruction loss
+ a_w = np.sum(ones_ns_d.dot((Yembedded_diff**2).T) * T)
+ b_w = 2 * np.sum(T * (ones_ns_d.dot((Yembedded * Yembedded_diff).T) - Y.dot(Yembedded_diff.T)))
+
+ a = alpha * a_gw + (1 - alpha) * a_w
+ b = alpha * b_gw + (1 - alpha) * b_w
+ if reg != 0:
+ a -= reg * np.sum((x - w)**2)
+ b -= 2 * reg * np.sum(w * (x - w))
+ if a > 0:
+ gamma = min(1, max(0, -b / (2 * a)))
+ elif a + b < 0:
+ gamma = 1
+ else:
+ gamma = 0
+
+ return gamma, a, b, Cembedded_diff, Yembedded_diff
diff --git a/ot/gromov/_estimators.py b/ot/gromov/_estimators.py
new file mode 100644
index 0000000..0a29a91
--- /dev/null
+++ b/ot/gromov/_estimators.py
@@ -0,0 +1,425 @@
+# -*- coding: utf-8 -*-
+"""
+Gromov-Wasserstein and Fused-Gromov-Wasserstein stochastic estimators.
+"""
+
+# Author: Rémi Flamary <remi.flamary@unice.fr>
+# Tanguy Kerdoncuff <tanguy.kerdoncuff@laposte.net>
+#
+# License: MIT License
+
+import numpy as np
+
+
+from ..bregman import sinkhorn
+from ..utils import list_to_array, check_random_state
+from ..lp import emd_1d, emd
+from ..backend import get_backend
+
+
+def GW_distance_estimation(C1, C2, p, q, loss_fun, T,
+ nb_samples_p=None, nb_samples_q=None, std=True, random_state=None):
+ r"""
+ Returns an approximation of the gromov-wasserstein cost between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
+ with a fixed transport plan :math:`\mathbf{T}`.
+
+ The function gives an unbiased approximation of the following equation:
+
+ .. math::
+
+ GW = \sum_{i,j,k,l} L(\mathbf{C_{1}}_{i,k}, \mathbf{C_{2}}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - `L` : Loss function to account for the misfit between the similarity matrices
+ - :math:`\mathbf{T}`: Matrix with marginal :math:`\mathbf{p}` and :math:`\mathbf{q}`
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space
+ q : array-like, shape (nt,)
+ Distribution in the target space
+ loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}`
+ Loss function used for the distance, the transport plan does not depend on the loss function
+ T : csr or array-like, shape (ns, nt)
+ Transport plan matrix, either a sparse csr or a dense matrix
+ nb_samples_p : int, optional
+ `nb_samples_p` is the number of samples (without replacement) along the first dimension of :math:`\mathbf{T}`
+ nb_samples_q : int, optional
+ `nb_samples_q` is the number of samples along the second dimension of :math:`\mathbf{T}`, for each sample along the first
+ std : bool, optional
+ Standard deviation associated with the prediction of the gromov-wasserstein cost
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
+
+ Returns
+ -------
+ : float
+ Gromov-wasserstein cost
+
+ References
+ ----------
+ .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
+ "Sampled Gromov Wasserstein."
+ Machine Learning Journal (MLJ). 2021.
+
+ """
+ C1, C2, p, q = list_to_array(C1, C2, p, q)
+ nx = get_backend(C1, C2, p, q)
+
+ generator = check_random_state(random_state)
+
+ len_p = p.shape[0]
+ len_q = q.shape[0]
+
+ # It is always better to sample from the biggest distribution first.
+ if len_p < len_q:
+ p, q = q, p
+ len_p, len_q = len_q, len_p
+ C1, C2 = C2, C1
+ T = T.T
+
+ if nb_samples_p is None:
+ if nx.issparse(T):
+ # If T is sparse, it probably mean that PoGroW was used, thus the number of sample is reduced
+ nb_samples_p = min(int(5 * (len_p * np.log(len_p)) ** 0.5), len_p)
+ else:
+ nb_samples_p = len_p
+ else:
+ # The number of sample along the first dimension is without replacement.
+ nb_samples_p = min(nb_samples_p, len_p)
+ if nb_samples_q is None:
+ nb_samples_q = 1
+ if std:
+ nb_samples_q = max(2, nb_samples_q)
+
+ index_k = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
+ index_l = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
+
+ index_i = generator.choice(
+ len_p, size=nb_samples_p, p=nx.to_numpy(p), replace=False
+ )
+ index_j = generator.choice(
+ len_p, size=nb_samples_p, p=nx.to_numpy(p), replace=False
+ )
+
+ for i in range(nb_samples_p):
+ if nx.issparse(T):
+ T_indexi = nx.reshape(nx.todense(T[index_i[i], :]), (-1,))
+ T_indexj = nx.reshape(nx.todense(T[index_j[i], :]), (-1,))
+ else:
+ T_indexi = T[index_i[i], :]
+ T_indexj = T[index_j[i], :]
+ # For each of the row sampled, the column is sampled.
+ index_k[i] = generator.choice(
+ len_q,
+ size=nb_samples_q,
+ p=nx.to_numpy(T_indexi / nx.sum(T_indexi)),
+ replace=True
+ )
+ index_l[i] = generator.choice(
+ len_q,
+ size=nb_samples_q,
+ p=nx.to_numpy(T_indexj / nx.sum(T_indexj)),
+ replace=True
+ )
+
+ list_value_sample = nx.stack([
+ loss_fun(
+ C1[np.ix_(index_i, index_j)],
+ C2[np.ix_(index_k[:, n], index_l[:, n])]
+ ) for n in range(nb_samples_q)
+ ], axis=2)
+
+ if std:
+ std_value = nx.sum(nx.std(list_value_sample, axis=2) ** 2) ** 0.5
+ return nx.mean(list_value_sample), std_value / (nb_samples_p * nb_samples_p)
+ else:
+ return nx.mean(list_value_sample)
+
+
+def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun,
+ alpha=1, max_iter=100, threshold_plan=0, log=False, verbose=False, random_state=None):
+ r"""
+ Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a stochastic Frank-Wolfe.
+ This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times PN^2)` time complexity with `P` the number of Sinkhorn iterations.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{T} &\geq 0
+
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity matrices
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space
+ q : array-like, shape (nt,)
+ Distribution in the target space
+ loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}`
+ Loss function used for the distance, the transport plan does not depend on the loss function
+ alpha : float
+ Step of the Frank-Wolfe algorithm, should be between 0 and 1
+ max_iter : int, optional
+ Max number of iterations
+ threshold_plan : float, optional
+ Deleting very small values in the transport plan. If above zero, it violates the marginal constraints.
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ Gives the distance estimated and the standard deviation
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
+
+ Returns
+ -------
+ T : array-like, shape (`ns`, `nt`)
+ Optimal coupling between the two spaces
+
+ References
+ ----------
+ .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
+ "Sampled Gromov Wasserstein."
+ Machine Learning Journal (MLJ). 2021.
+
+ """
+ C1, C2, p, q = list_to_array(C1, C2, p, q)
+ nx = get_backend(C1, C2, p, q)
+
+ len_p = p.shape[0]
+ len_q = q.shape[0]
+
+ generator = check_random_state(random_state)
+
+ index = np.zeros(2, dtype=int)
+
+ # Initialize with default marginal
+ index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p))
+ index[1] = generator.choice(len_q, size=1, p=nx.to_numpy(q))
+ T = nx.tocsr(emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False))
+
+ best_gw_dist_estimated = np.inf
+ for cpt in range(max_iter):
+ index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p))
+ T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,))
+ index[1] = generator.choice(
+ len_q, size=1, p=nx.to_numpy(T_index0 / nx.sum(T_index0))
+ )
+
+ if alpha == 1:
+ T = nx.tocsr(
+ emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False)
+ )
+ else:
+ new_T = nx.tocsr(
+ emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False)
+ )
+ T = (1 - alpha) * T + alpha * new_T
+ # To limit the number of non 0, the values below the threshold are set to 0.
+ T = nx.eliminate_zeros(T, threshold=threshold_plan)
+
+ if cpt % 10 == 0 or cpt == (max_iter - 1):
+ gw_dist_estimated = GW_distance_estimation(
+ C1=C1, C2=C2, loss_fun=loss_fun,
+ p=p, q=q, T=T, std=False, random_state=generator
+ )
+
+ if gw_dist_estimated < best_gw_dist_estimated:
+ best_gw_dist_estimated = gw_dist_estimated
+ best_T = nx.copy(T)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Best gw estimated') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, best_gw_dist_estimated))
+
+ if log:
+ log = {}
+ log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(
+ C1=C1, C2=C2, loss_fun=loss_fun,
+ p=p, q=q, T=best_T, random_state=generator
+ )
+ return best_T, log
+ return best_T
+
+
+def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun,
+ nb_samples_grad=100, epsilon=1, max_iter=500, log=False, verbose=False,
+ random_state=None):
+ r"""
+ Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a 1-stochastic Frank-Wolfe.
+ This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times N \log(N))` time complexity by relying on the 1D Optimal Transport solver.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{T} &\geq 0
+
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity matrices
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space
+ q : array-like, shape (nt,)
+ Distribution in the target space
+ loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}`
+ Loss function used for the distance, the transport plan does not depend on the loss function
+ nb_samples_grad : int
+ Number of samples to approximate the gradient
+ epsilon : float
+ Weight of the Kullback-Leibler regularization
+ max_iter : int, optional
+ Max number of iterations
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ Gives the distance estimated and the standard deviation
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
+
+ Returns
+ -------
+ T : array-like, shape (`ns`, `nt`)
+ Optimal coupling between the two spaces
+
+ References
+ ----------
+ .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
+ "Sampled Gromov Wasserstein."
+ Machine Learning Journal (MLJ). 2021.
+
+ """
+ C1, C2, p, q = list_to_array(C1, C2, p, q)
+ nx = get_backend(C1, C2, p, q)
+
+ len_p = p.shape[0]
+ len_q = q.shape[0]
+
+ generator = check_random_state(random_state)
+
+ # The most natural way to define nb_sample is with a simple integer.
+ if isinstance(nb_samples_grad, int):
+ if nb_samples_grad > len_p:
+ # As the sampling along the first dimension is done without replacement, the rest is reported to the second
+ # dimension.
+ nb_samples_grad_p, nb_samples_grad_q = len_p, nb_samples_grad // len_p
+ else:
+ nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad, 1
+ else:
+ nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad
+ T = nx.outer(p, q)
+ # continue_loop allows to stop the loop if there is several successive small modification of T.
+ continue_loop = 0
+
+ # The gradient of GW is more complex if the two matrices are not symmetric.
+ C_are_symmetric = nx.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and nx.allclose(C2, C2.T, rtol=1e-10, atol=1e-10)
+
+ for cpt in range(max_iter):
+ index0 = generator.choice(
+ len_p, size=nb_samples_grad_p, p=nx.to_numpy(p), replace=False
+ )
+ Lik = 0
+ for i, index0_i in enumerate(index0):
+ index1 = generator.choice(
+ len_q, size=nb_samples_grad_q,
+ p=nx.to_numpy(T[index0_i, :] / nx.sum(T[index0_i, :])),
+ replace=False
+ )
+ # If the matrices C are not symmetric, the gradient has 2 terms, thus the term is chosen randomly.
+ if (not C_are_symmetric) and generator.rand(1) > 0.5:
+ Lik += nx.mean(loss_fun(
+ C1[:, [index0[i]] * nb_samples_grad_q][:, None, :],
+ C2[:, index1][None, :, :]
+ ), axis=2)
+ else:
+ Lik += nx.mean(loss_fun(
+ C1[[index0[i]] * nb_samples_grad_q, :][:, :, None],
+ C2[index1, :][:, None, :]
+ ), axis=0)
+
+ max_Lik = nx.max(Lik)
+ if max_Lik == 0:
+ continue
+ # This division by the max is here to facilitate the choice of epsilon.
+ Lik /= max_Lik
+
+ if epsilon > 0:
+ # Set to infinity all the numbers below exp(-200) to avoid log of 0.
+ log_T = nx.log(nx.clip(T, np.exp(-200), 1))
+ log_T = nx.where(log_T == -200, -np.inf, log_T)
+ Lik = Lik - epsilon * log_T
+
+ try:
+ new_T = sinkhorn(a=p, b=q, M=Lik, reg=epsilon)
+ except (RuntimeWarning, UserWarning):
+ print("Warning catched in Sinkhorn: Return last stable T")
+ break
+ else:
+ new_T = emd(a=p, b=q, M=Lik)
+
+ change_T = nx.mean((T - new_T) ** 2)
+ if change_T <= 10e-20:
+ continue_loop += 1
+ if continue_loop > 100: # Number max of low modifications of T
+ T = nx.copy(new_T)
+ break
+ else:
+ continue_loop = 0
+
+ if verbose and cpt % 10 == 0:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', '||T_n - T_{n+1}||') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, change_T))
+ T = nx.copy(new_T)
+
+ if log:
+ log = {}
+ log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(
+ C1=C1, C2=C2, loss_fun=loss_fun,
+ p=p, q=q, T=T, random_state=generator
+ )
+ return T, log
+ return T
diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py
new file mode 100644
index 0000000..c6e4076
--- /dev/null
+++ b/ot/gromov/_gw.py
@@ -0,0 +1,978 @@
+# -*- coding: utf-8 -*-
+"""
+Gromov-Wasserstein and Fused-Gromov-Wasserstein conditional gradient solvers.
+"""
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+# Rémi Flamary <remi.flamary@unice.fr>
+# Titouan Vayer <titouan.vayer@irisa.fr>
+# Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
+#
+# License: MIT License
+
+import numpy as np
+
+
+from ..utils import dist, UndefinedParameter, list_to_array
+from ..optim import cg, line_search_armijo, solve_1d_linesearch_quad
+from ..utils import check_random_state
+from ..backend import get_backend, NumpyBackend
+
+from ._utils import init_matrix, gwloss, gwggrad
+from ._utils import update_square_loss, update_kl_loss
+
+
+def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None,
+ max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
+ r"""
+ Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{\gamma} &\geq 0
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity matrices
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+ .. note:: All computations in the conjugate gradient solver are done with
+ numpy to limit memory overhead.
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space
+ q : array-like, shape (nt,)
+ Distribution in the target space
+ loss_fun : str
+ loss function used for the solver either 'square_loss' or 'kl_loss'
+ symmetric : bool, optional
+ Either C1 and C2 are to be assumed symmetric or not.
+ If let to its default None value, a symmetry test will be conducted.
+ Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric).
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ armijo : bool, optional
+ If True the step of the line-search is found via an armijo research. Else closed form is used.
+ If there are convergence issues use False.
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+ max_iter : int, optional
+ Max number of iterations
+ tol_rel : float, optional
+ Stop threshold on relative error (>0)
+ tol_abs : float, optional
+ Stop threshold on absolute error (>0)
+ **kwargs : dict
+ parameters can be directly passed to the ot.optim.cg solver
+
+ Returns
+ -------
+ T : array-like, shape (`ns`, `nt`)
+ Coupling between the two spaces that minimizes:
+
+ :math:`\sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}`
+ log : dict
+ Convergence information and loss.
+
+ References
+ ----------
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
+ metric approach to object matching. Foundations of computational
+ mathematics 11.4 (2011): 417-487.
+
+ .. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein
+ distance between networks and stable network invariants.
+ Information and Inference: A Journal of the IMA, 8(4), 757-787.
+ """
+ p, q = list_to_array(p, q)
+ p0, q0, C10, C20 = p, q, C1, C2
+ if G0 is None:
+ nx = get_backend(p0, q0, C10, C20)
+ else:
+ G0_ = G0
+ nx = get_backend(p0, q0, C10, C20, G0_)
+ p = nx.to_numpy(p)
+ q = nx.to_numpy(q)
+ C1 = nx.to_numpy(C10)
+ C2 = nx.to_numpy(C20)
+ if symmetric is None:
+ symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10)
+
+ if G0 is None:
+ G0 = p[:, None] * q[None, :]
+ else:
+ G0 = nx.to_numpy(G0_)
+ # Check marginals of G0
+ np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
+ np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
+ # cg for GW is implemented using numpy on CPU
+ np_ = NumpyBackend()
+
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, np_)
+
+ def f(G):
+ return gwloss(constC, hC1, hC2, G, np_)
+
+ if symmetric:
+ def df(G):
+ return gwggrad(constC, hC1, hC2, G, np_)
+ else:
+ constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, np_)
+
+ def df(G):
+ return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_))
+ if loss_fun == 'kl_loss':
+ armijo = True # there is no closed form line-search with KL
+
+ if armijo:
+ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
+ return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs)
+ else:
+ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
+ return solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M=0., reg=1., nx=np_, **kwargs)
+ if log:
+ res, log = cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
+ log['gw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10)
+ log['u'] = nx.from_numpy(log['u'], type_as=C10)
+ log['v'] = nx.from_numpy(log['v'], type_as=C10)
+ return nx.from_numpy(res, type_as=C10), log
+ else:
+ return nx.from_numpy(cg(p, q, 0., 1., f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=C10)
+
+
+def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None,
+ max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
+ r"""
+ Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
+
+ The function solves the following optimization problem:
+
+ .. math::
+ GW = \min_\mathbf{T} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{\gamma} &\geq 0
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity
+ matrices
+
+ Note that when using backends, this loss function is differentiable wrt the
+ matrices (C1, C2) and weights (p, q) for quadratic loss using the gradients from [38]_.
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+ .. note:: All computations in the conjugate gradient solver are done with
+ numpy to limit memory overhead.
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space.
+ q : array-like, shape (nt,)
+ Distribution in the target space.
+ loss_fun : str
+ loss function used for the solver either 'square_loss' or 'kl_loss'
+ symmetric : bool, optional
+ Either C1 and C2 are to be assumed symmetric or not.
+ If let to its default None value, a symmetry test will be conducted.
+ Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric).
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ armijo : bool, optional
+ If True the step of the line-search is found via an armijo research. Else closed form is used.
+ If there are convergence issues use False.
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+ max_iter : int, optional
+ Max number of iterations
+ tol_rel : float, optional
+ Stop threshold on relative error (>0)
+ tol_abs : float, optional
+ Stop threshold on absolute error (>0)
+ **kwargs : dict
+ parameters can be directly passed to the ot.optim.cg solver
+
+ Returns
+ -------
+ gw_dist : float
+ Gromov-Wasserstein distance
+ log : dict
+ convergence information and Coupling marix
+
+ References
+ ----------
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
+ metric approach to object matching. Foundations of computational
+ mathematics 11.4 (2011): 417-487.
+
+ .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
+ Graph Dictionary Learning, International Conference on Machine Learning
+ (ICML), 2021.
+
+ .. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein
+ distance between networks and stable network invariants.
+ Information and Inference: A Journal of the IMA, 8(4), 757-787.
+ """
+ # simple get_backend as the full one will be handled in gromov_wasserstein
+ nx = get_backend(C1, C2)
+
+ T, log_gw = gromov_wasserstein(
+ C1, C2, p, q, loss_fun, symmetric, log=True, armijo=armijo, G0=G0,
+ max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs)
+
+ log_gw['T'] = T
+ gw = log_gw['gw_dist']
+
+ if loss_fun == 'square_loss':
+ gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
+ gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
+ gw = nx.set_gradients(gw, (p, q, C1, C2),
+ (log_gw['u'] - nx.mean(log_gw['u']),
+ log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2))
+
+ if log:
+ return gw, log_gw
+ else:
+ return gw
+
+
+def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=None, alpha=0.5,
+ armijo=False, G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
+ r"""
+ Computes the FGW transport between two graphs (see :ref:`[24] <references-fused-gromov-wasserstein>`)
+
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F +
+ \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{\gamma} &\geq 0
+
+ where :
+
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
+ - `L` is a loss function to account for the misfit between the similarity matrices
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+ .. note:: All computations in the conjugate gradient solver are done with
+ numpy to limit memory overhead.
+ The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] <references-fused-gromov-wasserstein>`
+
+ Parameters
+ ----------
+ M : array-like, shape (ns, nt)
+ Metric cost matrix between features across domains
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix representative of the structure in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix representative of the structure in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space
+ q : array-like, shape (nt,)
+ Distribution in the target space
+ loss_fun : str, optional
+ Loss function used for the solver
+ symmetric : bool, optional
+ Either C1 and C2 are to be assumed symmetric or not.
+ If let to its default None value, a symmetry test will be conducted.
+ Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric).
+ alpha : float, optional
+ Trade-off parameter (0 < alpha < 1)
+ armijo : bool, optional
+ If True the step of the line-search is found via an armijo research. Else closed form is used.
+ If there are convergence issues use False.
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+ log : bool, optional
+ record log if True
+ max_iter : int, optional
+ Max number of iterations
+ tol_rel : float, optional
+ Stop threshold on relative error (>0)
+ tol_abs : float, optional
+ Stop threshold on absolute error (>0)
+ **kwargs : dict
+ parameters can be directly passed to the ot.optim.cg solver
+
+ Returns
+ -------
+ gamma : array-like, shape (`ns`, `nt`)
+ Optimal transportation matrix for the given parameters.
+ log : dict
+ Log dictionary return only if log==True in parameters.
+
+
+ .. _references-fused-gromov-wasserstein:
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+ and Courty Nicolas "Optimal Transport for structured data with
+ application on graphs", International Conference on Machine Learning
+ (ICML). 2019.
+
+ .. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein
+ distance between networks and stable network invariants.
+ Information and Inference: A Journal of the IMA, 8(4), 757-787.
+ """
+ p, q = list_to_array(p, q)
+ p0, q0, C10, C20, M0 = p, q, C1, C2, M
+ if G0 is None:
+ nx = get_backend(p0, q0, C10, C20, M0)
+ else:
+ G0_ = G0
+ nx = get_backend(p0, q0, C10, C20, M0, G0_)
+
+ p = nx.to_numpy(p)
+ q = nx.to_numpy(q)
+ C1 = nx.to_numpy(C10)
+ C2 = nx.to_numpy(C20)
+ M = nx.to_numpy(M0)
+
+ if symmetric is None:
+ symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10)
+
+ if G0 is None:
+ G0 = p[:, None] * q[None, :]
+ else:
+ G0 = nx.to_numpy(G0_)
+ # Check marginals of G0
+ np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
+ np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
+ # cg for GW is implemented using numpy on CPU
+ np_ = NumpyBackend()
+
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, np_)
+
+ def f(G):
+ return gwloss(constC, hC1, hC2, G, np_)
+
+ if symmetric:
+ def df(G):
+ return gwggrad(constC, hC1, hC2, G, np_)
+ else:
+ constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, np_)
+
+ def df(G):
+ return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_))
+
+ if loss_fun == 'kl_loss':
+ armijo = True # there is no closed form line-search with KL
+
+ if armijo:
+ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
+ return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs)
+ else:
+ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
+ return solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M=(1 - alpha) * M, reg=alpha, nx=np_, **kwargs)
+ if log:
+ res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
+ log['fgw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10)
+ log['u'] = nx.from_numpy(log['u'], type_as=C10)
+ log['v'] = nx.from_numpy(log['v'], type_as=C10)
+ return nx.from_numpy(res, type_as=C10), log
+ else:
+ return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=C10)
+
+
+def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', symmetric=None, alpha=0.5,
+ armijo=False, G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
+ r"""
+ Computes the FGW distance between two graphs see (see :ref:`[24] <references-fused-gromov-wasserstein2>`)
+
+ .. math::
+ \min_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{\gamma} &\geq 0
+
+ where :
+
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
+ - `L` is a loss function to account for the misfit between the similarity matrices
+
+ The algorithm used for solving the problem is conditional gradient as
+ discussed in :ref:`[24] <references-fused-gromov-wasserstein2>`
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+ .. note:: All computations in the conjugate gradient solver are done with
+ numpy to limit memory overhead.
+
+ Note that when using backends, this loss function is differentiable wrt the
+ matrices (C1, C2, M) and weights (p, q) for quadratic loss using the gradients from [38]_.
+
+ Parameters
+ ----------
+ M : array-like, shape (ns, nt)
+ Metric cost matrix between features across domains
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix representative of the structure in the source space.
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix representative of the structure in the target space.
+ p : array-like, shape (ns,)
+ Distribution in the source space.
+ q : array-like, shape (nt,)
+ Distribution in the target space.
+ loss_fun : str, optional
+ Loss function used for the solver.
+ symmetric : bool, optional
+ Either C1 and C2 are to be assumed symmetric or not.
+ If let to its default None value, a symmetry test will be conducted.
+ Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
+ alpha : float, optional
+ Trade-off parameter (0 < alpha < 1)
+ armijo : bool, optional
+ If True the step of the line-search is found via an armijo research.
+ Else closed form is used. If there are convergence issues use False.
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+ log : bool, optional
+ Record log if True.
+ max_iter : int, optional
+ Max number of iterations
+ tol_rel : float, optional
+ Stop threshold on relative error (>0)
+ tol_abs : float, optional
+ Stop threshold on absolute error (>0)
+ **kwargs : dict
+ Parameters can be directly passed to the ot.optim.cg solver.
+
+ Returns
+ -------
+ fgw-distance : float
+ Fused gromov wasserstein distance for the given parameters.
+ log : dict
+ Log dictionary return only if log==True in parameters.
+
+
+ .. _references-fused-gromov-wasserstein2:
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+
+ .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
+ Graph Dictionary Learning, International Conference on Machine Learning
+ (ICML), 2021.
+
+ .. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein
+ distance between networks and stable network invariants.
+ Information and Inference: A Journal of the IMA, 8(4), 757-787.
+ """
+ nx = get_backend(C1, C2, M)
+
+ T, log_fgw = fused_gromov_wasserstein(
+ M, C1, C2, p, q, loss_fun, symmetric, alpha, armijo, G0, log=True,
+ max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs)
+
+ fgw_dist = log_fgw['fgw_dist']
+ log_fgw['T'] = T
+
+ if loss_fun == 'square_loss':
+ gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
+ gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
+ fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M),
+ (log_fgw['u'] - nx.mean(log_fgw['u']),
+ log_fgw['v'] - nx.mean(log_fgw['v']),
+ alpha * gC1, alpha * gC2, (1 - alpha) * T))
+
+ if log:
+ return fgw_dist, log_fgw
+ else:
+ return fgw_dist
+
+
+def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
+ alpha_min=None, alpha_max=None, nx=None, **kwargs):
+ """
+ Solve the linesearch in the FW iterations
+
+ Parameters
+ ----------
+
+ G : array-like, shape(ns,nt)
+ The transport map at a given iteration of the FW
+ deltaG : array-like (ns,nt)
+ Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration
+ cost_G : float
+ Value of the cost at `G`
+ C1 : array-like (ns,ns), optional
+ Structure matrix in the source domain.
+ C2 : array-like (nt,nt), optional
+ Structure matrix in the target domain.
+ M : array-like (ns,nt)
+ Cost matrix between the features.
+ reg : float
+ Regularization parameter.
+ alpha_min : float, optional
+ Minimum value for alpha
+ alpha_max : float, optional
+ Maximum value for alpha
+ nx : backend, optional
+ If let to its default value None, a backend test will be conducted.
+ Returns
+ -------
+ alpha : float
+ The optimal step size of the FW
+ fc : int
+ nb of function call. Useless here
+ cost_G : float
+ The value of the cost for the next iteration
+
+
+ .. _references-solve-linesearch:
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+ if nx is None:
+ G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M)
+
+ if isinstance(M, int) or isinstance(M, float):
+ nx = get_backend(G, deltaG, C1, C2)
+ else:
+ nx = get_backend(G, deltaG, C1, C2, M)
+
+ dot = nx.dot(nx.dot(C1, deltaG), C2.T)
+ a = -2 * reg * nx.sum(dot * deltaG)
+ b = nx.sum(M * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG))
+
+ alpha = solve_1d_linesearch_quad(a, b)
+ if alpha_min is not None or alpha_max is not None:
+ alpha = np.clip(alpha, alpha_min, alpha_max)
+
+ # the new cost is deduced from the line search quadratic function
+ cost_G = cost_G + a * (alpha ** 2) + b * alpha
+
+ return alpha, 1, cost_G
+
+
+def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=False,
+ max_iter=1000, tol=1e-9, verbose=False, log=False,
+ init_C=None, random_state=None, **kwargs):
+ r"""
+ Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}`
+
+ The function solves the following optimization problem with block coordinate descent:
+
+ .. math::
+
+ \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s)
+
+ Where :
+
+ - :math:`\mathbf{C}_s`: metric cost matrix
+ - :math:`\mathbf{p}_s`: distribution
+
+ Parameters
+ ----------
+ N : int
+ Size of the targeted barycenter
+ Cs : list of S array-like of shape (ns, ns)
+ Metric cost matrices
+ ps : list of S array-like of shape (ns,)
+ Sample weights in the `S` spaces
+ p : array-like, shape (N,)
+ Weights in the targeted barycenter
+ lambdas : list of float
+ List of the `S` spaces' weights
+ loss_fun : callable
+ tensor-matrix multiplication function based on specific loss function
+ symmetric : bool, optional.
+ Either structures are to be assumed symmetric or not. Default value is True.
+ Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
+ update : callable
+ function(:math:`\mathbf{p}`, lambdas, :math:`\mathbf{T}`, :math:`\mathbf{Cs}`) that updates
+ :math:`\mathbf{C}` according to a specific Kernel with the `S` :math:`\mathbf{T}_s` couplings
+ calculated at each iteration
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on relative error (>0)
+ verbose : bool, optional
+ Print information along iterations.
+ log : bool, optional
+ Record log if True.
+ init_C : bool | array-like, shape(N,N)
+ Random initial value for the :math:`\mathbf{C}` matrix provided by user.
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
+
+ Returns
+ -------
+ C : array-like, shape (`N`, `N`)
+ Similarity matrix in the barycenter space (permutated arbitrarily)
+ log : dict
+ Log dictionary of error during iterations. Return only if `log=True` in parameters.
+
+ References
+ ----------
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
+ Cs = list_to_array(*Cs)
+ ps = list_to_array(*ps)
+ p = list_to_array(p)
+ nx = get_backend(*Cs, *ps, p)
+
+ S = len(Cs)
+
+ # Initialization of C : random SPD matrix (if not provided by user)
+ if init_C is None:
+ generator = check_random_state(random_state)
+ xalea = generator.randn(N, 2)
+ C = dist(xalea, xalea)
+ C /= C.max()
+ C = nx.from_numpy(C, type_as=p)
+ else:
+ C = init_C
+
+ if loss_fun == 'kl_loss':
+ armijo = True
+
+ cpt = 0
+ err = 1
+
+ error = []
+
+ while (err > tol and cpt < max_iter):
+ Cprev = C
+
+ T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, symmetric=symmetric, armijo=armijo,
+ max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)]
+ if loss_fun == 'square_loss':
+ C = update_square_loss(p, lambdas, T, Cs)
+
+ elif loss_fun == 'kl_loss':
+ C = update_kl_loss(p, lambdas, T, Cs)
+
+ if cpt % 10 == 0:
+ # we can speed up the process by checking for the error only all
+ # the 10th iterations
+ err = nx.norm(C - Cprev)
+ error.append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+
+ if log:
+ return C, {"err": error}
+ else:
+ return C
+
+
+def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
+ p=None, loss_fun='square_loss', armijo=False, symmetric=True, max_iter=100, tol=1e-9,
+ verbose=False, log=False, init_C=None, init_X=None, random_state=None, **kwargs):
+ r"""Compute the fgw barycenter as presented eq (5) in :ref:`[24] <references-fgw-barycenters>`
+
+ Parameters
+ ----------
+ N : int
+ Desired number of samples of the target barycenter
+ Ys: list of array-like, each element has shape (ns,d)
+ Features of all samples
+ Cs : list of array-like, each element has shape (ns,ns)
+ Structure matrices of all samples
+ ps : list of array-like, each element has shape (ns,)
+ Masses of all samples.
+ lambdas : list of float
+ List of the `S` spaces' weights
+ alpha : float
+ Alpha parameter for the fgw distance
+ fixed_structure : bool
+ Whether to fix the structure of the barycenter during the updates
+ fixed_features : bool
+ Whether to fix the feature of the barycenter during the updates
+ loss_fun : str
+ Loss function used for the solver either 'square_loss' or 'kl_loss'
+ symmetric : bool, optional
+ Either structures are to be assumed symmetric or not. Default value is True.
+ Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on relative error (>0)
+ verbose : bool, optional
+ Print information along iterations.
+ log : bool, optional
+ Record log if True.
+ init_C : array-like, shape (N,N), optional
+ Initialization for the barycenters' structure matrix. If not set
+ a random init is used.
+ init_X : array-like, shape (N,d), optional
+ Initialization for the barycenters' features. If not set a
+ random init is used.
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
+
+ Returns
+ -------
+ X : array-like, shape (`N`, `d`)
+ Barycenters' features
+ C : array-like, shape (`N`, `N`)
+ Barycenters' structure matrix
+ log : dict
+ Only returned when log=True. It contains the keys:
+
+ - :math:`\mathbf{T}`: list of (`N`, `ns`) transport matrices
+ - :math:`(\mathbf{M}_s)_s`: all distance matrices between the feature of the barycenter and the other features :math:`(dist(\mathbf{X}, \mathbf{Y}_s))_s` shape (`N`, `ns`)
+
+
+ .. _references-fgw-barycenters:
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+ Cs = list_to_array(*Cs)
+ ps = list_to_array(*ps)
+ Ys = list_to_array(*Ys)
+ p = list_to_array(p)
+ nx = get_backend(*Cs, *Ys, *ps)
+
+ S = len(Cs)
+ d = Ys[0].shape[1] # dimension on the node features
+ if p is None:
+ p = nx.ones(N, type_as=Cs[0]) / N
+
+ if fixed_structure:
+ if init_C is None:
+ raise UndefinedParameter('If C is fixed it must be initialized')
+ else:
+ C = init_C
+ else:
+ if init_C is None:
+ generator = check_random_state(random_state)
+ xalea = generator.randn(N, 2)
+ C = dist(xalea, xalea)
+ C = nx.from_numpy(C, type_as=ps[0])
+ else:
+ C = init_C
+
+ if fixed_features:
+ if init_X is None:
+ raise UndefinedParameter('If X is fixed it must be initialized')
+ else:
+ X = init_X
+ else:
+ if init_X is None:
+ X = nx.zeros((N, d), type_as=ps[0])
+ else:
+ X = init_X
+
+ T = [nx.outer(p, q) for q in ps]
+
+ Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
+
+ if loss_fun == 'kl_loss':
+ armijo = True
+
+ cpt = 0
+ err_feature = 1
+ err_structure = 1
+
+ if log:
+ log_ = {}
+ log_['err_feature'] = []
+ log_['err_structure'] = []
+ log_['Ts_iter'] = []
+
+ while ((err_feature > tol or err_structure > tol) and cpt < max_iter):
+ Cprev = C
+ Xprev = X
+
+ if not fixed_features:
+ Ys_temp = [y.T for y in Ys]
+ X = update_feature_matrix(lambdas, Ys_temp, T, p).T
+
+ Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
+
+ if not fixed_structure:
+ if loss_fun == 'square_loss':
+ T_temp = [t.T for t in T]
+ C = update_structure_matrix(p, lambdas, T_temp, Cs)
+
+ T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric,
+ max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)]
+
+ # T is N,ns
+ err_feature = nx.norm(X - nx.reshape(Xprev, (N, d)))
+ err_structure = nx.norm(C - Cprev)
+ if log:
+ log_['err_feature'].append(err_feature)
+ log_['err_structure'].append(err_structure)
+ log_['Ts_iter'].append(T)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err_structure))
+ print('{:5d}|{:8e}|'.format(cpt, err_feature))
+
+ cpt += 1
+
+ if log:
+ log_['T'] = T # from target to Ys
+ log_['p'] = p
+ log_['Ms'] = Ms
+
+ if log:
+ return X, C, log_
+ else:
+ return X, C
+
+
+def update_structure_matrix(p, lambdas, T, Cs):
+ r"""Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings.
+
+ It is calculated at each iteration
+
+ Parameters
+ ----------
+ p : array-like, shape (N,)
+ Masses in the targeted barycenter.
+ lambdas : list of float
+ List of the `S` spaces' weights.
+ T : list of S array-like of shape (ns, N)
+ The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
+ Cs : list of S array-like, shape (ns, ns)
+ Metric cost matrices.
+
+ Returns
+ -------
+ C : array-like, shape (`nt`, `nt`)
+ Updated :math:`\mathbf{C}` matrix.
+ """
+ p = list_to_array(p)
+ T = list_to_array(*T)
+ Cs = list_to_array(*Cs)
+ nx = get_backend(*Cs, *T, p)
+
+ tmpsum = sum([
+ lambdas[s] * nx.dot(
+ nx.dot(T[s].T, Cs[s]),
+ T[s]
+ ) for s in range(len(T))
+ ])
+ ppt = nx.outer(p, p)
+ return tmpsum / ppt
+
+
+def update_feature_matrix(lambdas, Ys, Ts, p):
+ r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings.
+
+
+ See "Solving the barycenter problem with Block Coordinate Descent (BCD)"
+ in :ref:`[24] <references-update-feature-matrix>` calculated at each iteration
+
+ Parameters
+ ----------
+ p : array-like, shape (N,)
+ masses in the targeted barycenter
+ lambdas : list of float
+ List of the `S` spaces' weights
+ Ts : list of S array-like, shape (ns,N)
+ The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
+ Ys : list of S array-like, shape (d,ns)
+ The features.
+
+ Returns
+ -------
+ X : array-like, shape (`d`, `N`)
+
+
+ .. _references-update-feature-matrix:
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+ p = list_to_array(p)
+ Ts = list_to_array(*Ts)
+ Ys = list_to_array(*Ys)
+ nx = get_backend(*Ys, *Ts, p)
+
+ p = 1. / p
+ tmpsum = sum([
+ lambdas[s] * nx.dot(Ys[s], Ts[s].T) * p[None, :]
+ for s in range(len(Ts))
+ ])
+ return tmpsum
diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py
new file mode 100644
index 0000000..638bb1c
--- /dev/null
+++ b/ot/gromov/_semirelaxed.py
@@ -0,0 +1,543 @@
+# -*- coding: utf-8 -*-
+"""
+Semi-relaxed Gromov-Wasserstein and Fused-Gromov-Wasserstein solvers.
+"""
+
+# Author: Rémi Flamary <remi.flamary@unice.fr>
+# Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
+#
+# License: MIT License
+
+import numpy as np
+
+
+from ..utils import list_to_array, unif
+from ..optim import semirelaxed_cg, solve_1d_linesearch_quad
+from ..backend import get_backend
+
+from ._utils import init_matrix_semirelaxed, gwloss, gwggrad
+
+
+def semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, log=False, G0=None,
+ max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
+ r"""
+ Returns the semi-relaxed gromov-wasserstein divergence transport from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}`
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{srGW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{\gamma} &\geq 0
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+
+ - `L`: loss function to account for the misfit between the similarity matrices
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. However all the steps in the conditional
+ gradient are not differentiable.
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space
+ loss_fun : str
+ loss function used for the solver either 'square_loss' or 'kl_loss'.
+ 'kl_loss' is not implemented yet and will raise an error.
+ symmetric : bool, optional
+ Either C1 and C2 are to be assumed symmetric or not.
+ If let to its default None value, a symmetry test will be conducted.
+ Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric).
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+ max_iter : int, optional
+ Max number of iterations
+ tol_rel : float, optional
+ Stop threshold on relative error (>0)
+ tol_abs : float, optional
+ Stop threshold on absolute error (>0)
+ **kwargs : dict
+ parameters can be directly passed to the ot.optim.cg solver
+
+ Returns
+ -------
+ T : array-like, shape (`ns`, `nt`)
+ Coupling between the two spaces that minimizes:
+
+ :math:`\sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}`
+ log : dict
+ Convergence information and loss.
+
+ References
+ ----------
+ .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+ "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+ International Conference on Learning Representations (ICLR), 2022.
+ """
+ if loss_fun == 'kl_loss':
+ raise NotImplementedError()
+ p = list_to_array(p)
+ if G0 is None:
+ nx = get_backend(p, C1, C2)
+ else:
+ nx = get_backend(p, C1, C2, G0)
+
+ if symmetric is None:
+ symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10)
+ if G0 is None:
+ q = unif(C2.shape[0], type_as=p)
+ G0 = nx.outer(p, q)
+ else:
+ q = nx.sum(G0, 0)
+ # Check first marginal of G0
+ np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08)
+
+ constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx)
+
+ ones_p = nx.ones(p.shape[0], type_as=p)
+
+ def f(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t))
+ return gwloss(constC + marginal_product, hC1, hC2, G, nx)
+
+ if symmetric:
+ def df(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t))
+ return gwggrad(constC + marginal_product, hC1, hC2, G, nx)
+ else:
+ constCt, hC1t, hC2t, fC2 = init_matrix_semirelaxed(C1.T, C2.T, p, loss_fun, nx)
+
+ def df(G):
+ qG = nx.sum(G, 0)
+ marginal_product_1 = nx.outer(ones_p, nx.dot(qG, fC2t))
+ marginal_product_2 = nx.outer(ones_p, nx.dot(qG, fC2))
+ return 0.5 * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx))
+
+ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
+ return solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, M=0., reg=1., nx=nx, **kwargs)
+
+ if log:
+ res, log = semirelaxed_cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
+ log['srgw_dist'] = log['loss'][-1]
+ return res, log
+ else:
+ return semirelaxed_cg(p, q, 0., 1., f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
+
+
+def semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=None, log=False, G0=None,
+ max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
+ r"""
+ Returns the semi-relaxed gromov-wasserstein divergence from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}`
+
+ The function solves the following optimization problem:
+
+ .. math::
+ srGW = \min_\mathbf{T} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{\gamma} &\geq 0
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - `L`: loss function to account for the misfit between the similarity
+ matrices
+
+ Note that when using backends, this loss function is differentiable wrt the
+ matrices (C1, C2) but not yet for the weights p.
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. However all the steps in the conditional
+ gradient are not differentiable.
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space.
+ loss_fun : str
+ loss function used for the solver either 'square_loss' or 'kl_loss'.
+ 'kl_loss' is not implemented yet and will raise an error.
+ symmetric : bool, optional
+ Either C1 and C2 are to be assumed symmetric or not.
+ If let to its default None value, a symmetry test will be conducted.
+ Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric).
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+ max_iter : int, optional
+ Max number of iterations
+ tol_rel : float, optional
+ Stop threshold on relative error (>0)
+ tol_abs : float, optional
+ Stop threshold on absolute error (>0)
+ **kwargs : dict
+ parameters can be directly passed to the ot.optim.cg solver
+
+ Returns
+ -------
+ srgw : float
+ Semi-relaxed Gromov-Wasserstein divergence
+ log : dict
+ convergence information and Coupling matrix
+
+ References
+ ----------
+
+ .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+ "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+ International Conference on Learning Representations (ICLR), 2022.
+ """
+ nx = get_backend(p, C1, C2)
+
+ T, log_srgw = semirelaxed_gromov_wasserstein(
+ C1, C2, p, loss_fun, symmetric, log=True, G0=G0,
+ max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs)
+
+ q = nx.sum(T, 0)
+ log_srgw['T'] = T
+ srgw = log_srgw['srgw_dist']
+
+ if loss_fun == 'square_loss':
+ gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
+ gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
+ srgw = nx.set_gradients(srgw, (C1, C2), (gC1, gC2))
+
+ if log:
+ return srgw, log_srgw
+ else:
+ return srgw
+
+
+def semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False,
+ max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
+ r"""
+ Computes the semi-relaxed FGW transport between two graphs (see :ref:`[48] <references-semirelaxed-fused-gromov-wasserstein>`)
+
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F +
+ \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{\gamma} &\geq 0
+
+ where :
+
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`\mathbf{p}` source weights (sum to 1)
+ - `L` is a loss function to account for the misfit between the similarity matrices
+
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. However all the steps in the conditional
+ gradient are not differentiable.
+
+ The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[48] <references-semirelaxed-fused-gromov-wasserstein>`
+
+ Parameters
+ ----------
+ M : array-like, shape (ns, nt)
+ Metric cost matrix between features across domains
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix representative of the structure in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix representative of the structure in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space
+ loss_fun : str
+ loss function used for the solver either 'square_loss' or 'kl_loss'.
+ 'kl_loss' is not implemented yet and will raise an error.
+ symmetric : bool, optional
+ Either C1 and C2 are to be assumed symmetric or not.
+ If let to its default None value, a symmetry test will be conducted.
+ Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric).
+ alpha : float, optional
+ Trade-off parameter (0 < alpha < 1)
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+ log : bool, optional
+ record log if True
+ max_iter : int, optional
+ Max number of iterations
+ tol_rel : float, optional
+ Stop threshold on relative error (>0)
+ tol_abs : float, optional
+ Stop threshold on absolute error (>0)
+ **kwargs : dict
+ parameters can be directly passed to the ot.optim.cg solver
+
+ Returns
+ -------
+ gamma : array-like, shape (`ns`, `nt`)
+ Optimal transportation matrix for the given parameters.
+ log : dict
+ Log dictionary return only if log==True in parameters.
+
+
+ .. _references-semirelaxed-fused-gromov-wasserstein:
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+ and Courty Nicolas "Optimal Transport for structured data with
+ application on graphs", International Conference on Machine Learning
+ (ICML). 2019.
+
+ .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+ "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+ International Conference on Learning Representations (ICLR), 2022.
+ """
+ if loss_fun == 'kl_loss':
+ raise NotImplementedError()
+
+ p = list_to_array(p)
+ if G0 is None:
+ nx = get_backend(p, C1, C2, M)
+ else:
+ nx = get_backend(p, C1, C2, M, G0)
+
+ if symmetric is None:
+ symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10)
+
+ if G0 is None:
+ q = unif(C2.shape[0], type_as=p)
+ G0 = nx.outer(p, q)
+ else:
+ q = nx.sum(G0, 0)
+ # Check marginals of G0
+ np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08)
+
+ constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx)
+
+ ones_p = nx.ones(p.shape[0], type_as=p)
+
+ def f(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t))
+ return gwloss(constC + marginal_product, hC1, hC2, G, nx)
+
+ if symmetric:
+ def df(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t))
+ return gwggrad(constC + marginal_product, hC1, hC2, G, nx)
+ else:
+ constCt, hC1t, hC2t, fC2 = init_matrix_semirelaxed(C1.T, C2.T, p, loss_fun, nx)
+
+ def df(G):
+ qG = nx.sum(G, 0)
+ marginal_product_1 = nx.outer(ones_p, nx.dot(qG, fC2t))
+ marginal_product_2 = nx.outer(ones_p, nx.dot(qG, fC2))
+ return 0.5 * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx))
+
+ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
+ return solve_semirelaxed_gromov_linesearch(
+ G, deltaG, cost_G, C1, C2, ones_p, M=(1 - alpha) * M, reg=alpha, nx=nx, **kwargs)
+
+ if log:
+ res, log = semirelaxed_cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
+ log['srfgw_dist'] = log['loss'][-1]
+ return res, log
+ else:
+ return semirelaxed_cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
+
+
+def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False,
+ max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
+ r"""
+ Computes the semi-relaxed FGW divergence between two graphs (see :ref:`[48] <references-semirelaxed-fused-gromov-wasserstein2>`)
+
+ .. math::
+ \min_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{\gamma} &\geq 0
+
+ where :
+
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`\mathbf{p}` source weights (sum to 1)
+ - `L` is a loss function to account for the misfit between the similarity matrices
+
+ The algorithm used for solving the problem is conditional gradient as
+ discussed in :ref:`[48] <semirelaxed-fused-gromov-wasserstein2>`
+
+ Note that when using backends, this loss function is differentiable wrt the
+ matrices (C1, C2) but not yet for the weights p.
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. However all the steps in the conditional
+ gradient are not differentiable.
+
+ Parameters
+ ----------
+ M : array-like, shape (ns, nt)
+ Metric cost matrix between features across domains
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix representative of the structure in the source space.
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix representative of the structure in the target space.
+ p : array-like, shape (ns,)
+ Distribution in the source space.
+ loss_fun : str, optional
+ loss function used for the solver either 'square_loss' or 'kl_loss'.
+ 'kl_loss' is not implemented yet and will raise an error.
+ symmetric : bool, optional
+ Either C1 and C2 are to be assumed symmetric or not.
+ If let to its default None value, a symmetry test will be conducted.
+ Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric).
+ alpha : float, optional
+ Trade-off parameter (0 < alpha < 1)
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+ log : bool, optional
+ Record log if True.
+ max_iter : int, optional
+ Max number of iterations
+ tol_rel : float, optional
+ Stop threshold on relative error (>0)
+ tol_abs : float, optional
+ Stop threshold on absolute error (>0)
+ **kwargs : dict
+ Parameters can be directly passed to the ot.optim.cg solver.
+
+ Returns
+ -------
+ srfgw-divergence : float
+ Semi-relaxed Fused gromov wasserstein divergence for the given parameters.
+ log : dict
+ Log dictionary return only if log==True in parameters.
+
+
+ .. _references-semirelaxed-fused-gromov-wasserstein2:
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+ and Courty Nicolas "Optimal Transport for structured data with
+ application on graphs", International Conference on Machine Learning
+ (ICML). 2019.
+
+ .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+ "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+ International Conference on Learning Representations (ICLR), 2022.
+ """
+ nx = get_backend(p, C1, C2, M)
+
+ T, log_fgw = semirelaxed_fused_gromov_wasserstein(
+ M, C1, C2, p, loss_fun, symmetric, alpha, G0, log=True,
+ max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs)
+ q = nx.sum(T, 0)
+ srfgw_dist = log_fgw['srfgw_dist']
+ log_fgw['T'] = T
+
+ if loss_fun == 'square_loss':
+ gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
+ gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
+ srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M),
+ (alpha * gC1, alpha * gC2, (1 - alpha) * T))
+
+ if log:
+ return srfgw_dist, log_fgw
+ else:
+ return srfgw_dist
+
+
+def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p,
+ M, reg, alpha_min=None, alpha_max=None, nx=None, **kwargs):
+ """
+ Solve the linesearch in the FW iterations
+
+ Parameters
+ ----------
+
+ G : array-like, shape(ns,nt)
+ The transport map at a given iteration of the FW
+ deltaG : array-like (ns,nt)
+ Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration
+ cost_G : float
+ Value of the cost at `G`
+ C1 : array-like (ns,ns)
+ Structure matrix in the source domain.
+ C2 : array-like (nt,nt)
+ Structure matrix in the target domain.
+ ones_p: array-like (ns,1)
+ Array of ones of size ns
+ M : array-like (ns,nt)
+ Cost matrix between the features.
+ reg : float
+ Regularization parameter.
+ alpha_min : float, optional
+ Minimum value for alpha
+ alpha_max : float, optional
+ Maximum value for alpha
+ nx : backend, optional
+ If let to its default value None, a backend test will be conducted.
+ Returns
+ -------
+ alpha : float
+ The optimal step size of the FW
+ fc : int
+ nb of function call. Useless here
+ cost_G : float
+ The value of the cost for the next iteration
+
+ References
+ ----------
+ .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+ "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+ International Conference on Learning Representations (ICLR), 2021.
+ """
+ if nx is None:
+ G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M)
+
+ if isinstance(M, int) or isinstance(M, float):
+ nx = get_backend(G, deltaG, C1, C2)
+ else:
+ nx = get_backend(G, deltaG, C1, C2, M)
+
+ qG, qdeltaG = nx.sum(G, 0), nx.sum(deltaG, 0)
+ dot = nx.dot(nx.dot(C1, deltaG), C2.T)
+ C2t_square = C2.T ** 2
+ dot_qG = nx.dot(nx.outer(ones_p, qG), C2t_square)
+ dot_qdeltaG = nx.dot(nx.outer(ones_p, qdeltaG), C2t_square)
+ a = reg * nx.sum((dot_qdeltaG - 2 * dot) * deltaG)
+ b = nx.sum(M * deltaG) + reg * (nx.sum((dot_qdeltaG - 2 * dot) * G) + nx.sum((dot_qG - 2 * nx.dot(nx.dot(C1, G), C2.T)) * deltaG))
+ alpha = solve_1d_linesearch_quad(a, b)
+ if alpha_min is not None or alpha_max is not None:
+ alpha = np.clip(alpha, alpha_min, alpha_max)
+
+ # the new cost can be deduced from the line search quadratic function
+ cost_G = cost_G + a * (alpha ** 2) + b * alpha
+
+ return alpha, 1, cost_G
diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py
new file mode 100644
index 0000000..e842250
--- /dev/null
+++ b/ot/gromov/_utils.py
@@ -0,0 +1,413 @@
+# -*- coding: utf-8 -*-
+"""
+Gromov-Wasserstein and Fused-Gromov-Wasserstein utils.
+"""
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+# Rémi Flamary <remi.flamary@unice.fr>
+# Titouan Vayer <titouan.vayer@irisa.fr>
+# Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
+#
+# License: MIT License
+
+
+from ..utils import list_to_array
+from ..backend import get_backend
+
+
+def init_matrix(C1, C2, p, q, loss_fun='square_loss', nx=None):
+ r"""Return loss matrices and tensors for Gromov-Wasserstein fast computation
+
+ Returns the value of :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` with the
+ selected loss function as the loss function of Gromow-Wasserstein discrepancy.
+
+ The matrices are computed as described in Proposition 1 in :ref:`[12] <references-init-matrix>`
+
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{T}`: A coupling between those two spaces
+
+ The square-loss function :math:`L(a, b) = |a - b|^2` is read as :
+
+ .. math::
+
+ L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b)
+
+ \mathrm{with} \ f_1(a) &= a^2
+
+ f_2(b) &= b^2
+
+ h_1(a) &= a
+
+ h_2(b) &= 2b
+
+ The kl-loss function :math:`L(a, b) = a \log\left(\frac{a}{b}\right) - a + b` is read as :
+
+ .. math::
+
+ L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b)
+
+ \mathrm{with} \ f_1(a) &= a \log(a) - a
+
+ f_2(b) &= b
+
+ h_1(a) &= a
+
+ h_2(b) &= \log(b)
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Probability distribution in the source space
+ q : array-like, shape (nt,)
+ Probability distribution in the target space
+ loss_fun : str, optional
+ Name of loss function to use: either 'square_loss' or 'kl_loss' (default='square_loss')
+ nx : backend, optional
+ If let to its default value None, a backend test will be conducted.
+ Returns
+ -------
+ constC : array-like, shape (ns, nt)
+ Constant :math:`\mathbf{C}` matrix in Eq. (6)
+ hC1 : array-like, shape (ns, ns)
+ :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
+ hC2 : array-like, shape (nt, nt)
+ :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
+
+
+ .. _references-init-matrix:
+ References
+ ----------
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
+ if nx is None:
+ C1, C2, p, q = list_to_array(C1, C2, p, q)
+ nx = get_backend(C1, C2, p, q)
+
+ if loss_fun == 'square_loss':
+ def f1(a):
+ return (a**2)
+
+ def f2(b):
+ return (b**2)
+
+ def h1(a):
+ return a
+
+ def h2(b):
+ return 2 * b
+ elif loss_fun == 'kl_loss':
+ def f1(a):
+ return a * nx.log(a + 1e-15) - a
+
+ def f2(b):
+ return b
+
+ def h1(a):
+ return a
+
+ def h2(b):
+ return nx.log(b + 1e-15)
+
+ constC1 = nx.dot(
+ nx.dot(f1(C1), nx.reshape(p, (-1, 1))),
+ nx.ones((1, len(q)), type_as=q)
+ )
+ constC2 = nx.dot(
+ nx.ones((len(p), 1), type_as=p),
+ nx.dot(nx.reshape(q, (1, -1)), f2(C2).T)
+ )
+ constC = constC1 + constC2
+ hC1 = h1(C1)
+ hC2 = h2(C2)
+
+ return constC, hC1, hC2
+
+
+def tensor_product(constC, hC1, hC2, T, nx=None):
+ r"""Return the tensor for Gromov-Wasserstein fast computation
+
+ The tensor is computed as described in Proposition 1 Eq. (6) in :ref:`[12] <references-tensor-product>`
+
+ Parameters
+ ----------
+ constC : array-like, shape (ns, nt)
+ Constant :math:`\mathbf{C}` matrix in Eq. (6)
+ hC1 : array-like, shape (ns, ns)
+ :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
+ hC2 : array-like, shape (nt, nt)
+ :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
+ nx : backend, optional
+ If let to its default value None, a backend test will be conducted.
+ Returns
+ -------
+ tens : array-like, shape (`ns`, `nt`)
+ :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` tensor-matrix multiplication result
+
+
+ .. _references-tensor-product:
+ References
+ ----------
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
+ if nx is None:
+ constC, hC1, hC2, T = list_to_array(constC, hC1, hC2, T)
+ nx = get_backend(constC, hC1, hC2, T)
+
+ A = - nx.dot(
+ nx.dot(hC1, T), hC2.T
+ )
+ tens = constC + A
+ # tens -= tens.min()
+ return tens
+
+
+def gwloss(constC, hC1, hC2, T, nx=None):
+ r"""Return the Loss for Gromov-Wasserstein
+
+ The loss is computed as described in Proposition 1 Eq. (6) in :ref:`[12] <references-gwloss>`
+
+ Parameters
+ ----------
+ constC : array-like, shape (ns, nt)
+ Constant :math:`\mathbf{C}` matrix in Eq. (6)
+ hC1 : array-like, shape (ns, ns)
+ :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
+ hC2 : array-like, shape (nt, nt)
+ :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
+ T : array-like, shape (ns, nt)
+ Current value of transport matrix :math:`\mathbf{T}`
+ nx : backend, optional
+ If let to its default value None, a backend test will be conducted.
+ Returns
+ -------
+ loss : float
+ Gromov Wasserstein loss
+
+
+ .. _references-gwloss:
+ References
+ ----------
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
+
+ tens = tensor_product(constC, hC1, hC2, T, nx)
+ if nx is None:
+ tens, T = list_to_array(tens, T)
+ nx = get_backend(tens, T)
+
+ return nx.sum(tens * T)
+
+
+def gwggrad(constC, hC1, hC2, T, nx=None):
+ r"""Return the gradient for Gromov-Wasserstein
+
+ The gradient is computed as described in Proposition 2 in :ref:`[12] <references-gwggrad>`
+
+ Parameters
+ ----------
+ constC : array-like, shape (ns, nt)
+ Constant :math:`\mathbf{C}` matrix in Eq. (6)
+ hC1 : array-like, shape (ns, ns)
+ :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
+ hC2 : array-like, shape (nt, nt)
+ :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
+ T : array-like, shape (ns, nt)
+ Current value of transport matrix :math:`\mathbf{T}`
+ nx : backend, optional
+ If let to its default value None, a backend test will be conducted.
+ Returns
+ -------
+ grad : array-like, shape (`ns`, `nt`)
+ Gromov Wasserstein gradient
+
+
+ .. _references-gwggrad:
+ References
+ ----------
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
+ return 2 * tensor_product(constC, hC1, hC2,
+ T, nx) # [12] Prop. 2 misses a 2 factor
+
+
+def update_square_loss(p, lambdas, T, Cs):
+ r"""
+ Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s`
+ couplings calculated at each iteration
+
+ Parameters
+ ----------
+ p : array-like, shape (N,)
+ Masses in the targeted barycenter.
+ lambdas : list of float
+ List of the `S` spaces' weights.
+ T : list of S array-like of shape (ns,N)
+ The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
+ Cs : list of S array-like, shape(ns,ns)
+ Metric cost matrices.
+
+ Returns
+ ----------
+ C : array-like, shape (`nt`, `nt`)
+ Updated :math:`\mathbf{C}` matrix.
+ """
+ T = list_to_array(*T)
+ Cs = list_to_array(*Cs)
+ p = list_to_array(p)
+ nx = get_backend(p, *T, *Cs)
+
+ tmpsum = sum([
+ lambdas[s] * nx.dot(
+ nx.dot(T[s].T, Cs[s]),
+ T[s]
+ ) for s in range(len(T))
+ ])
+ ppt = nx.outer(p, p)
+
+ return tmpsum / ppt
+
+
+def update_kl_loss(p, lambdas, T, Cs):
+ r"""
+ Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
+
+
+ Parameters
+ ----------
+ p : array-like, shape (N,)
+ Weights in the targeted barycenter.
+ lambdas : list of float
+ List of the `S` spaces' weights
+ T : list of S array-like of shape (ns,N)
+ The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
+ Cs : list of S array-like, shape(ns,ns)
+ Metric cost matrices.
+
+ Returns
+ ----------
+ C : array-like, shape (`ns`, `ns`)
+ updated :math:`\mathbf{C}` matrix
+ """
+ Cs = list_to_array(*Cs)
+ T = list_to_array(*T)
+ p = list_to_array(p)
+ nx = get_backend(p, *T, *Cs)
+
+ tmpsum = sum([
+ lambdas[s] * nx.dot(
+ nx.dot(T[s].T, Cs[s]),
+ T[s]
+ ) for s in range(len(T))
+ ])
+ ppt = nx.outer(p, p)
+
+ return nx.exp(tmpsum / ppt)
+
+
+def init_matrix_semirelaxed(C1, C2, p, loss_fun='square_loss', nx=None):
+ r"""Return loss matrices and tensors for semi-relaxed Gromov-Wasserstein fast computation
+
+ Returns the value of :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` with the
+ selected loss function as the loss function of semi-relaxed Gromow-Wasserstein discrepancy.
+
+ The matrices are computed as described in Proposition 1 in :ref:`[12] <references-init-matrix>`
+ and adapted to the semi-relaxed problem where the second marginal is not a constant anymore.
+
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{T}`: A coupling between those two spaces
+
+ The square-loss function :math:`L(a, b) = |a - b|^2` is read as :
+
+ .. math::
+
+ L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b)
+
+ \mathrm{with} \ f_1(a) &= a^2
+
+ f_2(b) &= b^2
+
+ h_1(a) &= a
+
+ h_2(b) &= 2b
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ T : array-like, shape (ns, nt)
+ Coupling between source and target spaces
+ p : array-like, shape (ns,)
+ nx : backend, optional
+ If let to its default value None, a backend test will be conducted.
+ Returns
+ -------
+ constC : array-like, shape (ns, nt)
+ Constant :math:`\mathbf{C}` matrix in Eq. (6) adapted to srGW
+ hC1 : array-like, shape (ns, ns)
+ :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
+ hC2 : array-like, shape (nt, nt)
+ :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
+ fC2t: array-like, shape (nt, nt)
+ :math:`\mathbf{f2}(\mathbf{C2})^\top` matrix in Eq. (6)
+
+
+ .. _references-init-matrix:
+ References
+ ----------
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+ "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+ International Conference on Learning Representations (ICLR), 2022.
+ """
+ if nx is None:
+ C1, C2, p = list_to_array(C1, C2, p)
+ nx = get_backend(C1, C2, p)
+
+ if loss_fun == 'square_loss':
+ def f1(a):
+ return (a**2)
+
+ def f2(b):
+ return (b**2)
+
+ def h1(a):
+ return a
+
+ def h2(b):
+ return 2 * b
+
+ constC = nx.dot(nx.dot(f1(C1), nx.reshape(p, (-1, 1))),
+ nx.ones((1, C2.shape[0]), type_as=p))
+
+ hC1 = h1(C1)
+ hC2 = h2(C2)
+ fC2t = f2(C2).T
+ return constC, hC1, hC2, fC2t
diff --git a/ot/helpers/pre_build_helpers.py b/ot/helpers/pre_build_helpers.py
index 93ecd6a..2930036 100644
--- a/ot/helpers/pre_build_helpers.py
+++ b/ot/helpers/pre_build_helpers.py
@@ -4,34 +4,14 @@ import os
import sys
import glob
import tempfile
-import setuptools # noqa
import subprocess
-from distutils.dist import Distribution
-from distutils.sysconfig import customize_compiler
-from numpy.distutils.ccompiler import new_compiler
-from numpy.distutils.command.config_compiler import config_cc
+from setuptools.command.build_ext import customize_compiler, new_compiler
def _get_compiler():
- """Get a compiler equivalent to the one that will be used to build POT
- Handles compiler specified as follows:
- - python setup.py build_ext --compiler=<compiler>
- - CC=<compiler> python setup.py build_ext
- """
- dist = Distribution({'script_name': os.path.basename(sys.argv[0]),
- 'script_args': sys.argv[1:],
- 'cmdclass': {'config_cc': config_cc}})
-
- cmd_opts = dist.command_options.get('build_ext')
- if cmd_opts is not None and 'compiler' in cmd_opts:
- compiler = cmd_opts['compiler'][1]
- else:
- compiler = None
-
- ccompiler = new_compiler(compiler=compiler)
+ ccompiler = new_compiler()
customize_compiler(ccompiler)
-
return ccompiler
diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h
index 8a1f9ac..b56f060 100644
--- a/ot/lp/EMD.h
+++ b/ot/lp/EMD.h
@@ -18,6 +18,7 @@
#include <iostream>
#include <vector>
+#include <cstdint>
typedef unsigned int node_id_type;
@@ -28,8 +29,8 @@ enum ProblemType {
MAX_ITER_REACHED
};
-int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter);
-int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads);
+int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter);
+int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads);
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp
index 2bdc172..4aa5a6e 100644
--- a/ot/lp/EMD_wrapper.cpp
+++ b/ot/lp/EMD_wrapper.cpp
@@ -20,11 +20,11 @@
int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
- double* alpha, double* beta, double *cost, int maxIter) {
+ double* alpha, double* beta, double *cost, uint64_t maxIter) {
// beware M and C are stored in row major C style!!!
using namespace lemon;
- int n, m, cur;
+ uint64_t n, m, cur;
typedef FullBipartiteDigraph Digraph;
DIGRAPH_TYPEDEFS(Digraph);
@@ -51,15 +51,15 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
// Define the graph
- std::vector<int> indI(n), indJ(m);
+ std::vector<uint64_t> indI(n), indJ(m);
std::vector<double> weights1(n), weights2(m);
Digraph di(n, m);
- NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter);
+ NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, (int) (n + m), n * m, maxIter);
// Set supply and demand, don't account for 0 values (faster)
cur=0;
- for (int i=0; i<n1; i++) {
+ for (uint64_t i=0; i<n1; i++) {
double val=*(X+i);
if (val>0) {
weights1[ cur ] = val;
@@ -70,7 +70,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
// Demand is actually negative supply...
cur=0;
- for (int i=0; i<n2; i++) {
+ for (uint64_t i=0; i<n2; i++) {
double val=*(Y+i);
if (val>0) {
weights2[ cur ] = -val;
@@ -79,12 +79,12 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
}
- net.supplyMap(&weights1[0], n, &weights2[0], m);
+ net.supplyMap(&weights1[0], (int) n, &weights2[0], (int) m);
// Set the cost of each edge
int64_t idarc = 0;
- for (int i=0; i<n; i++) {
- for (int j=0; j<m; j++) {
+ for (uint64_t i=0; i<n; i++) {
+ for (uint64_t j=0; j<m; j++) {
double val=*(D+indI[i]*n2+indJ[j]);
net.setCost(di.arcFromId(idarc), val);
++idarc;
@@ -95,7 +95,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
// Solve the problem with the network simplex algorithm
int ret=net.run();
- int i, j;
+ uint64_t i, j;
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
*cost = 0;
Arc a; di.first(a);
@@ -122,11 +122,11 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
- double* alpha, double* beta, double *cost, int maxIter, int numThreads) {
+ double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) {
// beware M and C are stored in row major C style!!!
using namespace lemon_omp;
- int n, m, cur;
+ uint64_t n, m, cur;
typedef FullBipartiteDigraph Digraph;
DIGRAPH_TYPEDEFS(Digraph);
@@ -153,15 +153,15 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
// Define the graph
- std::vector<int> indI(n), indJ(m);
+ std::vector<uint64_t> indI(n), indJ(m);
std::vector<double> weights1(n), weights2(m);
Digraph di(n, m);
- NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter, numThreads);
+ NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, (int) (n + m), n * m, maxIter, numThreads);
// Set supply and demand, don't account for 0 values (faster)
cur=0;
- for (int i=0; i<n1; i++) {
+ for (uint64_t i=0; i<n1; i++) {
double val=*(X+i);
if (val>0) {
weights1[ cur ] = val;
@@ -172,7 +172,7 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
// Demand is actually negative supply...
cur=0;
- for (int i=0; i<n2; i++) {
+ for (uint64_t i=0; i<n2; i++) {
double val=*(Y+i);
if (val>0) {
weights2[ cur ] = -val;
@@ -181,12 +181,12 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
}
- net.supplyMap(&weights1[0], n, &weights2[0], m);
+ net.supplyMap(&weights1[0], (int) n, &weights2[0], (int) m);
// Set the cost of each edge
int64_t idarc = 0;
- for (int i=0; i<n; i++) {
- for (int j=0; j<m; j++) {
+ for (uint64_t i=0; i<n; i++) {
+ for (uint64_t j=0; j<m; j++) {
double val=*(D+indI[i]*n2+indJ[j]);
net.setCost(di.arcFromId(idarc), val);
++idarc;
@@ -197,7 +197,7 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
// Solve the problem with the network simplex algorithm
int ret=net.run();
- int i, j;
+ uint64_t i, j;
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
*cost = 0;
Arc a; di.first(a);
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 390c32d..2ff02ab 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""
-Solvers for the original linear program OT problem
+Solvers for the original linear program OT problem.
"""
@@ -20,16 +20,17 @@ from .cvx import barycenter
# import compiled emd
from .emd_wrap import emd_c, check_result, emd_1d_sorted
-from .solver_1d import emd_1d, emd2_1d, wasserstein_1d
+from .solver_1d import (emd_1d, emd2_1d, wasserstein_1d,
+ binary_search_circle, wasserstein_circle,
+ semidiscrete_wasserstein2_unif_circle)
from ..utils import dist, list_to_array
from ..utils import parmap
from ..backend import get_backend
-
-
__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted',
- 'emd_1d', 'emd2_1d', 'wasserstein_1d']
+ 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter',
+ 'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle']
def check_number_threads(numThreads):
@@ -232,6 +233,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
If this behaviour is unwanted, please make sure to provide a
floating point input.
+ .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value.
+
Uses the algorithm proposed in :ref:`[1] <references-emd>`.
Parameters
@@ -391,6 +394,8 @@ def emd2(a, b, M, processes=1,
If this behaviour is unwanted, please make sure to provide a
floating point input.
+ .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value.
+
Uses the algorithm proposed in :ref:`[1] <references-emd2>`.
Parameters
@@ -483,6 +488,11 @@ def emd2(a, b, M, processes=1,
assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
"Dimension mismatch, check dimensions of M with a and b"
+ # ensure that same mass
+ np.testing.assert_almost_equal(a.sum(0),
+ b.sum(0,keepdims=True), err_msg='a and b vector must have the same sum')
+ b = b * a.sum(0) / b.sum(0,keepdims=True)
+
asel = a != 0
numThreads = check_number_threads(numThreads)
@@ -517,8 +527,8 @@ def emd2(a, b, M, processes=1,
log['warning'] = result_code_string
log['result_code'] = result_code
cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as),
- (a0, b0, M0), (log['u'] - nx.mean(log['u']),
- log['v'] - nx.mean(log['v']), G))
+ (a0, b0, M0), (log['u'] - nx.mean(log['u']),
+ log['v'] - nx.mean(log['v']), G))
return [cost, log]
else:
def f(b):
@@ -572,18 +582,18 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
where :
- :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one
- - the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the empirical measures weights and sum to one for each :math:`i`
- - the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}` are the empirical measures atoms locations
+ - `measure_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex)
+ - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations
- :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter
- This problem is considered in :ref:`[1] <references-free-support-barycenter>` (Algorithm 2).
+ This problem is considered in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).
There are two differences with the following codes:
- we do not optimize over the weights
- we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in
- :ref:`[1] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete
+ :ref:`[20] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete
implementation of the fixed-point algorithm of
- :ref:`[2] <references-free-support-barycenter>` proposed in the continuous setting.
+ :ref:`[43] <references-free-support-barycenter>` proposed in the continuous setting.
Parameters
----------
@@ -623,13 +633,13 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
.. _references-free-support-barycenter:
References
----------
- .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
+ .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
- .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
+ .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
"""
- nx = get_backend(*measures_locations,*measures_weights,X_init)
+ nx = get_backend(*measures_locations, *measures_weights, X_init)
iter_count = 0
@@ -637,9 +647,9 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
k = X_init.shape[0]
d = X_init.shape[1]
if b is None:
- b = nx.ones((k,),type_as=X_init) / k
+ b = nx.ones((k,), type_as=X_init) / k
if weights is None:
- weights = nx.ones((N,),type_as=X_init) / N
+ weights = nx.ones((N,), type_as=X_init) / N
X = X_init
@@ -650,15 +660,14 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
while (displacement_square_norm > stopThr and iter_count < numItermax):
- T_sum = nx.zeros((k, d),type_as=X_init)
-
+ T_sum = nx.zeros((k, d), type_as=X_init)
- for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights):
+ for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights):
M_i = dist(X, measure_locations_i)
T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads)
- T_sum = T_sum + weight_i * 1. / b[:,None] * nx.dot(T_i, measure_locations_i)
+ T_sum = T_sum + weight_i * 1. / b[:, None] * nx.dot(T_i, measure_locations_i)
- displacement_square_norm = nx.sum((T_sum - X)**2)
+ displacement_square_norm = nx.sum((T_sum - X) ** 2)
if log:
displacement_square_norms.append(displacement_square_norm)
@@ -675,3 +684,111 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
else:
return X
+
+def generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary, Y_init=None, b=None, weights=None,
+ numItermax=100, stopThr=1e-7, verbose=False, log=None, numThreads=1, eps=0):
+ r"""
+ Solves the free support generalised Wasserstein barycenter problem: finding a barycenter (a discrete measure with
+ a fixed amount of points of uniform weights) whose respective projections fit the input measures.
+ More formally:
+
+ .. math::
+ \min_\gamma \quad \sum_{i=1}^p w_i W_2^2(\nu_i, \mathbf{P}_i\#\gamma)
+
+ where :
+
+ - :math:`\gamma = \sum_{l=1}^n b_l\delta_{y_l}` is the desired barycenter with each :math:`y_l \in \mathbb{R}^d`
+ - :math:`\mathbf{b} \in \mathbb{R}^{n}` is the desired weights vector of the barycenter
+ - The input measures are :math:`\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{x_{i,j}}`
+ - The :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the respective empirical measures weights (on the simplex)
+ - The :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d_i}` are the respective empirical measures atoms locations
+ - :math:`w = (w_1, \cdots w_p)` are the barycenter coefficients (on the simplex)
+ - Each :math:`\mathbf{P}_i \in \mathbb{R}^{d, d_i}`, and :math:`P_i\#\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{P_ix_{i,j}}`
+
+ As show by :ref:`[42] <references-generalized-free-support-barycenter>`,
+ this problem can be re-written as a Wasserstein Barycenter problem,
+ which we solve using the free support method :ref:`[20] <references-generalized-free-support-barycenter>`
+ (Algorithm 2).
+
+ Parameters
+ ----------
+ X_list : list of p (k_i,d_i) array-like
+ Discrete supports of the input measures: each consists of :math:`k_i` locations of a `d_i`-dimensional space
+ (:math:`k_i` can be different for each element of the list)
+ a_list : list of p (k_i,) array-like
+ Measure weights: each element is a vector (k_i) on the simplex
+ P_list : list of p (d_i,d) array-like
+ Each :math:`P_i` is a linear map :math:`\mathbb{R}^{d} \rightarrow \mathbb{R}^{d_i}`
+ n_samples_bary : int
+ Number of barycenter points
+ Y_init : (n_samples_bary,d) array-like
+ Initialization of the support locations (on `k` atoms) of the barycenter
+ b : (n_samples_bary,) array-like
+ Initialization of the weights of the barycenter measure (on the simplex)
+ weights : (p,) array-like
+ Initialization of the coefficients of the barycenter (on the simplex)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
+ If compiled with OpenMP, chooses the number of threads to parallelize.
+ "max" selects the highest number possible.
+ eps: Stability coefficient for the change of variable matrix inversion
+ If the :math:`\mathbf{P}_i^T` matrices don't span :math:`\mathbb{R}^d`, the problem is ill-defined and a matrix
+ inversion will fail. In this case one may set eps=1e-8 and get a solution anyway (which may make little sense)
+
+
+ Returns
+ -------
+ Y : (n_samples_bary,d) array-like
+ Support locations (on n_samples_bary atoms) of the barycenter
+
+
+ .. _references-generalized-free-support-barycenter:
+ References
+ ----------
+ .. [20] Cuturi, M. and Doucet, A.. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
+
+ .. [42] Delon, J., Gozlan, N., and Saint-Dizier, A.. Generalized Wasserstein barycenters between probability measures living on different subspaces. arXiv preprint arXiv:2105.09755, 2021.
+
+ """
+ nx = get_backend(*X_list, *a_list, *P_list)
+ d = P_list[0].shape[1]
+ p = len(X_list)
+
+ if weights is None:
+ weights = nx.ones(p, type_as=X_list[0]) / p
+
+ # variable change matrix to reduce the problem to a Wasserstein Barycenter (WB)
+ A = eps * nx.eye(d, type_as=X_list[0]) # if eps nonzero: will force the invertibility of A
+ for (P_i, lambda_i) in zip(P_list, weights):
+ A = A + lambda_i * P_i.T @ P_i
+ B = nx.inv(nx.sqrtm(A))
+
+ Z_list = [x @ Pi @ B.T for (x, Pi) in zip(X_list, P_list)] # change of variables -> (WB) problem on Z
+
+ if Y_init is None:
+ Y_init = nx.randn(n_samples_bary, d, type_as=X_list[0])
+
+ if b is None:
+ b = nx.ones(n_samples_bary, type_as=X_list[0]) / n_samples_bary # not optimised
+
+ out = free_support_barycenter(Z_list, a_list, Y_init, b, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log, numThreads=numThreads)
+
+ if log: # unpack
+ Y, log_dict = out
+ else:
+ Y = out
+ log_dict = None
+ Y = Y @ B.T # return to the Generalised WB formulation
+
+ if log:
+ return Y, log_dict
+ else:
+ return Y
diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py
index fbf3c0e..361ad0f 100644
--- a/ot/lp/cvx.py
+++ b/ot/lp/cvx.py
@@ -80,7 +80,7 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po
if weights is None:
weights = np.ones(A.shape[1]) / A.shape[1]
else:
- assert(len(weights) == A.shape[1])
+ assert len(weights) == A.shape[1]
n_distributions = A.shape[1]
n = A.shape[0]
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 42e08f4..e5cec89 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -14,13 +14,14 @@ from ..utils import dist
cimport cython
cimport libc.math as math
+from libc.stdint cimport uint64_t
import warnings
cdef extern from "EMD.h":
- int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) nogil
- int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads) nogil
+ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil
+ int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
@@ -39,7 +40,7 @@ def check_result(result_code):
@cython.boundscheck(False)
@cython.wraparound(False)
-def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, int numThreads):
+def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, uint64_t max_iter, int numThreads):
"""
Solves the Earth Movers distance problem and returns the optimal transport matrix
@@ -75,7 +76,7 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
target histogram
M : (ns,nt) numpy.ndarray, float64
loss matrix
- max_iter : int
+ max_iter : uint64_t
The maximum number of iterations before stopping the optimization
algorithm if it has not converged.
diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h
index 3b46b9b..9612a8a 100644
--- a/ot/lp/network_simplex_simple.h
+++ b/ot/lp/network_simplex_simple.h
@@ -233,7 +233,7 @@ namespace lemon {
/// mixed order in the internal data structure.
/// In special cases, it could lead to better overall performance,
/// but it is usually slower. Therefore it is disabled by default.
- NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters) :
+ NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, uint64_t maxiters) :
_graph(graph), //_arc_id(graph),
_arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs),
MAX(std::numeric_limits<Value>::max()),
@@ -242,7 +242,7 @@ namespace lemon {
{
// Reset data structures
reset();
- max_iter=maxiters;
+ max_iter = maxiters;
}
/// The type of the flow amounts, capacity bounds and supply values
@@ -293,7 +293,7 @@ namespace lemon {
private:
- size_t max_iter;
+ uint64_t max_iter;
TEMPLATE_DIGRAPH_TYPEDEFS(GR);
typedef std::vector<int> IntVector;
@@ -1427,14 +1427,12 @@ namespace lemon {
// Perform heuristic initial pivots
if (!initialPivots()) return UNBOUNDED;
- size_t iter_number=0;
+ uint64_t iter_number = 0;
//pivot.setDantzig(true);
// Execute the Network Simplex algorithm
while (pivot.findEnteringArc()) {
if(max_iter > 0 && ++iter_number>=max_iter&&max_iter>0){
- char errMess[1000];
- sprintf( errMess, "RESULT MIGHT BE INACURATE\nMax number of iteration reached, currently \%d. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher\n",iter_number );
- std::cerr << errMess;
+ // max iterations hit
retVal = MAX_ITER_REACHED;
break;
}
diff --git a/ot/lp/network_simplex_simple_omp.h b/ot/lp/network_simplex_simple_omp.h
index 87e4c05..890b7ab 100644
--- a/ot/lp/network_simplex_simple_omp.h
+++ b/ot/lp/network_simplex_simple_omp.h
@@ -41,8 +41,8 @@
#undef EPSILON
#undef _EPSILON
#undef MAX_DEBUG_ITER
-#define EPSILON std::numeric_limits<Cost>::epsilon()*10
-#define _EPSILON 1e-8
+#define EPSILON std::numeric_limits<Cost>::epsilon()
+#define _EPSILON 1e-14
#define MAX_DEBUG_ITER 100000
/// \ingroup min_cost_flow_algs
@@ -67,7 +67,7 @@
//#include "core.h"
//#include "lmath.h"
-#ifdef OMP
+#ifdef _OPENMP
#include <omp.h>
#endif
#include <cmath>
@@ -244,7 +244,7 @@ namespace lemon_omp {
/// mixed order in the internal data structure.
/// In special cases, it could lead to better overall performance,
/// but it is usually slower. Therefore it is disabled by default.
- NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters = 0, int numThreads=-1) :
+ NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, uint64_t maxiters = 0, int numThreads=-1) :
_graph(graph), //_arc_id(graph),
_arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs),
MAX(std::numeric_limits<Value>::max()),
@@ -254,7 +254,7 @@ namespace lemon_omp {
// Reset data structures
reset();
max_iter = maxiters;
-#ifdef OMP
+#ifdef _OPENMP
if (max_threads < 0) {
max_threads = omp_get_max_threads();
}
@@ -317,7 +317,7 @@ namespace lemon_omp {
private:
- size_t max_iter;
+ uint64_t max_iter;
int num_threads;
TEMPLATE_DIGRAPH_TYPEDEFS(GR);
@@ -513,7 +513,7 @@ namespace lemon_omp {
int j;
#pragma omp parallel
{
-#ifdef OMP
+#ifdef _OPENMP
int t = omp_get_thread_num();
#else
int t = 0;
@@ -1563,7 +1563,7 @@ namespace lemon_omp {
// Perform heuristic initial pivots
if (!initialPivots()) return UNBOUNDED;
- size_t iter_number = 0;
+ uint64_t iter_number = 0;
// Execute the Network Simplex algorithm
while (pivot.findEnteringArc()) {
if ((++iter_number <= max_iter&&max_iter > 0) || max_iter<=0) {
@@ -1610,9 +1610,7 @@ namespace lemon_omp {
} else {
- char errMess[1000];
- sprintf( errMess, "RESULT MIGHT BE INACURATE\nMax number of iteration reached, currently \%d. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher\n",iter_number );
- std::cerr << errMess;
+ // max iters
retVal = MAX_ITER_REACHED;
break;
}
diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py
index 43763a9..bcfc920 100644
--- a/ot/lp/solver_1d.py
+++ b/ot/lp/solver_1d.py
@@ -53,7 +53,7 @@ def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, requ
distributions
.. math:
- OT_{loss} = \int_0^1 |cdf_u^{-1}(q) cdf_v^{-1}(q)|^p dq
+ OT_{loss} = \int_0^1 |cdf_u^{-1}(q) - cdf_v^{-1}(q)|^p dq
It is formally the p-Wasserstein distance raised to the power p.
We do so in a vectorized way by first building the individual quantile functions then integrating them.
@@ -129,7 +129,7 @@ def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, requ
diff_quantiles = nx.abs(u_quantiles - v_quantiles)
if p == 1:
- return nx.sum(delta * nx.abs(diff_quantiles), axis=0)
+ return nx.sum(delta * diff_quantiles, axis=0)
return nx.sum(delta * nx.power(diff_quantiles, p), axis=0)
@@ -365,3 +365,628 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
log_emd = {'G': G}
return cost, log_emd
return cost
+
+
+def roll_cols(M, shifts):
+ r"""
+ Utils functions which allow to shift the order of each row of a 2d matrix
+
+ Parameters
+ ----------
+ M : (nr, nc) ndarray
+ Matrix to shift
+ shifts: int or (nr,) ndarray
+
+ Returns
+ -------
+ Shifted array
+
+ Examples
+ --------
+ >>> M = np.array([[1,2,3],[4,5,6],[7,8,9]])
+ >>> roll_cols(M, 2)
+ array([[2, 3, 1],
+ [5, 6, 4],
+ [8, 9, 7]])
+ >>> roll_cols(M, np.array([[1],[2],[1]]))
+ array([[3, 1, 2],
+ [5, 6, 4],
+ [9, 7, 8]])
+
+ References
+ ----------
+ https://stackoverflow.com/questions/66596699/how-to-shift-columns-or-rows-in-a-tensor-with-different-offsets-in-pytorch
+ """
+ nx = get_backend(M)
+
+ n_rows, n_cols = M.shape
+
+ arange1 = nx.tile(nx.reshape(nx.arange(n_cols), (1, n_cols)), (n_rows, 1))
+ arange2 = (arange1 - shifts) % n_cols
+
+ return nx.take_along_axis(M, arange2, 1)
+
+
+def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2):
+ r""" Computes the left and right derivative of the cost (Equation (6.3) and (6.4) of [1])
+
+ Parameters
+ ----------
+ theta: array-like, shape (n_batch, n)
+ Cuts on the circle
+ u_values: array-like, shape (n_batch, n)
+ locations of the first empirical distribution
+ v_values: array-like, shape (n_batch, n)
+ locations of the second empirical distribution
+ u_cdf: array-like, shape (n_batch, n)
+ cdf of the first empirical distribution
+ v_cdf: array-like, shape (n_batch, n)
+ cdf of the second empirical distribution
+ p: float, optional = 2
+ Power p used for computing the Wasserstein distance
+
+ Returns
+ -------
+ dCp: array-like, shape (n_batch, 1)
+ The batched right derivative
+ dCm: array-like, shape (n_batch, 1)
+ The batched left derivative
+
+ References
+ ---------
+ .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
+ """
+ nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf)
+
+ v_values = nx.copy(v_values)
+
+ n = u_values.shape[-1]
+ m_batch, m = v_values.shape
+
+ v_cdf_theta = v_cdf - (theta - nx.floor(theta))
+
+ mask_p = v_cdf_theta >= 0
+ mask_n = v_cdf_theta < 0
+
+ v_values[mask_n] += nx.floor(theta)[mask_n] + 1
+ v_values[mask_p] += nx.floor(theta)[mask_p]
+
+ if nx.any(mask_n) and nx.any(mask_p):
+ v_cdf_theta[mask_n] += 1
+
+ v_cdf_theta2 = nx.copy(v_cdf_theta)
+ v_cdf_theta2[mask_n] = np.inf
+ shift = (-nx.argmin(v_cdf_theta2, axis=-1))
+
+ v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1)))
+ v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1)))
+ v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1)
+
+ if nx.__name__ == 'torch':
+ # this is to ensure the best performance for torch searchsorted
+ # and avoid a warninng related to non-contiguous arrays
+ u_cdf = u_cdf.contiguous()
+ v_cdf_theta = v_cdf_theta.contiguous()
+
+ # quantiles of F_u evaluated in F_v^\theta
+ u_index = nx.searchsorted(u_cdf, v_cdf_theta)
+ u_icdf_theta = nx.take_along_axis(u_values, nx.clip(u_index, 0, n - 1), -1)
+
+ # Deal with 1
+ u_cdfm = nx.concatenate([u_cdf, nx.reshape(u_cdf[:, 0], (-1, 1)) + 1], axis=1)
+ u_valuesm = nx.concatenate([u_values, nx.reshape(u_values[:, 0], (-1, 1)) + 1], axis=1)
+
+ if nx.__name__ == 'torch':
+ # this is to ensure the best performance for torch searchsorted
+ # and avoid a warninng related to non-contiguous arrays
+ u_cdfm = u_cdfm.contiguous()
+ v_cdf_theta = v_cdf_theta.contiguous()
+
+ u_indexm = nx.searchsorted(u_cdfm, v_cdf_theta, side="right")
+ u_icdfm_theta = nx.take_along_axis(u_valuesm, nx.clip(u_indexm, 0, n), -1)
+
+ dCp = nx.sum(nx.power(nx.abs(u_icdf_theta - v_values[:, 1:]), p)
+ - nx.power(nx.abs(u_icdf_theta - v_values[:, :-1]), p), axis=-1)
+
+ dCm = nx.sum(nx.power(nx.abs(u_icdfm_theta - v_values[:, 1:]), p)
+ - nx.power(nx.abs(u_icdfm_theta - v_values[:, :-1]), p), axis=-1)
+
+ return dCp.reshape(-1, 1), dCm.reshape(-1, 1)
+
+
+def ot_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p):
+ r""" Computes the the cost (Equation (6.2) of [1])
+
+ Parameters
+ ----------
+ theta: array-like, shape (n_batch, n)
+ Cuts on the circle
+ u_values: array-like, shape (n_batch, n)
+ locations of the first empirical distribution
+ v_values: array-like, shape (n_batch, n)
+ locations of the second empirical distribution
+ u_cdf: array-like, shape (n_batch, n)
+ cdf of the first empirical distribution
+ v_cdf: array-like, shape (n_batch, n)
+ cdf of the second empirical distribution
+ p: float, optional = 2
+ Power p used for computing the Wasserstein distance
+
+ Returns
+ -------
+ ot_cost: array-like, shape (n_batch,)
+ OT cost evaluated at theta
+
+ References
+ ---------
+ .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
+ """
+ nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf)
+
+ v_values = nx.copy(v_values)
+
+ m_batch, m = v_values.shape
+ n_batch, n = u_values.shape
+
+ v_cdf_theta = v_cdf - (theta - nx.floor(theta))
+
+ mask_p = v_cdf_theta >= 0
+ mask_n = v_cdf_theta < 0
+
+ v_values[mask_n] += nx.floor(theta)[mask_n] + 1
+ v_values[mask_p] += nx.floor(theta)[mask_p]
+
+ if nx.any(mask_n) and nx.any(mask_p):
+ v_cdf_theta[mask_n] += 1
+
+ # Put negative values at the end
+ v_cdf_theta2 = nx.copy(v_cdf_theta)
+ v_cdf_theta2[mask_n] = np.inf
+ shift = (-nx.argmin(v_cdf_theta2, axis=-1))
+
+ v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1)))
+ v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1)))
+ v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1)
+
+ # Compute absciss
+ cdf_axis = nx.sort(nx.concatenate((u_cdf, v_cdf_theta), -1), -1)
+ cdf_axis_pad = nx.zero_pad(cdf_axis, pad_width=[(0, 0), (1, 0)])
+
+ delta = cdf_axis_pad[..., 1:] - cdf_axis_pad[..., :-1]
+
+ if nx.__name__ == 'torch':
+ # this is to ensure the best performance for torch searchsorted
+ # and avoid a warninng related to non-contiguous arrays
+ u_cdf = u_cdf.contiguous()
+ v_cdf_theta = v_cdf_theta.contiguous()
+ cdf_axis = cdf_axis.contiguous()
+
+ # Compute icdf
+ u_index = nx.searchsorted(u_cdf, cdf_axis)
+ u_icdf = nx.take_along_axis(u_values, u_index.clip(0, n - 1), -1)
+
+ v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1)
+ v_index = nx.searchsorted(v_cdf_theta, cdf_axis)
+ v_icdf = nx.take_along_axis(v_values, v_index.clip(0, m), -1)
+
+ if p == 1:
+ ot_cost = nx.sum(delta * nx.abs(u_icdf - v_icdf), axis=-1)
+ else:
+ ot_cost = nx.sum(delta * nx.power(nx.abs(u_icdf - v_icdf), p), axis=-1)
+
+ return ot_cost
+
+
+def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1,
+ Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True,
+ log=False):
+ r"""Computes the Wasserstein distance on the circle using the Binary search algorithm proposed in [44].
+ Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
+ takes the value modulo 1.
+ If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates
+ using e.g. the atan2 function.
+
+ .. math::
+ W_p^p(u,v) = \inf_{\theta\in\mathbb{R}}\int_0^1 |F_u^{-1}(q) - (F_v-\theta)^{-1}(q)|^p\ \mathrm{d}q
+
+ where:
+
+ - :math:`F_u` and :math:`F_v` are respectively the cdfs of :math:`u` and :math:`v`
+
+ For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with
+
+ .. math::
+ u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}
+
+ using e.g. ot.utils.get_coordinate_circle(x)
+
+ The function runs on backend but tensorflow is not supported.
+
+ Parameters
+ ----------
+ u_values : ndarray, shape (n, ...)
+ samples in the source domain (coordinates on [0,1[)
+ v_values : ndarray, shape (n, ...)
+ samples in the target domain (coordinates on [0,1[)
+ u_weights : ndarray, shape (n, ...), optional
+ samples weights in the source domain
+ v_weights : ndarray, shape (n, ...), optional
+ samples weights in the target domain
+ p : float, optional (default=1)
+ Power p used for computing the Wasserstein distance
+ Lm : int, optional
+ Lower bound dC
+ Lp : int, optional
+ Upper bound dC
+ tm: float, optional
+ Lower bound theta
+ tp: float, optional
+ Upper bound theta
+ eps: float, optional
+ Stopping condition
+ require_sort: bool, optional
+ If True, sort the values.
+ log: bool, optional
+ If True, returns also the optimal theta
+
+ Returns
+ -------
+ loss: float
+ Cost associated to the optimal transportation
+ log: dict, optional
+ log dictionary returned only if log==True in parameters
+
+ Examples
+ --------
+ >>> u = np.array([[0.2,0.5,0.8]])%1
+ >>> v = np.array([[0.4,0.5,0.7]])%1
+ >>> binary_search_circle(u.T, v.T, p=1)
+ array([0.1])
+
+ References
+ ----------
+ .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
+ .. Matlab Code: https://users.mccme.ru/ansobol/otarie/software.html
+ """
+ assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p)
+
+ if u_weights is not None and v_weights is not None:
+ nx = get_backend(u_values, v_values, u_weights, v_weights)
+ else:
+ nx = get_backend(u_values, v_values)
+
+ n = u_values.shape[0]
+ m = v_values.shape[0]
+
+ if len(u_values.shape) == 1:
+ u_values = nx.reshape(u_values, (n, 1))
+ if len(v_values.shape) == 1:
+ v_values = nx.reshape(v_values, (m, 1))
+
+ if u_values.shape[1] != v_values.shape[1]:
+ raise ValueError(
+ "u and v must have the same number of batchs {} and {} respectively given".format(u_values.shape[1],
+ v_values.shape[1]))
+
+ u_values = u_values % 1
+ v_values = v_values % 1
+
+ if u_weights is None:
+ u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values)
+ elif u_weights.ndim != u_values.ndim:
+ u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
+ if v_weights is None:
+ v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values)
+ elif v_weights.ndim != v_values.ndim:
+ v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)
+
+ if require_sort:
+ u_sorter = nx.argsort(u_values, 0)
+ u_values = nx.take_along_axis(u_values, u_sorter, 0)
+
+ v_sorter = nx.argsort(v_values, 0)
+ v_values = nx.take_along_axis(v_values, v_sorter, 0)
+
+ u_weights = nx.take_along_axis(u_weights, u_sorter, 0)
+ v_weights = nx.take_along_axis(v_weights, v_sorter, 0)
+
+ u_cdf = nx.cumsum(u_weights, 0).T
+ v_cdf = nx.cumsum(v_weights, 0).T
+
+ u_values = u_values.T
+ v_values = v_values.T
+
+ L = max(Lm, Lp)
+
+ tm = tm * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1))
+ tm = nx.tile(tm, (1, m))
+ tp = tp * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1))
+ tp = nx.tile(tp, (1, m))
+ tc = (tm + tp) / 2
+
+ done = nx.zeros((u_values.shape[0], m))
+
+ cpt = 0
+ while nx.any(1 - done):
+ cpt += 1
+
+ dCp, dCm = derivative_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p)
+ done = ((dCp * dCm) <= 0) * 1
+
+ mask = ((tp - tm) < eps / L) * (1 - done)
+
+ if nx.any(mask):
+ # can probably be improved by computing only relevant values
+ dCptp, dCmtp = derivative_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p)
+ dCptm, dCmtm = derivative_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p)
+ Ctm = ot_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p).reshape(-1, 1)
+ Ctp = ot_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p).reshape(-1, 1)
+
+ mask_end = mask * (nx.abs(dCptm - dCmtp) > 0.001)
+ tc[mask_end > 0] = ((Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp))[mask_end > 0]
+ done[nx.prod(mask, axis=-1) > 0] = 1
+ elif nx.any(1 - done):
+ tm[((1 - mask) * (dCp < 0)) > 0] = tc[((1 - mask) * (dCp < 0)) > 0]
+ tp[((1 - mask) * (dCp >= 0)) > 0] = tc[((1 - mask) * (dCp >= 0)) > 0]
+ tc[((1 - mask) * (1 - done)) > 0] = (tm[((1 - mask) * (1 - done)) > 0] + tp[((1 - mask) * (1 - done)) > 0]) / 2
+
+ w = ot_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p)
+
+ if log:
+ return w, {"optimal_theta": tc[:, 0]}
+ return w
+
+
+def wasserstein1_circle(u_values, v_values, u_weights=None, v_weights=None, require_sort=True):
+ r"""Computes the 1-Wasserstein distance on the circle using the level median [45].
+ Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
+ takes the value modulo 1.
+ If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates
+ using e.g. the atan2 function.
+ The function runs on backend but tensorflow is not supported.
+
+ .. math::
+ W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t
+
+ Parameters
+ ----------
+ u_values : ndarray, shape (n, ...)
+ samples in the source domain (coordinates on [0,1[)
+ v_values : ndarray, shape (n, ...)
+ samples in the target domain (coordinates on [0,1[)
+ u_weights : ndarray, shape (n, ...), optional
+ samples weights in the source domain
+ v_weights : ndarray, shape (n, ...), optional
+ samples weights in the target domain
+ require_sort: bool, optional
+ If True, sort the values.
+
+ Returns
+ -------
+ loss: float
+ Cost associated to the optimal transportation
+
+ Examples
+ --------
+ >>> u = np.array([[0.2,0.5,0.8]])%1
+ >>> v = np.array([[0.4,0.5,0.7]])%1
+ >>> wasserstein1_circle(u.T, v.T)
+ array([0.1])
+
+ References
+ ----------
+ .. [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82.
+ .. Code R: https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/
+ """
+
+ if u_weights is not None and v_weights is not None:
+ nx = get_backend(u_values, v_values, u_weights, v_weights)
+ else:
+ nx = get_backend(u_values, v_values)
+
+ n = u_values.shape[0]
+ m = v_values.shape[0]
+
+ if len(u_values.shape) == 1:
+ u_values = nx.reshape(u_values, (n, 1))
+ if len(v_values.shape) == 1:
+ v_values = nx.reshape(v_values, (m, 1))
+
+ if u_values.shape[1] != v_values.shape[1]:
+ raise ValueError(
+ "u and v must have the same number of batchs {} and {} respectively given".format(u_values.shape[1],
+ v_values.shape[1]))
+
+ u_values = u_values % 1
+ v_values = v_values % 1
+
+ if u_weights is None:
+ u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values)
+ elif u_weights.ndim != u_values.ndim:
+ u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
+ if v_weights is None:
+ v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values)
+ elif v_weights.ndim != v_values.ndim:
+ v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)
+
+ if require_sort:
+ u_sorter = nx.argsort(u_values, 0)
+ u_values = nx.take_along_axis(u_values, u_sorter, 0)
+
+ v_sorter = nx.argsort(v_values, 0)
+ v_values = nx.take_along_axis(v_values, v_sorter, 0)
+
+ u_weights = nx.take_along_axis(u_weights, u_sorter, 0)
+ v_weights = nx.take_along_axis(v_weights, v_sorter, 0)
+
+ # Code inspired from https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/
+ values_sorted, values_sorter = nx.sort2(nx.concatenate((u_values, v_values), 0), 0)
+
+ cdf_diff = nx.cumsum(nx.take_along_axis(nx.concatenate((u_weights, -v_weights), 0), values_sorter, 0), 0)
+ cdf_diff_sorted, cdf_diff_sorter = nx.sort2(cdf_diff, axis=0)
+
+ values_sorted = nx.zero_pad(values_sorted, pad_width=[(0, 1), (0, 0)], value=1)
+ delta = values_sorted[1:, ...] - values_sorted[:-1, ...]
+ weight_sorted = nx.take_along_axis(delta, cdf_diff_sorter, 0)
+
+ sum_weights = nx.cumsum(weight_sorted, axis=0) - 0.5
+ sum_weights[sum_weights < 0] = np.inf
+ inds = nx.argmin(sum_weights, axis=0)
+
+ levMed = nx.take_along_axis(cdf_diff_sorted, nx.reshape(inds, (1, -1)), 0)
+
+ return nx.sum(delta * nx.abs(cdf_diff - levMed), axis=0)
+
+
+def wasserstein_circle(u_values, v_values, u_weights=None, v_weights=None, p=1,
+ Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True):
+ r"""Computes the Wasserstein distance on the circle using either [45] for p=1 or
+ the binary search algorithm proposed in [44] otherwise.
+ Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
+ takes the value modulo 1.
+ If the values are on :math:`S^1\subset\mathbb{R}^2`, it requires to first find the coordinates
+ using e.g. the atan2 function.
+
+ General loss returned:
+
+ .. math::
+ OT_{loss} = \inf_{\theta\in\mathbb{R}}\int_0^1 |cdf_u^{-1}(q) - (cdf_v-\theta)^{-1}(q)|^p\ \mathrm{d}q
+
+ For p=1, [45]
+
+ .. math::
+ W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t
+
+ For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with
+
+ .. math::
+ u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}
+
+ using e.g. ot.utils.get_coordinate_circle(x)
+
+ The function runs on backend but tensorflow is not supported.
+
+ Parameters
+ ----------
+ u_values : ndarray, shape (n, ...)
+ samples in the source domain (coordinates on [0,1[)
+ v_values : ndarray, shape (n, ...)
+ samples in the target domain (coordinates on [0,1[)
+ u_weights : ndarray, shape (n, ...), optional
+ samples weights in the source domain
+ v_weights : ndarray, shape (n, ...), optional
+ samples weights in the target domain
+ p : float, optional (default=1)
+ Power p used for computing the Wasserstein distance
+ Lm : int, optional
+ Lower bound dC. For p>1.
+ Lp : int, optional
+ Upper bound dC. For p>1.
+ tm: float, optional
+ Lower bound theta. For p>1.
+ tp: float, optional
+ Upper bound theta. For p>1.
+ eps: float, optional
+ Stopping condition. For p>1.
+ require_sort: bool, optional
+ If True, sort the values.
+
+ Returns
+ -------
+ loss: float
+ Cost associated to the optimal transportation
+
+ Examples
+ --------
+ >>> u = np.array([[0.2,0.5,0.8]])%1
+ >>> v = np.array([[0.4,0.5,0.7]])%1
+ >>> wasserstein_circle(u.T, v.T)
+ array([0.1])
+
+ References
+ ----------
+ .. [44] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82.
+ .. [45] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
+ """
+ assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p)
+
+ if p == 1:
+ return wasserstein1_circle(u_values, v_values, u_weights, v_weights, require_sort)
+
+ return binary_search_circle(u_values, v_values, u_weights, v_weights,
+ p=p, Lm=Lm, Lp=Lp, tm=tm, tp=tp, eps=eps,
+ require_sort=require_sort)
+
+
+def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None):
+ r"""Computes the closed-form for the 2-Wasserstein distance between samples and a uniform distribution on :math:`S^1`
+ Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
+ takes the value modulo 1.
+ If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates
+ using e.g. the atan2 function.
+
+ .. math::
+ W_2^2(\mu_n, \nu) = \sum_{i=1}^n \alpha_i x_i^2 - \left(\sum_{i=1}^n \alpha_i x_i\right)^2 + \sum_{i=1}^n \alpha_i x_i \left(1-\alpha_i-2\sum_{k=1}^{i-1}\alpha_k\right) + \frac{1}{12}
+
+ where:
+
+ - :math:`\nu=\mathrm{Unif}(S^1)` and :math:`\mu_n = \sum_{i=1}^n \alpha_i \delta_{x_i}`
+
+ For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with
+
+ .. math::
+ u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi},
+
+ using e.g. ot.utils.get_coordinate_circle(x)
+
+ Parameters
+ ----------
+ u_values: ndarray, shape (n, ...)
+ Samples
+ u_weights : ndarray, shape (n, ...), optional
+ samples weights in the source domain
+
+ Returns
+ -------
+ loss: float
+ Cost associated to the optimal transportation
+
+ Examples
+ --------
+ >>> x0 = np.array([[0], [0.2], [0.4]])
+ >>> semidiscrete_wasserstein2_unif_circle(x0)
+ array([0.02111111])
+
+ References
+ ----------
+ .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations.
+ """
+
+ if u_weights is not None:
+ nx = get_backend(u_values, u_weights)
+ else:
+ nx = get_backend(u_values)
+
+ n = u_values.shape[0]
+
+ u_values = u_values % 1
+
+ if len(u_values.shape) == 1:
+ u_values = nx.reshape(u_values, (n, 1))
+
+ if u_weights is None:
+ u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values)
+ elif u_weights.ndim != u_values.ndim:
+ u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
+
+ u_values = nx.sort(u_values, 0)
+ u_cdf = nx.cumsum(u_weights, 0)
+ u_cdf = nx.zero_pad(u_cdf, [(1, 0), (0, 0)])
+
+ cpt1 = nx.sum(u_weights * u_values**2, axis=0)
+ u_mean = nx.sum(u_weights * u_values, axis=0)
+
+ ns = 1 - u_weights - 2 * u_cdf[:-1]
+ cpt2 = nx.sum(u_values * u_weights * ns, axis=0)
+
+ return cpt1 - u_mean**2 + cpt2 + 1 / 12
diff --git a/ot/optim.py b/ot/optim.py
index 5a1d605..58e5596 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -1,11 +1,11 @@
# -*- coding: utf-8 -*-
"""
-Generic solvers for regularized OT
+Generic solvers for regularized OT or its semi-relaxed version.
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
-#
+# Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
# License: MIT License
import numpy as np
@@ -27,7 +27,7 @@ with warnings.catch_warnings():
def line_search_armijo(
f, xk, pk, gfk, old_fval, args=(), c1=1e-4,
- alpha0=0.99, alpha_min=None, alpha_max=None
+ alpha0=0.99, alpha_min=None, alpha_max=None, nx=None, **kwargs
):
r"""
Armijo linesearch function that works with matrices
@@ -35,6 +35,9 @@ def line_search_armijo(
Find an approximate minimum of :math:`f(x_k + \alpha \cdot p_k)` that satisfies the
armijo conditions.
+ .. note:: If the loss function f returns a float (resp. a 1d array) then
+ the returned alpha and fa are float (resp. 1d arrays).
+
Parameters
----------
f : callable
@@ -45,7 +48,7 @@ def line_search_armijo(
descent direction
gfk : array-like
gradient of `f` at :math:`x_k`
- old_fval : float
+ old_fval : float or 1d array
loss value at :math:`x_k`
args : tuple, optional
arguments given to `f`
@@ -57,138 +60,97 @@ def line_search_armijo(
minimum value for alpha
alpha_max : float, optional
maximum value for alpha
-
+ nx : backend, optional
+ If let to its default value None, a backend test will be conducted.
Returns
-------
- alpha : float
+ alpha : float or 1d array
step that satisfy armijo conditions
fc : int
nb of function call
- fa : float
+ fa : float or 1d array
loss value at step alpha
"""
-
- xk, pk, gfk = list_to_array(xk, pk, gfk)
- nx = get_backend(xk, pk)
+ if nx is None:
+ xk, pk, gfk = list_to_array(xk, pk, gfk)
+ xk0, pk0 = xk, pk
+ nx = get_backend(xk0, pk0)
+ else:
+ xk0, pk0 = xk, pk
if len(xk.shape) == 0:
xk = nx.reshape(xk, (-1,))
+ xk = nx.to_numpy(xk)
+ pk = nx.to_numpy(pk)
+ gfk = nx.to_numpy(gfk)
+
fc = [0]
def phi(alpha1):
+ # The callable function operates on nx backend
fc[0] += 1
- return f(xk + alpha1 * pk, *args)
+ alpha10 = nx.from_numpy(alpha1)
+ fval = f(xk0 + alpha10 * pk0, *args)
+ if type(fval) is float:
+ # prevent bug from nx.to_numpy that can look for .cpu or .gpu
+ return fval
+ else:
+ return nx.to_numpy(fval)
if old_fval is None:
phi0 = phi(0.)
- else:
+ elif type(old_fval) is float:
+ # prevent bug from nx.to_numpy that can look for .cpu or .gpu
phi0 = old_fval
+ else:
+ phi0 = nx.to_numpy(old_fval)
- derphi0 = nx.sum(pk * gfk) # Quickfix for matrices
+ derphi0 = np.sum(pk * gfk) # Quickfix for matrices
alpha, phi1 = scalar_search_armijo(
phi, phi0, derphi0, c1=c1, alpha0=alpha0)
if alpha is None:
- return 0., fc[0], phi0
+ return 0., fc[0], nx.from_numpy(phi0, type_as=xk0)
else:
if alpha_min is not None or alpha_max is not None:
alpha = np.clip(alpha, alpha_min, alpha_max)
- return float(alpha), fc[0], phi1
+ return nx.from_numpy(alpha, type_as=xk0), fc[0], nx.from_numpy(phi1, type_as=xk0)
-def solve_linesearch(
- cost, G, deltaG, Mi, f_val, armijo=True, C1=None, C2=None,
- reg=None, Gc=None, constC=None, M=None, alpha_min=None, alpha_max=None
-):
- """
- Solve the linesearch in the FW iterations
-
- Parameters
- ----------
- cost : method
- Cost in the FW for the linesearch
- G : array-like, shape(ns,nt)
- The transport map at a given iteration of the FW
- deltaG : array-like (ns,nt)
- Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration
- Mi : array-like (ns,nt)
- Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost
- f_val : float
- Value of the cost at `G`
- armijo : bool, optional
- If True the steps of the line-search is found via an armijo research. Else closed form is used.
- If there is convergence issues use False.
- C1 : array-like (ns,ns), optional
- Structure matrix in the source domain. Only used and necessary when armijo=False
- C2 : array-like (nt,nt), optional
- Structure matrix in the target domain. Only used and necessary when armijo=False
- reg : float, optional
- Regularization parameter. Only used and necessary when armijo=False
- Gc : array-like (ns,nt)
- Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False
- constC : array-like (ns,nt)
- Constant for the gromov cost. See :ref:`[24] <references-solve-linesearch>`. Only used and necessary when armijo=False
- M : array-like (ns,nt), optional
- Cost matrix between the features. Only used and necessary when armijo=False
- alpha_min : float, optional
- Minimum value for alpha
- alpha_max : float, optional
- Maximum value for alpha
-
- Returns
- -------
- alpha : float
- The optimal step size of the FW
- fc : int
- nb of function call. Useless here
- f_val : float
- The value of the cost for the next iteration
+def generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_search, G0=None,
+ numItermax=200, stopThr=1e-9,
+ stopThr2=1e-9, verbose=False, log=False, **kwargs):
+ r"""
+ Solve the general regularized OT problem or its semi-relaxed version with
+ conditional gradient or generalized conditional gradient depending on the
+ provided linear program solver.
+ The function solves the following optimization problem if set as a conditional gradient:
- .. _references-solve-linesearch:
- References
- ----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
- "Optimal Transport for structured data with application on graphs"
- International Conference on Machine Learning (ICML). 2019.
- """
- if armijo:
- alpha, fc, f_val = line_search_armijo(
- cost, G, deltaG, Mi, f_val, alpha_min=alpha_min, alpha_max=alpha_max
- )
- else: # requires symetric matrices
- G, deltaG, C1, C2, constC, M = list_to_array(G, deltaG, C1, C2, constC, M)
- if isinstance(M, int) or isinstance(M, float):
- nx = get_backend(G, deltaG, C1, C2, constC)
- else:
- nx = get_backend(G, deltaG, C1, C2, constC, M)
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg_1} \cdot f(\gamma)
- dot = nx.dot(nx.dot(C1, deltaG), C2)
- a = -2 * reg * nx.sum(dot * deltaG)
- b = nx.sum((M + reg * constC) * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2) * deltaG))
- c = cost(G)
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- alpha = solve_1d_linesearch_quad(a, b, c)
- if alpha_min is not None or alpha_max is not None:
- alpha = np.clip(alpha, alpha_min, alpha_max)
- fc = None
- f_val = cost(G + alpha * deltaG)
+ \gamma^T \mathbf{1} &= \mathbf{b} (optional constraint)
- return alpha, fc, f_val
+ \gamma &\geq 0
+ where :
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`f` is the regularization term (and `df` is its gradient)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
-def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
- stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs):
- r"""
- Solve the general regularized OT problem with conditional gradient
+ The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>`
- The function solves the following optimization problem:
+ The function solves the following optimization problem if set a generalized conditional gradient:
.. math::
\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
- \mathrm{reg} \cdot f(\gamma)
+ \mathrm{reg_1}\cdot f(\gamma) + \mathrm{reg_2}\cdot\Omega(\gamma)
s.t. \ \gamma \mathbf{1} &= \mathbf{a}
@@ -197,29 +159,39 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
\gamma &\geq 0
where :
- - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- - :math:`f` is the regularization term (and `df` is its gradient)
- - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
-
- The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>`
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ The algorithm used for solving the problem is the generalized conditional gradient as discussed in :ref:`[5, 7] <references-gcg>`
Parameters
----------
a : array-like, shape (ns,)
samples weights in the source domain
b : array-like, shape (nt,)
- samples in the target domain
+ samples weights in the target domain
M : array-like, shape (ns, nt)
loss matrix
- reg : float
+ f : function
+ Regularization function taking a transportation matrix as argument
+ df: function
+ Gradient of the regularization function taking a transportation matrix as argument
+ reg1 : float
Regularization term >0
+ reg2 : float,
+ Entropic Regularization term >0. Ignored if set to None.
+ lp_solver: function,
+ linear program solver for direction finding of the (generalized) conditional gradient.
+ If set to emd will solve the general regularized OT problem using cg.
+ If set to lp_semi_relaxed_OT will solve the general regularized semi-relaxed OT problem using cg.
+ If set to sinkhorn will solve the general regularized OT problem using generalized cg.
+ line_search: function,
+ Function to find the optimal step. Currently used instances are:
+ line_search_armijo (generic solver). solve_gromov_linesearch for (F)GW problem.
+ solve_semirelaxed_gromov_linesearch for sr(F)GW problem. gcg_linesearch for the Generalized cg.
G0 : array-like, shape (ns,nt), optional
initial guess (default is indep joint density)
numItermax : int, optional
Max number of iterations
- numItermaxEmd : int, optional
- Max number of iterations for emd
stopThr : float, optional
Stop threshold on the relative variation (>0)
stopThr2 : float, optional
@@ -240,16 +212,20 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
.. _references-cg:
+ .. _references_gcg:
References
----------
.. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
+ .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
+
+ .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
+
See Also
--------
ot.lp.emd : Unregularized optimal ransport
ot.bregman.sinkhorn : Entropic regularized optimal transport
-
"""
a, b, M, G0 = list_to_array(a, b, M, G0)
if isinstance(M, int) or isinstance(M, float):
@@ -265,42 +241,45 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
if G0 is None:
G = nx.outer(a, b)
else:
- G = G0
-
- def cost(G):
- return nx.sum(M * G) + reg * f(G)
+ # to not change G0 in place.
+ G = nx.copy(G0)
- f_val = cost(G)
+ if reg2 is None:
+ def cost(G):
+ return nx.sum(M * G) + reg1 * f(G)
+ else:
+ def cost(G):
+ return nx.sum(M * G) + reg1 * f(G) + reg2 * nx.sum(G * nx.log(G))
+ cost_G = cost(G)
if log:
- log['loss'].append(f_val)
+ log['loss'].append(cost_G)
it = 0
if verbose:
print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
- print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0))
+ print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, cost_G, 0, 0))
while loop:
it += 1
- old_fval = f_val
-
+ old_cost_G = cost_G
# problem linearization
- Mi = M + reg * df(G)
+ Mi = M + reg1 * df(G)
+
+ if not (reg2 is None):
+ Mi = Mi + reg2 * (1 + nx.log(G))
# set M positive
- Mi += nx.min(Mi)
+ Mi = Mi + nx.min(Mi)
# solve linear program
- Gc, logemd = emd(a, b, Mi, numItermax=numItermaxEmd, log=True)
+ Gc, innerlog_ = lp_solver(a, b, Mi, **kwargs)
+ # line search
deltaG = Gc - G
- # line search
- alpha, fc, f_val = solve_linesearch(
- cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc,
- alpha_min=0., alpha_max=1., **kwargs
- )
+ alpha, fc, cost_G = line_search(cost, G, deltaG, Mi, cost_G, **kwargs)
G = G + alpha * deltaG
@@ -308,29 +287,197 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
if it >= numItermax:
loop = 0
- abs_delta_fval = abs(f_val - old_fval)
- relative_delta_fval = abs_delta_fval / abs(f_val)
- if relative_delta_fval < stopThr or abs_delta_fval < stopThr2:
+ abs_delta_cost_G = abs(cost_G - old_cost_G)
+ relative_delta_cost_G = abs_delta_cost_G / abs(cost_G)
+ if relative_delta_cost_G < stopThr or abs_delta_cost_G < stopThr2:
loop = 0
if log:
- log['loss'].append(f_val)
+ log['loss'].append(cost_G)
if verbose:
if it % 20 == 0:
print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
- print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
+ print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, cost_G, relative_delta_cost_G, abs_delta_cost_G))
if log:
- log.update(logemd)
+ log.update(innerlog_)
return G, log
else:
return G
+def cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo,
+ numItermax=200, numItermaxEmd=100000, stopThr=1e-9, stopThr2=1e-9,
+ verbose=False, log=False, **kwargs):
+ r"""
+ Solve the general regularized OT problem with conditional gradient
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot f(\gamma)
+
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
+
+ \gamma^T \mathbf{1} &= \mathbf{b}
+
+ \gamma &\geq 0
+ where :
+
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`f` is the regularization term (and `df` is its gradient)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
+
+ The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>`
+
+
+ Parameters
+ ----------
+ a : array-like, shape (ns,)
+ samples weights in the source domain
+ b : array-like, shape (nt,)
+ samples in the target domain
+ M : array-like, shape (ns, nt)
+ loss matrix
+ reg : float
+ Regularization term >0
+ G0 : array-like, shape (ns,nt), optional
+ initial guess (default is indep joint density)
+ line_search: function,
+ Function to find the optimal step.
+ Default is line_search_armijo.
+ numItermax : int, optional
+ Max number of iterations
+ numItermaxEmd : int, optional
+ Max number of iterations for emd
+ stopThr : float, optional
+ Stop threshold on the relative variation (>0)
+ stopThr2 : float, optional
+ Stop threshold on the absolute variation (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ **kwargs : dict
+ Parameters for linesearch
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ .. _references-cg:
+ References
+ ----------
+
+ .. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized optimal ransport
+ ot.bregman.sinkhorn : Entropic regularized optimal transport
+
+ """
+
+ def lp_solver(a, b, M, **kwargs):
+ return emd(a, b, M, numItermaxEmd, log=True)
+
+ return generic_conditional_gradient(a, b, M, f, df, reg, None, lp_solver, line_search, G0=G0,
+ numItermax=numItermax, stopThr=stopThr,
+ stopThr2=stopThr2, verbose=verbose, log=log, **kwargs)
+
+
+def semirelaxed_cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo,
+ numItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs):
+ r"""
+ Solve the general regularized and semi-relaxed OT problem with conditional gradient
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot f(\gamma)
+
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
+
+ \gamma &\geq 0
+ where :
+
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`f` is the regularization term (and `df` is its gradient)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
+
+ The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>`
+
+
+ Parameters
+ ----------
+ a : array-like, shape (ns,)
+ samples weights in the source domain
+ b : array-like, shape (nt,)
+ currently estimated samples weights in the target domain
+ M : array-like, shape (ns, nt)
+ loss matrix
+ reg : float
+ Regularization term >0
+ G0 : array-like, shape (ns,nt), optional
+ initial guess (default is indep joint density)
+ line_search: function,
+ Function to find the optimal step.
+ Default is the armijo line-search.
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on the relative variation (>0)
+ stopThr2 : float, optional
+ Stop threshold on the absolute variation (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ **kwargs : dict
+ Parameters for linesearch
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ .. _references-cg:
+ References
+ ----------
+
+ .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+ "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+ International Conference on Learning Representations (ICLR), 2021.
+
+ """
+
+ nx = get_backend(a, b)
+
+ def lp_solver(a, b, Mi, **kwargs):
+ # get minimum by rows as binary mask
+ Gc = nx.ones(1, type_as=a) * (Mi == nx.reshape(nx.min(Mi, axis=1), (-1, 1)))
+ Gc *= nx.reshape((a / nx.sum(Gc, axis=1)), (-1, 1))
+ # return by default an empty inner_log
+ return Gc, {}
+
+ return generic_conditional_gradient(a, b, M, f, df, reg, None, lp_solver, line_search, G0=G0,
+ numItermax=numItermax, stopThr=stopThr,
+ stopThr2=stopThr2, verbose=verbose, log=log, **kwargs)
+
+
def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
- numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False):
+ numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs):
r"""
Solve the general regularized OT problem with the generalized conditional gradient
@@ -403,81 +550,18 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
ot.optim.cg : conditional gradient
"""
- a, b, M, G0 = list_to_array(a, b, M, G0)
- nx = get_backend(a, b, M)
-
- loop = 1
-
- if log:
- log = {'loss': []}
-
- if G0 is None:
- G = nx.outer(a, b)
- else:
- G = G0
-
- def cost(G):
- return nx.sum(M * G) + reg1 * nx.sum(G * nx.log(G)) + reg2 * f(G)
-
- f_val = cost(G)
- if log:
- log['loss'].append(f_val)
-
- it = 0
-
- if verbose:
- print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
- 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
- print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0))
-
- while loop:
-
- it += 1
- old_fval = f_val
-
- # problem linearization
- Mi = M + reg2 * df(G)
-
- # solve linear program with Sinkhorn
- # Gc = sinkhorn_stabilized(a,b, Mi, reg1, numItermax = numInnerItermax)
- Gc = sinkhorn(a, b, Mi, reg1, numItermax=numInnerItermax)
- deltaG = Gc - G
-
- # line search
- dcost = Mi + reg1 * (1 + nx.log(G)) # ??
- alpha, fc, f_val = line_search_armijo(
- cost, G, deltaG, dcost, f_val, alpha_min=0., alpha_max=1.
- )
+ def lp_solver(a, b, Mi, **kwargs):
+ return sinkhorn(a, b, Mi, reg1, numItermax=numInnerItermax, log=True, **kwargs)
- G = G + alpha * deltaG
+ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
+ return line_search_armijo(cost, G, deltaG, Mi, cost_G, **kwargs)
- # test convergence
- if it >= numItermax:
- loop = 0
+ return generic_conditional_gradient(a, b, M, f, df, reg2, reg1, lp_solver, line_search, G0=G0,
+ numItermax=numItermax, stopThr=stopThr, stopThr2=stopThr2, verbose=verbose, log=log, **kwargs)
- abs_delta_fval = abs(f_val - old_fval)
- relative_delta_fval = abs_delta_fval / abs(f_val)
- if relative_delta_fval < stopThr or abs_delta_fval < stopThr2:
- loop = 0
-
- if log:
- log['loss'].append(f_val)
-
- if verbose:
- if it % 20 == 0:
- print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
- 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
- print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
-
- if log:
- return G, log
- else:
- return G
-
-
-def solve_1d_linesearch_quad(a, b, c):
+def solve_1d_linesearch_quad(a, b):
r"""
For any convex or non-convex 1d quadratic function `f`, solve the following problem:
@@ -487,7 +571,7 @@ def solve_1d_linesearch_quad(a, b, c):
Parameters
----------
- a,b,c : float
+ a,b : float or tensors (1,)
The coefficients of the quadratic function
Returns
@@ -495,15 +579,11 @@ def solve_1d_linesearch_quad(a, b, c):
x : float
The optimal value which leads to the minimal cost
"""
- f0 = c
- df0 = b
- f1 = a + f0 + df0
-
if a > 0: # convex
- minimum = min(1, max(0, np.divide(-b, 2.0 * a)))
+ minimum = min(1., max(0., -b / (2.0 * a)))
return minimum
else: # non convex
- if f0 > f1:
- return 1
+ if a + b < 0:
+ return 1.
else:
- return 0
+ return 0.
diff --git a/ot/partial.py b/ot/partial.py
index 0a9e450..bf4119d 100755
--- a/ot/partial.py
+++ b/ot/partial.py
@@ -8,6 +8,8 @@ Partial OT solvers
import numpy as np
from .lp import emd
+from .backend import get_backend
+from .utils import list_to_array
def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
@@ -114,14 +116,22 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
ot.partial.partial_wasserstein : Partial Wasserstein with fixed mass
"""
- if np.sum(a) > 1 or np.sum(b) > 1:
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(a, b, M)
+
+ if nx.sum(a) > 1 + 1e-15 or nx.sum(b) > 1 + 1e-15: # 1e-15 for numerical errors
raise ValueError("Problem infeasible. Check that a and b are in the "
"simplex")
if reg_m is None:
- reg_m = np.max(M) + 1
- if reg_m < -np.max(M):
- return np.zeros((len(a), len(b)))
+ reg_m = float(nx.max(M)) + 1
+ if reg_m < -nx.max(M):
+ return nx.zeros((len(a), len(b)), type_as=M)
+
+ a0, b0, M0 = a, b, M
+ # convert to humpy
+ a, b, M = nx.to_numpy(a, b, M)
eps = 1e-20
M = np.asarray(M, dtype=np.float64)
@@ -149,10 +159,16 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
gamma = np.zeros((len(a), len(b)))
gamma[np.ix_(idx_x, idx_y)] = gamma_extended[:-nb_dummies, :-nb_dummies]
+ # convert back to backend
+ gamma = nx.from_numpy(gamma, type_as=M0)
+
if log_emd['warning'] is not None:
raise ValueError("Error in the EMD resolution: try to increase the"
" number of dummy points")
- log_emd['cost'] = np.sum(gamma * M)
+ log_emd['cost'] = nx.sum(gamma * M0)
+ log_emd['u'] = nx.from_numpy(log_emd['u'], type_as=a0)
+ log_emd['v'] = nx.from_numpy(log_emd['v'], type_as=b0)
+
if log:
return gamma, log_emd
else:
@@ -250,32 +266,52 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
entropic regularization parameter
"""
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(a, b, M)
+
+ dim_a, dim_b = M.shape
+ if len(a) == 0:
+ a = nx.ones(dim_a, type_as=a) / dim_a
+ if len(b) == 0:
+ b = nx.ones(dim_b, type_as=b) / dim_b
+
if m is None:
return partial_wasserstein_lagrange(a, b, M, log=log, **kwargs)
elif m < 0:
raise ValueError("Problem infeasible. Parameter m should be greater"
" than 0.")
- elif m > np.min((np.sum(a), np.sum(b))):
+ elif m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))):
raise ValueError("Problem infeasible. Parameter m should lower or"
" equal than min(|a|_1, |b|_1).")
- b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies)
- a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies)
- M_extended = np.zeros((len(a_extended), len(b_extended)))
- M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 2
- M_extended[:len(a), :len(b)] = M
+ b_extension = nx.ones(nb_dummies, type_as=b) * (nx.sum(a) - m) / nb_dummies
+ b_extended = nx.concatenate((b, b_extension))
+ a_extension = nx.ones(nb_dummies, type_as=a) * (nx.sum(b) - m) / nb_dummies
+ a_extended = nx.concatenate((a, a_extension))
+ M_extension = nx.ones((nb_dummies, nb_dummies), type_as=M) * nx.max(M) * 2
+ M_extended = nx.concatenate(
+ (nx.concatenate((M, nx.zeros((M.shape[0], M_extension.shape[1]))), axis=1),
+ nx.concatenate((nx.zeros((M_extension.shape[0], M.shape[1])), M_extension), axis=1)),
+ axis=0
+ )
gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True,
**kwargs)
+
+ gamma = gamma[:len(a), :len(b)]
+
if log_emd['warning'] is not None:
raise ValueError("Error in the EMD resolution: try to increase the"
" number of dummy points")
- log_emd['partial_w_dist'] = np.sum(M * gamma[:len(a), :len(b)])
+ log_emd['partial_w_dist'] = nx.sum(M * gamma)
+ log_emd['u'] = log_emd['u'][:len(a)]
+ log_emd['v'] = log_emd['v'][:len(b)]
if log:
- return gamma[:len(a), :len(b)], log_emd
+ return gamma, log_emd
else:
- return gamma[:len(a), :len(b)]
+ return gamma
def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
@@ -360,14 +396,18 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
NeurIPS.
"""
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(a, b, M)
+
partial_gw, log_w = partial_wasserstein(a, b, M, m, nb_dummies, log=True,
**kwargs)
log_w['T'] = partial_gw
if log:
- return np.sum(partial_gw * M), log_w
+ return nx.sum(partial_gw * M), log_w
else:
- return np.sum(partial_gw * M)
+ return nx.sum(partial_gw * M)
def gwgrad_partial(C1, C2, T):
@@ -809,60 +849,64 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
ot.partial.partial_wasserstein: exact Partial Wasserstein
"""
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(a, b, M)
dim_a, dim_b = M.shape
- dx = np.ones(dim_a, dtype=np.float64)
- dy = np.ones(dim_b, dtype=np.float64)
+ dx = nx.ones(dim_a, type_as=a)
+ dy = nx.ones(dim_b, type_as=b)
if len(a) == 0:
- a = np.ones(dim_a, dtype=np.float64) / dim_a
+ a = nx.ones(dim_a, type_as=a) / dim_a
if len(b) == 0:
- b = np.ones(dim_b, dtype=np.float64) / dim_b
+ b = nx.ones(dim_b, type_as=b) / dim_b
if m is None:
- m = np.min((np.sum(a), np.sum(b))) * 1.0
+ m = nx.min(nx.stack((nx.sum(a), nx.sum(b)))) * 1.0
if m < 0:
raise ValueError("Problem infeasible. Parameter m should be greater"
" than 0.")
- if m > np.min((np.sum(a), np.sum(b))):
+ if m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))):
raise ValueError("Problem infeasible. Parameter m should lower or"
" equal than min(|a|_1, |b|_1).")
log_e = {'err': []}
- # Next 3 lines equivalent to K=np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
- np.multiply(K, m / np.sum(K), out=K)
+ if type(a) == type(b) == type(M) == np.ndarray:
+ # Next 3 lines equivalent to K=nx.exp(-M/reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+ np.multiply(K, m / np.sum(K), out=K)
+ else:
+ K = nx.exp(-M / reg)
+ K = K * m / nx.sum(K)
err, cpt = 1, 0
- q1 = np.ones(K.shape)
- q2 = np.ones(K.shape)
- q3 = np.ones(K.shape)
+ q1 = nx.ones(K.shape, type_as=K)
+ q2 = nx.ones(K.shape, type_as=K)
+ q3 = nx.ones(K.shape, type_as=K)
while (err > stopThr and cpt < numItermax):
Kprev = K
K = K * q1
- K1 = np.dot(np.diag(np.minimum(a / np.sum(K, axis=1), dx)), K)
+ K1 = nx.dot(nx.diag(nx.minimum(a / nx.sum(K, axis=1), dx)), K)
q1 = q1 * Kprev / K1
K1prev = K1
K1 = K1 * q2
- K2 = np.dot(K1, np.diag(np.minimum(b / np.sum(K1, axis=0), dy)))
+ K2 = nx.dot(K1, nx.diag(nx.minimum(b / nx.sum(K1, axis=0), dy)))
q2 = q2 * K1prev / K2
K2prev = K2
K2 = K2 * q3
- K = K2 * (m / np.sum(K2))
+ K = K2 * (m / nx.sum(K2))
q3 = q3 * K2prev / K
- if np.any(np.isnan(K)) or np.any(np.isinf(K)):
+ if nx.any(nx.isnan(K)) or nx.any(nx.isinf(K)):
print('Warning: numerical errors at iteration', cpt)
break
if cpt % 10 == 0:
- err = np.linalg.norm(Kprev - K)
+ err = nx.norm(Kprev - K)
if log:
log_e['err'].append(err)
if verbose:
@@ -872,7 +916,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
print('{:5d}|{:8e}|'.format(cpt, err))
cpt = cpt + 1
- log_e['partial_w_dist'] = np.sum(M * K)
+ log_e['partial_w_dist'] = nx.sum(M * K)
if log:
return K, log_e
else:
diff --git a/ot/sliced.py b/ot/sliced.py
index cf2d3be..077ff0b 100644
--- a/ot/sliced.py
+++ b/ot/sliced.py
@@ -12,7 +12,8 @@ Sliced OT Distances
import numpy as np
from .backend import get_backend, NumpyBackend
-from .utils import list_to_array
+from .utils import list_to_array, get_coordinate_circle
+from .lp import wasserstein_circle, semidiscrete_wasserstein2_unif_circle
def get_random_projections(d, n_projections, seed=None, backend=None, type_as=None):
@@ -107,7 +108,6 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2,
--------
>>> n_samples_a = 20
- >>> reg = 0.1
>>> X = np.random.normal(0., 1., (n_samples_a, 5))
>>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
0.0
@@ -147,6 +147,8 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2,
if projections is None:
projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s)
+ else:
+ n_projections = projections.shape[1]
X_s_projections = nx.dot(X_s, projections)
X_t_projections = nx.dot(X_t, projections)
@@ -206,7 +208,6 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50,
--------
>>> n_samples_a = 20
- >>> reg = 0.1
>>> X = np.random.normal(0., 1., (n_samples_a, 5))
>>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
0.0
@@ -256,3 +257,183 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50,
if log:
return res, {"projections": projections, "projected_emds": projected_emd}
return res
+
+
+def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50,
+ p=2, seed=None, log=False):
+ r"""
+ Compute the spherical sliced-Wasserstein discrepancy.
+
+ .. math::
+ SSW_p(\mu,\nu) = \left(\int_{\mathbb{V}_{d,2}} W_p^p(P^U_\#\mu, P^U_\#\nu)\ \mathrm{d}\sigma(U)\right)^{\frac{1}{p}}
+
+ where:
+
+ - :math:`P^U_\# \mu` stands for the pushforwards of the projection :math:`\forall x\in S^{d-1},\ P^U(x) = \frac{U^Tx}{\|U^Tx\|_2}`
+
+ The function runs on backend but tensorflow is not supported.
+
+ Parameters
+ ----------
+ X_s: ndarray, shape (n_samples_a, dim)
+ Samples in the source domain
+ X_t: ndarray, shape (n_samples_b, dim)
+ Samples in the target domain
+ a : ndarray, shape (n_samples_a,), optional
+ samples weights in the source domain
+ b : ndarray, shape (n_samples_b,), optional
+ samples weights in the target domain
+ n_projections : int, optional
+ Number of projections used for the Monte-Carlo approximation
+ p: float, optional (default=2)
+ Power p used for computing the spherical sliced Wasserstein
+ seed: int or RandomState or None, optional
+ Seed used for random number generator
+ log: bool, optional
+ if True, sliced_wasserstein_sphere returns the projections used and their associated EMD.
+
+ Returns
+ -------
+ cost: float
+ Spherical Sliced Wasserstein Cost
+ log: dict, optional
+ log dictionary return only if log==True in parameters
+
+ Examples
+ --------
+ >>> n_samples_a = 20
+ >>> X = np.random.normal(0., 1., (n_samples_a, 5))
+ >>> X = X / np.sqrt(np.sum(X**2, -1, keepdims=True))
+ >>> sliced_wasserstein_sphere(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
+ 0.0
+
+ References
+ ----------
+ .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations.
+ """
+ if a is not None and b is not None:
+ nx = get_backend(X_s, X_t, a, b)
+ else:
+ nx = get_backend(X_s, X_t)
+
+ n, d = X_s.shape
+ m, _ = X_t.shape
+
+ if X_s.shape[1] != X_t.shape[1]:
+ raise ValueError(
+ "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1],
+ X_t.shape[1]))
+ if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10**(-4)):
+ raise ValueError("X_s is not on the sphere.")
+ if nx.any(nx.abs(nx.sum(X_t**2, axis=-1) - 1) > 10**(-4)):
+ raise ValueError("Xt is not on the sphere.")
+
+ # Uniforms and independent samples on the Stiefel manifold V_{d,2}
+ if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy':
+ Z = seed.randn(n_projections, d, 2)
+ else:
+ if seed is not None:
+ nx.seed(seed)
+ Z = nx.randn(n_projections, d, 2, type_as=X_s)
+
+ projections, _ = nx.qr(Z)
+
+ # Projection on S^1
+ # Projection on plane
+ Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1))
+ Xpt = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_t[:, :, None]), (n_projections, 2, m)), (0, 2, 1))
+
+ # Projection on sphere
+ Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True))
+ Xpt = Xpt / nx.sqrt(nx.sum(Xpt**2, -1, keepdims=True))
+
+ # Get coordinates on [0,1[
+ Xps_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n))
+ Xpt_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xpt, (-1, 2))), (n_projections, m))
+
+ projected_emd = wasserstein_circle(Xps_coords.T, Xpt_coords.T, u_weights=a, v_weights=b, p=p)
+ res = nx.mean(projected_emd) ** (1 / p)
+
+ if log:
+ return res, {"projections": projections, "projected_emds": projected_emd}
+ return res
+
+
+def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log=False):
+ r"""Compute the 2-spherical sliced wasserstein w.r.t. a uniform distribution.
+
+ .. math::
+ SSW_2(\mu_n, \nu)
+
+ where
+
+ - :math:`\mu_n=\sum_{i=1}^n \alpha_i \delta_{x_i}`
+ - :math:`\nu=\mathrm{Unif}(S^1)`
+
+ Parameters
+ ----------
+ X_s: ndarray, shape (n_samples_a, dim)
+ Samples in the source domain
+ a : ndarray, shape (n_samples_a,), optional
+ samples weights in the source domain
+ n_projections : int, optional
+ Number of projections used for the Monte-Carlo approximation
+ seed: int or RandomState or None, optional
+ Seed used for random number generator
+ log: bool, optional
+ if True, sliced_wasserstein_distance returns the projections used and their associated EMD.
+
+ Returns
+ -------
+ cost: float
+ Spherical Sliced Wasserstein Cost
+ log: dict, optional
+ log dictionary return only if log==True in parameters
+
+ Examples
+ ---------
+ >>> np.random.seed(42)
+ >>> x0 = np.random.randn(500,3)
+ >>> x0 = x0 / np.sqrt(np.sum(x0**2, -1, keepdims=True))
+ >>> ssw = sliced_wasserstein_sphere_unif(x0, seed=42)
+ >>> np.allclose(sliced_wasserstein_sphere_unif(x0, seed=42), 0.01734, atol=1e-3)
+ True
+
+ References:
+ -----------
+ .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations.
+ """
+ if a is not None:
+ nx = get_backend(X_s, a)
+ else:
+ nx = get_backend(X_s)
+
+ n, d = X_s.shape
+
+ if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10**(-4)):
+ raise ValueError("X_s is not on the sphere.")
+
+ # Uniforms and independent samples on the Stiefel manifold V_{d,2}
+ if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy':
+ Z = seed.randn(n_projections, d, 2)
+ else:
+ if seed is not None:
+ nx.seed(seed)
+ Z = nx.randn(n_projections, d, 2, type_as=X_s)
+
+ projections, _ = nx.qr(Z)
+
+ # Projection on S^1
+ # Projection on plane
+ Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1))
+ # Projection on sphere
+ Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True))
+ # Get coordinates on [0,1[
+ Xps_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n))
+
+ projected_emd = semidiscrete_wasserstein2_unif_circle(Xps_coords.T, u_weights=a)
+ res = nx.mean(projected_emd) ** (1 / 2)
+
+ if log:
+ return res, {"projections": projections, "projected_emds": projected_emd}
+ return res
diff --git a/ot/smooth.py b/ot/smooth.py
index 6855005..8e0ef38 100644
--- a/ot/smooth.py
+++ b/ot/smooth.py
@@ -44,6 +44,7 @@ Original code from https://github.com/mblondel/smooth-ot/
import numpy as np
from scipy.optimize import minimize
+from .backend import get_backend
def projection_simplex(V, z=1, axis=None):
@@ -511,6 +512,8 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
"""
+ nx = get_backend(a, b, M)
+
if reg_type.lower() in ['l2', 'squaredl2']:
regul = SquaredL2(gamma=reg)
elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
@@ -518,15 +521,19 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
else:
raise NotImplementedError('Unknown regularization')
+ a0, b0, M0 = a, b, M
+ # convert to humpy
+ a, b, M = nx.to_numpy(a, b, M)
+
# solve dual
alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax,
tol=stopThr, verbose=verbose)
# reconstruct transport matrix
- G = get_plan_from_dual(alpha, beta, M, regul)
+ G = nx.from_numpy(get_plan_from_dual(alpha, beta, M, regul), type_as=M0)
if log:
- log = {'alpha': alpha, 'beta': beta, 'res': res}
+ log = {'alpha': nx.from_numpy(alpha, type_as=a0), 'beta': nx.from_numpy(beta, type_as=b0), 'res': res}
return G, log
else:
return G
diff --git a/ot/solvers.py b/ot/solvers.py
new file mode 100644
index 0000000..0294d71
--- /dev/null
+++ b/ot/solvers.py
@@ -0,0 +1,347 @@
+# -*- coding: utf-8 -*-
+"""
+General OT solvers with unified API
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+from .utils import OTResult
+from .lp import emd2
+from .backend import get_backend
+from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced
+from .bregman import sinkhorn_log
+from .partial import partial_wasserstein_lagrange
+from .smooth import smooth_ot_dual
+
+
+def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
+ unbalanced_type='KL', n_threads=1, max_iter=None, plan_init=None,
+ potentials_init=None, tol=None, verbose=False):
+ r"""Solve the discrete optimal transport problem and return :any:`OTResult` object
+
+ The function solves the following general optimal transport problem
+
+ .. math::
+ \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) +
+ \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) +
+ \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})
+
+ The regularization is selected with :any:`reg` (:math:`\lambda_r`) and :any:`reg_type`. By
+ default ``reg=None`` and there is no regularization. The unbalanced marginal
+ penalization can be selected with :any:`unbalanced` (:math:`\lambda_u`) and
+ :any:`unbalanced_type`. By default ``unbalanced=None`` and the function
+ solves the exact optimal transport problem (respecting the marginals).
+
+ Parameters
+ ----------
+ M : array_like, shape (dim_a, dim_b)
+ Loss matrix
+ a : array-like, shape (dim_a,), optional
+ Samples weights in the source domain (default is uniform)
+ b : array-like, shape (dim_b,), optional
+ Samples weights in the source domain (default is uniform)
+ reg : float, optional
+ Regularization weight :math:`\lambda_r`, by default None (no reg., exact
+ OT)
+ reg_type : str, optional
+ Type of regularization :math:`R` either "KL", "L2", 'entropy', by default "KL"
+ unbalanced : float, optional
+ Unbalanced penalization weight :math:`\lambda_u`, by default None
+ (balanced OT)
+ unbalanced_type : str, optional
+ Type of unbalanced penalization unction :math:`U` either "KL", "L2", 'TV', by default 'KL'
+ n_threads : int, optional
+ Number of OMP threads for exact OT solver, by default 1
+ max_iter : int, optional
+ Maximum number of iteration, by default None (default values in each solvers)
+ plan_init : array_like, shape (dim_a, dim_b), optional
+ Initialization of the OT plan for iterative methods, by default None
+ potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional
+ Initialization of the OT dual potentials for iterative methods, by default None
+ tol : _type_, optional
+ Tolerance for solution precision, by default None (default values in each solvers)
+ verbose : bool, optional
+ Print information in the solver, by default False
+
+ Returns
+ -------
+ res : OTResult()
+ Result of the optimization problem. The information can be obtained as follows:
+
+ - res.plan : OT plan :math:`\mathbf{T}`
+ - res.potentials : OT dual potentials
+ - res.value : Optimal value of the optimization problem
+ - res.value_linear : Linear OT loss with the optimal OT plan
+
+ See :any:`OTResult` for more information.
+
+ Notes
+ -----
+
+ The following methods are available for solving the OT problems:
+
+ - **Classical exact OT problem** (default parameters):
+
+ .. math::
+ \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F
+
+ s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}
+
+ \mathbf{T}^T \mathbf{1} = \mathbf{b}
+
+ \mathbf{T} \geq 0
+
+ can be solved with the following code:
+
+ .. code-block:: python
+
+ res = ot.solve(M, a, b)
+
+ - **Entropic regularized OT** (when ``reg!=None``):
+
+ .. math::
+ \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T})
+
+ s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}
+
+ \mathbf{T}^T \mathbf{1} = \mathbf{b}
+
+ \mathbf{T} \geq 0
+
+ can be solved with the following code:
+
+ .. code-block:: python
+
+ # default is ``"KL"`` regularization (``reg_type="KL"``)
+ res = ot.solve(M, a, b, reg=1.0)
+ # or for original Sinkhorn paper formulation [2]
+ res = ot.solve(M, a, b, reg=1.0, reg_type='entropy')
+
+ - **Quadratic regularized OT** (when ``reg!=None`` and ``reg_type="L2"``):
+
+ .. math::
+ \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T})
+
+ s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}
+
+ \mathbf{T}^T \mathbf{1} = \mathbf{b}
+
+ \mathbf{T} \geq 0
+
+ can be solved with the following code:
+
+ .. code-block:: python
+
+ res = ot.solve(M,a,b,reg=1.0,reg_type='L2')
+
+ - **Unbalanced OT** (when ``unbalanced!=None``):
+
+ .. math::
+ \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})
+
+ can be solved with the following code:
+
+ .. code-block:: python
+
+ # default is ``"KL"``
+ res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0)
+ # quadratic unbalanced OT
+ res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0,unbalanced_type='L2')
+ # TV = partial OT
+ res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0,unbalanced_type='TV')
+
+
+ - **Regularized unbalanced regularized OT** (when ``unbalanced!=None`` and ``reg!=None``):
+
+ .. math::
+ \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})
+
+ can be solved with the following code:
+
+ .. code-block:: python
+
+ # default is ``"KL"`` for both
+ res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0)
+ # quadratic unbalanced OT with KL regularization
+ res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0,unbalanced_type='L2')
+ # both quadratic
+ res = ot.solve(M,a,b,reg=1.0, reg_type='L2',unbalanced=1.0,unbalanced_type='L2')
+
+
+ .. _references-solve:
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
+ of Optimal Transport, Advances in Neural Information Processing
+ Systems (NIPS) 26, 2013
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems.
+ arXiv preprint arXiv:1607.05816.
+
+ .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse
+ Optimal Transport. Proceedings of the Twenty-First International
+ Conference on Artificial Intelligence and Statistics (AISTATS).
+
+ .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé,
+ A., & Peyré, G. (2019, April). Interpolating between optimal transport
+ and MMD using Sinkhorn divergences. In The 22nd International Conference
+ on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
+
+ """
+
+ # detect backend
+ arr = [M]
+ if a is not None:
+ arr.append(a)
+ if b is not None:
+ arr.append(b)
+ nx = get_backend(*arr)
+
+ # create uniform weights if not given
+ if a is None:
+ a = nx.ones(M.shape[0], type_as=M) / M.shape[0]
+ if b is None:
+ b = nx.ones(M.shape[1], type_as=M) / M.shape[1]
+
+ # default values for solutions
+ potentials = None
+ value = None
+ value_linear = None
+ plan = None
+ status = None
+
+ if reg is None or reg == 0: # exact OT
+
+ if unbalanced is None: # Exact balanced OT
+
+ # default values for EMD solver
+ if max_iter is None:
+ max_iter = 1000000
+
+ value_linear, log = emd2(a, b, M, numItermax=max_iter, log=True, return_matrix=True, numThreads=n_threads)
+
+ value = value_linear
+ potentials = (log['u'], log['v'])
+ plan = log['G']
+ status = log["warning"] if log["warning"] is not None else 'Converged'
+
+ elif unbalanced_type.lower() in ['kl', 'l2']: # unbalanced exact OT
+
+ # default values for exact unbalanced OT
+ if max_iter is None:
+ max_iter = 1000
+ if tol is None:
+ tol = 1e-12
+
+ plan, log = mm_unbalanced(a, b, M, reg_m=unbalanced,
+ div=unbalanced_type.lower(), numItermax=max_iter,
+ stopThr=tol, log=True,
+ verbose=verbose, G0=plan_init)
+
+ value_linear = log['cost']
+
+ if unbalanced_type.lower() == 'kl':
+ value = value_linear + unbalanced * (nx.kl_div(nx.sum(plan, 1), a) + nx.kl_div(nx.sum(plan, 0), b))
+ else:
+ err_a = nx.sum(plan, 1) - a
+ err_b = nx.sum(plan, 0) - b
+ value = value_linear + unbalanced * nx.sum(err_a**2) + unbalanced * nx.sum(err_b**2)
+
+ elif unbalanced_type.lower() == 'tv':
+
+ if max_iter is None:
+ max_iter = 1000000
+
+ plan, log = partial_wasserstein_lagrange(a, b, M, reg_m=unbalanced**2, log=True, numItermax=max_iter)
+
+ value_linear = nx.sum(M * plan)
+ err_a = nx.sum(plan, 1) - a
+ err_b = nx.sum(plan, 0) - b
+ value = value_linear + nx.sqrt(unbalanced**2 / 2.0 * (nx.sum(nx.abs(err_a)) +
+ nx.sum(nx.abs(err_b))))
+
+ else:
+ raise (NotImplementedError('Unknown unbalanced_type="{}"'.format(unbalanced_type)))
+
+ else: # regularized OT
+
+ if unbalanced is None: # Balanced regularized OT
+
+ if reg_type.lower() in ['entropy', 'kl']:
+
+ # default values for sinkhorn
+ if max_iter is None:
+ max_iter = 1000
+ if tol is None:
+ tol = 1e-9
+
+ plan, log = sinkhorn_log(a, b, M, reg=reg, numItermax=max_iter,
+ stopThr=tol, log=True,
+ verbose=verbose)
+
+ value_linear = nx.sum(M * plan)
+
+ if reg_type.lower() == 'entropy':
+ value = value_linear + reg * nx.sum(plan * nx.log(plan + 1e-16))
+ else:
+ value = value_linear + reg * nx.kl_div(plan, a[:, None] * b[None, :])
+
+ potentials = (log['log_u'], log['log_v'])
+
+ elif reg_type.lower() == 'l2':
+
+ if max_iter is None:
+ max_iter = 1000
+ if tol is None:
+ tol = 1e-9
+
+ plan, log = smooth_ot_dual(a, b, M, reg=reg, numItermax=max_iter, stopThr=tol, log=True, verbose=verbose)
+
+ value_linear = nx.sum(M * plan)
+ value = value_linear + reg * nx.sum(plan**2)
+ potentials = (log['alpha'], log['beta'])
+
+ else:
+ raise (NotImplementedError('Not implemented reg_type="{}"'.format(reg_type)))
+
+ else: # unbalanced AND regularized OT
+
+ if reg_type.lower() in ['kl'] and unbalanced_type.lower() == 'kl':
+
+ if max_iter is None:
+ max_iter = 1000
+ if tol is None:
+ tol = 1e-9
+
+ plan, log = sinkhorn_knopp_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, numItermax=max_iter, stopThr=tol, verbose=verbose, log=True)
+
+ value_linear = nx.sum(M * plan)
+
+ value = value_linear + reg * nx.kl_div(plan, a[:, None] * b[None, :]) + unbalanced * (nx.kl_div(nx.sum(plan, 1), a) + nx.kl_div(nx.sum(plan, 0), b))
+
+ potentials = (log['logu'], log['logv'])
+
+ elif reg_type.lower() in ['kl', 'l2', 'entropy'] and unbalanced_type.lower() in ['kl', 'l2']:
+
+ if max_iter is None:
+ max_iter = 1000
+ if tol is None:
+ tol = 1e-12
+
+ plan, log = lbfgsb_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, reg_div=reg_type.lower(), regm_div=unbalanced_type.lower(), numItermax=max_iter, stopThr=tol, verbose=verbose, log=True)
+
+ value_linear = nx.sum(M * plan)
+
+ value = log['loss']
+
+ else:
+ raise (NotImplementedError('Not implemented reg_type="{}" and unbalanced_type="{}"'.format(reg_type, unbalanced_type)))
+
+ res = OTResult(potentials=potentials, value=value,
+ value_linear=value_linear, plan=plan, status=status, backend=nx)
+
+ return res
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index 90c920c..a71a0dd 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -10,6 +10,9 @@ Regularized Unbalanced OT solvers
from __future__ import division
import warnings
+import numpy as np
+from scipy.optimize import minimize, Bounds
+
from .backend import get_backend
from .utils import list_to_array
# from .utils import unif, dist
@@ -269,7 +272,8 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
stopThr=1e-6, verbose=False, log=False, **kwargs):
r"""
- Solve the entropic regularization unbalanced optimal transport problem and return the loss
+ Solve the entropic regularization unbalanced optimal transport problem and
+ return the OT plan
The function solves the following optimization problem:
@@ -734,7 +738,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
if weights is None:
weights = nx.ones(n_hists, type_as=A) / n_hists
else:
- assert(len(weights) == A.shape[1])
+ assert (len(weights) == A.shape[1])
if log:
log = {'err': []}
@@ -882,7 +886,7 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
if weights is None:
weights = nx.ones(n_hists, type_as=A) / n_hists
else:
- assert(len(weights) == A.shape[1])
+ assert (len(weights) == A.shape[1])
if log:
log = {'err': []}
@@ -1252,3 +1256,182 @@ def mm_unbalanced2(a, b, M, reg_m, div='kl', G0=None, numItermax=1000,
return log_mm['cost'], log_mm
else:
return log_mm['cost']
+
+
+def _get_loss_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl'):
+ """
+ return the loss function (scipy.optimize compatible) for regularized
+ unbalanced OT
+ """
+
+ m, n = M.shape
+
+ def kl(p, q):
+ return np.sum(p * np.log(p / q + 1e-16))
+
+ def reg_l2(G):
+ return np.sum((G - a[:, None] * b[None, :])**2) / 2
+
+ def grad_l2(G):
+ return G - a[:, None] * b[None, :]
+
+ def reg_kl(G):
+ return kl(G, a[:, None] * b[None, :])
+
+ def grad_kl(G):
+ return np.log(G / (a[:, None] * b[None, :]) + 1e-16) + 1
+
+ def reg_entropy(G):
+ return kl(G, 1)
+
+ def grad_entropy(G):
+ return np.log(G + 1e-16) + 1
+
+ if reg_div == 'kl':
+ reg_fun = reg_kl
+ grad_reg_fun = grad_kl
+ elif reg_div == 'entropy':
+ reg_fun = reg_entropy
+ grad_reg_fun = grad_entropy
+ else:
+ reg_fun = reg_l2
+ grad_reg_fun = grad_l2
+
+ def marg_l2(G):
+ return 0.5 * np.sum((G.sum(1) - a)**2) + 0.5 * np.sum((G.sum(0) - b)**2)
+
+ def grad_marg_l2(G):
+ return np.outer((G.sum(1) - a), np.ones(n)) + np.outer(np.ones(m), (G.sum(0) - b))
+
+ def marg_kl(G):
+ return kl(G.sum(1), a) + kl(G.sum(0), b)
+
+ def grad_marg_kl(G):
+ return np.outer(np.log(G.sum(1) / a + 1e-16) + 1, np.ones(n)) + np.outer(np.ones(m), np.log(G.sum(0) / b + 1e-16) + 1)
+
+ if regm_div == 'kl':
+ regm_fun = marg_kl
+ grad_regm_fun = grad_marg_kl
+ else:
+ regm_fun = marg_l2
+ grad_regm_fun = grad_marg_l2
+
+ def _func(G):
+ G = G.reshape((m, n))
+
+ # compute loss
+ val = np.sum(G * M) + reg * reg_fun(G) + reg_m * regm_fun(G)
+
+ # compute gradient
+ grad = M + reg * grad_reg_fun(G) + reg_m * grad_regm_fun(G)
+
+ return val, grad.ravel()
+
+ return _func
+
+
+def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, numItermax=1000,
+ stopThr=1e-15, method='L-BFGS-B', verbose=False, log=False):
+ r"""
+ Solve the unbalanced optimal transport problem and return the OT plan using L-BFGS-B.
+ The function solves the following optimization problem:
+
+ .. math::
+ W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ + \mathrm{reg} \mathrm{div}(\gamma,\mathbf{a}\mathbf{b}^T)
+ \mathrm{reg_m} \cdot \mathrm{div_m}(\gamma \mathbf{1}, \mathbf{a}) +
+ \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b})
+
+ s.t.
+ \gamma \geq 0
+
+ where:
+
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ unbalanced distributions
+ - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence
+
+ The algorithm used for solving the problem is a L-BFGS-B from scipy.optimize
+
+ Parameters
+ ----------
+ a : array-like (dim_a,)
+ Unnormalized histogram of dimension `dim_a`
+ b : array-like (dim_b,)
+ Unnormalized histogram of dimension `dim_b`
+ M : array-like (dim_a, dim_b)
+ loss matrix
+ reg: float
+ regularization term (>=0)
+ reg_m: float
+ Marginal relaxation term >= 0
+ reg_div: string, optional
+ Divergence used for regularization.
+ Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic)
+ reg_div: string, optional
+ Divergence to quantify the difference between the marginals.
+ Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic)
+ G0: array-like (dim_a, dim_b)
+ Initialization of the transport matrix
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ ot_distance : array-like
+ the OT distance between :math:`\mathbf{a}` and :math:`\mathbf{b}`
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+ Examples
+ --------
+ >>> import ot
+ >>> import numpy as np
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[1., 36.],[9., 4.]]
+ >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'l2'),2)
+ 0.25
+ >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'kl'),2)
+ 0.57
+
+ References
+ ----------
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression. NeurIPS.
+ See Also
+ --------
+ ot.lp.emd2 : Unregularized OT loss
+ ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss
+ """
+ nx = get_backend(M, a, b)
+
+ M0 = M
+ # convert to humpy
+ a, b, M = nx.to_numpy(a, b, M)
+
+ if G0 is not None:
+ G0 = nx.to_numpy(G0)
+ else:
+ G0 = np.zeros(M.shape)
+
+ _func = _get_loss_unbalanced(a, b, M, reg, reg_m, reg_div, regm_div)
+
+ res = minimize(_func, G0.ravel(), method=method, jac=True, bounds=Bounds(0, np.inf),
+ tol=stopThr, options=dict(maxiter=numItermax, disp=verbose))
+
+ G = nx.from_numpy(res.x.reshape(M.shape), type_as=M0)
+
+ if log:
+ log = {'loss': nx.from_numpy(res.fun, type_as=M0), 'res': res}
+ return G, log
+ else:
+ return G
diff --git a/ot/utils.py b/ot/utils.py
index a23ce7e..3423a7e 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -15,7 +15,7 @@ from scipy.spatial.distance import cdist
import sys
import warnings
from inspect import signature
-from .backend import get_backend, Backend
+from .backend import get_backend, Backend, NumpyBackend
__time_tic_toc = time.time()
@@ -232,9 +232,11 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2, w=None):
if not get_backend(x1, x2).__name__ == 'numpy':
raise NotImplementedError()
else:
- if metric.endswith("minkowski"):
+ if isinstance(metric, str) and metric.endswith("minkowski"):
return cdist(x1, x2, metric=metric, p=p, w=w)
- return cdist(x1, x2, metric=metric, w=w)
+ if w is not None:
+ return cdist(x1, x2, metric=metric, w=w)
+ return cdist(x1, x2, metric=metric)
def dist0(n, method='lin_square'):
@@ -373,6 +375,36 @@ def check_random_state(seed):
' instance'.format(seed))
+def get_coordinate_circle(x):
+ r"""For :math:`x\in S^1 \subset \mathbb{R}^2`, returns the coordinates in
+ turn (in [0,1[).
+
+ .. math::
+ u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}
+
+ Parameters
+ ----------
+ x: ndarray, shape (n, 2)
+ Samples on the circle with ambient coordinates
+
+ Returns
+ -------
+ x_t: ndarray, shape (n,)
+ Coordinates on [0,1[
+
+ Examples
+ --------
+ >>> u = np.array([[0.2,0.5,0.8]]) * (2 * np.pi)
+ >>> x1, y1 = np.cos(u), np.sin(u)
+ >>> x = np.concatenate([x1, y1]).T
+ >>> get_coordinate_circle(x)
+ array([0.2, 0.5, 0.8])
+ """
+ nx = get_backend(x)
+ x_t = (nx.atan2(-x[:, 1], -x[:, 0]) + np.pi) / (2 * np.pi)
+ return x_t
+
+
class deprecated(object):
r"""Decorator to mark a function or class as deprecated.
@@ -609,3 +641,203 @@ class UndefinedParameter(Exception):
"""
pass
+
+
+class OTResult:
+ def __init__(self, potentials=None, value=None, value_linear=None, plan=None, log=None, backend=None, sparse_plan=None, lazy_plan=None, status=None):
+
+ self._potentials = potentials
+ self._value = value
+ self._value_linear = value_linear
+ self._plan = plan
+ self._log = log
+ self._sparse_plan = sparse_plan
+ self._lazy_plan = lazy_plan
+ self._backend = backend if backend is not None else NumpyBackend()
+ self._status = status
+
+ # I assume that other solvers may return directly
+ # some primal objects?
+ # In the code below, let's define the main quantities
+ # that may be of interest to users.
+ # An OT solver returns an object that inherits from OTResult
+ # (e.g. SinkhornOTResult) and implements the relevant
+ # methods (e.g. "plan" and "lazy_plan" but not "sparse_plan", etc.).
+ # log is a dictionary containing potential information about the solver
+
+ # Dual potentials --------------------------------------------
+
+ def __repr__(self):
+ s = 'OTResult('
+ if self._value is not None:
+ s += 'value={},'.format(self._value)
+ if self._value_linear is not None:
+ s += 'value_linear={},'.format(self._value_linear)
+ if self._plan is not None:
+ s += 'plan={}(shape={}),'.format(self._plan.__class__.__name__, self._plan.shape)
+
+ if s[-1] != '(':
+ s = s[:-1] + ')'
+ else:
+ s = s + ')'
+ return s
+
+ @property
+ def potentials(self):
+ """Dual potentials, i.e. Lagrange multipliers for the marginal constraints.
+
+ This pair of arrays has the same shape, numerical type
+ and properties as the input weights "a" and "b".
+ """
+ if self._potentials is not None:
+ return self._potentials
+ else:
+ raise NotImplementedError()
+
+ @property
+ def potential_a(self):
+ """First dual potential, associated to the "source" measure "a"."""
+ if self._potentials is not None:
+ return self._potentials[0]
+ else:
+ raise NotImplementedError()
+
+ @property
+ def potential_b(self):
+ """Second dual potential, associated to the "target" measure "b"."""
+ if self._potentials is not None:
+ return self._potentials[1]
+ else:
+ raise NotImplementedError()
+
+ # Transport plan -------------------------------------------
+ @property
+ def plan(self):
+ """Transport plan, encoded as a dense array."""
+ # N.B.: We may catch out-of-memory errors and suggest
+ # the use of lazy_plan or sparse_plan when appropriate.
+
+ if self._plan is not None:
+ return self._plan
+ else:
+ raise NotImplementedError()
+
+ @property
+ def sparse_plan(self):
+ """Transport plan, encoded as a sparse array."""
+ if self._sparse_plan is not None:
+ return self._sparse_plan
+ elif self._plan is not None:
+ return self._backend.tocsr(self._plan)
+ else:
+ raise NotImplementedError()
+
+ @property
+ def lazy_plan(self):
+ """Transport plan, encoded as a symbolic KeOps LazyTensor."""
+ raise NotImplementedError()
+
+ # Loss values --------------------------------
+
+ @property
+ def value(self):
+ """Full transport cost, including possible regularization terms."""
+ if self._value is not None:
+ return self._value
+ else:
+ raise NotImplementedError()
+
+ @property
+ def value_linear(self):
+ """The "minimal" transport cost, i.e. the product between the transport plan and the cost."""
+ if self._value_linear is not None:
+ return self._value_linear
+ else:
+ raise NotImplementedError()
+
+ # Marginal constraints -------------------------
+ @property
+ def marginals(self):
+ """Marginals of the transport plan: should be very close to "a" and "b"
+ for balanced OT."""
+ if self._plan is not None:
+ return self.marginal_a, self.marginal_b
+ else:
+ raise NotImplementedError()
+
+ @property
+ def marginal_a(self):
+ """First marginal of the transport plan, with the same shape as "a"."""
+ if self._plan is not None:
+ return self._backend.sum(self._plan, 1)
+ else:
+ raise NotImplementedError()
+
+ @property
+ def marginal_b(self):
+ """Second marginal of the transport plan, with the same shape as "b"."""
+ if self._plan is not None:
+ return self._backend.sum(self._plan, 0)
+ else:
+ raise NotImplementedError()
+
+ @property
+ def status(self):
+ """Optimization status of the solver."""
+ if self._status is not None:
+ return self._status
+ else:
+ raise NotImplementedError()
+
+ # Barycentric mappings -------------------------
+ # Return the displacement vectors as an array
+ # that has the same shape as "xa"/"xb" (for samples)
+ # or "a"/"b" * D (for images)?
+
+ @property
+ def a_to_b(self):
+ """Displacement vectors from the first to the second measure."""
+ raise NotImplementedError()
+
+ @property
+ def b_to_a(self):
+ """Displacement vectors from the second to the first measure."""
+ raise NotImplementedError()
+
+ # # Wasserstein barycenters ----------------------
+ # @property
+ # def masses(self):
+ # """Masses for the Wasserstein barycenter."""
+ # raise NotImplementedError()
+
+ # @property
+ # def samples(self):
+ # """Sample locations for the Wasserstein barycenter."""
+ # raise NotImplementedError()
+
+ # Miscellaneous --------------------------------
+
+ @property
+ def citation(self):
+ """Appropriate citation(s) for this result, in plain text and BibTex formats."""
+
+ # The string below refers to the POT library:
+ # successor methods may concatenate the relevant references
+ # to the original definitions, solvers and underlying numerical backends.
+ return """POT library:
+
+ POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021.
+ Website: https://pythonot.github.io/
+ Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer;
+
+ @article{flamary2021pot,
+ author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer},
+ title = {{POT}: {Python} {Optimal} {Transport}},
+ journal = {Journal of Machine Learning Research},
+ year = {2021},
+ volume = {22},
+ number = {78},
+ pages = {1-8},
+ url = {http://jmlr.org/papers/v22/20-451.html}
+ }
+ """
diff --git a/ot/weak.py b/ot/weak.py
index f7d5b23..7364e68 100644
--- a/ot/weak.py
+++ b/ot/weak.py
@@ -18,7 +18,7 @@ def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=
.. math::
- \gamma = \mathop{\arg \min}_\gamma \quad \|X_a-diag(1/a)\gammaX_b\|_F^2
+ \gamma = \mathop{\arg \min}_\gamma \quad \sum_i \mathbf{a}_i \left(\mathbf{X^a}_i - \frac{1}{\mathbf{a}_i} \sum_j \gamma_{ij} \mathbf{X^b}_j \right)^2
s.t. \ \gamma \mathbf{1} = \mathbf{a}
@@ -28,7 +28,7 @@ def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=
where :
- - :math:`X_a` :math:`X_b` are the sample matrices.
+ - :math:`X^a` and :math:`X^b` are the sample matrices.
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
@@ -49,6 +49,8 @@ def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=
Source histogram (uniform weight if empty list)
b : (nt,) array-like, float
Target histogram (uniform weight if empty list))
+ G0 : (ns,nt) array-like, float
+ initial guess (default is indep joint density)
numItermax : int, optional
Max number of iterations
numItermaxEmd : int, optional
diff --git a/requirements.txt b/requirements.txt
index 7cbb29a..9be4deb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,8 +2,7 @@ numpy>=1.20
scipy>=1.3
matplotlib
autograd
-pymanopt==0.2.4; python_version <'3'
-pymanopt==0.2.6rc1; python_version >= '3'
+pymanopt
cvxopt
scikit-learn
torch
diff --git a/setup.py b/setup.py
index c03191a..dc9066d 100644
--- a/setup.py
+++ b/setup.py
@@ -37,7 +37,7 @@ compile_args = ["/O2" if sys.platform == "win32" else "-O3"]
link_args = []
if openmp_supported:
- compile_args += flags + ["/DOMP" if sys.platform == 'win32' else "-DOMP"]
+ compile_args += flags
link_args += flags
if sys.platform.startswith('darwin'):
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
index 20f307a..21abd1d 100644
--- a/test/test_1d_solver.py
+++ b/test/test_1d_solver.py
@@ -218,3 +218,130 @@ def test_emd1d_device_tf():
nx.assert_same_dtype_device(xb, emd)
nx.assert_same_dtype_device(xb, emd2)
assert nx.dtype_device(emd)[1].startswith("GPU")
+
+
+def test_wasserstein_1d_circle():
+ # test binary_search_circle and wasserstein_circle give similar results as emd
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.rand(n,)
+ v = rng.rand(m,)
+
+ w_u = rng.uniform(0., 1., n)
+ w_u = w_u / w_u.sum()
+
+ w_v = rng.uniform(0., 1., m)
+ w_v = w_v / w_v.sum()
+
+ M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None]))
+
+ wass1 = ot.emd2(w_u, w_v, M1)
+
+ wass1_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=1)
+ w1_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=1)
+
+ M2 = M1**2
+ wass2 = ot.emd2(w_u, w_v, M2)
+ wass2_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=2)
+ w2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2)
+
+ # check loss is similar
+ np.testing.assert_allclose(wass1, wass1_bsc)
+ np.testing.assert_allclose(wass1, w1_circle, rtol=1e-2)
+ np.testing.assert_allclose(wass2, wass2_bsc)
+ np.testing.assert_allclose(wass2, w2_circle)
+
+
+@pytest.skip_backend("tf")
+def test_wasserstein1d_circle_devices(nx):
+ rng = np.random.RandomState(0)
+
+ n = 10
+ x = np.linspace(0, 1, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp)
+
+ w1 = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=1)
+ w2_bsc = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=2)
+
+ nx.assert_same_dtype_device(xb, w1)
+ nx.assert_same_dtype_device(xb, w2_bsc)
+
+
+def test_wasserstein_1d_unif_circle():
+ # test semidiscrete_wasserstein2_unif_circle versus wasserstein_circle
+ n = 20
+ m = 50000
+
+ rng = np.random.RandomState(0)
+ u = rng.rand(n,)
+ v = rng.rand(m,)
+
+ # w_u = rng.uniform(0., 1., n)
+ # w_u = w_u / w_u.sum()
+
+ w_u = ot.utils.unif(n)
+ w_v = ot.utils.unif(m)
+
+ M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None]))
+ wass2 = ot.emd2(w_u, w_v, M1**2)
+
+ wass2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2, eps=1e-15)
+ wass2_unif_circle = ot.semidiscrete_wasserstein2_unif_circle(u, w_u)
+
+ # check loss is similar
+ np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-3)
+ np.testing.assert_allclose(wass2_circle, wass2_unif_circle, atol=1e-3)
+
+
+def test_wasserstein1d_unif_circle_devices(nx):
+ rng = np.random.RandomState(0)
+
+ n = 10
+ x = np.linspace(0, 1, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb, rho_ub = nx.from_numpy(x, rho_u, type_as=tp)
+
+ w2 = ot.semidiscrete_wasserstein2_unif_circle(xb, rho_ub)
+
+ nx.assert_same_dtype_device(xb, w2)
+
+
+def test_binary_search_circle_log():
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.rand(n,)
+ v = rng.rand(m,)
+
+ wass2_bsc, log = ot.binary_search_circle(u, v, p=2, log=True)
+ optimal_thetas = log["optimal_theta"]
+
+ assert optimal_thetas.shape[0] == 1
+
+
+def test_wasserstein_circle_bad_shape():
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.rand(n, 2)
+ v = rng.rand(m, 1)
+
+ with pytest.raises(ValueError):
+ _ = ot.wasserstein_circle(u, v, p=2)
+
+ with pytest.raises(ValueError):
+ _ = ot.wasserstein_circle(u, v, p=1)
diff --git a/test/test_backend.py b/test/test_backend.py
index 311c075..fd9a761 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -275,11 +275,27 @@ def test_empty_backend():
with pytest.raises(NotImplementedError):
nx.sqrtm(M)
with pytest.raises(NotImplementedError):
+ nx.kl_div(M, M)
+ with pytest.raises(NotImplementedError):
nx.isfinite(M)
with pytest.raises(NotImplementedError):
nx.array_equal(M, M)
with pytest.raises(NotImplementedError):
nx.is_floating_point(M)
+ with pytest.raises(NotImplementedError):
+ nx.tile(M, (10, 1))
+ with pytest.raises(NotImplementedError):
+ nx.floor(M)
+ with pytest.raises(NotImplementedError):
+ nx.prod(M)
+ with pytest.raises(NotImplementedError):
+ nx.sort2(M)
+ with pytest.raises(NotImplementedError):
+ nx.qr(M)
+ with pytest.raises(NotImplementedError):
+ nx.atan2(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.transpose(M)
def test_func_backends(nx):
@@ -592,11 +608,47 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(A))
lst_name.append("matrix square root")
+ A = nx.kl_div(nx.abs(Mb), nx.abs(Mb) + 1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("Kullback-Leibler divergence")
+
A = nx.concatenate([vb, nx.from_numpy(np.array([np.inf, np.nan]))], axis=0)
A = nx.isfinite(A)
lst_b.append(nx.to_numpy(A))
lst_name.append("isfinite")
+ A = nx.tile(vb, (10, 1))
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("tile")
+
+ A = nx.floor(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("floor")
+
+ A = nx.prod(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("prod")
+
+ A, B = nx.sort2(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("sort2 sort")
+ lst_b.append(nx.to_numpy(B))
+ lst_name.append("sort2 argsort")
+
+ A, B = nx.qr(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("QR Q")
+ lst_b.append(nx.to_numpy(B))
+ lst_name.append("QR R")
+
+ A = nx.atan2(vb, vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("atan2")
+
+ A = nx.transpose(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("transpose")
+
assert not nx.array_equal(Mb, vb), "array_equal (shape)"
assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true"
assert not nx.array_equal(
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 6c37984..f01bb14 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -3,9 +3,11 @@
# Author: Remi Flamary <remi.flamary@unice.fr>
# Kilian Fatras <kilian.fatras@irisa.fr>
# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
+# Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr>
#
# License: MIT License
+import warnings
from itertools import product
import numpy as np
@@ -57,7 +59,12 @@ def test_convergence_warning(method):
with pytest.warns(UserWarning):
ot.barycenter(A, M, 1, method=method, stopThr=0, numItermax=1)
with pytest.warns(UserWarning):
- ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1)
+ ot.sinkhorn2(a1, a2, M, 1, method=method,
+ stopThr=0, numItermax=1, warn=True)
+ with warnings.catch_warnings():
+ warnings.simplefilter("error")
+ ot.sinkhorn2(a1, a2, M, 1, method=method,
+ stopThr=0, numItermax=1, warn=False)
def test_not_implemented_method():
@@ -261,12 +268,16 @@ def test_sinkhorn_variants(nx):
ub, M_nx = nx.from_numpy(u, M)
G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
- Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
- G0 = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10))
- Gs = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+ Gl = nx.to_numpy(ot.sinkhorn(
+ ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn(
+ ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn(
+ ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
Ges = nx.to_numpy(ot.sinkhorn(
ub, ub, M_nx, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10))
- G_green = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10))
+ G_green = nx.to_numpy(ot.sinkhorn(
+ ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10))
# check values
np.testing.assert_allclose(G, G0, atol=1e-05)
@@ -366,9 +377,12 @@ def test_sinkhorn_variants_multi_b(nx):
ub, bb, M_nx = nx.from_numpy(u, b, M)
G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
- Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
- G0 = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
- Gs = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+ Gl = nx.to_numpy(ot.sinkhorn(
+ ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn(
+ ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn(
+ ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
# check values
np.testing.assert_allclose(G, G0, atol=1e-05)
@@ -394,9 +408,12 @@ def test_sinkhorn2_variants_multi_b(nx):
ub, bb, M_nx = nx.from_numpy(u, b, M)
G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
- Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
- G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
- Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+ Gl = nx.to_numpy(ot.sinkhorn2(
+ ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn2(
+ ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn2(
+ ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
# check values
np.testing.assert_allclose(G, G0, atol=1e-05)
@@ -414,12 +431,16 @@ def test_sinkhorn_variants_log():
M = ot.dist(x, x)
- G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
- Gl, logl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True)
- Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
+ G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn',
+ stopThr=1e-10, log=True)
+ Gl, logl = ot.sinkhorn(
+ u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True)
+ Gs, logs = ot.sinkhorn(
+ u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
Ges, loges = ot.sinkhorn(
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True,)
- G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True)
+ G_green, loggreen = ot.sinkhorn(
+ u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True)
# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
@@ -441,7 +462,8 @@ def test_sinkhorn_variants_log_multib(verbose, warn):
M = ot.dist(x, x)
- G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
+ G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn',
+ stopThr=1e-10, log=True)
Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True,
verbose=verbose, warn=warn)
Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True,
@@ -480,8 +502,73 @@ def test_barycenter(nx, method, verbose, warn):
ot.bregman.barycenter(A_nx, M_nx, reg, weights, method=method)
else:
# wasserstein
- bary_wass_np = ot.bregman.barycenter(A, M, reg, weights, method=method, verbose=verbose, warn=warn)
- bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, weights_nx, method=method, log=True)
+ bary_wass_np = ot.bregman.barycenter(
+ A, M, reg, weights, method=method, verbose=verbose, warn=warn)
+ bary_wass, _ = ot.bregman.barycenter(
+ A_nx, M_nx, reg, weights_nx, method=method, log=True)
+ bary_wass = nx.to_numpy(bary_wass)
+
+ np.testing.assert_allclose(1, np.sum(bary_wass))
+ np.testing.assert_allclose(bary_wass, bary_wass_np)
+
+ ot.bregman.barycenter(A_nx, M_nx, reg, log=True)
+
+
+def test_free_support_sinkhorn_barycenter():
+ measures_locations = [
+ np.array([-1.]).reshape((1, 1)), # First dirac support
+ np.array([1.]).reshape((1, 1)) # Second dirac support
+ ]
+
+ measures_weights = [
+ np.array([1.]), # First dirac sample weights
+ np.array([1.]) # Second dirac sample weights
+ ]
+
+ # Barycenter initialization
+ X_init = np.array([-12.]).reshape((1, 1))
+
+ # Obvious barycenter locations. Take a look on test_ot.py, test_free_support_barycenter
+ bar_locations = np.array([0.]).reshape((1, 1))
+
+ # Calculate free support barycenter w/ Sinkhorn algorithm. We set the entropic regularization
+ # term to 1, but this should be, in general, fine-tuned to the problem.
+ X = ot.bregman.free_support_sinkhorn_barycenter(
+ measures_locations, measures_weights, X_init, reg=1)
+
+ # Verifies if calculated barycenter matches ground-truth
+ np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)
+
+
+@pytest.mark.parametrize("method, verbose, warn",
+ product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"],
+ [True, False], [True, False]))
+def test_barycenter_assymetric_cost(nx, method, verbose, warn):
+ n_bins = 20 # nb bins
+
+ # Gaussian distributions
+ A = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std
+
+ # creating matrix A containing all distributions
+ A = A[:, None]
+
+ # assymetric loss matrix + normalization
+ rng = np.random.RandomState(42)
+ M = rng.randn(n_bins, n_bins) ** 2
+ M /= M.max()
+
+ A_nx, M_nx = nx.from_numpy(A, M)
+ reg = 1e-2
+
+ if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
+ with pytest.raises(NotImplementedError):
+ ot.bregman.barycenter(A_nx, M_nx, reg, method=method)
+ else:
+ # wasserstein
+ bary_wass_np = ot.bregman.barycenter(
+ A, M, reg, method=method, verbose=verbose, warn=warn)
+ bary_wass, _ = ot.bregman.barycenter(
+ A_nx, M_nx, reg, method=method, log=True)
bary_wass = nx.to_numpy(bary_wass)
np.testing.assert_allclose(1, np.sum(bary_wass))
@@ -516,17 +603,20 @@ def test_barycenter_debiased(nx, method, verbose, warn):
reg = 1e-2
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
- ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method)
+ ot.bregman.barycenter_debiased(
+ A_nx, M_nx, reg, weights, method=method)
else:
bary_wass_np = ot.bregman.barycenter_debiased(A, M, reg, weights, method=method,
verbose=verbose, warn=warn)
- bary_wass, _ = ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights_nx, method=method, log=True)
+ bary_wass, _ = ot.bregman.barycenter_debiased(
+ A_nx, M_nx, reg, weights_nx, method=method, log=True)
bary_wass = nx.to_numpy(bary_wass)
np.testing.assert_allclose(1, np.sum(bary_wass), atol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-5)
- ot.bregman.barycenter_debiased(A_nx, M_nx, reg, log=True, verbose=False)
+ ot.bregman.barycenter_debiased(
+ A_nx, M_nx, reg, log=True, verbose=False)
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
@@ -551,7 +641,8 @@ def test_convergence_warning_barycenters(method):
weights = np.array([1 - alpha, alpha])
reg = 0.1
with pytest.warns(UserWarning):
- ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, numItermax=1)
+ ot.bregman.barycenter_debiased(
+ A, M, reg, weights, method=method, numItermax=1)
with pytest.warns(UserWarning):
ot.bregman.barycenter(A, M, reg, weights, method=method, numItermax=1)
with pytest.warns(UserWarning):
@@ -583,7 +674,8 @@ def test_barycenter_stabilization(nx):
# wasserstein
reg = 1e-2
- bar_np = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True)
+ bar_np = ot.bregman.barycenter(
+ A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True)
bar_stable = nx.to_numpy(ot.bregman.barycenter(
A_nx, M_nx, reg, weights_b, method="sinkhorn_stabilized",
stopThr=1e-8, verbose=True
@@ -618,8 +710,10 @@ def test_wasserstein_bary_2d(nx, method):
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
else:
- bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method, verbose=True, log=True)
- bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method))
+ bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
+ A, reg, method=method, verbose=True, log=True)
+ bary_wass = nx.to_numpy(
+ ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method))
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
@@ -648,10 +742,13 @@ def test_wasserstein_bary_2d_debiased(nx, method):
reg = 1e-2
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
- ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
+ ot.bregman.convolutional_barycenter2d_debiased(
+ A_nx, reg, method=method)
else:
- bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method, verbose=True, log=True)
- bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method))
+ bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
+ A, reg, method=method, verbose=True, log=True)
+ bary_wass = nx.to_numpy(
+ ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method))
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
@@ -685,7 +782,8 @@ def test_unmix(nx):
# wasserstein
reg = 1e-3
um_np = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01)
- um = nx.to_numpy(ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01))
+ um = nx.to_numpy(ot.bregman.unmix(
+ ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01))
np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03)
np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03)
@@ -716,10 +814,12 @@ def test_empirical_sinkhorn(nx):
sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True)
sinkhorn_log = nx.to_numpy(sinkhorn_log)
- G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean'))
+ G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(
+ X_sb, X_tb, 1, metric='euclidean'))
sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
- loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1))
+ loss_emp_sinkhorn = nx.to_numpy(
+ ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1))
loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))
# check constraints
@@ -752,23 +852,27 @@ def test_lazy_empirical_sinkhorn(nx):
ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m)
- f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
+ f, g = ot.bregman.empirical_sinkhorn(
+ X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
f, g = nx.to_numpy(f), nx.to_numpy(g)
G_sqe = np.exp(f[:, None] + g[None, :] - M / 1)
sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1))
- f, g, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ f, g, log_es = ot.bregman.empirical_sinkhorn(
+ X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
f, g = nx.to_numpy(f), nx.to_numpy(g)
G_log = np.exp(f[:, None] + g[None, :] - M / 0.1)
sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True)
sinkhorn_log = nx.to_numpy(sinkhorn_log)
- f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1)
+ f, g = ot.bregman.empirical_sinkhorn(
+ X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1)
f, g = nx.to_numpy(f), nx.to_numpy(g)
G_m = np.exp(f[:, None] + g[None, :] - M_m / 1)
sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
- loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(
+ X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn)
loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))
@@ -800,22 +904,57 @@ def test_empirical_sinkhorn_divergence(nx):
M_s = ot.dist(X_s, X_s)
M_t = ot.dist(X_t, X_t)
- ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy(a, b, X_s, X_t, M, M_s, M_t)
+ ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy(
+ a, b, X_s, X_t, M, M_s, M_t)
- emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb))
+ emp_sinkhorn_div = nx.to_numpy(
+ ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb))
sinkhorn_div = nx.to_numpy(
ot.sinkhorn2(ab, bb, M_nx, 1)
- 1 / 2 * ot.sinkhorn2(ab, ab, M_sb, 1)
- 1 / 2 * ot.sinkhorn2(bb, bb, M_tb, 1)
)
- emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b)
+ emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence(
+ X_s, X_t, 1, a=a, b=b)
# check constraints
- np.testing.assert_allclose(emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05)
+ np.testing.assert_allclose(
+ emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05)
np.testing.assert_allclose(
emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn
- ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb, log=True)
+ ot.bregman.empirical_sinkhorn_divergence(
+ X_sb, X_tb, 1, a=ab, b=bb, log=True)
+
+
+@pytest.mark.skipif(not torch, reason="No torch available")
+def test_empirical_sinkhorn_divergence_gradient():
+ # Test sinkhorn divergence
+ n = 10
+ a = np.linspace(1, n, n)
+ a /= a.sum()
+ b = ot.unif(n)
+ X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1))
+ X_t = np.reshape(np.arange(0, n * 2, 2, dtype=np.float64), (n, 1))
+
+ nx = ot.backend.TorchBackend()
+
+ ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t)
+
+ ab.requires_grad = True
+ bb.requires_grad = True
+ X_sb.requires_grad = True
+ X_tb.requires_grad = True
+
+ emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(
+ X_sb, X_tb, 1, a=ab, b=bb)
+
+ emp_sinkhorn_div.backward()
+
+ assert ab.grad is not None
+ assert bb.grad is not None
+ assert X_sb.grad is not None
+ assert X_tb.grad is not None
def test_stabilized_vs_sinkhorn_multidim(nx):
@@ -837,7 +976,8 @@ def test_stabilized_vs_sinkhorn_multidim(nx):
ab, bb, M_nx = nx.from_numpy(a, b, M)
- G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True)
+ G_np, _ = ot.bregman.sinkhorn(
+ a, b, M, reg=epsilon, method="sinkhorn", log=True)
G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon,
method="sinkhorn_stabilized",
log=True)
@@ -902,7 +1042,8 @@ def test_screenkhorn(nx):
# sinkhorn
G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1))
# screenkhorn
- G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-1, uniform=True, verbose=True))
+ G_screen = nx.to_numpy(ot.bregman.screenkhorn(
+ ab, bb, M_nx, 1e-1, uniform=True, verbose=True))
# check marginals
np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)
@@ -919,3 +1060,93 @@ def test_convolutional_barycenter_non_square(nx):
np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)
np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)
np.testing.assert_allclose(b, b_np)
+
+
+def test_sinkhorn_warmstart():
+ m, n = 10, 20
+ a = ot.unif(m)
+ b = ot.unif(n)
+
+ Xs = np.arange(m) * 1.0
+ Xt = np.arange(n) * 1.0
+ M = ot.dist(Xs.reshape(-1, 1), Xt.reshape(-1, 1))
+
+ # Generate warmstart from dual vectors of unregularized OT
+ _, log = ot.lp.emd(a, b, M, log=True)
+ warmstart = (log["u"], log["v"])
+
+ reg = 1
+
+ # Optimal plan with uniform warmstart
+ pi_unif, _ = ot.bregman.sinkhorn(
+ a, b, M, reg, method="sinkhorn", log=True, warmstart=None)
+ # Optimal plan with warmstart generated from unregularized OT
+ pi_sh, _ = ot.bregman.sinkhorn(
+ a, b, M, reg, method="sinkhorn", log=True, warmstart=warmstart)
+ pi_sh_log, _ = ot.bregman.sinkhorn(
+ a, b, M, reg, method="sinkhorn_log", log=True, warmstart=warmstart)
+ pi_sh_stab, _ = ot.bregman.sinkhorn(
+ a, b, M, reg, method="sinkhorn_stabilized", log=True, warmstart=warmstart)
+ pi_sh_sc, _ = ot.bregman.sinkhorn(
+ a, b, M, reg, method="sinkhorn_epsilon_scaling", log=True, warmstart=warmstart)
+
+ np.testing.assert_allclose(pi_unif, pi_sh, atol=1e-05)
+ np.testing.assert_allclose(pi_unif, pi_sh_log, atol=1e-05)
+ np.testing.assert_allclose(pi_unif, pi_sh_stab, atol=1e-05)
+ np.testing.assert_allclose(pi_unif, pi_sh_sc, atol=1e-05)
+
+
+def test_empirical_sinkhorn_warmstart():
+ m, n = 10, 20
+ Xs = np.arange(m).reshape(-1, 1) * 1.0
+ Xt = np.arange(n).reshape(-1, 1) * 1.0
+ M = ot.dist(Xs, Xt)
+
+ # Generate warmstart from dual vectors of unregularized OT
+ a = ot.unif(m)
+ b = ot.unif(n)
+ _, log = ot.lp.emd(a, b, M, log=True)
+ warmstart = (log["u"], log["v"])
+
+ reg = 1
+
+ # Optimal plan with uniform warmstart
+ f, g, _ = ot.bregman.empirical_sinkhorn(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=None)
+ pi_unif = np.exp(f[:, None] + g[None, :] - M / reg)
+ # Optimal plan with warmstart generated from unregularized OT
+ f, g, _ = ot.bregman.empirical_sinkhorn(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=warmstart)
+ pi_ws_lazy = np.exp(f[:, None] + g[None, :] - M / reg)
+ pi_ws_not_lazy, _ = ot.bregman.empirical_sinkhorn(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=False, log=True, warmstart=warmstart)
+
+ np.testing.assert_allclose(pi_unif, pi_ws_lazy, atol=1e-05)
+ np.testing.assert_allclose(pi_unif, pi_ws_not_lazy, atol=1e-05)
+
+
+def test_empirical_sinkhorn_divergence_warmstart():
+ m, n = 10, 20
+ Xs = np.arange(m).reshape(-1, 1) * 1.0
+ Xt = np.arange(n).reshape(-1, 1) * 1.0
+ M = ot.dist(Xs, Xt)
+
+ # Generate warmstart from dual vectors of unregularized OT
+ a = ot.unif(m)
+ b = ot.unif(n)
+ _, log = ot.lp.emd(a, b, M, log=True)
+ warmstart = (log["u"], log["v"])
+
+ reg = 1
+
+ # Optimal plan with uniform warmstart
+ sd_unif, _ = ot.bregman.empirical_sinkhorn_divergence(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=None)
+ # Optimal plan with warmstart generated from unregularized OT
+ sd_ws_lazy, _ = ot.bregman.empirical_sinkhorn_divergence(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=warmstart)
+ sd_ws_not_lazy, _ = ot.bregman.empirical_sinkhorn_divergence(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=False, log=True, warmstart=warmstart)
+
+ np.testing.assert_allclose(sd_unif, sd_ws_lazy, atol=1e-05)
+ np.testing.assert_allclose(sd_unif, sd_ws_not_lazy, atol=1e-05)
diff --git a/test/test_coot.py b/test/test_coot.py
new file mode 100644
index 0000000..ef68a9b
--- /dev/null
+++ b/test/test_coot.py
@@ -0,0 +1,359 @@
+"""Tests for module COOT on OT """
+
+# Author: Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
+#
+# License: MIT License
+
+import numpy as np
+import ot
+from ot.coot import co_optimal_transport as coot
+from ot.coot import co_optimal_transport2 as coot2
+import pytest
+
+
+@pytest.mark.parametrize("verbose", [False, True, 1, 0])
+def test_coot(nx, verbose):
+ n_samples = 60 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(
+ n_samples, mu_s, cov_s, random_state=4)
+ xt = xs[::-1].copy()
+ xs_nx = nx.from_numpy(xs)
+ xt_nx = nx.from_numpy(xt)
+
+ # test couplings
+ pi_sample, pi_feature = coot(X=xs, Y=xt, verbose=verbose)
+ pi_sample_nx, pi_feature_nx = coot(X=xs_nx, Y=xt_nx, verbose=verbose)
+ pi_sample_nx = nx.to_numpy(pi_sample_nx)
+ pi_feature_nx = nx.to_numpy(pi_feature_nx)
+
+ anti_id_sample = np.flipud(np.eye(n_samples, n_samples)) / n_samples
+ id_feature = np.eye(2, 2) / 2
+
+ np.testing.assert_allclose(pi_sample, anti_id_sample, atol=1e-04)
+ np.testing.assert_allclose(pi_sample_nx, anti_id_sample, atol=1e-04)
+ np.testing.assert_allclose(pi_feature, id_feature, atol=1e-04)
+ np.testing.assert_allclose(pi_feature_nx, id_feature, atol=1e-04)
+
+ # test marginal distributions
+ px_s, px_f = ot.unif(n_samples), ot.unif(2)
+ py_s, py_f = ot.unif(n_samples), ot.unif(2)
+
+ np.testing.assert_allclose(px_s, pi_sample_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample_nx.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature_nx.sum(1), atol=1e-04)
+
+ np.testing.assert_allclose(px_s, pi_sample.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature.sum(1), atol=1e-04)
+
+ # test COOT distance
+
+ coot_np = coot2(X=xs, Y=xt, verbose=verbose)
+ coot_nx = nx.to_numpy(coot2(X=xs_nx, Y=xt_nx, verbose=verbose))
+ np.testing.assert_allclose(coot_np, 0, atol=1e-08)
+ np.testing.assert_allclose(coot_nx, 0, atol=1e-08)
+
+
+def test_entropic_coot(nx):
+ n_samples = 60 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(
+ n_samples, mu_s, cov_s, random_state=4)
+ xt = xs[::-1].copy()
+ xs_nx = nx.from_numpy(xs)
+ xt_nx = nx.from_numpy(xt)
+
+ epsilon = (1, 1e-1)
+ nits_ot = 2000
+
+ # test couplings
+ pi_sample, pi_feature = coot(X=xs, Y=xt, epsilon=epsilon, nits_ot=nits_ot)
+ pi_sample_nx, pi_feature_nx = coot(
+ X=xs_nx, Y=xt_nx, epsilon=epsilon, nits_ot=nits_ot)
+ pi_sample_nx = nx.to_numpy(pi_sample_nx)
+ pi_feature_nx = nx.to_numpy(pi_feature_nx)
+
+ np.testing.assert_allclose(pi_sample, pi_sample_nx, atol=1e-04)
+ np.testing.assert_allclose(pi_feature, pi_feature_nx, atol=1e-04)
+
+ # test marginal distributions
+ px_s, px_f = ot.unif(n_samples), ot.unif(2)
+ py_s, py_f = ot.unif(n_samples), ot.unif(2)
+
+ np.testing.assert_allclose(px_s, pi_sample_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample_nx.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature_nx.sum(1), atol=1e-04)
+
+ np.testing.assert_allclose(px_s, pi_sample.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature.sum(1), atol=1e-04)
+
+ # test entropic COOT distance
+
+ coot_np = coot2(X=xs, Y=xt, epsilon=epsilon, nits_ot=nits_ot)
+ coot_nx = nx.to_numpy(
+ coot2(X=xs_nx, Y=xt_nx, epsilon=epsilon, nits_ot=nits_ot))
+
+ np.testing.assert_allclose(coot_np, coot_nx, atol=1e-08)
+
+
+def test_coot_with_linear_terms(nx):
+ n_samples = 60 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(
+ n_samples, mu_s, cov_s, random_state=4)
+ xt = xs[::-1].copy()
+ xs_nx = nx.from_numpy(xs)
+ xt_nx = nx.from_numpy(xt)
+
+ M_samp = np.ones((n_samples, n_samples))
+ np.fill_diagonal(np.fliplr(M_samp), 0)
+ M_feat = np.ones((2, 2))
+ np.fill_diagonal(M_feat, 0)
+ M_samp_nx, M_feat_nx = nx.from_numpy(M_samp), nx.from_numpy(M_feat)
+
+ alpha = (1, 2)
+
+ # test couplings
+ anti_id_sample = np.flipud(np.eye(n_samples, n_samples)) / n_samples
+ id_feature = np.eye(2, 2) / 2
+
+ pi_sample, pi_feature = coot(
+ X=xs, Y=xt, alpha=alpha, M_samp=M_samp, M_feat=M_feat)
+ pi_sample_nx, pi_feature_nx = coot(
+ X=xs_nx, Y=xt_nx, alpha=alpha, M_samp=M_samp_nx, M_feat=M_feat_nx)
+ pi_sample_nx = nx.to_numpy(pi_sample_nx)
+ pi_feature_nx = nx.to_numpy(pi_feature_nx)
+
+ np.testing.assert_allclose(pi_sample, anti_id_sample, atol=1e-04)
+ np.testing.assert_allclose(pi_sample_nx, anti_id_sample, atol=1e-04)
+ np.testing.assert_allclose(pi_feature, id_feature, atol=1e-04)
+ np.testing.assert_allclose(pi_feature_nx, id_feature, atol=1e-04)
+
+ # test marginal distributions
+ px_s, px_f = ot.unif(n_samples), ot.unif(2)
+ py_s, py_f = ot.unif(n_samples), ot.unif(2)
+
+ np.testing.assert_allclose(px_s, pi_sample_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample_nx.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature_nx.sum(1), atol=1e-04)
+
+ np.testing.assert_allclose(px_s, pi_sample.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature.sum(1), atol=1e-04)
+
+ # test COOT distance
+
+ coot_np = coot2(X=xs, Y=xt, alpha=alpha, M_samp=M_samp, M_feat=M_feat)
+ coot_nx = nx.to_numpy(
+ coot2(X=xs_nx, Y=xt_nx, alpha=alpha, M_samp=M_samp_nx, M_feat=M_feat_nx))
+ np.testing.assert_allclose(coot_np, 0, atol=1e-08)
+ np.testing.assert_allclose(coot_nx, 0, atol=1e-08)
+
+
+def test_coot_raise_value_error(nx):
+ n_samples = 80 # nb samples
+
+ mu_s = np.array([2, 4])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(
+ n_samples, mu_s, cov_s, random_state=43)
+ xt = xs[::-1].copy()
+ xs_nx = nx.from_numpy(xs)
+ xt_nx = nx.from_numpy(xt)
+
+ # raise value error of method sinkhorn
+ def coot_sh(method_sinkhorn):
+ return coot(X=xs, Y=xt, method_sinkhorn=method_sinkhorn)
+
+ def coot_sh_nx(method_sinkhorn):
+ return coot(X=xs_nx, Y=xt_nx, method_sinkhorn=method_sinkhorn)
+
+ np.testing.assert_raises(ValueError, coot_sh, "not_sinkhorn")
+ np.testing.assert_raises(ValueError, coot_sh_nx, "not_sinkhorn")
+
+ # raise value error for epsilon
+ def coot_eps(epsilon):
+ return coot(X=xs, Y=xt, epsilon=epsilon)
+
+ def coot_eps_nx(epsilon):
+ return coot(X=xs_nx, Y=xt_nx, epsilon=epsilon)
+
+ np.testing.assert_raises(ValueError, coot_eps, (1, 2, 3))
+ np.testing.assert_raises(ValueError, coot_eps_nx, [1, 2, 3, 4])
+
+ # raise value error for alpha
+ def coot_alpha(alpha):
+ return coot(X=xs, Y=xt, alpha=alpha)
+
+ def coot_alpha_nx(alpha):
+ return coot(X=xs_nx, Y=xt_nx, alpha=alpha)
+
+ np.testing.assert_raises(ValueError, coot_alpha, [1])
+ np.testing.assert_raises(ValueError, coot_alpha_nx, np.arange(4))
+
+
+def test_coot_warmstart(nx):
+ n_samples = 80 # nb samples
+
+ mu_s = np.array([2, 3])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(
+ n_samples, mu_s, cov_s, random_state=125)
+ xt = xs[::-1].copy()
+ xs_nx = nx.from_numpy(xs)
+ xt_nx = nx.from_numpy(xt)
+
+ # initialize warmstart
+ init_pi_sample = np.random.rand(n_samples, n_samples)
+ init_pi_sample = init_pi_sample / np.sum(init_pi_sample)
+ init_pi_sample_nx = nx.from_numpy(init_pi_sample)
+
+ init_pi_feature = np.random.rand(2, 2)
+ init_pi_feature /= init_pi_feature / np.sum(init_pi_feature)
+ init_pi_feature_nx = nx.from_numpy(init_pi_feature)
+
+ init_duals_sample = (np.random.random(n_samples) * 2 - 1,
+ np.random.random(n_samples) * 2 - 1)
+ init_duals_sample_nx = (nx.from_numpy(init_duals_sample[0]),
+ nx.from_numpy(init_duals_sample[1]))
+
+ init_duals_feature = (np.random.random(2) * 2 - 1,
+ np.random.random(2) * 2 - 1)
+ init_duals_feature_nx = (nx.from_numpy(init_duals_feature[0]),
+ nx.from_numpy(init_duals_feature[1]))
+
+ warmstart = {
+ "pi_sample": init_pi_sample,
+ "pi_feature": init_pi_feature,
+ "duals_sample": init_duals_sample,
+ "duals_feature": init_duals_feature
+ }
+
+ warmstart_nx = {
+ "pi_sample": init_pi_sample_nx,
+ "pi_feature": init_pi_feature_nx,
+ "duals_sample": init_duals_sample_nx,
+ "duals_feature": init_duals_feature_nx
+ }
+
+ # test couplings
+ pi_sample, pi_feature = coot(X=xs, Y=xt, warmstart=warmstart)
+ pi_sample_nx, pi_feature_nx = coot(
+ X=xs_nx, Y=xt_nx, warmstart=warmstart_nx)
+ pi_sample_nx = nx.to_numpy(pi_sample_nx)
+ pi_feature_nx = nx.to_numpy(pi_feature_nx)
+
+ anti_id_sample = np.flipud(np.eye(n_samples, n_samples)) / n_samples
+ id_feature = np.eye(2, 2) / 2
+
+ np.testing.assert_allclose(pi_sample, anti_id_sample, atol=1e-04)
+ np.testing.assert_allclose(pi_sample_nx, anti_id_sample, atol=1e-04)
+ np.testing.assert_allclose(pi_feature, id_feature, atol=1e-04)
+ np.testing.assert_allclose(pi_feature_nx, id_feature, atol=1e-04)
+
+ # test marginal distributions
+ px_s, px_f = ot.unif(n_samples), ot.unif(2)
+ py_s, py_f = ot.unif(n_samples), ot.unif(2)
+
+ np.testing.assert_allclose(px_s, pi_sample_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample_nx.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature_nx.sum(1), atol=1e-04)
+
+ np.testing.assert_allclose(px_s, pi_sample.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature.sum(1), atol=1e-04)
+
+ # test COOT distance
+ coot_np = coot2(X=xs, Y=xt, warmstart=warmstart)
+ coot_nx = nx.to_numpy(coot2(X=xs_nx, Y=xt_nx, warmstart=warmstart_nx))
+ np.testing.assert_allclose(coot_np, 0, atol=1e-08)
+ np.testing.assert_allclose(coot_nx, 0, atol=1e-08)
+
+
+def test_coot_log(nx):
+ n_samples = 90 # nb samples
+
+ mu_s = np.array([-2, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(
+ n_samples, mu_s, cov_s, random_state=43)
+ xt = xs[::-1].copy()
+ xs_nx = nx.from_numpy(xs)
+ xt_nx = nx.from_numpy(xt)
+
+ pi_sample, pi_feature, log = coot(X=xs, Y=xt, log=True)
+ pi_sample_nx, pi_feature_nx, log_nx = coot(X=xs_nx, Y=xt_nx, log=True)
+
+ duals_sample, duals_feature = log["duals_sample"], log["duals_feature"]
+ assert len(duals_sample) == 2
+ assert len(duals_feature) == 2
+ assert len(duals_sample[0]) == n_samples
+ assert len(duals_sample[1]) == n_samples
+ assert len(duals_feature[0]) == 2
+ assert len(duals_feature[1]) == 2
+
+ duals_sample_nx = log_nx["duals_sample"]
+ assert len(duals_sample_nx) == 2
+ assert len(duals_sample_nx[0]) == n_samples
+ assert len(duals_sample_nx[1]) == n_samples
+
+ duals_feature_nx = log_nx["duals_feature"]
+ assert len(duals_feature_nx) == 2
+ assert len(duals_feature_nx[0]) == 2
+ assert len(duals_feature_nx[1]) == 2
+
+ list_coot = log["distances"]
+ assert len(list_coot) >= 1
+
+ list_coot_nx = log_nx["distances"]
+ assert len(list_coot_nx) >= 1
+
+ # test with coot distance
+ coot_np, log = coot2(X=xs, Y=xt, log=True)
+ coot_nx, log_nx = coot2(X=xs_nx, Y=xt_nx, log=True)
+
+ duals_sample, duals_feature = log["duals_sample"], log["duals_feature"]
+ assert len(duals_sample) == 2
+ assert len(duals_feature) == 2
+ assert len(duals_sample[0]) == n_samples
+ assert len(duals_sample[1]) == n_samples
+ assert len(duals_feature[0]) == 2
+ assert len(duals_feature[1]) == 2
+
+ duals_sample_nx = log_nx["duals_sample"]
+ assert len(duals_sample_nx) == 2
+ assert len(duals_sample_nx[0]) == n_samples
+ assert len(duals_sample_nx[1]) == n_samples
+
+ duals_feature_nx = log_nx["duals_feature"]
+ assert len(duals_feature_nx) == 2
+ assert len(duals_feature_nx[0]) == 2
+ assert len(duals_feature_nx[1]) == 2
+
+ list_coot = log["distances"]
+ assert len(list_coot) >= 1
+
+ list_coot_nx = log_nx["distances"]
+ assert len(list_coot_nx) >= 1
diff --git a/test/test_da.py b/test/test_da.py
index 4bf0ab1..c5f08d6 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -44,12 +44,32 @@ def test_class_jax_tf():
@pytest.skip_backend("jax")
@pytest.skip_backend("tf")
+@pytest.mark.parametrize("class_to_test", [ot.da.EMDTransport, ot.da.SinkhornTransport, ot.da.SinkhornLpl1Transport, ot.da.SinkhornL1l2Transport, ot.da.SinkhornL1l2Transport])
+def test_log_da(nx, class_to_test):
+
+ ns = 50
+ nt = 50
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+
+ otda = class_to_test(log=True)
+
+ # test its computed
+ otda.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert hasattr(otda, "log_")
+
+
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
def test_sinkhorn_lpl1_transport_class(nx):
"""test_sinkhorn_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -136,7 +156,7 @@ def test_sinkhorn_l1l2_transport_class(nx):
"""
ns = 50
- nt = 100
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -230,8 +250,8 @@ def test_sinkhorn_transport_class(nx):
"""test_sinkhorn_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -323,8 +343,8 @@ def test_unbalanced_sinkhorn_transport_class(nx):
"""test_sinkhorn_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -402,8 +422,8 @@ def test_emd_transport_class(nx):
"""test_sinkhorn_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -557,30 +577,9 @@ def test_mapping_transport_class_specific_seed(nx):
@pytest.skip_backend("jax")
@pytest.skip_backend("tf")
-def test_linear_mapping(nx):
- ns = 150
- nt = 200
-
- Xs, ys = make_data_classif('3gauss', ns)
- Xt, yt = make_data_classif('3gauss2', nt)
-
- Xsb, Xtb = nx.from_numpy(Xs, Xt)
-
- A, b = ot.da.OT_mapping_linear(Xsb, Xtb)
-
- Xst = nx.to_numpy(nx.dot(Xsb, A) + b)
-
- Ct = np.cov(Xt.T)
- Cst = np.cov(Xst.T)
-
- np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
-
-
-@pytest.skip_backend("jax")
-@pytest.skip_backend("tf")
def test_linear_mapping_class(nx):
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -609,9 +608,9 @@ def test_jcpot_transport_class(nx):
"""test_jcpot_transport
"""
- ns1 = 150
- ns2 = 150
- nt = 200
+ ns1 = 50
+ ns2 = 50
+ nt = 50
Xs1, ys1 = make_data_classif('3gauss', ns1)
Xs2, ys2 = make_data_classif('3gauss', ns2)
@@ -681,9 +680,9 @@ def test_jcpot_barycenter(nx):
"""test_jcpot_barycenter
"""
- ns1 = 150
- ns2 = 150
- nt = 200
+ ns1 = 50
+ ns2 = 50
+ nt = 50
sigma = 0.1
np.random.seed(1985)
@@ -713,8 +712,8 @@ def test_jcpot_barycenter(nx):
def test_emd_laplace_class(nx):
"""test_emd_laplace_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
diff --git a/test/test_gaussian.py b/test/test_gaussian.py
new file mode 100644
index 0000000..be7a806
--- /dev/null
+++ b/test/test_gaussian.py
@@ -0,0 +1,98 @@
+"""Tests for module gaussian"""
+
+# Author: Theo Gnassounou <theo.gnassounou@inria.fr>
+# Remi Flamary <remi.flamary@polytehnique.edu>
+#
+# License: MIT License
+
+import numpy as np
+
+import pytest
+
+import ot
+from ot.datasets import make_data_classif
+
+
+def test_bures_wasserstein_mapping(nx):
+ ns = 50
+ nt = 50
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+ ms = np.mean(Xs, axis=0)[None, :]
+ mt = np.mean(Xt, axis=0)[None, :]
+ Cs = np.cov(Xs.T)
+ Ct = np.cov(Xt.T)
+
+ Xsb, msb, mtb, Csb, Ctb = nx.from_numpy(Xs, ms, mt, Cs, Ct)
+
+ A_log, b_log, log = ot.gaussian.bures_wasserstein_mapping(msb, mtb, Csb, Ctb, log=True)
+ A, b = ot.gaussian.bures_wasserstein_mapping(msb, mtb, Csb, Ctb, log=False)
+
+ Xst = nx.to_numpy(nx.dot(Xsb, A) + b)
+ Xst_log = nx.to_numpy(nx.dot(Xsb, A_log) + b_log)
+
+ Cst = np.cov(Xst.T)
+ Cst_log = np.cov(Xst_log.T)
+
+ np.testing.assert_allclose(Cst_log, Cst, rtol=1e-2, atol=1e-2)
+ np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
+
+
+@pytest.mark.parametrize("bias", [True, False])
+def test_empirical_bures_wasserstein_mapping(nx, bias):
+ ns = 50
+ nt = 50
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ if not bias:
+ ms = np.mean(Xs, axis=0)[None, :]
+ mt = np.mean(Xt, axis=0)[None, :]
+
+ Xs = Xs - ms
+ Xt = Xt - mt
+
+ Xsb, Xtb = nx.from_numpy(Xs, Xt)
+
+ A, b, log = ot.gaussian.empirical_bures_wasserstein_mapping(Xsb, Xtb, log=True, bias=bias)
+ A_log, b_log = ot.gaussian.empirical_bures_wasserstein_mapping(Xsb, Xtb, log=False, bias=bias)
+
+ Xst = nx.to_numpy(nx.dot(Xsb, A) + b)
+ Xst_log = nx.to_numpy(nx.dot(Xsb, A_log) + b_log)
+
+ Ct = np.cov(Xt.T)
+ Cst = np.cov(Xst.T)
+ Cst_log = np.cov(Xst_log.T)
+
+ np.testing.assert_allclose(Cst_log, Cst, rtol=1e-2, atol=1e-2)
+ np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
+
+
+def test_bures_wasserstein_distance(nx):
+ ms, mt = np.array([0]), np.array([10])
+ Cs, Ct = np.array([[1]]).astype(np.float32), np.array([[1]]).astype(np.float32)
+ msb, mtb, Csb, Ctb = nx.from_numpy(ms, mt, Cs, Ct)
+ Wb_log, log = ot.gaussian.bures_wasserstein_distance(msb, mtb, Csb, Ctb, log=True)
+ Wb = ot.gaussian.bures_wasserstein_distance(msb, mtb, Csb, Ctb, log=False)
+
+ np.testing.assert_allclose(nx.to_numpy(Wb_log), nx.to_numpy(Wb), rtol=1e-2, atol=1e-2)
+ np.testing.assert_allclose(10, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2)
+
+
+@pytest.mark.parametrize("bias", [True, False])
+def test_empirical_bures_wasserstein_distance(nx, bias):
+ ns = 400
+ nt = 400
+
+ rng = np.random.RandomState(10)
+ Xs = rng.normal(0, 1, ns)[:, np.newaxis]
+ Xt = rng.normal(10 * bias, 1, nt)[:, np.newaxis]
+
+ Xsb, Xtb = nx.from_numpy(Xs, Xt)
+ Wb_log, log = ot.gaussian.empirical_bures_wasserstein_distance(Xsb, Xtb, log=True, bias=bias)
+ Wb = ot.gaussian.empirical_bures_wasserstein_distance(Xsb, Xtb, log=False, bias=bias)
+
+ np.testing.assert_allclose(nx.to_numpy(Wb_log), nx.to_numpy(Wb), rtol=1e-2, atol=1e-2)
+ np.testing.assert_allclose(10 * bias, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2)
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 9c85b92..80b6df4 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -3,7 +3,7 @@
# Author: Erwan Vautier <erwan.vautier@gmail.com>
# Nicolas Courty <ncourty@irisa.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
-# Cédric Vincent-Cuaz <cedric.vincent-cuaz@inria.fr>
+# Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
#
# License: MIT License
@@ -11,18 +11,15 @@ import numpy as np
import ot
from ot.backend import NumpyBackend
from ot.backend import torch, tf
-
import pytest
def test_gromov(nx):
n_samples = 50 # nb samples
-
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
- xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
-
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=1)
xt = xs[::-1].copy()
p = ot.unif(n_samples)
@@ -38,7 +35,7 @@ def test_gromov(nx):
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, verbose=True)
- Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True))
+ Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=G0b, verbose=True))
# check constraints
np.testing.assert_allclose(G, Gb, atol=1e-06)
@@ -51,13 +48,13 @@ def test_gromov(nx):
np.testing.assert_allclose(Gb, np.flipud(Id), atol=1e-04)
- gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
- gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=True)
+ gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=True, log=True)
+ gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, log=True)
gwb = nx.to_numpy(gwb)
- gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', G0=G0, log=False)
+ gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=True, G0=G0, log=False)
gw_valb = nx.to_numpy(
- ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False)
+ ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False)
)
G = log['T']
@@ -77,6 +74,49 @@ def test_gromov(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+def test_asymmetric_gromov(nx):
+ n_samples = 30 # nb samples
+ np.random.seed(0)
+ C1 = np.random.uniform(low=0., high=10, size=(n_samples, n_samples))
+ idx = np.arange(n_samples)
+ np.random.shuffle(idx)
+ C2 = C1[idx, :][:, idx]
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
+
+ G, log = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, log=True, symmetric=False, verbose=True)
+ Gb, logb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', log=True, symmetric=False, G0=G0b, verbose=True)
+ Gb = nx.to_numpy(Gb)
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log['gw_dist'], 0., atol=1e-04)
+ np.testing.assert_allclose(logb['gw_dist'], 0., atol=1e-04)
+
+ gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'square_loss', G0=G0, log=True, symmetric=False, verbose=True)
+ gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'square_loss', log=True, symmetric=False, G0=G0b, verbose=True)
+
+ G = log['T']
+ Gb = nx.to_numpy(logb['T'])
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log['gw_dist'], 0., atol=1e-04)
+ np.testing.assert_allclose(logb['gw_dist'], 0., atol=1e-04)
+
+
def test_gromov_dtype_device(nx):
# setup
n_samples = 50 # nb samples
@@ -104,7 +144,7 @@ def test_gromov_dtype_device(nx):
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0, type_as=tp)
Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True)
- gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False)
nx.assert_same_dtype_device(C1b, Gb)
nx.assert_same_dtype_device(C1b, gw_valb)
@@ -130,7 +170,7 @@ def test_gromov_device_tf():
with tf.device("/CPU:0"):
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True)
- gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False)
nx.assert_same_dtype_device(C1b, Gb)
nx.assert_same_dtype_device(C1b, gw_valb)
@@ -138,7 +178,7 @@ def test_gromov_device_tf():
# Check that everything happens on the GPU
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
- gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, log=False)
nx.assert_same_dtype_device(C1b, Gb)
nx.assert_same_dtype_device(C1b, gw_valb)
assert nx.dtype_device(Gb)[1].startswith("GPU")
@@ -174,6 +214,7 @@ def test_gromov2_gradients():
C11 = torch.tensor(C1, requires_grad=True, device=device)
C12 = torch.tensor(C2, requires_grad=True, device=device)
+ # Test with exact line-search
val = ot.gromov_wasserstein2(C11, C12, p1, q1)
val.backward()
@@ -184,6 +225,60 @@ def test_gromov2_gradients():
assert C11.shape == C11.grad.shape
assert C12.shape == C12.grad.shape
+ # Test with armijo line-search
+ q1.grad = None
+ p1.grad = None
+ C11.grad = None
+ C12.grad = None
+ val = ot.gromov_wasserstein2(C11, C12, p1, q1, armijo=True)
+
+ val.backward()
+
+ assert val.device == p1.device
+ assert q1.shape == q1.grad.shape
+ assert p1.shape == p1.grad.shape
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+
+
+def test_gw_helper_backend(nx):
+ n_samples = 20 # nb samples
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0)
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1)
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
+ Gb, logb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', armijo=False, symmetric=True, G0=G0b, log=True)
+
+ # calls with nx=None
+ constCb, hC1b, hC2b = ot.gromov.init_matrix(C1b, C2b, pb, qb, loss_fun='square_loss')
+
+ def f(G):
+ return ot.gromov.gwloss(constCb, hC1b, hC2b, G, None)
+
+ def df(G):
+ return ot.gromov.gwggrad(constCb, hC1b, hC2b, G, None)
+
+ def line_search(cost, G, deltaG, Mi, cost_G):
+ return ot.gromov.solve_gromov_linesearch(G, deltaG, cost_G, C1b, C2b, M=0., reg=1., nx=None)
+ # feed the precomputed local optimum Gb to cg
+ res, log = ot.optim.cg(pb, qb, 0., 1., f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9)
+ # check constraints
+ np.testing.assert_allclose(res, Gb, atol=1e-06)
+
@pytest.skip_backend("jax", reason="test very slow with jax backend")
@pytest.skip_backend("tf", reason="test very slow with tf backend")
@@ -199,19 +294,21 @@ def test_entropic_gromov(nx):
p = ot.unif(n_samples)
q = ot.unif(n_samples)
-
+ G0 = p[:, None] * q[None, :]
C1 = ot.dist(xs, xs)
C2 = ot.dist(xt, xt)
C1 /= C1.max()
C2 /= C2.max()
- C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q)
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
- G = ot.gromov.entropic_gromov_wasserstein(
- C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True)
+ G, log = ot.gromov.entropic_gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', symmetric=None, G0=G0,
+ epsilon=1e-2, verbose=True, log=True)
Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein(
- C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
+ C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=None,
+ epsilon=1e-2, verbose=True, log=False
))
# check constraints
@@ -222,9 +319,11 @@ def test_entropic_gromov(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
gw, log = ot.gromov.entropic_gromov_wasserstein2(
- C1, C2, p, q, 'kl_loss', max_iter=10, epsilon=1e-2, log=True)
+ C1, C2, p, q, 'kl_loss', symmetric=True, G0=None,
+ max_iter=10, epsilon=1e-2, log=True)
gwb, logb = ot.gromov.entropic_gromov_wasserstein2(
- C1b, C2b, pb, qb, 'kl_loss', max_iter=10, epsilon=1e-2, log=True)
+ C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b,
+ max_iter=10, epsilon=1e-2, log=True)
gwb = nx.to_numpy(gwb)
G = log['T']
@@ -241,6 +340,45 @@ def test_entropic_gromov(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+def test_asymmetric_entropic_gromov(nx):
+ n_samples = 10 # nb samples
+ np.random.seed(0)
+ C1 = np.random.uniform(low=0., high=10, size=(n_samples, n_samples))
+ idx = np.arange(n_samples)
+ np.random.shuffle(idx)
+ C2 = C1[idx, :][:, idx]
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
+ G = ot.gromov.entropic_gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', symmetric=None, G0=G0,
+ epsilon=1e-1, verbose=True, log=False)
+ Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein(
+ C1b, C2b, pb, qb, 'square_loss', symmetric=False, G0=None,
+ epsilon=1e-1, verbose=True, log=False
+ ))
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ gw = ot.gromov.entropic_gromov_wasserstein2(
+ C1, C2, p, q, 'kl_loss', symmetric=False, G0=None,
+ max_iter=10, epsilon=1e-1, log=False)
+ gwb = ot.gromov.entropic_gromov_wasserstein2(
+ C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b,
+ max_iter=10, epsilon=1e-1, log=False)
+ gwb = nx.to_numpy(gwb)
+
+ np.testing.assert_allclose(gw, gwb, atol=1e-06)
+ np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
+
+
@pytest.skip_backend("jax", reason="test very slow with jax backend")
@pytest.skip_backend("tf", reason="test very slow with tf backend")
def test_entropic_gromov_dtype_device(nx):
@@ -539,8 +677,8 @@ def test_fgw(nx):
Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
- G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True)
- Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, G0=G0b, log=True)
+ G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, armijo=True, symmetric=None, G0=G0, log=True)
+ Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, armijo=True, symmetric=True, G0=G0b, log=True)
Gb = nx.to_numpy(Gb)
# check constraints
@@ -555,8 +693,8 @@ def test_fgw(nx):
np.testing.assert_allclose(
Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov
- fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', G0=None, alpha=0.5, log=True)
- fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', G0=G0b, alpha=0.5, log=True)
+ fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', armijo=True, symmetric=True, G0=None, alpha=0.5, log=True)
+ fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', armijo=True, symmetric=None, G0=G0b, alpha=0.5, log=True)
fgwb = nx.to_numpy(fgwb)
G = log['T']
@@ -573,6 +711,82 @@ def test_fgw(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+def test_asymmetric_fgw(nx):
+ n_samples = 50 # nb samples
+ np.random.seed(0)
+ C1 = np.random.uniform(low=0., high=10, size=(n_samples, n_samples))
+ idx = np.arange(n_samples)
+ np.random.shuffle(idx)
+ C2 = C1[idx, :][:, idx]
+
+ # add features
+ F1 = np.random.uniform(low=0., high=10, size=(n_samples, 1))
+ F2 = F1[idx, :]
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+
+ M = ot.dist(F1, F2)
+ Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
+
+ G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True, symmetric=False, verbose=True)
+ Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True, symmetric=None, G0=G0b, verbose=True)
+ Gb = nx.to_numpy(Gb)
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
+ np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
+
+ fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True, symmetric=None, verbose=True)
+ fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True, symmetric=False, G0=G0b, verbose=True)
+
+ G = log['T']
+ Gb = nx.to_numpy(logb['T'])
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
+ np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
+
+ # Tests with kl-loss:
+ G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, symmetric=False, verbose=True)
+ Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, symmetric=None, G0=G0b, verbose=True)
+ Gb = nx.to_numpy(Gb)
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
+ np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
+
+ fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, symmetric=None, verbose=True)
+ fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, symmetric=False, G0=G0b, verbose=True)
+
+ G = log['T']
+ Gb = nx.to_numpy(logb['T'])
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
+ np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
+
+
def test_fgw2_gradients():
n_samples = 20 # nb samples
@@ -617,6 +831,57 @@ def test_fgw2_gradients():
assert M1.shape == M1.grad.shape
+def test_fgw_helper_backend(nx):
+ n_samples = 20 # nb samples
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0)
+ ys = np.random.randn(xs.shape[0], 2)
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1)
+ yt = np.random.randn(xt.shape[0], 2)
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ M = ot.dist(ys, yt)
+ M /= M.max()
+
+ Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
+ alpha = 0.5
+ Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, armijo=False, symmetric=True, G0=G0b, log=True)
+
+ # calls with nx=None
+ constCb, hC1b, hC2b = ot.gromov.init_matrix(C1b, C2b, pb, qb, loss_fun='square_loss')
+
+ def f(G):
+ return ot.gromov.gwloss(constCb, hC1b, hC2b, G, None)
+
+ def df(G):
+ return ot.gromov.gwggrad(constCb, hC1b, hC2b, G, None)
+
+ def line_search(cost, G, deltaG, Mi, cost_G):
+ return ot.gromov.solve_gromov_linesearch(G, deltaG, cost_G, C1b, C2b, M=(1 - alpha) * Mb, reg=alpha, nx=None)
+ # feed the precomputed local optimum Gb to cg
+ res, log = ot.optim.cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9)
+
+ def line_search(cost, G, deltaG, Mi, cost_G):
+ return ot.optim.line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=None)
+ # feed the precomputed local optimum Gb to cg
+ res_armijo, log_armijo = ot.optim.cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9)
+ # check constraints
+ np.testing.assert_allclose(res, Gb, atol=1e-06)
+ np.testing.assert_allclose(res_armijo, Gb, atol=1e-06)
+
+
def test_fgw_barycenter(nx):
np.random.seed(42)
@@ -1186,3 +1451,327 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx):
# > Compare results with/without backend
total_reconstruction_b_bis2 = nx.to_numpy(total_reconstruction_b_bis2)
np.testing.assert_allclose(total_reconstruction_bis2, total_reconstruction_b_bis2, atol=1e-05)
+
+
+def test_semirelaxed_gromov(nx):
+ np.random.seed(0)
+ # unbalanced proportions
+ list_n = [30, 15]
+ nt = 2
+ ns = np.sum(list_n)
+ # create directed sbm with C2 as connectivity matrix
+ C1 = np.zeros((ns, ns), dtype=np.float64)
+ C2 = np.array([[0.8, 0.05],
+ [0.05, 1.]], dtype=np.float64)
+ for i in range(nt):
+ for j in range(nt):
+ ni, nj = list_n[i], list_n[j]
+ xij = np.random.binomial(size=(ni, nj), n=1, p=C2[i, j])
+ C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij
+ p = ot.unif(ns, type_as=C1)
+ q0 = ot.unif(C2.shape[0], type_as=C1)
+ G0 = p[:, None] * q0[None, :]
+ # asymmetric
+ C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0)
+
+ G, log = ot.gromov.semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=G0)
+ Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, loss_fun='square_loss', symmetric=False, log=True, G0=None)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04)
+ np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01)
+ np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01)
+
+ srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=False, log=True, G0=G0)
+ srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None)
+
+ G = log2['T']
+ Gb = nx.to_numpy(logb2['T'])
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(list_n / ns, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07)
+
+ # symmetric
+ C1 = 0.5 * (C1 + C1.T)
+ C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0)
+
+ G, log = ot.gromov.semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=None)
+ Gb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, loss_fun='square_loss', symmetric=True, log=False, G0=G0b)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov
+
+ srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=True, log=True, G0=G0)
+ srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None)
+
+ srgw_ = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=True, log=False, G0=G0)
+
+ G = log2['T']
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, nx.sum(Gb, 1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01)
+ np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01)
+
+ np.testing.assert_allclose(log2['srgw_dist'], log['srgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(srgw, srgw_, atol=1e-07)
+
+
+def test_semirelaxed_gromov2_gradients():
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5)
+
+ p = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ if torch:
+
+ devices = [torch.device("cpu")]
+ if torch.cuda.is_available():
+ devices.append(torch.device("cuda"))
+ for device in devices:
+ # semirelaxed solvers do not support gradients over masses yet.
+ p1 = torch.tensor(p, requires_grad=False, device=device)
+ C11 = torch.tensor(C1, requires_grad=True, device=device)
+ C12 = torch.tensor(C2, requires_grad=True, device=device)
+
+ val = ot.gromov.semirelaxed_gromov_wasserstein2(C11, C12, p1)
+
+ val.backward()
+
+ assert val.device == p1.device
+ assert p1.grad is None
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+
+
+def test_srgw_helper_backend(nx):
+ n_samples = 20 # nb samples
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0)
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1)
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q)
+ Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, 'square_loss', armijo=False, symmetric=True, G0=None, log=True)
+
+ # calls with nx=None
+ constCb, hC1b, hC2b, fC2tb = ot.gromov.init_matrix_semirelaxed(C1b, C2b, pb, loss_fun='square_loss')
+ ones_pb = nx.ones(pb.shape[0], type_as=pb)
+
+ def f(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb))
+ return ot.gromov.gwloss(constCb + marginal_product, hC1b, hC2b, G, nx=None)
+
+ def df(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb))
+ return ot.gromov.gwggrad(constCb + marginal_product, hC1b, hC2b, G, nx=None)
+
+ def line_search(cost, G, deltaG, Mi, cost_G):
+ return ot.gromov.solve_semirelaxed_gromov_linesearch(
+ G, deltaG, cost_G, C1b, C2b, ones_pb, 0., 1., nx=None)
+ # feed the precomputed local optimum Gb to semirelaxed_cg
+ res, log = ot.optim.semirelaxed_cg(pb, qb, 0., 1., f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9)
+ # check constraints
+ np.testing.assert_allclose(res, Gb, atol=1e-06)
+
+
+def test_semirelaxed_fgw(nx):
+ np.random.seed(0)
+ list_n = [16, 8]
+ nt = 2
+ ns = 24
+ # create directed sbm with C2 as connectivity matrix
+ C1 = np.zeros((ns, ns))
+ C2 = np.array([[0.7, 0.05],
+ [0.05, 0.9]])
+ for i in range(nt):
+ for j in range(nt):
+ ni, nj = list_n[i], list_n[j]
+ xij = np.random.binomial(size=(ni, nj), n=1, p=C2[i, j])
+ C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij
+ F1 = np.zeros((ns, 1))
+ F1[:16] = np.random.normal(loc=0., scale=0.01, size=(16, 1))
+ F1[16:] = np.random.normal(loc=1., scale=0.01, size=(8, 1))
+ F2 = np.zeros((2, 1))
+ F2[1, :] = 1.
+ M = (F1 ** 2).dot(np.ones((1, nt))) + np.ones((ns, 1)).dot((F2 ** 2).T) - 2 * F1.dot(F2.T)
+
+ p = ot.unif(ns)
+ q0 = ot.unif(C2.shape[0])
+ G0 = p[:, None] * q0[None, :]
+
+ # asymmetric
+ Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0)
+ G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None)
+ Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0b)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov
+
+ srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0)
+ srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None)
+
+ G = log2['T']
+ Gb = nx.to_numpy(logb2['T'])
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log2['srfgw_dist'], logb['srfgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07)
+
+ # symmetric
+ C1 = 0.5 * (C1 + C1.T)
+ Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0)
+
+ G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None)
+ Gb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=True, log=False, G0=G0b)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov
+
+ srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=True, log=True, G0=G0)
+ srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None)
+
+ srgw_ = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=True, log=False, G0=G0)
+
+ G = log2['T']
+ Gb = nx.to_numpy(logb2['T'])
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(srgw, srgw_, atol=1e-07)
+
+
+def test_semirelaxed_fgw2_gradients():
+ n_samples = 20 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5)
+
+ p = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+ M = ot.dist(xs, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ if torch:
+
+ devices = [torch.device("cpu")]
+ if torch.cuda.is_available():
+ devices.append(torch.device("cuda"))
+ for device in devices:
+ # semirelaxed solvers do not support gradients over masses yet.
+ p1 = torch.tensor(p, requires_grad=False, device=device)
+ C11 = torch.tensor(C1, requires_grad=True, device=device)
+ C12 = torch.tensor(C2, requires_grad=True, device=device)
+ M1 = torch.tensor(M, requires_grad=True, device=device)
+
+ val = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M1, C11, C12, p1)
+
+ val.backward()
+
+ assert val.device == p1.device
+ assert p1.grad is None
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+ assert M1.shape == M1.grad.shape
+
+
+def test_srfgw_helper_backend(nx):
+ n_samples = 20 # nb samples
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0)
+ ys = np.random.randn(xs.shape[0], 2)
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1)
+ yt = np.random.randn(xt.shape[0], 2)
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ M = ot.dist(ys, yt)
+ M /= M.max()
+
+ Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
+ alpha = 0.5
+ Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, 'square_loss', alpha=0.5, armijo=False, symmetric=True, G0=G0b, log=True)
+
+ # calls with nx=None
+ constCb, hC1b, hC2b, fC2tb = ot.gromov.init_matrix_semirelaxed(C1b, C2b, pb, loss_fun='square_loss')
+ ones_pb = nx.ones(pb.shape[0], type_as=pb)
+
+ def f(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb))
+ return ot.gromov.gwloss(constCb + marginal_product, hC1b, hC2b, G, nx=None)
+
+ def df(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb))
+ return ot.gromov.gwggrad(constCb + marginal_product, hC1b, hC2b, G, nx=None)
+
+ def line_search(cost, G, deltaG, Mi, cost_G):
+ return ot.gromov.solve_semirelaxed_gromov_linesearch(
+ G, deltaG, cost_G, C1b, C2b, ones_pb, M=(1 - alpha) * Mb, reg=alpha, nx=None)
+ # feed the precomputed local optimum Gb to semirelaxed_cg
+ res, log = ot.optim.semirelaxed_cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9)
+ # check constraints
+ np.testing.assert_allclose(res, Gb, atol=1e-06)
diff --git a/test/test_optim.py b/test/test_optim.py
index 67e9d13..a43e704 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -120,31 +120,33 @@ def test_generalized_conditional_gradient(nx):
Gb, log = ot.optim.gcg(ab, bb, Mb, reg1, reg2, fb, df, verbose=True, log=True)
Gb = nx.to_numpy(Gb)
- np.testing.assert_allclose(Gb, G)
+ np.testing.assert_allclose(Gb, G, atol=1e-12)
np.testing.assert_allclose(a, Gb.sum(1), atol=1e-05)
np.testing.assert_allclose(b, Gb.sum(0), atol=1e-05)
def test_solve_1d_linesearch_quad_funct():
- np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(1, -1, 0), 0.5)
- np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 5, 0), 0)
- np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1)
+ np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(1, -1), 0.5)
+ np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 5), 0)
+ np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5), 1)
def test_line_search_armijo(nx):
xk = np.array([[0.25, 0.25], [0.25, 0.25]])
pk = np.array([[-0.25, 0.25], [0.25, -0.25]])
gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]])
- old_fval = -123
+ old_fval = -123.
xkb, pkb, gfkb = nx.from_numpy(xk, pk, gfk)
+ def f(x):
+ return 1.
# Should not throw an exception and return 0. for alpha
alpha, a, b = ot.optim.line_search_armijo(
- lambda x: 1, xkb, pkb, gfkb, old_fval
+ f, xkb, pkb, gfkb, old_fval
)
alpha_np, anp, bnp = ot.optim.line_search_armijo(
- lambda x: 1, xk, pk, gfk, old_fval
+ f, xk, pk, gfk, old_fval
)
assert a == anp
assert b == bnp
@@ -182,3 +184,50 @@ def test_line_search_armijo(nx):
old_fval = f(xk)
alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval)
np.testing.assert_allclose(alpha, 0.1)
+
+
+def test_line_search_armijo_dtype_device(nx):
+ for tp in nx.__type_list__:
+ def f(x):
+ return nx.sum((x - 5.0) ** 2)
+
+ def grad(x):
+ return 2 * (x - 5.0)
+
+ xk = np.array([[[-5.0, -5.0]]])
+ pk = np.array([[[100.0, 100.0]]])
+ xkb, pkb = nx.from_numpy(xk, pk, type_as=tp)
+ gfkb = grad(xkb)
+ old_fval = f(xkb)
+
+ # chech the case where the optimum is on the direction
+ alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval)
+ alpha = nx.to_numpy(alpha)
+ np.testing.assert_allclose(alpha, 0.1)
+ nx.assert_same_dtype_device(old_fval, fval)
+
+ # check the case where the direction is not far enough
+ pk = np.array([[[3.0, 3.0]]])
+ pkb = nx.from_numpy(pk, type_as=tp)
+ alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval, alpha0=1.0)
+ alpha = nx.to_numpy(alpha)
+ np.testing.assert_allclose(alpha, 1.0)
+ nx.assert_same_dtype_device(old_fval, fval)
+
+ # check the case where checking the wrong direction
+ alpha, _, fval = ot.optim.line_search_armijo(f, xkb, -pkb, gfkb, old_fval)
+ alpha = nx.to_numpy(alpha)
+
+ assert alpha <= 0
+ nx.assert_same_dtype_device(old_fval, fval)
+
+ # check the case where the point is not a vector
+ xkb = nx.from_numpy(np.array(-5.0), type_as=tp)
+ pkb = nx.from_numpy(np.array(100), type_as=tp)
+ gfkb = grad(xkb)
+ old_fval = f(xkb)
+ alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval)
+ alpha = nx.to_numpy(alpha)
+
+ np.testing.assert_allclose(alpha, 0.1)
+ nx.assert_same_dtype_device(old_fval, fval)
diff --git a/test/test_ot.py b/test/test_ot.py
index bf832f6..f2338ac 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -29,9 +29,12 @@ def test_emd_dimension_and_mass_mismatch():
np.testing.assert_raises(AssertionError, ot.emd2, a, a, M)
+ # test emd and emd2 for mass mismatch
+ a = ot.utils.unif(n_samples)
b = a.copy()
a[0] = 100
np.testing.assert_raises(AssertionError, ot.emd, a, b, M)
+ np.testing.assert_raises(AssertionError, ot.emd2, a, b, M)
def test_emd_backends(nx):
@@ -201,6 +204,22 @@ def test_emd_emd2():
np.testing.assert_allclose(w, 0)
+def test_omp_emd2():
+ # test emd2 and emd2 with openmp for simple identity
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ w = ot.emd2(u, u, M)
+ w2 = ot.emd2(u, u, M, numThreads=2)
+
+ np.testing.assert_allclose(w, w2)
+
+
def test_emd_empty():
# test emd and emd2 for simple identity
n = 100
@@ -320,6 +339,46 @@ def test_free_support_barycenter_backends(nx):
np.testing.assert_allclose(X, nx.to_numpy(X2))
+def test_generalised_free_support_barycenter():
+ np.random.seed(42) # random inits
+ X = [np.array([-1., -1.]).reshape((1, 2)), np.array([1., 1.]).reshape((1, 2))] # two 2D points bar is obviously 0
+ a = [np.array([1.]), np.array([1.])]
+
+ P = [np.eye(2), np.eye(2)]
+
+ Y_init = np.array([-12., 7.]).reshape((1, 2))
+
+ # obvious barycenter location between two 2D diracs
+ Y_true = np.array([0., .0]).reshape((1, 2))
+
+ # test without log and no init
+ Y = ot.lp.generalized_free_support_barycenter(X, a, P, 1)
+ np.testing.assert_allclose(Y, Y_true, rtol=1e-5, atol=1e-7)
+
+ # test with log and init
+ Y, _ = ot.lp.generalized_free_support_barycenter(X, a, P, 1, Y_init=Y_init, b=np.array([1.]), log=True)
+ np.testing.assert_allclose(Y, Y_true, rtol=1e-5, atol=1e-7)
+
+
+def test_generalised_free_support_barycenter_backends(nx):
+ np.random.seed(42)
+ X = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))]
+ a = [np.array([1.]), np.array([1.])]
+ P = [np.array([1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))]
+ Y_init = np.array([-12.]).reshape((1, 1))
+
+ Y = ot.lp.generalized_free_support_barycenter(X, a, P, 1, Y_init=Y_init)
+
+ X2 = nx.from_numpy(*X)
+ a2 = nx.from_numpy(*a)
+ P2 = nx.from_numpy(*P)
+ Y_init2 = nx.from_numpy(Y_init)
+
+ Y2 = ot.lp.generalized_free_support_barycenter(X2, a2, P2, 1, Y_init=Y_init2)
+
+ np.testing.assert_allclose(Y, nx.to_numpy(Y2))
+
+
@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
def test_lp_barycenter_cvxopt():
a1 = np.array([1.0, 0, 0])[:, None]
diff --git a/test/test_partial.py b/test/test_partial.py
index 97c611b..86f9e62 100755
--- a/test/test_partial.py
+++ b/test/test_partial.py
@@ -8,6 +8,7 @@
import numpy as np
import scipy as sp
import ot
+from ot.backend import to_numpy, torch
import pytest
@@ -79,8 +80,10 @@ def test_partial_wasserstein_lagrange():
w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 1, log=True)
+ w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 100, log=True)
-def test_partial_wasserstein():
+
+def test_partial_wasserstein(nx):
n_samples = 20 # nb samples (gaussian)
n_noise = 20 # nb of samples (noise)
@@ -100,25 +103,20 @@ def test_partial_wasserstein():
m = 0.5
+ p, q, M = nx.from_numpy(p, q, M)
+
w0, log0 = ot.partial.partial_wasserstein(p, q, M, m=m, log=True)
- w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m,
- log=True, verbose=True)
+ w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m, log=True, verbose=True)
# check constraints
- np.testing.assert_equal(
- w0.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
- np.testing.assert_equal(
- w0.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
- np.testing.assert_equal(
- w.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
- np.testing.assert_equal(
- w.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q))
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q))
# check transported mass
- np.testing.assert_allclose(
- np.sum(w0), m, atol=1e-04)
- np.testing.assert_allclose(
- np.sum(w), m, atol=1e-04)
+ np.testing.assert_allclose(np.sum(to_numpy(w0)), m, atol=1e-04)
+ np.testing.assert_allclose(np.sum(to_numpy(w)), m, atol=1e-04)
w0, log0 = ot.partial.partial_wasserstein2(p, q, M, m=m, log=True)
w0_val = ot.partial.partial_wasserstein2(p, q, M, m=m, log=False)
@@ -128,15 +126,91 @@ def test_partial_wasserstein():
np.testing.assert_allclose(w0, w0_val, atol=1e-1, rtol=1e-1)
# check constraints
- np.testing.assert_equal(
- G.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
- np.testing.assert_equal(
- G.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
- np.testing.assert_allclose(
- np.sum(G), m, atol=1e-04)
+ np.testing.assert_equal(to_numpy(nx.sum(G, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(G, axis=0) - q) <= 1e-5, [True] * len(q))
+ np.testing.assert_allclose(np.sum(to_numpy(G)), m, atol=1e-04)
+
+ empty_array = nx.zeros(0, type_as=M)
+ w = ot.partial.partial_wasserstein(empty_array, empty_array, M=M, m=None)
+
+ # check constraints
+ np.testing.assert_equal(to_numpy(nx.sum(w, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(w, axis=0) - q) <= 1e-5, [True] * len(q))
+ np.testing.assert_equal(to_numpy(nx.sum(w, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(w, axis=0) - q) <= 1e-5, [True] * len(q))
+
+ # check transported mass
+ np.testing.assert_allclose(np.sum(to_numpy(w)), 1, atol=1e-04)
+
+ w0 = ot.partial.entropic_partial_wasserstein(empty_array, empty_array, M=M, reg=10, m=None)
+
+ # check constraints
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q))
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q))
+
+ # check transported mass
+ np.testing.assert_allclose(np.sum(to_numpy(w0)), 1, atol=1e-04)
+
+
+def test_partial_wasserstein2_gradient():
+ if torch:
+ n_samples = 40
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 2]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+
+ M = torch.tensor(ot.dist(xs, xt), requires_grad=True, dtype=torch.float64)
+
+ p = torch.tensor(ot.unif(n_samples), dtype=torch.float64)
+ q = torch.tensor(ot.unif(n_samples), dtype=torch.float64)
+
+ m = 0.5
+
+ w, log = ot.partial.partial_wasserstein2(p, q, M, m=m, log=True)
+
+ w.backward()
+
+ assert M.grad is not None
+ assert M.grad.shape == M.shape
+
+
+def test_entropic_partial_wasserstein_gradient():
+ if torch:
+ n_samples = 40
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 2]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+
+ M = torch.tensor(ot.dist(xs, xt), requires_grad=True, dtype=torch.float64)
+
+ p = torch.tensor(ot.unif(n_samples), requires_grad=True, dtype=torch.float64)
+ q = torch.tensor(ot.unif(n_samples), requires_grad=True, dtype=torch.float64)
+
+ m = 0.5
+ reg = 1
+
+ _, log = ot.partial.entropic_partial_wasserstein(p, q, M, m=m, reg=reg, log=True)
+
+ log['partial_w_dist'].backward()
+
+ assert M.grad is not None
+ assert p.grad is not None
+ assert q.grad is not None
+ assert M.grad.shape == M.shape
+ assert p.grad.shape == p.shape
+ assert q.grad.shape == q.shape
def test_partial_gromov_wasserstein():
+ rng = np.random.RandomState(seed=42)
n_samples = 20 # nb samples
n_noise = 10 # nb of samples (noise)
@@ -149,11 +223,11 @@ def test_partial_gromov_wasserstein():
mu_t = np.array([0, 0, 0])
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
- xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
- xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) + 1) * 4)), axis=0)
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, rng)
+ xs = np.concatenate((xs, ((rng.rand(n_noise, 2) + 1) * 4)), axis=0)
P = sp.linalg.sqrtm(cov_t)
- xt = np.random.randn(n_samples, 3).dot(P) + mu_t
- xt = np.concatenate((xt, ((np.random.rand(n_noise, 3) + 1) * 10)), axis=0)
+ xt = rng.randn(n_samples, 3).dot(P) + mu_t
+ xt = np.concatenate((xt, ((rng.rand(n_noise, 3) + 1) * 10)), axis=0)
xt2 = xs[::-1].copy()
C1 = ot.dist(xs, xs)
diff --git a/test/test_sliced.py b/test/test_sliced.py
index 08ab4fb..f54c799 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -110,6 +110,20 @@ def test_max_sliced_different_dists():
assert res > 0.
+def test_sliced_same_proj():
+ n_projections = 10
+ seed = 12
+ rng = np.random.RandomState(0)
+ X = rng.randn(8, 2)
+ Y = rng.randn(8, 2)
+ cost1, log1 = ot.sliced_wasserstein_distance(X, Y, seed=seed, n_projections=n_projections, log=True)
+ P = get_random_projections(X.shape[1], n_projections=10, seed=seed)
+ cost2, log2 = ot.sliced_wasserstein_distance(X, Y, projections=P, log=True)
+
+ assert np.allclose(log1['projections'], log2['projections'])
+ assert np.isclose(cost1, cost2)
+
+
def test_sliced_backend(nx):
n = 100
@@ -252,3 +266,189 @@ def test_max_sliced_backend_device_tf():
valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
assert nx.dtype_device(valb)[1].startswith("GPU")
+
+
+def test_projections_stiefel():
+ rng = np.random.RandomState(0)
+
+ n_projs = 500
+ x = np.random.randn(100, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ ssw, log = ot.sliced_wasserstein_sphere(x, x, n_projections=n_projs,
+ seed=rng, log=True)
+
+ P = log["projections"]
+ P_T = np.transpose(P, [0, 2, 1])
+ np.testing.assert_almost_equal(np.matmul(P_T, P), np.array([np.eye(2) for k in range(n_projs)]))
+
+
+def test_sliced_sphere_same_dist():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+ u = ot.utils.unif(n)
+
+ res = ot.sliced_wasserstein_sphere(x, x, u, u, 10, seed=rng)
+ np.testing.assert_almost_equal(res, 0.)
+
+
+def test_sliced_sphere_bad_shapes():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ y = rng.randn(n, 4)
+ y = y / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ u = ot.utils.unif(n)
+
+ with pytest.raises(ValueError):
+ _ = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng)
+
+
+def test_sliced_sphere_values_on_the_sphere():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ y = rng.randn(n, 4)
+
+ u = ot.utils.unif(n)
+
+ with pytest.raises(ValueError):
+ _ = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng)
+
+
+def test_sliced_sphere_log():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 4)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+ y = rng.randn(n, 4)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+ u = ot.utils.unif(n)
+
+ res, log = ot.sliced_wasserstein_sphere(x, y, u, u, 10, p=1, seed=rng, log=True)
+ assert len(log) == 2
+ projections = log["projections"]
+ projected_emds = log["projected_emds"]
+
+ assert projections.shape[0] == len(projected_emds) == 10
+ for emd in projected_emds:
+ assert emd > 0
+
+
+def test_sliced_sphere_different_dists():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ u = ot.utils.unif(n)
+ y = rng.randn(n, 3)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+
+ res = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng)
+ assert res > 0.
+
+
+def test_1d_sliced_sphere_equals_emd():
+ n = 100
+ m = 120
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+ x_coords = (np.arctan2(-x[:, 1], -x[:, 0]) + np.pi) / (2 * np.pi)
+ a = rng.uniform(0, 1, n)
+ a /= a.sum()
+
+ y = rng.randn(m, 2)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+ y_coords = (np.arctan2(-y[:, 1], -y[:, 0]) + np.pi) / (2 * np.pi)
+ u = ot.utils.unif(m)
+
+ res = ot.sliced_wasserstein_sphere(x, y, a, u, 10, seed=42, p=2)
+ expected = ot.binary_search_circle(x_coords.T, y_coords.T, a, u, p=2)
+
+ res1 = ot.sliced_wasserstein_sphere(x, y, a, u, 10, seed=42, p=1)
+ expected1 = ot.binary_search_circle(x_coords.T, y_coords.T, a, u, p=1)
+
+ np.testing.assert_almost_equal(res ** 2, expected)
+ np.testing.assert_almost_equal(res1, expected1, decimal=3)
+
+
+@pytest.skip_backend("tf")
+def test_sliced_sphere_backend_type_devices(nx):
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ y = rng.randn(2 * n, 3)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb, yb = nx.from_numpy(x, y, type_as=tp)
+
+ valb = ot.sliced_wasserstein_sphere(xb, yb)
+
+ nx.assert_same_dtype_device(xb, valb)
+
+
+def test_sliced_sphere_unif_values_on_the_sphere():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ u = ot.utils.unif(n)
+
+ with pytest.raises(ValueError):
+ _ = ot.sliced_wasserstein_sphere_unif(x, u, 10, seed=rng)
+
+
+def test_sliced_sphere_unif_log():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 4)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+ u = ot.utils.unif(n)
+
+ res, log = ot.sliced_wasserstein_sphere_unif(x, u, 10, seed=rng, log=True)
+ assert len(log) == 2
+ projections = log["projections"]
+ projected_emds = log["projected_emds"]
+
+ assert projections.shape[0] == len(projected_emds) == 10
+ for emd in projected_emds:
+ assert emd > 0
+
+
+def test_sliced_sphere_unif_backend_type_devices(nx):
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb = nx.from_numpy(x, type_as=tp)
+
+ valb = ot.sliced_wasserstein_sphere_unif(xb)
+
+ nx.assert_same_dtype_device(xb, valb)
diff --git a/test/test_solvers.py b/test/test_solvers.py
new file mode 100644
index 0000000..b792aca
--- /dev/null
+++ b/test/test_solvers.py
@@ -0,0 +1,133 @@
+"""Tests for ot solvers"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+
+import itertools
+import numpy as np
+import pytest
+
+import ot
+
+
+lst_reg = [None, 1.0]
+lst_reg_type = ['KL', 'entropy', 'L2']
+lst_unbalanced = [None, 0.9]
+lst_unbalanced_type = ['KL', 'L2', 'TV']
+
+
+def assert_allclose_sol(sol1, sol2):
+
+ lst_attr = ['value', 'value_linear', 'plan',
+ 'potential_a', 'potential_b', 'marginal_a', 'marginal_b']
+
+ nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend()
+ nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend()
+
+ for attr in lst_attr:
+ try:
+ np.allclose(nx1.to_numpy(getattr(sol1, attr)), nx2.to_numpy(getattr(sol2, attr)))
+ except NotImplementedError:
+ pass
+
+
+def test_solve(nx):
+ n_samples_s = 10
+ n_samples_t = 7
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples_s, n_features)
+ y = rng.randn(n_samples_t, n_features)
+ a = ot.utils.unif(n_samples_s)
+ b = ot.utils.unif(n_samples_t)
+
+ M = ot.dist(x, y)
+
+ # solve unif weights
+ sol0 = ot.solve(M)
+
+ print(sol0)
+
+ # solve signe weights
+ sol = ot.solve(M, a, b)
+
+ # check some attributes
+ sol.potentials
+ sol.sparse_plan
+ sol.marginals
+ sol.status
+
+ assert_allclose_sol(sol0, sol)
+
+ # solve in backend
+ ab, bb, Mb = nx.from_numpy(a, b, M)
+ solb = ot.solve(M, a, b)
+
+ assert_allclose_sol(sol, solb)
+
+ # test not implemented unbalanced and check raise
+ with pytest.raises(NotImplementedError):
+ sol0 = ot.solve(M, unbalanced=1, unbalanced_type='cryptic divergence')
+
+ # test not implemented reg_type and check raise
+ with pytest.raises(NotImplementedError):
+ sol0 = ot.solve(M, reg=1, reg_type='cryptic divergence')
+
+
+@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type))
+def test_solve_grid(nx, reg, reg_type, unbalanced, unbalanced_type):
+ n_samples_s = 10
+ n_samples_t = 7
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples_s, n_features)
+ y = rng.randn(n_samples_t, n_features)
+ a = ot.utils.unif(n_samples_s)
+ b = ot.utils.unif(n_samples_t)
+
+ M = ot.dist(x, y)
+
+ try:
+
+ # solve unif weights
+ sol0 = ot.solve(M, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type)
+
+ # solve signe weights
+ sol = ot.solve(M, a, b, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type)
+
+ assert_allclose_sol(sol0, sol)
+
+ # solve in backend
+ ab, bb, Mb = nx.from_numpy(a, b, M)
+ solb = ot.solve(M, a, b, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type)
+
+ assert_allclose_sol(sol, solb)
+ except NotImplementedError:
+ pass
+
+
+def test_solve_not_implemented(nx):
+
+ n_samples_s = 10
+ n_samples_t = 7
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples_s, n_features)
+ y = rng.randn(n_samples_t, n_features)
+
+ M = ot.dist(x, y)
+
+ # test not implemented and check raise
+ with pytest.raises(NotImplementedError):
+ ot.solve(M, reg=1.0, reg_type='cryptic divergence')
+ with pytest.raises(NotImplementedError):
+ ot.solve(M, unbalanced=1.0, unbalanced_type='cryptic divergence')
+
+ # pairs of incompatible divergences
+ with pytest.raises(NotImplementedError):
+ ot.solve(M, reg=1.0, reg_type='kl', unbalanced=1.0, unbalanced_type='tv')
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index 02b3fc3..b76d738 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -5,6 +5,7 @@
#
# License: MIT License
+import itertools
import numpy as np
import ot
import pytest
@@ -289,32 +290,55 @@ def test_implemented_methods(nx):
method=method)
+@pytest.mark.parametrize("reg_div,regm_div", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2']))
+def test_lbfgsb_unbalanced(nx, reg_div, regm_div):
+
+ np.random.seed(42)
+
+ xs = np.random.randn(5, 2)
+ xt = np.random.randn(6, 2)
+
+ M = ot.dist(xs, xt)
+
+ a = ot.unif(5)
+ b = ot.unif(6)
+
+ G, log = ot.unbalanced.lbfgsb_unbalanced(a, b, M, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False)
+
+ ab, bb, Mb = nx.from_numpy(a, b, M)
+
+ Gb, log = ot.unbalanced.lbfgsb_unbalanced(ab, bb, Mb, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False)
+
+ np.testing.assert_allclose(G, nx.to_numpy(Gb))
+
+
def test_mm_convergence(nx):
n = 100
rng = np.random.RandomState(42)
x = rng.randn(n, 2)
rng = np.random.RandomState(75)
y = rng.randn(n, 2)
- a = ot.utils.unif(n)
- b = ot.utils.unif(n)
+ a_np = ot.utils.unif(n)
+ b_np = ot.utils.unif(n)
M = ot.dist(x, y)
M = M / M.max()
reg_m = 100
- a, b, M = nx.from_numpy(a, b, M)
+ a, b, M = nx.from_numpy(a_np, b_np, M)
G_kl, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl',
- verbose=True, log=True)
- loss_kl = nx.to_numpy(ot.unbalanced.mm_unbalanced2(
- a, b, M, reg_m, div='kl', verbose=True))
+ verbose=False, log=True)
+ loss_kl = nx.to_numpy(
+ ot.unbalanced.mm_unbalanced2(a, b, M, reg_m, div='kl', verbose=True)
+ )
G_l2, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2',
verbose=False, log=True)
# check if the marginals come close to the true ones when large reg
- np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a, atol=1e-03)
- np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b, atol=1e-03)
- np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a, atol=1e-03)
- np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a_np, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b_np, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a_np, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b_np, atol=1e-03)
# check if mm_unbalanced2 returns the correct loss
np.testing.assert_allclose(nx.to_numpy(nx.sum(G_kl * M)), loss_kl,
@@ -324,15 +348,16 @@ def test_mm_convergence(nx):
a_np, b_np = np.array([]), np.array([])
a, b = nx.from_numpy(a_np, b_np)
- G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl')
- G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2')
- np.testing.assert_allclose(G_kl_null, G_kl)
- np.testing.assert_allclose(G_l2_null, G_l2)
+ G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', verbose=False)
+ G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', verbose=False)
+ np.testing.assert_allclose(nx.to_numpy(G_kl_null), nx.to_numpy(G_kl))
+ np.testing.assert_allclose(nx.to_numpy(G_l2_null), nx.to_numpy(G_l2))
# test when G0 is given
G0 = ot.emd(a, b, M)
+ G0_np = nx.to_numpy(G0)
reg_m = 10000
- G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0)
- G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0)
- np.testing.assert_allclose(G0, G_kl, atol=1e-05)
- np.testing.assert_allclose(G0, G_l2, atol=1e-05)
+ G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0, verbose=False)
+ G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0, verbose=False)
+ np.testing.assert_allclose(G0_np, nx.to_numpy(G_kl), atol=1e-05)
+ np.testing.assert_allclose(G0_np, nx.to_numpy(G_l2), atol=1e-05)
diff --git a/test/test_utils.py b/test/test_utils.py
index 3cfd295..31b12ef 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -143,6 +143,7 @@ def test_dist():
for metric in metrics_w:
print(metric)
ot.dist(x, x, metric=metric, p=3, w=np.random.random((2, )))
+ ot.dist(x, x, metric=metric, p=3, w=None) # check that not having any weight does not cause issues
for metric in metrics:
print(metric)
ot.dist(x, x, metric=metric, p=3)
@@ -300,3 +301,42 @@ def test_BaseEstimator():
cl.set_params(bibi=10)
assert cl.first == 'spam again'
+
+
+def test_OTResult():
+
+ res = ot.utils.OTResult()
+
+ # test print
+ print(res)
+
+ # tets get citation
+ print(res.citation)
+
+ lst_attributes = ['a_to_b',
+ 'b_to_a',
+ 'lazy_plan',
+ 'marginal_a',
+ 'marginal_b',
+ 'marginals',
+ 'plan',
+ 'potential_a',
+ 'potential_b',
+ 'potentials',
+ 'sparse_plan',
+ 'status',
+ 'value',
+ 'value_linear']
+ for at in lst_attributes:
+ with pytest.raises(NotImplementedError):
+ getattr(res, at)
+
+
+def test_get_coordinate_circle():
+
+ u = np.random.rand(1, 100)
+ x1, y1 = np.cos(u * (2 * np.pi)), np.sin(u * (2 * np.pi))
+ x = np.concatenate([x1, y1]).T
+ x_p = ot.utils.get_coordinate_circle(x)
+
+ np.testing.assert_allclose(u[0], x_p)