summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGard Spreemann <gspr@nonempty.org>2022-04-27 11:49:23 +0200
committerGard Spreemann <gspr@nonempty.org>2022-04-27 11:49:23 +0200
commit35bd2c98b642df78638d7d733bc1a89d873db1de (patch)
tree6bc637624004713808d3097b95acdccbb9608e52
parentc4753bd3f74139af8380127b66b484bc09b50661 (diff)
parenteccb1386eea52b94b82456d126bd20cbe3198e05 (diff)
Merge tag '0.8.2' into dfsg/latest
-rw-r--r--.circleci/config.yml25
-rw-r--r--.github/workflows/build_tests.yml6
-rw-r--r--.github/workflows/build_wheels.yml8
-rw-r--r--.github/workflows/build_wheels_weekly.yml2
-rw-r--r--CONTRIBUTORS.md52
-rw-r--r--README.md54
-rw-r--r--RELEASES.md81
-rw-r--r--docs/source/_static/images/bak.png (renamed from docs/source/auto_examples/images/bak.png)bin304669 -> 304669 bytes
-rw-r--r--docs/source/_static/images/logo.pngbin0 -> 4325 bytes
-rw-r--r--docs/source/_static/images/logo.svg200
-rw-r--r--docs/source/_static/images/logo_3ia.jpgbin0 -> 25029 bytes
-rw-r--r--docs/source/_static/images/logo_anr.jpgbin0 -> 23493 bytes
-rw-r--r--docs/source/_static/images/logo_cnrs.jpgbin0 -> 6918 bytes
-rw-r--r--docs/source/_static/images/logo_dark.pngbin0 -> 3437 bytes
-rw-r--r--docs/source/_static/images/logo_dark.svg187
-rw-r--r--docs/source/_static/images/sinkhorn.png (renamed from docs/source/auto_examples/images/sinkhorn.png)bin37204 -> 37204 bytes
-rw-r--r--docs/source/all.rst4
-rw-r--r--docs/source/conf.py13
-rw-r--r--docs/source/contributors.rst6
-rw-r--r--docs/source/index.rst6
-rw-r--r--docs/source/quickstart.rst279
-rw-r--r--docs/source/releases.rst2
-rw-r--r--examples/backends/plot_dual_ot_pytorch.py168
-rw-r--r--examples/backends/plot_sliced_wass_grad_flow_pytorch.py2
-rw-r--r--examples/backends/plot_stoch_continuous_ot_pytorch.py189
-rw-r--r--examples/backends/plot_wass1d_torch.py8
-rw-r--r--examples/barycenters/plot_free_support_barycenter.py55
-rwxr-xr-xexamples/gromov/plot_gromov_wasserstein_dictionary_learning.py357
-rw-r--r--examples/others/plot_WeakOT_VS_OT.py98
-rw-r--r--examples/others/plot_factored_coupling.py86
-rw-r--r--examples/others/plot_logo.py112
-rw-r--r--examples/others/plot_screenkhorn_1D.py (renamed from examples/plot_screenkhorn_1D.py)6
-rw-r--r--examples/others/plot_stochastic.py (renamed from examples/plot_stochastic.py)0
-rw-r--r--examples/plot_Intro_OT.py4
-rw-r--r--examples/plot_OT_1D.py12
-rw-r--r--examples/plot_OT_1D_smooth.py6
-rw-r--r--examples/plot_OT_2D_samples.py7
-rw-r--r--examples/plot_OT_L1_vs_L2.py34
-rw-r--r--examples/plot_compute_emd.py72
-rw-r--r--examples/plot_optim_OTreg.py38
-rw-r--r--examples/sliced-wasserstein/README.txt2
-rw-r--r--examples/sliced-wasserstein/plot_variance.py8
-rw-r--r--examples/unbalanced-partial/plot_UOT_1D.py17
-rw-r--r--examples/unbalanced-partial/plot_regpath.py88
-rw-r--r--examples/unbalanced-partial/plot_unbalanced_OT.py116
-rw-r--r--ot/__init__.py14
-rw-r--r--ot/backend.py306
-rw-r--r--ot/bregman.py17
-rw-r--r--ot/da.py382
-rw-r--r--ot/dr.py44
-rw-r--r--ot/factored.py145
-rw-r--r--ot/gpu/__init__.py50
-rw-r--r--ot/gpu/bregman.py196
-rw-r--r--ot/gpu/da.py144
-rw-r--r--ot/gpu/utils.py101
-rw-r--r--ot/gromov.py1109
-rw-r--r--ot/lp/__init__.py123
-rw-r--r--ot/lp/cvx.py1
-rw-r--r--ot/optim.py11
-rwxr-xr-xot/partial.py84
-rw-r--r--ot/plot.py7
-rw-r--r--ot/regpath.py545
-rw-r--r--ot/stochastic.py242
-rw-r--r--ot/unbalanced.py525
-rw-r--r--ot/utils.py36
-rw-r--r--ot/weak.py124
-rw-r--r--pyproject.toml2
-rw-r--r--requirements.txt1
-rw-r--r--setup.py5
-rw-r--r--test/test_1d_solver.py28
-rw-r--r--test/test_backend.py66
-rw-r--r--test/test_bregman.py94
-rw-r--r--test/test_da.py307
-rw-r--r--test/test_dr.py22
-rw-r--r--test/test_factored.py56
-rw-r--r--test/test_gpu.py106
-rw-r--r--test/test_gromov.py726
-rw-r--r--test/test_optim.py17
-rw-r--r--test/test_ot.py42
-rw-r--r--test/test_sliced.py32
-rw-r--r--test/test_stochastic.py115
-rw-r--r--test/test_unbalanced.py207
-rw-r--r--test/test_utils.py22
-rw-r--r--test/test_weak.py52
84 files changed, 6474 insertions, 2042 deletions
diff --git a/.circleci/config.yml b/.circleci/config.yml
index f5cb756..7e15a65 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -36,20 +36,26 @@ jobs:
- pip-cache
- run:
+ name: Install ffmpeg
+ command: |
+ sudo apt update
+ sudo apt install ffmpeg
+
+ - run:
name: Get Python running
command: |
python -m pip install --user --upgrade --progress-bar off pip
python -m pip install --user -e .
python -m pip install --user --upgrade --no-cache-dir --progress-bar off -r requirements.txt
python -m pip install --user --upgrade --progress-bar off -r docs/requirements.txt
- python -m pip install --user --upgrade --progress-bar off ipython "https://api.github.com/repos/sphinx-gallery/sphinx-gallery/zipball/master" memory_profiler
-
-
+ python -m pip install --user --upgrade --progress-bar off ipython sphinx-gallery memory_profiler
+ # python -m pip install --user --upgrade --progress-bar off ipython "https://api.github.com/repos/sphinx-gallery/sphinx-gallery/zipball/master" memory_profiler
- save_cache:
key: pip-cache
paths:
- ~/.cache/pip
+
# Look at what we have and fail early if there is some library conflict
- run:
name: Check installation
@@ -57,6 +63,11 @@ jobs:
which python
python -c "import ot"
+ - run:
+ name: Correct link in release file
+ command: |
+ sed -i -r 's/PR #([[:digit:]]*)/\[PR #\1\]\(https:\/\/github.com\/PythonOT\/POT\/pull\/\1\)/' RELEASES.md
+ sed -i -r 's/Issue #([[:digit:]]*)/\[Issue #\1\]\(https:\/\/github.com\/PythonOT\/POT\/issues\/\1\)/' RELEASES.md
# Build docs
- run:
name: make html
@@ -106,10 +117,10 @@ jobs:
echo "Deploying dev docs for ${CIRCLE_BRANCH}.";
cd master
cp -a /tmp/build/html/* .;
- cp -a /tmp/build/html/.github .github;
+ cp -a /tmp/build/html/.github/* .github/;
touch .nojekyll;
git add -A;
- git add -f .github/*.html ;
+ git add -f .github/* ;
git commit -m "CircleCI update of dev docs (${CIRCLE_BUILD_NUM}).";
git push origin master;
else
@@ -146,10 +157,10 @@ jobs:
git clean -xdf
echo "Deploying dev docs for ${CIRCLE_BRANCH}.";
cp -a /tmp/build/html/* .;
- cp -a /tmp/build/html/.github .github;
+ cp -a /tmp/build/html/.github/* .github/;
touch .nojekyll;
git add -A;
- git add -f .github/*.html ;
+ git add -f .github/* ;
git commit -m "CircleCI update of dev docs (${CIRCLE_BUILD_NUM}).";
git push origin master;
diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml
index 3c99da8..ce725c6 100644
--- a/.github/workflows/build_tests.yml
+++ b/.github/workflows/build_tests.yml
@@ -22,7 +22,7 @@ jobs:
strategy:
max-parallel: 4
matrix:
- python-version: ["3.7", "3.8", "3.9"]
+ python-version: ["3.7", "3.8", "3.9", "3.10"]
steps:
- uses: actions/checkout@v1
@@ -93,7 +93,7 @@ jobs:
strategy:
max-parallel: 4
matrix:
- python-version: ["3.7", "3.8", "3.9"]
+ python-version: ["3.7", "3.8", "3.9", "3.10"]
steps:
- uses: actions/checkout@v1
@@ -120,7 +120,7 @@ jobs:
strategy:
max-parallel: 4
matrix:
- python-version: ["3.7", "3.8", "3.9"]
+ python-version: ["3.7", "3.8", "3.9", "3.10"]
steps:
- uses: actions/checkout@v1
diff --git a/.github/workflows/build_wheels.yml b/.github/workflows/build_wheels.yml
index c746eb8..475058c 100644
--- a/.github/workflows/build_wheels.yml
+++ b/.github/workflows/build_wheels.yml
@@ -27,8 +27,6 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- pip install -r requirements.txt
- pip install -U "cython"
- name: Install cibuildwheel
run: |
@@ -37,7 +35,6 @@ jobs:
- name: Build wheels
env:
CIBW_SKIP: "pp*-win* pp*-macosx* cp2* pp* cp36*" # remove pypy on mac and win (wrong version)
- CIBW_BEFORE_BUILD: "pip install numpy cython"
run: |
python -m cibuildwheel --output-dir wheelhouse
@@ -65,8 +62,6 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- pip install -r requirements.txt
- pip install -U "cython"
- name: Install cibuildwheel
run: |
@@ -80,8 +75,7 @@ jobs:
- name: Build wheels
env:
- CIBW_SKIP: "pp*-win* pp*-macosx* cp2* pp* cp*musl* cp36*" # remove pypy on mac and win (wrong version)
- CIBW_BEFORE_BUILD: "pip install numpy cython"
+ CIBW_SKIP: "pp*-win* pp*-macosx* cp2* pp* cp*musl*" # remove pypy on mac and win (wrong version)
CIBW_ARCHS_LINUX: auto aarch64 # force aarch64 with QEMU
CIBW_ARCHS_MACOS: x86_64 universal2 arm64
run: |
diff --git a/.github/workflows/build_wheels_weekly.yml b/.github/workflows/build_wheels_weekly.yml
index dbf342f..b9154c5 100644
--- a/.github/workflows/build_wheels_weekly.yml
+++ b/.github/workflows/build_wheels_weekly.yml
@@ -26,8 +26,6 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- pip install -r requirements.txt
- pip install -U "cython"
- name: Install cibuildwheel
run: |
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
new file mode 100644
index 0000000..ab64fba
--- /dev/null
+++ b/CONTRIBUTORS.md
@@ -0,0 +1,52 @@
+
+
+## Creators and Maintainers
+
+This toolbox has been created and is maintained by:
+
+* [Rémi Flamary](http://remi.flamary.com/)
+* [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/)
+
+## Contributors
+
+The contributors to this library are:
+
+* [Rémi Flamary](http://remi.flamary.com/) (EMD wrapper, Pytorch backend, DA
+ classes, conditional gradients, WDA, weak OT, linear OT mapping, documentation)
+* [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/) (Original sinkhorn,
+ Wasserstein barycenters and convolutional barycenters, 1D wasserstein)
+* [Alexandre Gramfort](http://alexandre.gramfort.net/) (CI, documentation)
+* [Laetitia Chapel](http://people.irisa.fr/Laetitia.Chapel/) (Partial OT,
+ Unbalanced OT non-regularized)
+* [Michael Perrot](http://perso.univ-st-etienne.fr/pem82055/) (Mapping estimation)
+* [Léo Gautheron](https://github.com/aje) (Initial GPU implementation)
+* [Nathalie Gayraud](https://www.linkedin.com/in/nathalie-t-h-gayraud/?ppe=1) (DA classes)
+* [Stanislas Chambon](https://slasnista.github.io/) (DA classes)
+* [Antoine Rolet](https://arolet.github.io/) (EMD solver debug)
+* Erwan Vautier (Gromov-Wasserstein)
+* [Kilian Fatras](https://kilianfatras.github.io/) (Stochastic solvers,
+ empirical sinkhorn)
+* [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home) (Greenkhorn)
+* [Vayer Titouan](https://tvayer.github.io/) (Gromov-Wasserstein, Fused-Gromov-Wasserstein)
+* [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT, Debiased barycenters)
+* [Romain Tavenard](https://rtavenar.github.io/) (1D Wasserstein)
+* [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn)
+* [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT)
+* [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance)
+* [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein)
+* [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance)
+* [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends)
+* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning)
+
+## Acknowledgments
+
+This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
+
+* [Gabriel Peyré](http://gpeyre.github.io/) (Wasserstein Barycenters in Matlab)
+* [Mathieu Blondel](https://mblondel.org/) (original implementation smooth OT)
+* [Nicolas Bonneel](http://liris.cnrs.fr/~nbonneel/) (C++ code for EMD)
+* [Marco Cuturi](http://marcocuturi.net/) (Sinkhorn Knopp in Matlab/Cuda)
+
+POT has benefited from the financing or manpower from the following partners:
+
+<img src="https://pythonot.github.io/master/_static/images/logo_anr.jpg" alt="ANR" style="height:60px;"/><img src="https://pythonot.github.io/master/_static/images/logo_cnrs.jpg" alt="CNRS" style="height:60px;"/><img src="https://pythonot.github.io/master/_static/images/logo_3ia.jpg" alt="3IA" style="height:60px;"/> \ No newline at end of file
diff --git a/README.md b/README.md
index 17fbe81..e2b33d9 100644
--- a/README.md
+++ b/README.md
@@ -25,16 +25,21 @@ POT provides the following generic OT solvers (links to examples):
* Sinkhorn divergence [23] and entropic regularization OT from empirical data.
* Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37]
* [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17].
+* Weak OT solver between empirical distributions [39]
* Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html)) with LP solver (only small scale).
* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from
* [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24]
-* [Stochastic solver](https://pythonot.github.io/auto_examples/plot_stochastic.html) for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19])
-* [Stochastic solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33]
+* [Stochastic
+ solver](https://pythonot.github.io/auto_examples/others/plot_stochastic.html) and
+ [differentiable losses](https://pythonot.github.io/auto_examples/backends/plot_stoch_continuous_ot_pytorch.html) for
+ Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19])
+* [Sampled solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33]
* Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) [20].
-* [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25].
+* [One dimensional Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. Also [exact unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_unbalanced_ot.html) with KL and quadratic regularization and the [regularization path of UOT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_regpath.html) [41]
* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3]
formulations).
* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36].
+* [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38].
* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.
POT provides the following Machine Learning related solvers:
@@ -117,7 +122,7 @@ Note that for easier access the module is named `ot` instead of `pot`.
### Dependencies
-Some sub-modules require additional dependences which are discussed below
+Some sub-modules require additional dependencies which are discussed below
* **ot.dr** (Wasserstein dimensionality reduction) depends on autograd and pymanopt that can be installed with:
@@ -125,7 +130,6 @@ Some sub-modules require additional dependences which are discussed below
pip install pymanopt autograd
```
-* **ot.gpu** (GPU accelerated OT) depends on cupy that have to be installed following instructions on [this page](https://docs-cupy.chainer.org/en/stable/install.html). Obviously you will need CUDA installed and a compatible GPU. Note that this module is deprecated since version 0.8 and will be deleted in the future. GPU is now handled automatically through the backends and several solver already can run on GPU using the Pytorch backend.
## Examples
@@ -176,34 +180,12 @@ This toolbox has been created and is maintained by
* [Rémi Flamary](http://remi.flamary.com/)
* [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/)
-The contributors to this library are
-
-* [Alexandre Gramfort](http://alexandre.gramfort.net/) (CI, documentation)
-* [Laetitia Chapel](http://people.irisa.fr/Laetitia.Chapel/) (Partial OT)
-* [Michael Perrot](http://perso.univ-st-etienne.fr/pem82055/) (Mapping estimation)
-* [Léo Gautheron](https://github.com/aje) (GPU implementation)
-* [Nathalie Gayraud](https://www.linkedin.com/in/nathalie-t-h-gayraud/?ppe=1) (DA classes)
-* [Stanislas Chambon](https://slasnista.github.io/) (DA classes)
-* [Antoine Rolet](https://arolet.github.io/) (EMD solver debug)
-* Erwan Vautier (Gromov-Wasserstein)
-* [Kilian Fatras](https://kilianfatras.github.io/) (Stochastic solvers)
-* [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home)
-* [Vayer Titouan](https://tvayer.github.io/) (Gromov-Wasserstein -, Fused-Gromov-Wasserstein)
-* [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT, Debiased barycenters)
-* [Romain Tavenard](https://rtavenar.github.io/) (1d Wasserstein)
-* [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn)
-* [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT)
-* [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance)
-* [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein)
-* [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance)
-* [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends)
-
-This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
-
-* [Gabriel Peyré](http://gpeyre.github.io/) (Wasserstein Barycenters in Matlab)
-* [Mathieu Blondel](https://mblondel.org/) (original implementation smooth OT)
-* [Nicolas Bonneel](http://liris.cnrs.fr/~nbonneel/) (C++ code for EMD)
-* [Marco Cuturi](http://marcocuturi.net/) (Sinkhorn Knopp in Matlab/Cuda)
+The numerous contributors to this library are listed [here](CONTRIBUTORS.md).
+
+POT has benefited from the financing or manpower from the following partners:
+
+<img src="https://pythonot.github.io/master/_static/images/logo_anr.jpg" alt="ANR" style="height:60px;"/><img src="https://pythonot.github.io/master/_static/images/logo_cnrs.jpg" alt="CNRS" style="height:60px;"/><img src="https://pythonot.github.io/master/_static/images/logo_3ia.jpg" alt="3IA" style="height:60px;"/>
+
## Contributions and code of conduct
@@ -301,3 +283,9 @@ Conference on Machine Learning, PMLR 119:4692-4701, 2020
[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph
Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021.
+
+[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405.
+
+[40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., & Weed, J. (2019, April). [Statistical optimal transport via factored couplings](http://proceedings.mlr.press/v89/forrow19a/forrow19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2454-2465). PMLR.
+
+[41] Chapel*, L., Flamary*, R., Wu, H., Févotte, C., Gasso, G. (2021). [Unbalanced Optimal Transport through Non-negative Penalized Linear Regression](https://proceedings.neurips.cc/paper/2021/file/c3c617a9b80b3ae1ebd868b0017cc349-Paper.pdf) Advances in Neural Information Processing Systems (NeurIPS), 2020. (Two first co-authors) \ No newline at end of file
diff --git a/RELEASES.md b/RELEASES.md
index 00af0fb..be2192e 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -1,6 +1,77 @@
# Releases
+## 0.8.2
+
+This releases introduces several new notable features. The less important
+but most exiting one being that we now have a logo for the toolbox (color
+and dark background) :
+
+![](https://pythonot.github.io/master/_images/logo.svg)![](https://pythonot.github.io/master/_static/logo_dark.svg)
+
+This logo is generated using with matplotlib and using the solution of an OT
+problem provided by POT (with `ot.emd`). Generating the logo can be done with a
+simple python script also provided in the [documentation gallery](https://pythonot.github.io/auto_examples/others/plot_logo.html#sphx-glr-auto-examples-others-plot-logo-py).
+
+New OT solvers include [Weak
+OT](https://pythonot.github.io/gen_modules/ot.weak.html#ot.weak.weak_optimal_transport)
+ and [OT with factored
+coupling](https://pythonot.github.io/gen_modules/ot.factored.html#ot.factored.factored_optimal_transport)
+that can be used on large datasets. The [Majorization Minimization](https://pythonot.github.io/gen_modules/ot.unbalanced.html?highlight=mm_#ot.unbalanced.mm_unbalanced) solvers for
+non-regularized Unbalanced OT are now also available. We also now provide an
+implementation of [GW and FGW unmixing](https://pythonot.github.io/gen_modules/ot.gromov.html#ot.gromov.gromov_wasserstein_linear_unmixing) and [dictionary learning](https://pythonot.github.io/gen_modules/ot.gromov.html#ot.gromov.gromov_wasserstein_dictionary_learning). It is now
+possible to use autodiff to solve entropic an quadratic regularized OT in the
+dual for full or stochastic optimization thanks to the new functions to compute
+the dual loss for [entropic](https://pythonot.github.io/gen_modules/ot.stochastic.html#ot.stochastic.loss_dual_entropic) and [quadratic](https://pythonot.github.io/gen_modules/ot.stochastic.html#ot.stochastic.loss_dual_quadratic) regularized OT and reconstruct the [OT
+plan](https://pythonot.github.io/gen_modules/ot.stochastic.html#ot.stochastic.plan_dual_entropic) on part or all of the data. They can be used for instance to solve OT
+problems with stochastic gradient or for estimating the [dual potentials as
+neural networks](https://pythonot.github.io/auto_examples/backends/plot_stoch_continuous_ot_pytorch.html#sphx-glr-auto-examples-backends-plot-stoch-continuous-ot-pytorch-py).
+
+On the backend front, we now have backend compatible functions and classes in
+the domain adaptation [`ot.da`](https://pythonot.github.io/gen_modules/ot.da.html#module-ot.da) and unbalanced OT [`ot.unbalanced`](https://pythonot.github.io/gen_modules/ot.unbalanced.html) modules. This
+means that the DA classes can be used on tensors from all compatible backends.
+The [free support Wasserstein barycenter](https://pythonot.github.io/gen_modules/ot.lp.html?highlight=free%20support#ot.lp.free_support_barycenter) solver is now also backend compatible.
+
+Finally we have worked on the documentation to provide an update of existing
+examples in the gallery and and several new examples including [GW dictionary
+learning](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html#sphx-glr-auto-examples-gromov-plot-gromov-wasserstein-dictionary-learning-py)
+[weak Optimal
+Transport](https://pythonot.github.io/auto_examples/others/plot_WeakOT_VS_OT.html#sphx-glr-auto-examples-others-plot-weakot-vs-ot-py),
+[NN based dual potentials
+estimation](https://pythonot.github.io/auto_examples/backends/plot_stoch_continuous_ot_pytorch.html#sphx-glr-auto-examples-backends-plot-stoch-continuous-ot-pytorch-py)
+and [Factored coupling OT](https://pythonot.github.io/auto_examples/others/plot_factored_coupling.html#sphx-glr-auto-examples-others-plot-factored-coupling-py).
+.
+
+#### New features
+
+- Remove deprecated `ot.gpu` submodule (PR #361)
+- Update examples in the gallery (PR #359)
+- Add stochastic loss and OT plan computation for regularized OT and
+ backend examples(PR #360)
+- Implementation of factored OT with emd and sinkhorn (PR #358)
+- A brand new logo for POT (PR #357)
+- Better list of related examples in quick start guide with `minigallery` (PR #334)
+- Add optional log-domain Sinkhorn implementation in WDA to support smaller values
+ of the regularization parameter (PR #336)
+- Backend implementation for `ot.lp.free_support_barycenter` (PR #340)
+- Add weak OT solver + example (PR #341)
+- Add backend support for Domain Adaptation and Unbalanced solvers (PR #343)
+- Add (F)GW linear dictionary learning solvers + example (PR #319)
+- Add links to related PR and Issues in the doc release page (PR #350)
+- Add new minimization-maximization algorithms for solving exact Unbalanced OT + example (PR #362)
+
+#### Closed issues
+
+- Fix mass gradient of `ot.emd2` and `ot.gromov_wasserstein2` so that they are
+ centered (Issue #364, PR #363)
+- Fix bug in instantiating an `autograd` function `ValFunction` (Issue #337,
+ PR #338)
+- Fix POT ABI compatibility with old and new numpy (Issue #346, PR #349)
+- Warning when feeding integer cost matrix to EMD solver resulting in an integer transport plan (Issue #345, PR #343)
+- Fix bug where gromov_wasserstein2 does not perform backpropagation with CUDA
+ tensors (Issue #351, PR #352)
+
+
## 0.8.1.0
*December 2021*
@@ -43,10 +114,10 @@ As always we want to that the contributors who helped make POT better (and bug f
- Fix bug in older Numpy ABI (<1.20) (Issue #308, PR #326)
- Fix bug in `ot.dist` function when non euclidean distance (Issue #305, PR #306)
-- Fix gradient scaling for functions using `nx.set_gradients` (Issue #309, PR
- #310)
-- Fix bug in generalized Conditional gradient solver and SinkhornL1L2 (Issue
- #311, PR #313)
+- Fix gradient scaling for functions using `nx.set_gradients` (Issue #309,
+ PR #310)
+- Fix bug in generalized Conditional gradient solver and SinkhornL1L2
+ (Issue #311, PR #313)
- Fix log error in `gromov_barycenters` (Issue #317, PR #3018)
## 0.8.0
@@ -227,7 +298,7 @@ are coming for the next versions.
#### Closed issues
-- Add JMLR paper to teh readme ad Mathieu Blondel to the Acknoledgments (PR
+- Add JMLR paper to the readme and Mathieu Blondel to the Acknoledgments (PR
#231, #232)
- Bug in Unbalanced OT example (Issue #127)
- Clean Cython output when calling setup.py clean (Issue #122)
diff --git a/docs/source/auto_examples/images/bak.png b/docs/source/_static/images/bak.png
index 25e7e8e..25e7e8e 100644
--- a/docs/source/auto_examples/images/bak.png
+++ b/docs/source/_static/images/bak.png
Binary files differ
diff --git a/docs/source/_static/images/logo.png b/docs/source/_static/images/logo.png
new file mode 100644
index 0000000..2dd6f65
--- /dev/null
+++ b/docs/source/_static/images/logo.png
Binary files differ
diff --git a/docs/source/_static/images/logo.svg b/docs/source/_static/images/logo.svg
new file mode 100644
index 0000000..39fe900
--- /dev/null
+++ b/docs/source/_static/images/logo.svg
@@ -0,0 +1,200 @@
+<?xml version="1.0" encoding="utf-8" standalone="no"?>
+<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
+ "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
+<svg xmlns:xlink="http://www.w3.org/1999/xlink" width="209.7pt" height="75.384pt" viewBox="0 0 209.7 75.384" xmlns="http://www.w3.org/2000/svg" version="1.1">
+ <metadata>
+ <rdf:RDF xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:cc="http://creativecommons.org/ns#" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#">
+ <cc:Work>
+ <dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/>
+ <dc:date>2022-03-30T17:25:32.476826</dc:date>
+ <dc:format>image/svg+xml</dc:format>
+ <dc:creator>
+ <cc:Agent>
+ <dc:title>Matplotlib v3.5.1, https://matplotlib.org/</dc:title>
+ </cc:Agent>
+ </dc:creator>
+ </cc:Work>
+ </rdf:RDF>
+ </metadata>
+ <defs>
+ <style type="text/css">*{stroke-linejoin: round; stroke-linecap: butt}</style>
+ </defs>
+ <g id="figure_1">
+ <g id="patch_1">
+ <path d="M 0 75.384
+L 209.7 75.384
+L 209.7 0
+L 0 0
+L 0 75.384
+z
+" style="fill: none"/>
+ </g>
+ <g id="axes_1">
+ <g id="line2d_1">
+ <path d="M 16.077273 11.885975
+L 47.044503 11.885975
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
+ </g>
+ <g id="line2d_2">
+ <path d="M 16.077273 22.208385
+L 57.366913 22.208385
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
+ </g>
+ <g id="line2d_3">
+ <path d="M 16.077273 32.530795
+L 57.366913 32.530795
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
+ </g>
+ <g id="line2d_4">
+ <path d="M 16.077273 42.853205
+L 47.044503 42.853205
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
+ </g>
+ <g id="line2d_5">
+ <path d="M 16.077273 53.175615
+L 26.399683 53.175615
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
+ </g>
+ <g id="line2d_6">
+ <path d="M 16.077273 63.498025
+L 26.399683 63.498025
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
+ </g>
+ <g id="line2d_7">
+ <path d="M 95.353383 11.885975
+L 107.740275 11.885975
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
+ </g>
+ <g id="line2d_8">
+ <path d="M 82.96649 22.208385
+L 120.127167 22.208385
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
+ </g>
+ <g id="line2d_9">
+ <path d="M 76.773044 32.530795
+L 126.320613 32.530795
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
+ </g>
+ <g id="line2d_10">
+ <path d="M 76.773044 42.853205
+L 126.320613 42.853205
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
+ </g>
+ <g id="line2d_11">
+ <path d="M 82.96649 53.175615
+L 120.127167 53.175615
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
+ </g>
+ <g id="line2d_12">
+ <path d="M 95.353383 63.498025
+L 107.740275 63.498025
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
+ </g>
+ <g id="line2d_13">
+ <path d="M 142.010677 11.885975
+L 193.622727 11.885975
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
+ </g>
+ <g id="line2d_14">
+ <path d="M 142.010677 22.208385
+L 193.622727 22.208385
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
+ </g>
+ <g id="line2d_15">
+ <path d="M 162.655497 32.530795
+L 172.977907 32.530795
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
+ </g>
+ <g id="line2d_16">
+ <path d="M 162.655497 42.853205
+L 172.977907 42.853205
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
+ </g>
+ <g id="line2d_17">
+ <path d="M 162.655497 53.175615
+L 172.977907 53.175615
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
+ </g>
+ <g id="line2d_18">
+ <path d="M 162.655497 63.498025
+L 172.977907 63.498025
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
+ </g>
+ <g id="line2d_19">
+ <defs>
+ <path id="m5ead2df136" d="M 0 3
+C 0.795609 3 1.55874 2.683901 2.12132 2.12132
+C 2.683901 1.55874 3 0.795609 3 0
+C 3 -0.795609 2.683901 -1.55874 2.12132 -2.12132
+C 1.55874 -2.683901 0.795609 -3 0 -3
+C -0.795609 -3 -1.55874 -2.683901 -2.12132 -2.12132
+C -2.683901 -1.55874 -3 -0.795609 -3 0
+C -3 0.795609 -2.683901 1.55874 -2.12132 2.12132
+C -1.55874 2.683901 -0.795609 3 0 3
+z
+" style="stroke: #000000"/>
+ </defs>
+ <g clip-path="url(#p367fff45ba)">
+ <use xlink:href="#m5ead2df136" x="16.077273" y="11.885975" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="16.077273" y="22.208385" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="16.077273" y="32.530795" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="16.077273" y="42.853205" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="16.077273" y="53.175615" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="16.077273" y="63.498025" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="95.353383" y="11.885975" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="82.96649" y="22.208385" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="76.773044" y="32.530795" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="76.773044" y="42.853205" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="82.96649" y="53.175615" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="95.353383" y="63.498025" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="142.010677" y="11.885975" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="142.010677" y="22.208385" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="162.655497" y="32.530795" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="162.655497" y="42.853205" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="162.655497" y="53.175615" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="162.655497" y="63.498025" style="fill: #d62728; stroke: #000000"/>
+ </g>
+ </g>
+ <g id="line2d_20">
+ <defs>
+ <path id="m39fe4d1791" d="M 0 3
+C 0.795609 3 1.55874 2.683901 2.12132 2.12132
+C 2.683901 1.55874 3 0.795609 3 0
+C 3 -0.795609 2.683901 -1.55874 2.12132 -2.12132
+C 1.55874 -2.683901 0.795609 -3 0 -3
+C -0.795609 -3 -1.55874 -2.683901 -2.12132 -2.12132
+C -2.683901 -1.55874 -3 -0.795609 -3 0
+C -3 0.795609 -2.683901 1.55874 -2.12132 2.12132
+C -1.55874 2.683901 -0.795609 3 0 3
+z
+" style="stroke: #000000"/>
+ </defs>
+ <g clip-path="url(#p367fff45ba)">
+ <use xlink:href="#m39fe4d1791" x="47.044503" y="11.885975" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="57.366913" y="32.530795" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="57.366913" y="22.208385" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="47.044503" y="42.853205" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="26.399683" y="53.175615" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="26.399683" y="63.498025" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="107.740275" y="11.885975" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="120.127167" y="22.208385" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="126.320613" y="32.530795" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="126.320613" y="42.853205" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="120.127167" y="53.175615" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="107.740275" y="63.498025" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="193.622727" y="11.885975" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="193.622727" y="22.208385" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="172.977907" y="32.530795" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="172.977907" y="42.853205" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="172.977907" y="53.175615" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="172.977907" y="63.498025" style="fill: #0000ff; stroke: #000000"/>
+ </g>
+ </g>
+ </g>
+ </g>
+ <defs>
+ <clipPath id="p367fff45ba">
+ <rect x="7.2" y="7.2" width="195.3" height="60.984"/>
+ </clipPath>
+ </defs>
+</svg>
diff --git a/docs/source/_static/images/logo_3ia.jpg b/docs/source/_static/images/logo_3ia.jpg
new file mode 100644
index 0000000..ecc56b2
--- /dev/null
+++ b/docs/source/_static/images/logo_3ia.jpg
Binary files differ
diff --git a/docs/source/_static/images/logo_anr.jpg b/docs/source/_static/images/logo_anr.jpg
new file mode 100644
index 0000000..dcef212
--- /dev/null
+++ b/docs/source/_static/images/logo_anr.jpg
Binary files differ
diff --git a/docs/source/_static/images/logo_cnrs.jpg b/docs/source/_static/images/logo_cnrs.jpg
new file mode 100644
index 0000000..902cf6f
--- /dev/null
+++ b/docs/source/_static/images/logo_cnrs.jpg
Binary files differ
diff --git a/docs/source/_static/images/logo_dark.png b/docs/source/_static/images/logo_dark.png
new file mode 100644
index 0000000..f484188
--- /dev/null
+++ b/docs/source/_static/images/logo_dark.png
Binary files differ
diff --git a/docs/source/_static/images/logo_dark.svg b/docs/source/_static/images/logo_dark.svg
new file mode 100644
index 0000000..56ce2d9
--- /dev/null
+++ b/docs/source/_static/images/logo_dark.svg
@@ -0,0 +1,187 @@
+<?xml version="1.0" encoding="utf-8" standalone="no"?>
+<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
+ "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
+<!-- Created with matplotlib (https://matplotlib.org/) -->
+<svg height="75.384pt" version="1.1" viewBox="0 0 209.7 75.384" width="209.7pt" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
+ <metadata>
+ <rdf:RDF xmlns:cc="http://creativecommons.org/ns#" xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#">
+ <cc:Work>
+ <dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/>
+ <dc:date>2022-03-17T17:25:30.847142</dc:date>
+ <dc:format>image/svg+xml</dc:format>
+ <dc:creator>
+ <cc:Agent>
+ <dc:title>Matplotlib v3.3.3, https://matplotlib.org/</dc:title>
+ </cc:Agent>
+ </dc:creator>
+ </cc:Work>
+ </rdf:RDF>
+ </metadata>
+ <defs>
+ <style type="text/css">*{stroke-linecap:butt;stroke-linejoin:round;}</style>
+ </defs>
+ <g id="figure_1">
+ <g id="patch_1">
+ <path d="M 0 75.384
+L 209.7 75.384
+L 209.7 0
+L 0 0
+z
+" style="fill:none;"/>
+ </g>
+ <g id="axes_1">
+ <g id="line2d_1">
+ <path clip-path="url(#pa995e487cb)" d="M 16.077273 11.885975
+L 47.044503 11.885975
+" style="fill:none;stroke:#ffffff;stroke-linecap:square;stroke-opacity:0.8;stroke-width:3;"/>
+ </g>
+ <g id="line2d_2">
+ <path clip-path="url(#pa995e487cb)" d="M 16.077273 22.208385
+L 57.366913 22.208385
+" style="fill:none;stroke:#ffffff;stroke-linecap:square;stroke-opacity:0.8;stroke-width:3;"/>
+ </g>
+ <g id="line2d_3">
+ <path clip-path="url(#pa995e487cb)" d="M 16.077273 32.530795
+L 57.366913 32.530795
+" style="fill:none;stroke:#ffffff;stroke-linecap:square;stroke-opacity:0.8;stroke-width:3;"/>
+ </g>
+ <g id="line2d_4">
+ <path clip-path="url(#pa995e487cb)" d="M 16.077273 42.853205
+L 47.044503 42.853205
+" style="fill:none;stroke:#ffffff;stroke-linecap:square;stroke-opacity:0.8;stroke-width:3;"/>
+ </g>
+ <g id="line2d_5">
+ <path clip-path="url(#pa995e487cb)" d="M 16.077273 53.175615
+L 26.399683 53.175615
+" style="fill:none;stroke:#ffffff;stroke-linecap:square;stroke-opacity:0.8;stroke-width:3;"/>
+ </g>
+ <g id="line2d_6">
+ <path clip-path="url(#pa995e487cb)" d="M 16.077273 63.498025
+L 26.399683 63.498025
+" style="fill:none;stroke:#ffffff;stroke-linecap:square;stroke-opacity:0.8;stroke-width:3;"/>
+ </g>
+ <g id="line2d_7">
+ <path clip-path="url(#pa995e487cb)" d="M 95.353383 11.885975
+L 107.740275 11.885975
+" style="fill:none;stroke:#ffffff;stroke-linecap:square;stroke-opacity:0.8;stroke-width:3;"/>
+ </g>
+ <g id="line2d_8">
+ <path clip-path="url(#pa995e487cb)" d="M 82.96649 22.208385
+L 120.127167 22.208385
+" style="fill:none;stroke:#ffffff;stroke-linecap:square;stroke-opacity:0.8;stroke-width:3;"/>
+ </g>
+ <g id="line2d_9">
+ <path clip-path="url(#pa995e487cb)" d="M 76.773044 32.530795
+L 126.320613 32.530795
+" style="fill:none;stroke:#ffffff;stroke-linecap:square;stroke-opacity:0.8;stroke-width:3;"/>
+ </g>
+ <g id="line2d_10">
+ <path clip-path="url(#pa995e487cb)" d="M 76.773044 42.853205
+L 126.320613 42.853205
+" style="fill:none;stroke:#ffffff;stroke-linecap:square;stroke-opacity:0.8;stroke-width:3;"/>
+ </g>
+ <g id="line2d_11">
+ <path clip-path="url(#pa995e487cb)" d="M 82.96649 53.175615
+L 120.127167 53.175615
+" style="fill:none;stroke:#ffffff;stroke-linecap:square;stroke-opacity:0.8;stroke-width:3;"/>
+ </g>
+ <g id="line2d_12">
+ <path clip-path="url(#pa995e487cb)" d="M 95.353383 63.498025
+L 107.740275 63.498025
+" style="fill:none;stroke:#ffffff;stroke-linecap:square;stroke-opacity:0.8;stroke-width:3;"/>
+ </g>
+ <g id="line2d_13">
+ <path clip-path="url(#pa995e487cb)" d="M 142.010677 11.885975
+L 193.622727 11.885975
+" style="fill:none;stroke:#ffffff;stroke-linecap:square;stroke-opacity:0.8;stroke-width:3;"/>
+ </g>
+ <g id="line2d_14">
+ <path clip-path="url(#pa995e487cb)" d="M 142.010677 22.208385
+L 193.622727 22.208385
+" style="fill:none;stroke:#ffffff;stroke-linecap:square;stroke-opacity:0.8;stroke-width:3;"/>
+ </g>
+ <g id="line2d_15">
+ <path clip-path="url(#pa995e487cb)" d="M 162.655497 32.530795
+L 172.977907 32.530795
+" style="fill:none;stroke:#ffffff;stroke-linecap:square;stroke-opacity:0.8;stroke-width:3;"/>
+ </g>
+ <g id="line2d_16">
+ <path clip-path="url(#pa995e487cb)" d="M 162.655497 42.853205
+L 172.977907 42.853205
+" style="fill:none;stroke:#ffffff;stroke-linecap:square;stroke-opacity:0.8;stroke-width:3;"/>
+ </g>
+ <g id="line2d_17">
+ <path clip-path="url(#pa995e487cb)" d="M 162.655497 53.175615
+L 172.977907 53.175615
+" style="fill:none;stroke:#ffffff;stroke-linecap:square;stroke-opacity:0.8;stroke-width:3;"/>
+ </g>
+ <g id="line2d_18">
+ <path clip-path="url(#pa995e487cb)" d="M 162.655497 63.498025
+L 172.977907 63.498025
+" style="fill:none;stroke:#ffffff;stroke-linecap:square;stroke-opacity:0.8;stroke-width:3;"/>
+ </g>
+ <g id="line2d_19">
+ <defs>
+ <path d="M 0 3
+C 0.795609 3 1.55874 2.683901 2.12132 2.12132
+C 2.683901 1.55874 3 0.795609 3 0
+C 3 -0.795609 2.683901 -1.55874 2.12132 -2.12132
+C 1.55874 -2.683901 0.795609 -3 0 -3
+C -0.795609 -3 -1.55874 -2.683901 -2.12132 -2.12132
+C -2.683901 -1.55874 -3 -0.795609 -3 0
+C -3 0.795609 -2.683901 1.55874 -2.12132 2.12132
+C -1.55874 2.683901 -0.795609 3 0 3
+z
+" id="m5a2277d5a1" style="stroke:#ffffff;"/>
+ </defs>
+ <g clip-path="url(#pa995e487cb)">
+ <use style="fill:#ffffff;stroke:#ffffff;" x="16.077273" xlink:href="#m5a2277d5a1" y="11.885975"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="16.077273" xlink:href="#m5a2277d5a1" y="22.208385"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="16.077273" xlink:href="#m5a2277d5a1" y="32.530795"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="16.077273" xlink:href="#m5a2277d5a1" y="42.853205"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="16.077273" xlink:href="#m5a2277d5a1" y="53.175615"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="16.077273" xlink:href="#m5a2277d5a1" y="63.498025"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="95.353383" xlink:href="#m5a2277d5a1" y="11.885975"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="82.96649" xlink:href="#m5a2277d5a1" y="22.208385"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="76.773044" xlink:href="#m5a2277d5a1" y="32.530795"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="76.773044" xlink:href="#m5a2277d5a1" y="42.853205"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="82.96649" xlink:href="#m5a2277d5a1" y="53.175615"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="95.353383" xlink:href="#m5a2277d5a1" y="63.498025"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="142.010677" xlink:href="#m5a2277d5a1" y="11.885975"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="142.010677" xlink:href="#m5a2277d5a1" y="22.208385"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="162.655497" xlink:href="#m5a2277d5a1" y="32.530795"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="162.655497" xlink:href="#m5a2277d5a1" y="42.853205"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="162.655497" xlink:href="#m5a2277d5a1" y="53.175615"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="162.655497" xlink:href="#m5a2277d5a1" y="63.498025"/>
+ </g>
+ </g>
+ <g id="line2d_20">
+ <g clip-path="url(#pa995e487cb)">
+ <use style="fill:#ffffff;stroke:#ffffff;" x="47.044503" xlink:href="#m5a2277d5a1" y="11.885975"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="57.366913" xlink:href="#m5a2277d5a1" y="32.530795"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="57.366913" xlink:href="#m5a2277d5a1" y="22.208385"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="47.044503" xlink:href="#m5a2277d5a1" y="42.853205"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="26.399683" xlink:href="#m5a2277d5a1" y="53.175615"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="26.399683" xlink:href="#m5a2277d5a1" y="63.498025"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="107.740275" xlink:href="#m5a2277d5a1" y="11.885975"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="120.127167" xlink:href="#m5a2277d5a1" y="22.208385"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="126.320613" xlink:href="#m5a2277d5a1" y="32.530795"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="126.320613" xlink:href="#m5a2277d5a1" y="42.853205"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="120.127167" xlink:href="#m5a2277d5a1" y="53.175615"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="107.740275" xlink:href="#m5a2277d5a1" y="63.498025"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="193.622727" xlink:href="#m5a2277d5a1" y="11.885975"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="193.622727" xlink:href="#m5a2277d5a1" y="22.208385"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="172.977907" xlink:href="#m5a2277d5a1" y="32.530795"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="172.977907" xlink:href="#m5a2277d5a1" y="42.853205"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="172.977907" xlink:href="#m5a2277d5a1" y="53.175615"/>
+ <use style="fill:#ffffff;stroke:#ffffff;" x="172.977907" xlink:href="#m5a2277d5a1" y="63.498025"/>
+ </g>
+ </g>
+ </g>
+ </g>
+ <defs>
+ <clipPath id="pa995e487cb">
+ <rect height="60.984" width="195.3" x="7.2" y="7.2"/>
+ </clipPath>
+ </defs>
+</svg>
diff --git a/docs/source/auto_examples/images/sinkhorn.png b/docs/source/_static/images/sinkhorn.png
index e003e13..e003e13 100644
--- a/docs/source/auto_examples/images/sinkhorn.png
+++ b/docs/source/_static/images/sinkhorn.png
Binary files differ
diff --git a/docs/source/all.rst b/docs/source/all.rst
index 6a07599..1ec6be3 100644
--- a/docs/source/all.rst
+++ b/docs/source/all.rst
@@ -20,15 +20,17 @@ API and modules
gromov
optim
da
- gpu
dr
utils
datasets
plot
stochastic
unbalanced
+ regpath
partial
sliced
+ weak
+ factored
.. autosummary::
:toctree: ../modules/generated/
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 849e97c..9526518 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -17,9 +17,15 @@ import os
import re
try:
import sphinx_gallery
+
except ImportError:
print("warning sphinx-gallery not installed")
+
+
+
+
+
# !!!! allow readthedoc compilation
try:
from unittest.mock import MagicMock
@@ -74,7 +80,6 @@ extensions = [
autosummary_generate = True
-
napoleon_numpy_docstring = True
# Add any paths that contain templates here, relative to this directory.
@@ -142,7 +147,7 @@ exclude_patterns = []
#show_authors = False
# The name of the Pygments (syntax highlighting) style to use.
-pygments_style = 'sphinx'
+pygments_style = 'default'
# A list of ignored prefixes for module index sorting.
#modindex_common_prefix = []
@@ -163,6 +168,7 @@ html_theme = 'sphinx_rtd_theme'
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
+
html_theme_options = {}
# Add any paths that contain custom themes here, relative to this directory.
@@ -177,7 +183,7 @@ html_theme_options = {}
# The name of an image file (relative to this directory) to place at the top
# of the sidebar.
-#html_logo = None
+html_logo = '_static/images/logo_dark.svg'
# The name of an image file (relative to this directory) to use as a favicon of
# the docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
@@ -189,6 +195,7 @@ html_theme_options = {}
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
+
# Add any extra paths that contain custom files (such as robots.txt or
# .htaccess) here, relative to this directory. These files are copied
# directly to the root of the documentation.
diff --git a/docs/source/contributors.rst b/docs/source/contributors.rst
new file mode 100644
index 0000000..f0acea6
--- /dev/null
+++ b/docs/source/contributors.rst
@@ -0,0 +1,6 @@
+Contributors
+============
+
+.. include:: ../../CONTRIBUTORS.md
+ :parser: myst_parser.sphinx_
+ :start-line: 2
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 8de31ae..3d53ef4 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -6,6 +6,10 @@
POT: Python Optimal Transport
=============================
+.. image:: _static/images/logo.svg
+ :width: 400
+ :alt: POT Logo
+
Contents
--------
@@ -18,8 +22,10 @@ Contents
auto_examples/index
releases
.github/CONTRIBUTING
+ contributors
.github/CODE_OF_CONDUCT
+
.. include:: ../../README.md
:parser: myst_parser.sphinx_
diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst
index 232df7b..b4cc8ab 100644
--- a/docs/source/quickstart.rst
+++ b/docs/source/quickstart.rst
@@ -207,13 +207,12 @@ The method implemented for solving the OT problem is the network simplex. It is
implemented in C from [1]_. It has a complexity of :math:`O(n^3)` but the
solver is quite efficient and uses sparsity of the solution.
-.. hint::
- Examples of use for :any:`ot.emd` are available in :
- - :any:`auto_examples/plot_OT_2D_samples`
- - :any:`auto_examples/plot_OT_1D`
- - :any:`auto_examples/plot_OT_L1_vs_L2`
+.. minigallery:: ot.emd
+ :add-heading: Examples of use for :any:`ot.emd`
+ :heading-level: "
+
Computing Wasserstein distance
@@ -255,11 +254,9 @@ the :math:`W_1` Wasserstein distance can be done directly with :any:`ot.emd2`
when providing :code:`M = ot.dist(xs, xt, metric='euclidean')` to use the Euclidean
distance.
-.. hint::
-
- An example of use for :any:`ot.emd2` is available in :
-
- - :any:`auto_examples/plot_compute_emd`
+.. minigallery:: ot.emd2
+ :add-heading: Examples of use for :any:`ot.emd2`
+ :heading-level: "
Special cases
@@ -416,17 +413,18 @@ of stochastic solvers for entropic regularized OT [18]_ [19]_. Those pure Pytho
implementations are not optimized for speed but provide a robust implementation
of algorithms in [18]_ [19]_.
-.. hint::
- Examples of use for :any:`ot.sinkhorn` are available in :
- - :any:`auto_examples/plot_OT_2D_samples`
- - :any:`auto_examples/plot_OT_1D`
- - :any:`auto_examples/plot_OT_1D_smooth`
- - :any:`auto_examples/plot_stochastic`
+.. minigallery:: ot.sinkhorn
+ :add-heading: Examples of use for :any:`ot.sinkhorn`
+ :heading-level: "
+.. minigallery:: ot.sinkhorn2
+ :add-heading: Examples of use for :any:`ot.sinkhorn2`
+ :heading-level: "
-Other regularization
-^^^^^^^^^^^^^^^^^^^^
+
+Other regularizations
+^^^^^^^^^^^^^^^^^^^^^
While entropic OT is the most common and favored in practice, there exists other
kinds of regularizations. We provide in POT two specific solvers for other
@@ -451,12 +449,9 @@ functions :any:`ot.smooth.smooth_ot_dual` or
:any:`ot.smooth.smooth_ot_semi_dual` with parameter :code:`reg_type='l2'` to
choose the quadratic regularization.
-.. hint::
- Examples of quadratic regularization are available in :
-
- - :any:`auto_examples/plot_OT_1D_smooth`
- - :any:`auto_examples/plot_optim_OTreg`
-
+.. minigallery:: ot.smooth.smooth_ot_dual ot.smooth.smooth_ot_semi_dual ot.optim.cg
+ :add-heading: Examples of use of quadratic regularization
+ :heading-level: "
Group Lasso regularization
@@ -480,11 +475,9 @@ be solved using an efficient majoration minimization approach with
convex group lasso and we provide a solver using generalized conditional
gradient algorithm [7]_ in function :any:`ot.da.sinkhorn_l1l2_gl`.
-.. hint::
- Examples of group Lasso regularization are available in:
-
- - :any:`auto_examples/domain-adaptation/plot_otda_classes`
- - :any:`auto_examples/domain-adaptation/plot_otda_d2`
+.. minigallery:: ot.da.SinkhornLpl1Transport ot.da.SinkhornL1l2Transport ot.da.sinkhorn_l1l2_gl ot.da.sinkhorn_lpl1_mm
+ :add-heading: Examples of group Lasso regularization
+ :heading-level: "
Generic solvers
@@ -520,10 +513,9 @@ generalized conditional gradient [7]_ implemented in :any:`ot.optim.gcg` that
does not linearize the entropic term but
relies on :any:`ot.sinkhorn` for its iterations.
-.. hint::
- An example of generic solvers are available in :
-
- - :any:`auto_examples/plot_optim_OTreg`
+.. minigallery:: ot.optim.cg ot.optim.gcg
+ :add-heading: Examples of the generic solvers
+ :heading-level: "
Wasserstein Barycenters
@@ -581,19 +573,15 @@ the matrix vector production in the Bregman projections by convolution
operators. We provide an implementation of this algorithm in function
:any:`ot.bregman.convolutional_barycenter2d`.
-.. hint::
- Examples of Wasserstein (:meth:`ot.lp.barycenter`) and regularized Wasserstein
- barycenter (:any:`ot.bregman.barycenter`) computation are available in :
- - :any:`auto_examples/barycenters/plot_barycenter_1D`
- - :any:`auto_examples/barycenters/plot_barycenter_lp_vs_entropic`
- An example of convolutional barycenter
- (:any:`ot.bregman.convolutional_barycenter2d`) computation
- for 2D images is available
- in :
+.. minigallery:: ot.lp.barycenter ot.bregman.barycenter ot.barycenter
+ :add-heading: Examples of Wasserstein and regularized Wasserstein barycenters
+ :heading-level: "
- - :any:`auto_examples/barycenters/plot_convolutional_barycenter`
+.. minigallery:: ot.bregman.convolutional_barycenter2d
+ :add-heading: An example of convolutional barycenter (:any:`ot.bregman.convolutional_barycenter2d`) computation
+ :heading-level: "
@@ -613,13 +601,9 @@ We provide a solver based on [20]_ in
return a locally optimal support :math:`\{x_i\}` for uniform or given weights
:math:`a`.
- .. hint::
-
- An example of the free support barycenter estimation is available
- in :
-
- - :any:`auto_examples/barycenters/plot_free_support_barycenter`
-
+.. minigallery:: ot.lp.free_support_barycenter
+ :add-heading: Examples of free support barycenter estimation
+ :heading-level: "
@@ -656,12 +640,10 @@ method proposed in [8]_ that estimates a continuous mapping approximating the
barycentric mapping is provided in :any:`ot.da.joint_OT_mapping_linear` for
linear mapping and :any:`ot.da.joint_OT_mapping_kernel` for non-linear mapping.
- .. hint::
-
- An example of the linear Monge mapping estimation is available
- in :
+.. minigallery:: ot.da.joint_OT_mapping_linear ot.da.joint_OT_mapping_linear ot.da.OT_mapping_linear
+ :add-heading: Examples of Monge mapping estimation
+ :heading-level: "
- - :any:`auto_examples/domain-adaptation/plot_otda_linear_mapping`
Domain adaptation classes
^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -704,14 +686,11 @@ A list of the provided implementation is given in the following note.
[14]_
* :any:`ot.da.MappingTransport`: Nonlinear mapping estimation [8]_
-.. hint::
- Examples of the use of OTDA classes are available in:
+.. minigallery:: ot.da.SinkhornTransport ot.da.LinearTransport
+ :add-heading: Examples of the use of OTDA classes
+ :heading-level: "
- - :any:`auto_examples/domain-adaptation/plot_otda_color_images`
- - :any:`auto_examples/domain-adaptation/plot_otda_mapping`
- - :any:`auto_examples/domain-adaptation/plot_otda_mapping_colors_images`
- - :any:`auto_examples/domain-adaptation/plot_otda_semi_supervised`
Other applications
------------------
@@ -746,11 +725,10 @@ respectively. Note that we also provide the Fisher discriminant estimator in
:code:`autograd`, :any:`ot.dr` is not imported by default. If you want to
use it you have to specifically import it with :code:`import ot.dr` .
-.. hint::
+.. minigallery:: ot.dr.wda
+ :add-heading: Examples of the use of WDA
+ :heading-level: "
- An example of the use of WDA is available in :
-
- - :any:`auto_examples/others/plot_WDA`
Unbalanced optimal transport
@@ -787,11 +765,9 @@ linear term.
the log stabilized version of the algorithm [10]_.
-.. hint::
-
- Examples of the use of :any:`ot.sinkhorn_unbalanced` are available in :
-
- - :any:`auto_examples/unbalanced-partial/plot_UOT_1D`
+.. minigallery:: ot.sinkhorn_unbalanced ot.sinkhorn_unbalanced2 ot.unbalanced.sinkhorn_unbalanced
+ :add-heading: Examples of Unbalanced OT
+ :heading-level: "
Unbalanced Barycenters
@@ -819,11 +795,10 @@ implemented the main function :any:`ot.barycenter_unbalanced`.
the log stabilized version of the algorithm [10]_.
-.. hint::
+.. minigallery:: ot.barycenter_unbalanced ot.unbalanced.barycenter_unbalanced
+ :add-heading: Examples of Unbalanced OT barycenters
+ :heading-level: "
- Examples of the use of :any:`ot.barycenter_unbalanced` are available in :
-
- - :any:`auto_examples/unbalanced-partial/plot_UOT_barycenter_1D`
Partial optimal transport
@@ -865,11 +840,10 @@ is computed in :any:`ot.partial.partial_gromov_wasserstein` and in
regularization of the problem.
-.. hint::
-
- Examples of the use of :any:`ot.partial` are available in:
+.. minigallery:: ot.partial.partial_wasserstein ot.partial.partial_gromov_wasserstein
+ :add-heading: Examples of Partial OT
+ :heading-level: "
- - :any:`auto_examples/unbalanced-partial/plot_partial_wass_and_gromov`
@@ -898,6 +872,12 @@ There also exists an entropic regularized variant of GW that has been proposed i
[12]_ and we provide an implementation of their algorithm in
:any:`ot.gromov.entropic_gromov_wasserstein`.
+
+.. minigallery:: ot.gromov.gromov_wasserstein ot.gromov.entropic_gromov_wasserstein ot.gromov.fused_gromov_wasserstein ot.gromov.gromov_wasserstein2
+ :add-heading: Examples of computation of GW, regularized G and FGW
+ :heading-level: "
+
+
Note that similarly to Wasserstein distance GW allows for the definition of GW
barycenters that can be expressed as
@@ -919,59 +899,15 @@ graphs for instance and also provide computable barycenters.
The implementations of FGW and FGW barycenter is provided in functions
:any:`ot.gromov.fused_gromov_wasserstein` and :any:`ot.gromov.fgw_barycenters`.
-.. hint::
-
- Examples of computation of GW, regularized G and FGW are available in:
-
- - :any:`auto_examples/gromov/plot_gromov`
- - :any:`auto_examples/gromov/plot_fgw`
-
- Examples of GW, regularized GW and FGW barycenters are available in:
-
- - :any:`auto_examples/gromov/plot_gromov_barycenter`
- - :any:`auto_examples/gromov/plot_barycenter_fgw`
+.. minigallery:: ot.gromov.gromov_barycenters ot.gromov.fgw_barycenters
+ :add-heading: Examples of GW, regularized G and FGW barycenters
+ :heading-level: "
-GPU acceleration
-^^^^^^^^^^^^^^^^
-
-.. warning::
-
- The :any:`ot.gpu` has been deprecated since the release 0.8 of POT and
- should not be used. The GPU implementation (in Pytorch for instance) can be
- used with the novel backends using the compatible functions from POT.
-
-
-We provide several implementation of our OT solvers in :any:`ot.gpu`. Those
-implementations use the :code:`cupy` toolbox that obviously need to be installed.
-
-
-.. note::
-
- Several implementations of POT functions (mainly those relying on linear
- algebra) have been implemented in :any:`ot.gpu`. Here is a short list on the
- main entries:
-
- - :meth:`ot.gpu.dist`: computation of distance matrix
- - :meth:`ot.gpu.sinkhorn`: computation of sinkhorn
- - :meth:`ot.gpu.sinkhorn_lpl1_mm`: computation of sinkhorn + group lasso
-Note that while the :any:`ot.gpu` module has been designed to be compatible with
-POT, calling its function with :any:`numpy` arrays will incur a large overhead due to
-the memory copy of the array on GPU prior to computation and conversion of the
-array after computation. To avoid this overhead, we provide functions
-:meth:`ot.gpu.to_gpu` and :meth:`ot.gpu.to_np` that perform the conversion
-explicitly.
-
-.. warning::
-
- Note that due to the hard dependency on :code:`cupy`, :any:`ot.gpu` is not
- imported by default. If you want to
- use it you have to specifically import it with :code:`import ot.gpu` .
-
-Solving OT with Multiple backends
----------------------------------
+Solving OT with Multiple backends on CPU/GPU
+--------------------------------------------
.. _backends_section:
@@ -1002,7 +938,21 @@ the function will be the same type as the inputs and on the same device. When
possible all computations are done on the same device and also when possible the
output will be differentiable with respect to the input of the function.
+GPU acceleration
+^^^^^^^^^^^^^^^^
+
+The backends provide automatic computations/compatibility on GPU for most of the
+POT functions.
+Note that all solvers relying on the exact OT solver en C++ will need to solve the
+problem on CPU which can incur some memory copy overhead and be far from optimal
+when all other computations are done on GPU. They will still work on array on
+GPU since the copy is done automatically.
+
+Some of the functions that rely on the exact C++ solver are:
+- :any:`ot.emd`, :any:`ot.emd2`
+- :any:`ot.gromov_wasserstein`, :any:`ot.gromov_wasserstein2`
+- :any:`ot.optim.cg`
List of compatible Backends
^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -1010,18 +960,21 @@ List of compatible Backends
- `Numpy <https://numpy.org/>`_ (all functions and solvers)
- `Pytorch <https://pytorch.org/>`_ (all outputs differentiable w.r.t. inputs)
- `Jax <https://github.com/google/jax>`_ (Some functions are differentiable some require a wrapper)
+- `Tensorflow <https://www.tensorflow.org/>`_ (all outputs differentiable w.r.t. inputs)
+- `Cupy <https://cupy.dev/>`_ (no differentiation, GPU only)
+
-List of compatible functions
+List of compatible modules
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
This list will get longer for new releases and will hopefully disappear when POT
become fully implemented with the backend.
-- :any:`ot.emd`
-- :any:`ot.emd2`
-- :any:`ot.sinkhorn`
-- :any:`ot.sinkhorn2`
-- :any:`ot.dist`
+- :any:`ot.bregman`
+- :any:`ot.gromov` (some functions use CPU only solvers with copy overhead)
+- :any:`ot.optim` (some functions use CPU only solvers with copy overhead)
+- :any:`ot.sliced`
+- :any:`ot.utils` (partial)
FAQ
@@ -1049,7 +1002,7 @@ FAQ
2. **pip install POT fails with error : ImportError: No module named Cython.Build**
- As discussed shortly in the README file. POT requires to have :code:`numpy`
+ As discussed shortly in the README file. POT<0.8 requires to have :code:`numpy`
and :code:`cython` installed to build. This corner case is not yet handled
by :code:`pip` and for now you need to install both library prior to
installing POT.
@@ -1075,15 +1028,6 @@ FAQ
speedup can be obtained by using a GPU implementation since all operations
are matrix/vector products.
-4. **Using GPU fails with error: module 'ot' has no attribute 'gpu'**
-
- In order to limit import time and hard dependencies in POT. we do not import
- some sub-modules automatically with :code:`import ot`. In order to use the
- acceleration in :any:`ot.gpu` you need first to import is with
- :code:`import ot.gpu`.
-
- See `Issue #85 <https://github.com/rflamary/POT/issues/85>`__ and :any:`ot.gpu`
- for more details.
References
@@ -1219,3 +1163,52 @@ References
.. [30] Flamary, Rémi, et al. "Optimal transport with Laplacian regularization:
Applications to domain adaptation and shape matching." NIPS Workshop on Optimal
Transport and Machine Learning OTML. 2014.
+
+.. [31] Bonneel, Nicolas, et al. `Sliced and radon wasserstein barycenters of
+ measures
+ <https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf>`_\
+ , Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
+
+.. [32] Huang, M., Ma S., Lai, L. (2021). `A Riemannian Block Coordinate Descent Method for Computing the Projection Robust Wasserstein Distance <http://proceedings.mlr.press/v139/huang21e.html>`_\ , Proceedings of the 38th International Conference on Machine Learning (ICML).
+
+.. [33] Kerdoncuff T., Emonet R., Marc S. `Sampled Gromov Wasserstein
+ <https://hal.archives-ouvertes.fr/hal-03232509/document>`_\ , Machine
+ Learning Journal (MJL), 2021
+
+.. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., &
+ Peyré, G. (2019, April). `Interpolating between optimal transport and MMD
+ using Sinkhorn divergences
+ <http://proceedings.mlr.press/v89/feydy19a/feydy19a.pdf>`_. In The 22nd
+ International Conference on Artificial Intelligence and Statistics (pp.
+ 2681-2690). PMLR.
+
+.. [35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S.,
+ & Schwing, A. G. (2019). `Max-sliced wasserstein distance and its use
+ for gans
+ <https://openaccess.thecvf.com/content_CVPR_2019/papers/Deshpande_Max-Sliced_Wasserstein_Distance_and_Its_Use_for_GANs_CVPR_2019_paper.pdf>`_.
+ In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656).
+
+.. [36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R.
+ (2019, May). `Sliced-Wasserstein flows: Nonparametric generative modeling via
+ optimal transport and diffusions
+ <http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf>`_. In International
+ Conference on Machine Learning (pp. 4104-4113). PMLR.
+
+.. [37] Janati, H., Cuturi, M., Gramfort, A. `Debiased sinkhorn barycenters
+ <http://proceedings.mlr.press/v119/janati20a/janati20a.pdf>`_ Proceedings of
+ the 37th International Conference on Machine Learning, PMLR 119:4692-4701, 2020
+
+.. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, `Online
+ Graph Dictionary Learning <https://arxiv.org/pdf/2102.06555.pdf>`_\ ,
+ International Conference on Machine Learning (ICML), 2021.
+
+.. [39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017).
+ `Kantorovich duality for general transport costs and applications
+ <https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf>`_.
+ Journal of Functional Analysis, 273(11), 3327-3405.
+
+.. [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., &
+ Weed, J. (2019, April). `Statistical optimal transport via factored
+ couplings <http://proceedings.mlr.press/v89/forrow19a/forrow19a.pdf>`_. In
+ The 22nd International Conference on Artificial Intelligence and Statistics
+ (pp. 2454-2465). PMLR.
diff --git a/docs/source/releases.rst b/docs/source/releases.rst
index 8250a4d..b2c7a44 100644
--- a/docs/source/releases.rst
+++ b/docs/source/releases.rst
@@ -3,4 +3,4 @@ Releases
.. include:: ../../RELEASES.md
:parser: myst_parser.sphinx_
- :start-line: 3
+ :start-line: 2
diff --git a/examples/backends/plot_dual_ot_pytorch.py b/examples/backends/plot_dual_ot_pytorch.py
new file mode 100644
index 0000000..d3f7a66
--- /dev/null
+++ b/examples/backends/plot_dual_ot_pytorch.py
@@ -0,0 +1,168 @@
+# -*- coding: utf-8 -*-
+r"""
+======================================================================
+Dual OT solvers for entropic and quadratic regularized OT with Pytorch
+======================================================================
+
+
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 3
+
+import numpy as np
+import matplotlib.pyplot as pl
+import torch
+import ot
+import ot.plot
+
+# %%
+# Data generation
+# ---------------
+
+torch.manual_seed(1)
+
+n_source_samples = 100
+n_target_samples = 100
+theta = 2 * np.pi / 20
+noise_level = 0.1
+
+Xs, ys = ot.datasets.make_data_classif(
+ 'gaussrot', n_source_samples, nz=noise_level)
+Xt, yt = ot.datasets.make_data_classif(
+ 'gaussrot', n_target_samples, theta=theta, nz=noise_level)
+
+# one of the target mode changes its variance (no linear mapping)
+Xt[yt == 2] *= 3
+Xt = Xt + 4
+
+
+# %%
+# Plot data
+# ---------
+
+pl.figure(1, (10, 5))
+pl.clf()
+pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples')
+pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples')
+pl.legend(loc=0)
+pl.title('Source and target distributions')
+
+# %%
+# Convert data to torch tensors
+# -----------------------------
+
+xs = torch.tensor(Xs)
+xt = torch.tensor(Xt)
+
+# %%
+# Estimating dual variables for entropic OT
+# -----------------------------------------
+
+u = torch.randn(n_source_samples, requires_grad=True)
+v = torch.randn(n_source_samples, requires_grad=True)
+
+reg = 0.5
+
+optimizer = torch.optim.Adam([u, v], lr=1)
+
+# number of iteration
+n_iter = 200
+
+
+losses = []
+
+for i in range(n_iter):
+
+ # generate noise samples
+
+ # minus because we maximize te dual loss
+ loss = -ot.stochastic.loss_dual_entropic(u, v, xs, xt, reg=reg)
+ losses.append(float(loss.detach()))
+
+ if i % 10 == 0:
+ print("Iter: {:3d}, loss={}".format(i, losses[-1]))
+
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+
+
+pl.figure(2)
+pl.plot(losses)
+pl.grid()
+pl.title('Dual objective (negative)')
+pl.xlabel("Iterations")
+
+Ge = ot.stochastic.plan_dual_entropic(u, v, xs, xt, reg=reg)
+
+# %%
+# Plot teh estimated entropic OT plan
+# -----------------------------------
+
+pl.figure(3, (10, 5))
+pl.clf()
+ot.plot.plot2D_samples_mat(Xs, Xt, Ge.detach().numpy(), alpha=0.1)
+pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples', zorder=2)
+pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples', zorder=2)
+pl.legend(loc=0)
+pl.title('Source and target distributions')
+
+
+# %%
+# Estimating dual variables for quadratic OT
+# -----------------------------------------
+
+u = torch.randn(n_source_samples, requires_grad=True)
+v = torch.randn(n_source_samples, requires_grad=True)
+
+reg = 0.01
+
+optimizer = torch.optim.Adam([u, v], lr=1)
+
+# number of iteration
+n_iter = 200
+
+
+losses = []
+
+
+for i in range(n_iter):
+
+ # generate noise samples
+
+ # minus because we maximize te dual loss
+ loss = -ot.stochastic.loss_dual_quadratic(u, v, xs, xt, reg=reg)
+ losses.append(float(loss.detach()))
+
+ if i % 10 == 0:
+ print("Iter: {:3d}, loss={}".format(i, losses[-1]))
+
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+
+
+pl.figure(4)
+pl.plot(losses)
+pl.grid()
+pl.title('Dual objective (negative)')
+pl.xlabel("Iterations")
+
+Gq = ot.stochastic.plan_dual_quadratic(u, v, xs, xt, reg=reg)
+
+
+# %%
+# Plot the estimated quadratic OT plan
+# -----------------------------------
+
+pl.figure(5, (10, 5))
+pl.clf()
+ot.plot.plot2D_samples_mat(Xs, Xt, Gq.detach().numpy(), alpha=0.1)
+pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples', zorder=2)
+pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples', zorder=2)
+pl.legend(loc=0)
+pl.title('OT plan with quadratic regularization')
diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
index 05b9952..cf5d64d 100644
--- a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
+++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
@@ -27,6 +27,8 @@ Machine Learning (pp. 4104-4113). PMLR.
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 4
+
# %%
# Loading the data
diff --git a/examples/backends/plot_stoch_continuous_ot_pytorch.py b/examples/backends/plot_stoch_continuous_ot_pytorch.py
new file mode 100644
index 0000000..6d9b916
--- /dev/null
+++ b/examples/backends/plot_stoch_continuous_ot_pytorch.py
@@ -0,0 +1,189 @@
+# -*- coding: utf-8 -*-
+r"""
+======================================================================
+Continuous OT plan estimation with Pytorch
+======================================================================
+
+
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 3
+
+import numpy as np
+import matplotlib.pyplot as pl
+import torch
+from torch import nn
+import ot
+import ot.plot
+
+# %%
+# Data generation
+# ---------------
+
+torch.manual_seed(42)
+np.random.seed(42)
+
+n_source_samples = 10000
+n_target_samples = 10000
+theta = 2 * np.pi / 20
+noise_level = 0.1
+
+Xs = np.random.randn(n_source_samples, 2) * 0.5
+Xt = np.random.randn(n_target_samples, 2) * 2
+
+# one of the target mode changes its variance (no linear mapping)
+Xt = Xt + 4
+
+
+# %%
+# Plot data
+# ---------
+nvisu = 300
+pl.figure(1, (5, 5))
+pl.clf()
+pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', label='Source samples', alpha=0.5)
+pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', label='Target samples', alpha=0.5)
+pl.legend(loc=0)
+ax_bounds = pl.axis()
+pl.title('Source and target distributions')
+
+# %%
+# Convert data to torch tensors
+# -----------------------------
+
+xs = torch.tensor(Xs)
+xt = torch.tensor(Xt)
+
+# %%
+# Estimating deep dual variables for entropic OT
+# ----------------------------------------------
+
+torch.manual_seed(42)
+
+# define the MLP model
+
+
+class Potential(torch.nn.Module):
+ def __init__(self):
+ super(Potential, self).__init__()
+ self.fc1 = nn.Linear(2, 200)
+ self.fc2 = nn.Linear(200, 1)
+ self.relu = torch.nn.ReLU() # instead of Heaviside step fn
+
+ def forward(self, x):
+ output = self.fc1(x)
+ output = self.relu(output) # instead of Heaviside step fn
+ output = self.fc2(output)
+ return output.ravel()
+
+
+u = Potential().double()
+v = Potential().double()
+
+reg = 1
+
+optimizer = torch.optim.Adam(list(u.parameters()) + list(v.parameters()), lr=.005)
+
+# number of iteration
+n_iter = 1000
+n_batch = 500
+
+
+losses = []
+
+for i in range(n_iter):
+
+ # generate noise samples
+
+ iperms = torch.randint(0, n_source_samples, (n_batch,))
+ ipermt = torch.randint(0, n_target_samples, (n_batch,))
+
+ xsi = xs[iperms]
+ xti = xt[ipermt]
+
+ # minus because we maximize te dual loss
+ loss = -ot.stochastic.loss_dual_entropic(u(xsi), v(xti), xsi, xti, reg=reg)
+ losses.append(float(loss.detach()))
+
+ if i % 10 == 0:
+ print("Iter: {:3d}, loss={}".format(i, losses[-1]))
+
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+
+
+pl.figure(2)
+pl.plot(losses)
+pl.grid()
+pl.title('Dual objective (negative)')
+pl.xlabel("Iterations")
+
+
+# %%
+# Plot the density on arget for a given source sample
+# ---------------------------------------------------
+
+
+nv = 100
+xl = np.linspace(ax_bounds[0], ax_bounds[1], nv)
+yl = np.linspace(ax_bounds[2], ax_bounds[3], nv)
+
+XX, YY = np.meshgrid(xl, yl)
+
+xg = np.concatenate((XX.ravel()[:, None], YY.ravel()[:, None]), axis=1)
+
+wxg = np.exp(-((xg[:, 0] - 4)**2 + (xg[:, 1] - 4)**2) / (2 * 2))
+wxg = wxg / np.sum(wxg)
+
+xg = torch.tensor(xg)
+wxg = torch.tensor(wxg)
+
+
+pl.figure(4, (12, 4))
+pl.clf()
+pl.subplot(1, 3, 1)
+
+iv = 2
+Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg)
+Gg = Gg.reshape((nv, nv)).detach().numpy()
+
+pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05)
+pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05)
+pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0')
+pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported sourec sample')
+pl.legend(loc=0)
+ax_bounds = pl.axis()
+pl.title('Density of transported source sample')
+
+pl.subplot(1, 3, 2)
+
+iv = 3
+Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg)
+Gg = Gg.reshape((nv, nv)).detach().numpy()
+
+pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05)
+pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05)
+pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0')
+pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported sourec sample')
+pl.legend(loc=0)
+ax_bounds = pl.axis()
+pl.title('Density of transported source sample')
+
+pl.subplot(1, 3, 3)
+
+iv = 6
+Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg)
+Gg = Gg.reshape((nv, nv)).detach().numpy()
+
+pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05)
+pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05)
+pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0')
+pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported sourec sample')
+pl.legend(loc=0)
+ax_bounds = pl.axis()
+pl.title('Density of transported source sample')
diff --git a/examples/backends/plot_wass1d_torch.py b/examples/backends/plot_wass1d_torch.py
index 0abdd6d..cd8e2fd 100644
--- a/examples/backends/plot_wass1d_torch.py
+++ b/examples/backends/plot_wass1d_torch.py
@@ -1,9 +1,9 @@
r"""
-=================================
-Wasserstein 1D with PyTorch
-=================================
+=================================================
+Wasserstein 1D (flow and barycenter) with PyTorch
+=================================================
-In this small example, we consider the following minization problem:
+In this small example, we consider the following minimization problem:
.. math::
\mu^* = \min_\mu W(\mu,\nu)
diff --git a/examples/barycenters/plot_free_support_barycenter.py b/examples/barycenters/plot_free_support_barycenter.py
index 2d68a39..226dfeb 100644
--- a/examples/barycenters/plot_free_support_barycenter.py
+++ b/examples/barycenters/plot_free_support_barycenter.py
@@ -9,61 +9,62 @@ sum of diracs.
"""
-# Author: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>
+# Authors: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>
+# Rémi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 2
+
import numpy as np
import matplotlib.pylab as pl
import ot
-##############################################################################
+# %%
# Generate data
# -------------
-N = 3
+N = 2
d = 2
-measures_locations = []
-measures_weights = []
-
-for i in range(N):
- n_i = np.random.randint(low=1, high=20) # nb samples
+I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::4, ::4, 2]
+I2 = pl.imread('../../data/duck.png').astype(np.float64)[::4, ::4, 2]
- mu_i = np.random.normal(0., 4., (d,)) # Gaussian mean
+sz = I2.shape[0]
+XX, YY = np.meshgrid(np.arange(sz), np.arange(sz))
- A_i = np.random.rand(d, d)
- cov_i = np.dot(A_i, A_i.transpose()) # Gaussian covariance matrix
+x1 = np.stack((XX[I1 == 0], YY[I1 == 0]), 1) * 1.0
+x2 = np.stack((XX[I2 == 0] + 80, -YY[I2 == 0] + 32), 1) * 1.0
+x3 = np.stack((XX[I2 == 0], -YY[I2 == 0] + 32), 1) * 1.0
- x_i = ot.datasets.make_2D_samples_gauss(n_i, mu_i, cov_i) # Dirac locations
- b_i = np.random.uniform(0., 1., (n_i,))
- b_i = b_i / np.sum(b_i) # Dirac weights
+measures_locations = [x1, x2]
+measures_weights = [ot.unif(x1.shape[0]), ot.unif(x2.shape[0])]
- measures_locations.append(x_i)
- measures_weights.append(b_i)
+pl.figure(1, (12, 4))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)
+pl.title('Distributions')
-##############################################################################
+# %%
# Compute free support barycenter
# -------------------------------
-k = 10 # number of Diracs of the barycenter
+k = 200 # number of Diracs of the barycenter
X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations
b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized)
X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b)
-
-##############################################################################
-# Plot data
+# %%
+# Plot the barycenter
# ---------
-pl.figure(1)
-for (x_i, b_i) in zip(measures_locations, measures_weights):
- color = np.random.randint(low=1, high=10 * N)
- pl.scatter(x_i[:, 0], x_i[:, 1], s=b_i * 1000, label='input measure')
-pl.scatter(X[:, 0], X[:, 1], s=b * 1000, c='black', marker='^', label='2-Wasserstein barycenter')
+pl.figure(2, (8, 3))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)
+pl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker='s', label='2-Wasserstein barycenter')
pl.title('Data measures and their barycenter')
-pl.legend(loc=0)
+pl.legend(loc="lower right")
pl.show()
diff --git a/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py
new file mode 100755
index 0000000..1fdc3b9
--- /dev/null
+++ b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py
@@ -0,0 +1,357 @@
+# -*- coding: utf-8 -*-
+
+r"""
+=================================
+(Fused) Gromov-Wasserstein Linear Dictionary Learning
+=================================
+
+In this exemple, we illustrate how to learn a Gromov-Wasserstein dictionary on
+a dataset of structured data such as graphs, denoted
+:math:`\{ \mathbf{C_s} \}_{s \in [S]}` where every nodes have uniform weights.
+Given a dictionary :math:`\mathbf{C_{dict}}` composed of D structures of a fixed
+size nt, each graph :math:`(\mathbf{C_s}, \mathbf{p_s})`
+is modeled as a convex combination :math:`\mathbf{w_s} \in \Sigma_D` of these
+dictionary atoms as :math:`\sum_d w_{s,d} \mathbf{C_{dict}[d]}`.
+
+
+First, we consider a dataset composed of graphs generated by Stochastic Block models
+with variable sizes taken in :math:`\{30, ... , 50\}` and quantities of clusters
+varying in :math:`\{ 1, 2, 3\}`. We learn a dictionary of 3 atoms, by minimizing
+the Gromov-Wasserstein distance from all samples to its model in the dictionary
+with respect to the dictionary atoms.
+
+Second, we illustrate the extension of this dictionary learning framework to
+structured data endowed with node features by using the Fused Gromov-Wasserstein
+distance. Starting from the aforementioned dataset of unattributed graphs, we
+add discrete labels uniformly depending on the number of clusters. Then we learn
+and visualize attributed graph atoms where each sample is modeled as a joint convex
+combination between atom structures and features.
+
+
+[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online Graph
+Dictionary Learning, International Conference on Machine Learning (ICML), 2021.
+
+"""
+# Author: Cédric Vincent-Cuaz <cedric.vincent-cuaz@inria.fr>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 4
+
+import numpy as np
+import matplotlib.pylab as pl
+from sklearn.manifold import MDS
+from ot.gromov import gromov_wasserstein_linear_unmixing, gromov_wasserstein_dictionary_learning, fused_gromov_wasserstein_linear_unmixing, fused_gromov_wasserstein_dictionary_learning
+import ot
+import networkx
+from networkx.generators.community import stochastic_block_model as sbm
+# %%
+# =============================================================================
+# Generate a dataset composed of graphs following Stochastic Block models of 1, 2 and 3 clusters.
+# =============================================================================
+
+np.random.seed(42)
+
+N = 60 # number of graphs in the dataset
+# For every number of clusters, we generate SBM with fixed inter/intra-clusters probability.
+clusters = [1, 2, 3]
+Nc = N // len(clusters) # number of graphs by cluster
+nlabels = len(clusters)
+dataset = []
+labels = []
+
+p_inter = 0.1
+p_intra = 0.9
+for n_cluster in clusters:
+ for i in range(Nc):
+ n_nodes = int(np.random.uniform(low=30, high=50))
+
+ if n_cluster > 1:
+ P = p_inter * np.ones((n_cluster, n_cluster))
+ np.fill_diagonal(P, p_intra)
+ else:
+ P = p_intra * np.eye(1)
+ sizes = np.round(n_nodes * np.ones(n_cluster) / n_cluster).astype(np.int32)
+ G = sbm(sizes, P, seed=i, directed=False)
+ C = networkx.to_numpy_array(G)
+ dataset.append(C)
+ labels.append(n_cluster)
+
+
+# Visualize samples
+
+def plot_graph(x, C, binary=True, color='C0', s=None):
+ for j in range(C.shape[0]):
+ for i in range(j):
+ if binary:
+ if C[i, j] > 0:
+ pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color='k')
+ else: # connection intensity proportional to C[i,j]
+ pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=C[i, j], color='k')
+
+ pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors='k', cmap='tab10', vmax=9)
+
+
+pl.figure(1, (12, 8))
+pl.clf()
+for idx_c, c in enumerate(clusters):
+ C = dataset[(c - 1) * Nc] # sample with c clusters
+ # get 2d position for nodes
+ x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C)
+ pl.subplot(2, nlabels, c)
+ pl.title('(graph) sample from label ' + str(c), fontsize=14)
+ plot_graph(x, C, binary=True, color='C0', s=50.)
+ pl.axis("off")
+ pl.subplot(2, nlabels, nlabels + c)
+ pl.title('(matrix) sample from label %s \n' % c, fontsize=14)
+ pl.imshow(C, interpolation='nearest')
+ pl.axis("off")
+pl.tight_layout()
+pl.show()
+
+# %%
+# =============================================================================
+# Estimate the gromov-wasserstein dictionary from the dataset
+# =============================================================================
+
+
+np.random.seed(0)
+ps = [ot.unif(C.shape[0]) for C in dataset]
+
+D = 3 # 3 atoms in the dictionary
+nt = 6 # of 6 nodes each
+
+q = ot.unif(nt)
+reg = 0. # regularization coefficient to promote sparsity of unmixings {w_s}
+
+Cdict_GW, log = gromov_wasserstein_dictionary_learning(
+ Cs=dataset, D=D, nt=nt, ps=ps, q=q, epochs=10, batch_size=16,
+ learning_rate=0.1, reg=reg, projection='nonnegative_symmetric',
+ tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=30, max_iter_inner=300,
+ use_log=True, use_adam_optimizer=True, verbose=True
+)
+# visualize loss evolution over epochs
+pl.figure(2, (4, 3))
+pl.clf()
+pl.title('loss evolution by epoch', fontsize=14)
+pl.plot(log['loss_epochs'])
+pl.xlabel('epochs', fontsize=12)
+pl.ylabel('loss', fontsize=12)
+pl.tight_layout()
+pl.show()
+
+# %%
+# =============================================================================
+# Visualization of the estimated dictionary atoms
+# =============================================================================
+
+
+# Continuous connections between nodes of the atoms are colored in shades of grey (1: dark / 2: white)
+
+pl.figure(3, (12, 8))
+pl.clf()
+for idx_atom, atom in enumerate(Cdict_GW):
+ scaled_atom = (atom - atom.min()) / (atom.max() - atom.min())
+ x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - scaled_atom)
+ pl.subplot(2, D, idx_atom + 1)
+ pl.title('(graph) atom ' + str(idx_atom + 1), fontsize=14)
+ plot_graph(x, atom / atom.max(), binary=False, color='C0', s=100.)
+ pl.axis("off")
+ pl.subplot(2, D, D + idx_atom + 1)
+ pl.title('(matrix) atom %s \n' % (idx_atom + 1), fontsize=14)
+ pl.imshow(scaled_atom, interpolation='nearest')
+ pl.colorbar()
+ pl.axis("off")
+pl.tight_layout()
+pl.show()
+#%%
+# =============================================================================
+# Visualization of the embedding space
+# =============================================================================
+
+unmixings = []
+reconstruction_errors = []
+for C in dataset:
+ p = ot.unif(C.shape[0])
+ unmixing, Cembedded, OT, reconstruction_error = gromov_wasserstein_linear_unmixing(
+ C, Cdict_GW, p=p, q=q, reg=reg,
+ tol_outer=10**(-5), tol_inner=10**(-5),
+ max_iter_outer=30, max_iter_inner=300
+ )
+ unmixings.append(unmixing)
+ reconstruction_errors.append(reconstruction_error)
+unmixings = np.array(unmixings)
+print('cumulated reconstruction error:', np.array(reconstruction_errors).sum())
+
+
+# Compute the 2D representation of the unmixing living in the 2-simplex of probability
+unmixings2D = np.zeros(shape=(N, 2))
+for i, w in enumerate(unmixings):
+ unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2.
+ unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2.
+x = [0., 0.]
+y = [1., 0.]
+z = [0.5, np.sqrt(3) / 2.]
+extremities = np.stack([x, y, z])
+
+pl.figure(4, (4, 4))
+pl.clf()
+pl.title('Embedding space', fontsize=14)
+for cluster in range(nlabels):
+ start, end = Nc * cluster, Nc * (cluster + 1)
+ if cluster == 0:
+ pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='1 cluster')
+ else:
+ pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='%s clusters' % (cluster + 1))
+pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=80., label='atoms')
+pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.)
+pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.)
+pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.)
+pl.axis('off')
+pl.legend(fontsize=11)
+pl.tight_layout()
+pl.show()
+# %%
+# =============================================================================
+# Endow the dataset with node features
+# =============================================================================
+
+# We follow this feature assignment on all nodes of a graph depending on its label/number of clusters
+# 1 cluster --> 0 as nodes feature
+# 2 clusters --> 1 as nodes feature
+# 3 clusters --> 2 as nodes feature
+# features are one-hot encoded following these assignments
+dataset_features = []
+for i in range(len(dataset)):
+ n = dataset[i].shape[0]
+ F = np.zeros((n, 3))
+ if i < Nc: # graph with 1 cluster
+ F[:, 0] = 1.
+ elif i < 2 * Nc: # graph with 2 clusters
+ F[:, 1] = 1.
+ else: # graph with 3 clusters
+ F[:, 2] = 1.
+ dataset_features.append(F)
+
+pl.figure(5, (12, 8))
+pl.clf()
+for idx_c, c in enumerate(clusters):
+ C = dataset[(c - 1) * Nc] # sample with c clusters
+ F = dataset_features[(c - 1) * Nc]
+ colors = ['C' + str(np.argmax(F[i])) for i in range(F.shape[0])]
+ # get 2d position for nodes
+ x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C)
+ pl.subplot(2, nlabels, c)
+ pl.title('(graph) sample from label ' + str(c), fontsize=14)
+ plot_graph(x, C, binary=True, color=colors, s=50)
+ pl.axis("off")
+ pl.subplot(2, nlabels, nlabels + c)
+ pl.title('(matrix) sample from label %s \n' % c, fontsize=14)
+ pl.imshow(C, interpolation='nearest')
+ pl.axis("off")
+pl.tight_layout()
+pl.show()
+# %%
+# =============================================================================
+# Estimate a Fused Gromov-Wasserstein dictionary from the dataset of attributed graphs
+# =============================================================================
+np.random.seed(0)
+ps = [ot.unif(C.shape[0]) for C in dataset]
+D = 3 # 6 atoms instead of 3
+nt = 6
+q = ot.unif(nt)
+reg = 0.001
+alpha = 0.5 # trade-off parameter between structure and feature information of Fused Gromov-Wasserstein
+
+
+Cdict_FGW, Ydict_FGW, log = fused_gromov_wasserstein_dictionary_learning(
+ Cs=dataset, Ys=dataset_features, D=D, nt=nt, ps=ps, q=q, alpha=alpha,
+ epochs=10, batch_size=16, learning_rate_C=0.1, learning_rate_Y=0.1, reg=reg,
+ tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=30, max_iter_inner=300,
+ projection='nonnegative_symmetric', use_log=True, use_adam_optimizer=True, verbose=True
+)
+# visualize loss evolution
+pl.figure(6, (4, 3))
+pl.clf()
+pl.title('loss evolution by epoch', fontsize=14)
+pl.plot(log['loss_epochs'])
+pl.xlabel('epochs', fontsize=12)
+pl.ylabel('loss', fontsize=12)
+pl.tight_layout()
+pl.show()
+
+# %%
+# =============================================================================
+# Visualization of the estimated dictionary atoms
+# =============================================================================
+
+pl.figure(7, (12, 8))
+pl.clf()
+max_features = Ydict_FGW.max()
+min_features = Ydict_FGW.min()
+
+for idx_atom, (Catom, Fatom) in enumerate(zip(Cdict_FGW, Ydict_FGW)):
+ scaled_atom = (Catom - Catom.min()) / (Catom.max() - Catom.min())
+ #scaled_F = 2 * (Fatom - min_features) / (max_features - min_features)
+ colors = ['C%s' % np.argmax(Fatom[i]) for i in range(Fatom.shape[0])]
+ x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - scaled_atom)
+ pl.subplot(2, D, idx_atom + 1)
+ pl.title('(attributed graph) atom ' + str(idx_atom + 1), fontsize=14)
+ plot_graph(x, Catom / Catom.max(), binary=False, color=colors, s=100)
+ pl.axis("off")
+ pl.subplot(2, D, D + idx_atom + 1)
+ pl.title('(matrix) atom %s \n' % (idx_atom + 1), fontsize=14)
+ pl.imshow(scaled_atom, interpolation='nearest')
+ pl.colorbar()
+ pl.axis("off")
+pl.tight_layout()
+pl.show()
+
+# %%
+# =============================================================================
+# Visualization of the embedding space
+# =============================================================================
+
+unmixings = []
+reconstruction_errors = []
+for i in range(len(dataset)):
+ C = dataset[i]
+ Y = dataset_features[i]
+ p = ot.unif(C.shape[0])
+ unmixing, Cembedded, Yembedded, OT, reconstruction_error = fused_gromov_wasserstein_linear_unmixing(
+ C, Y, Cdict_FGW, Ydict_FGW, p=p, q=q, alpha=alpha,
+ reg=reg, tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=30, max_iter_inner=300
+ )
+ unmixings.append(unmixing)
+ reconstruction_errors.append(reconstruction_error)
+unmixings = np.array(unmixings)
+print('cumulated reconstruction error:', np.array(reconstruction_errors).sum())
+
+# Visualize unmixings in the 2-simplex of probability
+unmixings2D = np.zeros(shape=(N, 2))
+for i, w in enumerate(unmixings):
+ unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2.
+ unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2.
+x = [0., 0.]
+y = [1., 0.]
+z = [0.5, np.sqrt(3) / 2.]
+extremities = np.stack([x, y, z])
+
+pl.figure(8, (4, 4))
+pl.clf()
+pl.title('Embedding space', fontsize=14)
+for cluster in range(nlabels):
+ start, end = Nc * cluster, Nc * (cluster + 1)
+ if cluster == 0:
+ pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='1 cluster')
+ else:
+ pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='%s clusters' % (cluster + 1))
+
+pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=80., label='atoms')
+pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.)
+pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.)
+pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.)
+pl.axis('off')
+pl.legend(fontsize=11)
+pl.tight_layout()
+pl.show()
diff --git a/examples/others/plot_WeakOT_VS_OT.py b/examples/others/plot_WeakOT_VS_OT.py
new file mode 100644
index 0000000..a29c875
--- /dev/null
+++ b/examples/others/plot_WeakOT_VS_OT.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+"""
+====================================================
+Weak Optimal Transport VS exact Optimal Transport
+====================================================
+
+Illustration of 2D optimal transport between distributions that are weighted
+sum of diracs. The OT matrix is plotted with the samples.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 4
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import ot.plot
+
+##############################################################################
+# Generate data an plot it
+# ------------------------
+
+#%% parameters and data generation
+
+n = 50 # nb samples
+
+mu_s = np.array([0, 0])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([4, 4])
+cov_t = np.array([[1, -.8], [-.8, 1]])
+
+xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
+xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
+
+a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples
+
+# loss matrix
+M = ot.dist(xs, xt)
+M /= M.max()
+
+#%% plot samples
+
+pl.figure(1)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('Source and target distributions')
+
+pl.figure(2)
+pl.imshow(M, interpolation='nearest')
+pl.title('Cost matrix M')
+
+
+##############################################################################
+# Compute Weak OT and exact OT solutions
+# --------------------------------------
+
+#%% EMD
+
+G0 = ot.emd(a, b, M)
+
+#%% Weak OT
+
+Gweak = ot.weak_optimal_transport(xs, xt, a, b)
+
+
+##############################################################################
+# Plot weak OT and exact OT solutions
+# --------------------------------------
+
+pl.figure(3, (8, 5))
+
+pl.subplot(1, 2, 1)
+pl.imshow(G0, interpolation='nearest')
+pl.title('OT matrix')
+
+pl.subplot(1, 2, 2)
+pl.imshow(Gweak, interpolation='nearest')
+pl.title('Weak OT matrix')
+
+pl.figure(4, (8, 5))
+
+pl.subplot(1, 2, 1)
+ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.title('OT matrix with samples')
+
+pl.subplot(1, 2, 2)
+ot.plot.plot2D_samples_mat(xs, xt, Gweak, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.title('Weak OT matrix with samples')
diff --git a/examples/others/plot_factored_coupling.py b/examples/others/plot_factored_coupling.py
new file mode 100644
index 0000000..b5b1c9f
--- /dev/null
+++ b/examples/others/plot_factored_coupling.py
@@ -0,0 +1,86 @@
+# -*- coding: utf-8 -*-
+"""
+==========================================
+Optimal transport with factored couplings
+==========================================
+
+Illustration of the factored coupling OT between 2D empirical distributions
+
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 2
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import ot.plot
+
+# %%
+# Generate data an plot it
+# ------------------------
+
+# parameters and data generation
+
+np.random.seed(42)
+
+n = 100 # nb samples
+
+xs = np.random.rand(n, 2) - .5
+
+xs = xs + np.sign(xs)
+
+xt = np.random.rand(n, 2) - .5
+
+a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples
+
+#%% plot samples
+
+pl.figure(1)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('Source and target distributions')
+
+
+# %%
+# Compute Factore OT and exact OT solutions
+# --------------------------------------
+
+#%% EMD
+M = ot.dist(xs, xt)
+G0 = ot.emd(a, b, M)
+
+#%% factored OT OT
+
+Ga, Gb, xb = ot.factored_optimal_transport(xs, xt, a, b, r=4)
+
+
+# %%
+# Plot factored OT and exact OT solutions
+# --------------------------------------
+
+pl.figure(2, (14, 4))
+
+pl.subplot(1, 3, 1)
+ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.2, .2, .2], alpha=0.1)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.title('Exact OT with samples')
+
+pl.subplot(1, 3, 2)
+ot.plot.plot2D_samples_mat(xs, xb, Ga, c=[.6, .6, .9], alpha=0.5)
+ot.plot.plot2D_samples_mat(xb, xt, Gb, c=[.9, .6, .6], alpha=0.5)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.plot(xb[:, 0], xb[:, 1], 'og', label='Template samples')
+pl.title('Factored OT with template samples')
+
+pl.subplot(1, 3, 3)
+ot.plot.plot2D_samples_mat(xs, xt, Ga.dot(Gb), c=[.2, .2, .2], alpha=0.1)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.title('Factored OT low rank OT plan')
diff --git a/examples/others/plot_logo.py b/examples/others/plot_logo.py
new file mode 100644
index 0000000..bb4f640
--- /dev/null
+++ b/examples/others/plot_logo.py
@@ -0,0 +1,112 @@
+
+# -*- coding: utf-8 -*-
+r"""
+=======================
+Logo of the POT toolbox
+=======================
+
+In this example we plot the logo of the POT toolbox.
+
+This logo is that it is done 100% in Python and generated using
+matplotlib and ploting teh solution of the EMD solver from POT.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 1
+
+# %% Load modules
+import numpy as np
+import matplotlib.pyplot as pl
+import ot
+
+# %%
+# Data for logo
+# -------------
+
+
+# Letter P
+p1 = np.array([[0, 6.], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], ])
+p2 = np.array([[1.5, 6], [2, 4], [2, 5], [1.5, 3], [0.5, 2], [.5, 1], ])
+
+# Letter O
+o1 = np.array([[0, 6.], [-1, 5], [-1.5, 4], [-1.5, 3], [-1, 2], [0, 1], ])
+o2 = np.array([[1, 6.], [2, 5], [2.5, 4], [2.5, 3], [2, 2], [1, 1], ])
+
+# Scaling and translation for letter O
+o1[:, 0] += 6.4
+o2[:, 0] += 6.4
+o1[:, 0] *= 0.6
+o2[:, 0] *= 0.6
+
+# Letter T
+t1 = np.array([[-1, 6.], [-1, 5], [0, 4], [0, 3], [0, 2], [0, 1], ])
+t2 = np.array([[1.5, 6.], [1.5, 5], [0.5, 4], [0.5, 3], [0.5, 2], [0.5, 1], ])
+
+# Translating the T
+t1[:, 0] += 7.1
+t2[:, 0] += 7.1
+
+# Concatenate all letters
+x1 = np.concatenate((p1, o1, t1), axis=0)
+x2 = np.concatenate((p2, o2, t2), axis=0)
+
+# Horizontal and vertical scaling
+sx = 1.0
+sy = .5
+x1[:, 0] *= sx
+x1[:, 1] *= sy
+x2[:, 0] *= sx
+x2[:, 1] *= sy
+
+# %%
+# Plot the logo (clear background)
+# --------------------------------
+
+# Solve OT problem between the points
+M = ot.dist(x1, x2, metric='euclidean')
+T = ot.emd([], [], M)
+
+pl.figure(1, (3.5, 1.1))
+pl.clf()
+# plot the OT plan
+for i in range(M.shape[0]):
+ for j in range(M.shape[1]):
+ if T[i, j] > 1e-8:
+ pl.plot([x1[i, 0], x2[j, 0]], [x1[i, 1], x2[j, 1]], color='k', alpha=0.6, linewidth=3, zorder=1)
+# plot the samples
+pl.plot(x1[:, 0], x1[:, 1], 'o', markerfacecolor='C3', markeredgecolor='k')
+pl.plot(x2[:, 0], x2[:, 1], 'o', markerfacecolor='b', markeredgecolor='k')
+
+
+pl.axis('equal')
+pl.axis('off')
+
+# Save logo file
+# pl.savefig('logo.svg', dpi=150, transparent=True, bbox_inches='tight')
+# pl.savefig('logo.png', dpi=150, transparent=True, bbox_inches='tight')
+
+# %%
+# Plot the logo (dark background)
+# --------------------------------
+
+pl.figure(2, (3.5, 1.1), facecolor='darkgray')
+pl.clf()
+# plot the OT plan
+for i in range(M.shape[0]):
+ for j in range(M.shape[1]):
+ if T[i, j] > 1e-8:
+ pl.plot([x1[i, 0], x2[j, 0]], [x1[i, 1], x2[j, 1]], color='w', alpha=0.8, linewidth=3, zorder=1)
+# plot the samples
+pl.plot(x1[:, 0], x1[:, 1], 'o', markerfacecolor='w', markeredgecolor='w')
+pl.plot(x2[:, 0], x2[:, 1], 'o', markerfacecolor='w', markeredgecolor='w')
+
+pl.axis('equal')
+pl.axis('off')
+
+# Save logo file
+# pl.savefig('logo_dark.svg', dpi=150, transparent=True, bbox_inches='tight')
+# pl.savefig('logo_dark.png', dpi=150, transparent=True, bbox_inches='tight')
diff --git a/examples/plot_screenkhorn_1D.py b/examples/others/plot_screenkhorn_1D.py
index 785642a..2023649 100644
--- a/examples/plot_screenkhorn_1D.py
+++ b/examples/others/plot_screenkhorn_1D.py
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
"""
-===============================
-1D Screened optimal transport
-===============================
+========================================
+Screened optimal transport (Screenkhorn)
+========================================
This example illustrates the computation of Screenkhorn [26].
diff --git a/examples/plot_stochastic.py b/examples/others/plot_stochastic.py
index 3a1ef31..3a1ef31 100644
--- a/examples/plot_stochastic.py
+++ b/examples/others/plot_stochastic.py
diff --git a/examples/plot_Intro_OT.py b/examples/plot_Intro_OT.py
index f282950..219aa51 100644
--- a/examples/plot_Intro_OT.py
+++ b/examples/plot_Intro_OT.py
@@ -58,7 +58,7 @@ help(ot.dist)
# number of Bakeries to Cafés in a City (in this case Manhattan). We did a
# quick google map search in Manhattan for bakeries and Cafés:
#
-# .. image:: images/bak.png
+# .. image:: ../_static/images/bak.png
# :align: center
# :alt: bakery-cafe-manhattan
# :width: 600px
@@ -233,7 +233,7 @@ print('Wasserstein loss (EMD) = {0:.2f}'.format(W))
# The Sinkhorn algorithm is very simple to code. You can implement it directly
# using the following pseudo-code
#
-# .. image:: images/sinkhorn.png
+# .. image:: ../_static/images/sinkhorn.png
# :align: center
# :alt: Sinkhorn algorithm
# :width: 440px
diff --git a/examples/plot_OT_1D.py b/examples/plot_OT_1D.py
index 15ead96..62f0b7d 100644
--- a/examples/plot_OT_1D.py
+++ b/examples/plot_OT_1D.py
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
"""
-====================
-1D optimal transport
-====================
+======================================
+Optimal Transport for 1D distributions
+======================================
This example illustrates the computation of EMD and Sinkhorn transport plans
and their visualization.
@@ -64,7 +64,11 @@ ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
#%% EMD
-G0 = ot.emd(a, b, M)
+# use fast 1D solver
+G0 = ot.emd_1d(x, x, a, b)
+
+# Equivalent to
+# G0 = ot.emd(a, b, M)
pl.figure(3, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
diff --git a/examples/plot_OT_1D_smooth.py b/examples/plot_OT_1D_smooth.py
index b07f99f..5415e4f 100644
--- a/examples/plot_OT_1D_smooth.py
+++ b/examples/plot_OT_1D_smooth.py
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
"""
-===========================
-1D smooth optimal transport
-===========================
+================================
+Smooth optimal transport example
+================================
This example illustrates the computation of EMD, Sinkhorn and smooth OT plans
and their visualization.
diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py
index af1bc12..1d82fb8 100644
--- a/examples/plot_OT_2D_samples.py
+++ b/examples/plot_OT_2D_samples.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""
====================================================
-2D Optimal transport between empirical distributions
+Optimal Transport between 2D empirical distributions
====================================================
Illustration of 2D optimal transport between discributions that are weighted
@@ -42,7 +42,6 @@ a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples
# loss matrix
M = ot.dist(xs, xt)
-M /= M.max()
##############################################################################
# Plot data
@@ -87,7 +86,7 @@ pl.title('OT matrix with samples')
#%% sinkhorn
# reg term
-lambd = 1e-3
+lambd = 1e-1
Gs = ot.sinkhorn(a, b, M, lambd)
@@ -112,7 +111,7 @@ pl.show()
#%% sinkhorn
# reg term
-lambd = 1e-3
+lambd = 1e-1
Ges = ot.bregman.empirical_sinkhorn(xs, xt, lambd)
diff --git a/examples/plot_OT_L1_vs_L2.py b/examples/plot_OT_L1_vs_L2.py
index 60353ab..cce51f8 100644
--- a/examples/plot_OT_L1_vs_L2.py
+++ b/examples/plot_OT_L1_vs_L2.py
@@ -1,10 +1,10 @@
# -*- coding: utf-8 -*-
"""
-==========================================
-2D Optimal transport for different metrics
-==========================================
+================================================
+Optimal Transport with different gournd metrics
+================================================
-2D OT on empirical distributio with different gound metric.
+2D OT on empirical distributio with different ground metric.
Stole the figure idea from Fig. 1 and 2 in
https://arxiv.org/pdf/1706.07650.pdf
@@ -23,7 +23,7 @@ import matplotlib.pylab as pl
import ot
import ot.plot
-##############################################################################
+# %%
# Dataset 1 : uniform sampling
# ----------------------------
@@ -46,7 +46,7 @@ M2 = ot.dist(xs, xt, metric='sqeuclidean')
M2 /= M2.max()
# loss matrix
-Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))
+Mp = ot.dist(xs, xt, metric='cityblock')
Mp /= Mp.max()
# Data
@@ -71,7 +71,7 @@ pl.title('Squared Euclidean cost')
pl.subplot(1, 3, 3)
pl.imshow(Mp, interpolation='nearest')
-pl.title('Sqrt Euclidean cost')
+pl.title('L1 (cityblock cost')
pl.tight_layout()
##############################################################################
@@ -109,22 +109,22 @@ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.axis('equal')
# pl.legend(loc=0)
-pl.title('OT sqrt Euclidean')
+pl.title('OT L1 (cityblock)')
pl.tight_layout()
pl.show()
-##############################################################################
+# %%
# Dataset 2 : Partial circle
# --------------------------
-n = 50 # nb samples
+n = 20 # nb samples
xtot = np.zeros((n + 1, 2))
xtot[:, 0] = np.cos(
- (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)
+ (np.arange(n + 1) + 1.0) * 0.8 / (n + 2) * 2 * np.pi)
xtot[:, 1] = np.sin(
- (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)
+ (np.arange(n + 1) + 1.0) * 0.8 / (n + 2) * 2 * np.pi)
xs = xtot[:n, :]
xt = xtot[1:, :]
@@ -140,7 +140,7 @@ M2 = ot.dist(xs, xt, metric='sqeuclidean')
M2 /= M2.max()
# loss matrix
-Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))
+Mp = ot.dist(xs, xt, metric='cityblock')
Mp /= Mp.max()
@@ -150,7 +150,7 @@ pl.clf()
pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.axis('equal')
-pl.title('Source and traget distributions')
+pl.title('Source and target distributions')
# Cost matrices
@@ -166,13 +166,13 @@ pl.title('Squared Euclidean cost')
pl.subplot(1, 3, 3)
pl.imshow(Mp, interpolation='nearest')
-pl.title('Sqrt Euclidean cost')
+pl.title('L1 (cityblock) cost')
pl.tight_layout()
##############################################################################
# Dataset 2 : Plot OT Matrices
# -----------------------------
-
+#
#%% EMD
G1 = ot.emd(a, b, M1)
@@ -204,7 +204,7 @@ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.axis('equal')
# pl.legend(loc=0)
-pl.title('OT sqrt Euclidean')
+pl.title('OT L1 (cityblock)')
pl.tight_layout()
pl.show()
diff --git a/examples/plot_compute_emd.py b/examples/plot_compute_emd.py
index 527a847..36cc7da 100644
--- a/examples/plot_compute_emd.py
+++ b/examples/plot_compute_emd.py
@@ -1,10 +1,10 @@
# -*- coding: utf-8 -*-
"""
-=================
-Plot multiple EMD
-=================
+==================
+OT distances in 1D
+==================
-Shows how to compute multiple EMD and Sinkhorn with two different
+Shows how to compute multiple Wassersein and Sinkhorn with two different
ground metrics and plot their values for different distributions.
@@ -14,7 +14,7 @@ ground metrics and plot their values for different distributions.
#
# License: MIT License
-# sphinx_gallery_thumbnail_number = 3
+# sphinx_gallery_thumbnail_number = 2
import numpy as np
import matplotlib.pylab as pl
@@ -29,7 +29,7 @@ from ot.datasets import make_1D_gauss as gauss
#%% parameters
n = 100 # nb bins
-n_target = 50 # nb target distributions
+n_target = 20 # nb target distributions
# bin positions
@@ -47,9 +47,9 @@ for i, m in enumerate(lst_m):
# loss matrix and normalization
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean')
-M /= M.max()
+M /= M.max() * 0.1
M2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'sqeuclidean')
-M2 /= M2.max()
+M2 /= M2.max() * 0.1
##############################################################################
# Plot data
@@ -59,10 +59,12 @@ M2 /= M2.max()
pl.figure(1)
pl.subplot(2, 1, 1)
-pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, a, 'r', label='Source distribution')
pl.title('Source distribution')
pl.subplot(2, 1, 2)
-pl.plot(x, B, label='Target distributions')
+for i in range(n_target):
+ pl.plot(x, B[:, i], 'b', alpha=i / n_target)
+pl.plot(x, B[:, -1], 'b', label='Target distributions')
pl.title('Target distributions')
pl.tight_layout()
@@ -73,14 +75,27 @@ pl.tight_layout()
#%% Compute and plot distributions and loss matrix
-d_emd = ot.emd2(a, B, M) # direct computation of EMD
-d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M2
-
+d_emd = ot.emd2(a, B, M) # direct computation of OT loss
+d_emd2 = ot.emd2(a, B, M2) # direct computation of OT loss with metrixc M2
+d_tv = [np.sum(abs(a - B[:, i])) for i in range(n_target)]
pl.figure(2)
-pl.plot(d_emd, label='Euclidean EMD')
-pl.plot(d_emd2, label='Squared Euclidean EMD')
-pl.title('EMD distances')
+pl.subplot(2, 1, 1)
+pl.plot(x, a, 'r', label='Source distribution')
+pl.title('Distributions')
+for i in range(n_target):
+ pl.plot(x, B[:, i], 'b', alpha=i / n_target)
+pl.plot(x, B[:, -1], 'b', label='Target distributions')
+pl.ylim((-.01, 0.13))
+pl.xticks(())
+pl.legend()
+pl.subplot(2, 1, 2)
+pl.plot(d_emd, label='Euclidean OT')
+pl.plot(d_emd2, label='Squared Euclidean OT')
+pl.plot(d_tv, label='Total Variation (TV)')
+#pl.xlim((-7,23))
+pl.xlabel('Displacement')
+pl.title('Divergences')
pl.legend()
##############################################################################
@@ -88,17 +103,30 @@ pl.legend()
# -----------------------------------------
#%%
-reg = 1e-2
+reg = 1e-1
d_sinkhorn = ot.sinkhorn2(a, B, M, reg)
d_sinkhorn2 = ot.sinkhorn2(a, B, M2, reg)
-pl.figure(2)
+pl.figure(3)
pl.clf()
-pl.plot(d_emd, label='Euclidean EMD')
-pl.plot(d_emd2, label='Squared Euclidean EMD')
+
+pl.subplot(2, 1, 1)
+pl.plot(x, a, 'r', label='Source distribution')
+pl.title('Distributions')
+for i in range(n_target):
+ pl.plot(x, B[:, i], 'b', alpha=i / n_target)
+pl.plot(x, B[:, -1], 'b', label='Target distributions')
+pl.ylim((-.01, 0.13))
+pl.xticks(())
+pl.legend()
+pl.subplot(2, 1, 2)
+pl.plot(d_emd, label='Euclidean OT')
+pl.plot(d_emd2, label='Squared Euclidean OT')
pl.plot(d_sinkhorn, '+', label='Euclidean Sinkhorn')
pl.plot(d_sinkhorn2, '+', label='Squared Euclidean Sinkhorn')
-pl.title('EMD distances')
+pl.plot(d_tv, label='Total Variation (TV)')
+#pl.xlim((-7,23))
+pl.xlabel('Displacement')
+pl.title('Divergences')
pl.legend()
-
pl.show()
diff --git a/examples/plot_optim_OTreg.py b/examples/plot_optim_OTreg.py
index 5eb15bd..7b021d2 100644
--- a/examples/plot_optim_OTreg.py
+++ b/examples/plot_optim_OTreg.py
@@ -24,7 +24,7 @@ arXiv preprint arXiv:1510.06567.
"""
-# sphinx_gallery_thumbnail_number = 4
+# sphinx_gallery_thumbnail_number = 5
import numpy as np
import matplotlib.pylab as pl
@@ -58,7 +58,7 @@ M /= M.max()
G0 = ot.emd(a, b, M)
-pl.figure(3, figsize=(5, 5))
+pl.figure(1, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
##############################################################################
@@ -80,7 +80,7 @@ reg = 1e-1
Gl2 = ot.optim.cg(a, b, M, reg, f, df, verbose=True)
-pl.figure(3)
+pl.figure(2)
ot.plot.plot1D_mat(a, b, Gl2, 'OT matrix Frob. reg')
##############################################################################
@@ -102,7 +102,7 @@ reg = 1e-3
Ge = ot.optim.cg(a, b, M, reg, f, df, verbose=True)
-pl.figure(4, figsize=(5, 5))
+pl.figure(3, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Ge, 'OT matrix Entrop. reg')
##############################################################################
@@ -125,6 +125,34 @@ reg2 = 1e-1
Gel2 = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True)
-pl.figure(5, figsize=(5, 5))
+pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Gel2, 'OT entropic + matrix Frob. reg')
pl.show()
+
+
+# %%
+# Comparison of the OT matrices
+
+nvisu = 40
+
+pl.figure(5, figsize=(10, 4))
+
+pl.subplot(2, 2, 1)
+pl.imshow(G0[:nvisu, :])
+pl.axis('off')
+pl.title('Exact OT')
+
+pl.subplot(2, 2, 2)
+pl.imshow(Gl2[:nvisu, :])
+pl.axis('off')
+pl.title('Frobenius reg.')
+
+pl.subplot(2, 2, 3)
+pl.imshow(Ge[:nvisu, :])
+pl.axis('off')
+pl.title('Entropic reg.')
+
+pl.subplot(2, 2, 4)
+pl.imshow(Gel2[:nvisu, :])
+pl.axis('off')
+pl.title('Entropic + Frobenius reg.')
diff --git a/examples/sliced-wasserstein/README.txt b/examples/sliced-wasserstein/README.txt
index a575345..73e6122 100644
--- a/examples/sliced-wasserstein/README.txt
+++ b/examples/sliced-wasserstein/README.txt
@@ -1,4 +1,4 @@
Sliced Wasserstein Distance
---------------------------- \ No newline at end of file
+---------------------------
diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py
index 7d73907..f12b522 100644
--- a/examples/sliced-wasserstein/plot_variance.py
+++ b/examples/sliced-wasserstein/plot_variance.py
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
"""
-==============================
-2D Sliced Wasserstein Distance
-==============================
+===============================================
+Sliced Wasserstein Distance on 2D distributions
+===============================================
This example illustrates the computation of the sliced Wasserstein Distance as
proposed in [31].
@@ -16,6 +16,8 @@ measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 2
+
import matplotlib.pylab as pl
import numpy as np
diff --git a/examples/unbalanced-partial/plot_UOT_1D.py b/examples/unbalanced-partial/plot_UOT_1D.py
index 183849c..06dd02d 100644
--- a/examples/unbalanced-partial/plot_UOT_1D.py
+++ b/examples/unbalanced-partial/plot_UOT_1D.py
@@ -12,6 +12,8 @@ using a Kullback-Leibler relaxation.
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 4
+
import numpy as np
import matplotlib.pylab as pl
import ot
@@ -69,7 +71,20 @@ epsilon = 0.1 # entropy parameter
alpha = 1. # Unbalanced KL relaxation parameter
Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True)
-pl.figure(4, figsize=(5, 5))
+pl.figure(3, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn')
pl.show()
+
+
+# %%
+# plot the transported mass
+# -------------------------
+
+pl.figure(4, figsize=(6.4, 3))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+pl.fill(x, Gs.sum(1), 'b', alpha=0.5, label='Transported source')
+pl.fill(x, Gs.sum(0), 'r', alpha=0.5, label='Transported target')
+pl.legend(loc='upper right')
+pl.title('Distributions and transported mass for UOT')
diff --git a/examples/unbalanced-partial/plot_regpath.py b/examples/unbalanced-partial/plot_regpath.py
index 4a51c2d..782e8c2 100644
--- a/examples/unbalanced-partial/plot_regpath.py
+++ b/examples/unbalanced-partial/plot_regpath.py
@@ -15,11 +15,12 @@ penalized linear regression.
# Author: Haoran Wu <haoran.wu@univ-ubs.fr>
# License: MIT License
+# sphinx_gallery_thumbnail_number = 2
import numpy as np
import matplotlib.pylab as pl
import ot
-
+import matplotlib.animation as animation
##############################################################################
# Generate data
# -------------
@@ -72,6 +73,9 @@ t2, t_list2, g_list2 = ot.regpath.regularization_path(a, b, M, reg=final_gamma,
##############################################################################
# Plot the regularization path
# ----------------
+#
+# The OT plan is ploted as a function of $\gamma$ that is the inverse of the
+# weight on the marginal relaxations.
#%% fully relaxed l2-penalized UOT
@@ -103,13 +107,53 @@ for p in range(4):
pl.show()
+# %%
+# Animation of the regpath for UOT l2
+# ------------------------
+
+nv = 100
+g_list_v = np.logspace(-.5, -2.5, nv)
+
+pl.figure(3)
+
+
+def _update_plot(iv):
+ pl.clf()
+ tp = ot.regpath.compute_transport_plan(g_list_v[iv], g_list,
+ t_list)
+ P = tp.reshape((n, n))
+ if P.sum() > 0:
+ P = P / P.max()
+ for i in range(n):
+ for j in range(n):
+ if P[i, j] > 0:
+ pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2',
+ alpha=P[i, j] * 0.5)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 4,
+ label='Re-weighted source', alpha=1)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 4,
+ label='Re-weighted target', alpha=1)
+ pl.plot([], [], color='C2', alpha=0.8, label='OT plan')
+ pl.title(r'$\ell_2$ UOT $\gamma$={:1.3f}'.format(g_list_v[iv]),
+ fontsize=11)
+ return 1
+
+
+i = 0
+_update_plot(i)
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000)
+
+
##############################################################################
# Plot the semi-relaxed regularization path
# -------------------
#%% semi-relaxed l2-penalized UOT
-pl.figure(3)
+pl.figure(4)
selected_gamma = [10, 1, 1e-1, 1e-2]
for p in range(4):
tp = ot.regpath.compute_transport_plan(selected_gamma[p], g_list2,
@@ -133,3 +177,43 @@ for p in range(4):
if p < 2:
pl.xticks(())
pl.show()
+
+
+# %%
+# Animation of the regpath for semi-relaxed UOT l2
+# ------------------------
+
+nv = 100
+g_list_v = np.logspace(2.5, -2, nv)
+
+pl.figure(5)
+
+
+def _update_plot(iv):
+ pl.clf()
+ tp = ot.regpath.compute_transport_plan(g_list_v[iv], g_list2,
+ t_list2)
+ P = tp.reshape((n, n))
+ if P.sum() > 0:
+ P = P / P.max()
+ for i in range(n):
+ for j in range(n):
+ if P[i, j] > 0:
+ pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2',
+ alpha=P[i, j] * 0.5)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 4,
+ label='Re-weighted source', alpha=1)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 4,
+ label='Re-weighted target', alpha=1)
+ pl.plot([], [], color='C2', alpha=0.8, label='OT plan')
+ pl.title(r'Semi-relaxed $\ell_2$ UOT $\gamma$={:1.3f}'.format(g_list_v[iv]),
+ fontsize=11)
+ return 1
+
+
+i = 0
+_update_plot(i)
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000)
diff --git a/examples/unbalanced-partial/plot_unbalanced_OT.py b/examples/unbalanced-partial/plot_unbalanced_OT.py
new file mode 100644
index 0000000..03487e7
--- /dev/null
+++ b/examples/unbalanced-partial/plot_unbalanced_OT.py
@@ -0,0 +1,116 @@
+# -*- coding: utf-8 -*-
+"""
+==============================================================
+2D examples of exact and entropic unbalanced optimal transport
+==============================================================
+This example is designed to show how to compute unbalanced and
+partial OT in POT.
+
+UOT aims at solving the following optimization problem:
+
+ .. math::
+ W = \min_{\gamma} <\gamma, \mathbf{M}>_F +
+ \mathrm{reg}\cdot\Omega(\gamma) +
+ \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) +
+ \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b})
+
+ s.t.
+ \gamma \geq 0
+
+where :math:`\mathrm{div}` is a divergence.
+When using the entropic UOT, :math:`\mathrm{reg}>0` and :math:`\mathrm{div}`
+should be the Kullback-Leibler divergence.
+When solving exact UOT, :math:`\mathrm{reg}=0` and :math:`\mathrm{div}`
+can be either the Kullback-Leibler or the quadratic divergence.
+Using :math:`\ell_1` norm gives the so-called partial OT.
+"""
+
+# Author: Laetitia Chapel <laetitia.chapel@univ-ubs.fr>
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+
+##############################################################################
+# Generate data
+# -------------
+
+# %% parameters and data generation
+
+n = 40 # nb samples
+
+mu_s = np.array([-1, -1])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([4, 4])
+cov_t = np.array([[1, -.8], [-.8, 1]])
+
+np.random.seed(0)
+xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
+xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
+
+n_noise = 10
+
+xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) - 4))), axis=0)
+xt = np.concatenate((xt, ((np.random.rand(n_noise, 2) + 6))), axis=0)
+
+n = n + n_noise
+
+a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples
+
+# loss matrix
+M = ot.dist(xs, xt)
+M /= M.max()
+
+
+##############################################################################
+# Compute entropic kl-regularized UOT, kl- and l2-regularized UOT
+# -----------
+
+reg = 0.005
+reg_m_kl = 0.05
+reg_m_l2 = 5
+mass = 0.7
+
+entropic_kl_uot = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl)
+kl_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_kl, div='kl')
+l2_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_l2, div='l2')
+partial_ot = ot.partial.partial_wasserstein(a, b, M, m=mass)
+
+##############################################################################
+# Plot the results
+# ----------------
+
+pl.figure(2)
+transp = [partial_ot, l2_uot, kl_uot, entropic_kl_uot]
+title = ["partial OT \n m=" + str(mass), "$\ell_2$-UOT \n $\mathrm{reg_m}$=" +
+ str(reg_m_l2), "kl-UOT \n $\mathrm{reg_m}$=" + str(reg_m_kl),
+ "entropic kl-UOT \n $\mathrm{reg_m}$=" + str(reg_m_kl)]
+
+for p in range(4):
+ pl.subplot(2, 4, p + 1)
+ P = transp[p]
+ if P.sum() > 0:
+ P = P / P.max()
+ for i in range(n):
+ for j in range(n):
+ if P[i, j] > 0:
+ pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2',
+ alpha=P[i, j] * 0.3)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 2)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 2)
+ pl.title(title[p])
+ pl.yticks(())
+ pl.xticks(())
+ if p < 1:
+ pl.ylabel("mappings")
+ pl.subplot(2, 4, p + 5)
+ pl.imshow(P, cmap='jet')
+ pl.yticks(())
+ pl.xticks(())
+ if p < 1:
+ pl.ylabel("transport plans")
+pl.show()
diff --git a/ot/__init__.py b/ot/__init__.py
index f55819d..86ed94e 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -1,5 +1,4 @@
"""
-
.. warning::
The list of automatically imported sub-modules is as follows:
:py:mod:`ot.lp`, :py:mod:`ot.bregman`, :py:mod:`ot.optim`
@@ -7,13 +6,10 @@
:py:mod:`ot.gromov`, :py:mod:`ot.smooth`
:py:mod:`ot.stochastic`, :py:mod:`ot.partial`, :py:mod:`ot.regpath`
, :py:mod:`ot.unbalanced`.
-
The following sub-modules are not imported due to additional dependencies:
-
- :any:`ot.dr` : depends on :code:`pymanopt` and :code:`autograd`.
- :any:`ot.gpu` : depends on :code:`cupy` and a CUDA GPU.
- :any:`ot.plot` : depends on :code:`matplotlib`
-
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
@@ -36,6 +32,8 @@ from . import unbalanced
from . import partial
from . import backend
from . import regpath
+from . import weak
+from . import factored
# OT functions
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
@@ -46,11 +44,14 @@ from .da import sinkhorn_lpl1_mm
from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance
from .gromov import (gromov_wasserstein, gromov_wasserstein2,
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
+from .weak import weak_optimal_transport
+from .factored import factored_optimal_transport
+
# utils functions
from .utils import dist, unif, tic, toc, toq
-__version__ = "0.8.1.0"
+__version__ = "0.8.2"
__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
@@ -59,5 +60,6 @@ __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
'sinkhorn_unbalanced', 'barycenter_unbalanced',
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance',
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2',
- 'max_sliced_wasserstein_distance',
+ 'max_sliced_wasserstein_distance', 'weak_optimal_transport',
+ 'factored_optimal_transport',
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath']
diff --git a/ot/backend.py b/ot/backend.py
index 58b652b..361ffba 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -87,7 +87,9 @@ Performance
# License: MIT License
import numpy as np
-import scipy.special as scipy
+import scipy
+import scipy.linalg
+import scipy.special as special
from scipy.sparse import issparse, coo_matrix, csr_matrix
import warnings
import time
@@ -102,7 +104,7 @@ except ImportError:
try:
import jax
import jax.numpy as jnp
- import jax.scipy.special as jscipy
+ import jax.scipy.special as jspecial
from jax.lib import xla_bridge
jax_type = jax.numpy.ndarray
except ImportError:
@@ -202,13 +204,29 @@ class Backend():
def __str__(self):
return self.__name__
- # convert to numpy
- def to_numpy(self, a):
+ # convert batch of tensors to numpy
+ def to_numpy(self, *arrays):
+ """Returns the numpy version of tensors"""
+ if len(arrays) == 1:
+ return self._to_numpy(arrays[0])
+ else:
+ return [self._to_numpy(array) for array in arrays]
+
+ # convert a tensor to numpy
+ def _to_numpy(self, a):
"""Returns the numpy version of a tensor"""
raise NotImplementedError()
- # convert from numpy
- def from_numpy(self, a, type_as=None):
+ # convert batch of arrays from numpy
+ def from_numpy(self, *arrays, type_as=None):
+ """Creates tensors cloning a numpy array, with the given precision (defaulting to input's precision) and the given device (in case of GPUs)"""
+ if len(arrays) == 1:
+ return self._from_numpy(arrays[0], type_as=type_as)
+ else:
+ return [self._from_numpy(array, type_as=type_as) for array in arrays]
+
+ # convert an array from numpy
+ def _from_numpy(self, a, type_as=None):
"""Creates a tensor cloning a numpy array, with the given precision (defaulting to input's precision) and the given device (in case of GPUs)"""
raise NotImplementedError()
@@ -536,6 +554,16 @@ class Backend():
"""
raise NotImplementedError()
+ def argmin(self, a, axis=None):
+ r"""
+ Returns the indices of the minimum values of a tensor along given dimensions.
+
+ This function follows the api from :any:`numpy.argmin`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.argmin.html
+ """
+ raise NotImplementedError()
+
def mean(self, a, axis=None):
r"""
Computes the arithmetic mean of a tensor along given dimensions.
@@ -786,6 +814,72 @@ class Backend():
"""
raise NotImplementedError()
+ def solve(self, a, b):
+ r"""
+ Solves a linear matrix equation, or system of linear scalar equations.
+
+ This function follows the api from :any:`numpy.linalg.solve`.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.solve.html
+ """
+ raise NotImplementedError()
+
+ def trace(self, a):
+ r"""
+ Returns the sum along diagonals of the array.
+
+ This function follows the api from :any:`numpy.trace`.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.trace.html
+ """
+ raise NotImplementedError()
+
+ def inv(self, a):
+ r"""
+ Computes the inverse of a matrix.
+
+ This function follows the api from :any:`scipy.linalg.inv`.
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.inv.html
+ """
+ raise NotImplementedError()
+
+ def sqrtm(self, a):
+ r"""
+ Computes the matrix square root. Requires input to be definite positive.
+
+ This function follows the api from :any:`scipy.linalg.sqrtm`.
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.sqrtm.html
+ """
+ raise NotImplementedError()
+
+ def isfinite(self, a):
+ r"""
+ Tests element-wise for finiteness (not infinity and not Not a Number).
+
+ This function follows the api from :any:`numpy.isfinite`.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.isfinite.html
+ """
+ raise NotImplementedError()
+
+ def array_equal(self, a, b):
+ r"""
+ True if two arrays have the same shape and elements, False otherwise.
+
+ This function follows the api from :any:`numpy.array_equal`.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.array_equal.html
+ """
+ raise NotImplementedError()
+
+ def is_floating_point(self, a):
+ r"""
+ Returns whether or not the input consists of floats
+ """
+ raise NotImplementedError()
+
class NumpyBackend(Backend):
"""
@@ -802,10 +896,10 @@ class NumpyBackend(Backend):
rng_ = np.random.RandomState()
- def to_numpy(self, a):
+ def _to_numpy(self, a):
return a
- def from_numpy(self, a, type_as=None):
+ def _from_numpy(self, a, type_as=None):
if type_as is None:
return a
elif isinstance(a, float):
@@ -936,6 +1030,9 @@ class NumpyBackend(Backend):
def argmax(self, a, axis=None):
return np.argmax(a, axis=axis)
+ def argmin(self, a, axis=None):
+ return np.argmin(a, axis=axis)
+
def mean(self, a, axis=None):
return np.mean(a, axis=axis)
@@ -955,7 +1052,7 @@ class NumpyBackend(Backend):
return np.unique(a)
def logsumexp(self, a, axis=None):
- return scipy.logsumexp(a, axis=axis)
+ return special.logsumexp(a, axis=axis)
def stack(self, arrays, axis=0):
return np.stack(arrays, axis)
@@ -1004,8 +1101,11 @@ class NumpyBackend(Backend):
else:
return a
- def where(self, condition, x, y):
- return np.where(condition, x, y)
+ def where(self, condition, x=None, y=None):
+ if x is None and y is None:
+ return np.where(condition)
+ else:
+ return np.where(condition, x, y)
def copy(self, a):
return a.copy()
@@ -1046,6 +1146,27 @@ class NumpyBackend(Backend):
results[key] = (t1 - t0) / n_runs
return results
+ def solve(self, a, b):
+ return np.linalg.solve(a, b)
+
+ def trace(self, a):
+ return np.trace(a)
+
+ def inv(self, a):
+ return scipy.linalg.inv(a)
+
+ def sqrtm(self, a):
+ return scipy.linalg.sqrtm(a)
+
+ def isfinite(self, a):
+ return np.isfinite(a)
+
+ def array_equal(self, a, b):
+ return np.array_equal(a, b)
+
+ def is_floating_point(self, a):
+ return a.dtype.kind == "f"
+
class JaxBackend(Backend):
"""
@@ -1075,13 +1196,15 @@ class JaxBackend(Backend):
jax.device_put(jnp.array(1, dtype=jnp.float64), d)
]
- def to_numpy(self, a):
+ def _to_numpy(self, a):
return np.array(a)
def _change_device(self, a, type_as):
return jax.device_put(a, type_as.device_buffer.device())
- def from_numpy(self, a, type_as=None):
+ def _from_numpy(self, a, type_as=None):
+ if isinstance(a, float):
+ a = np.array(a)
if type_as is None:
return jnp.array(a)
else:
@@ -1216,6 +1339,9 @@ class JaxBackend(Backend):
def argmax(self, a, axis=None):
return jnp.argmax(a, axis=axis)
+ def argmin(self, a, axis=None):
+ return jnp.argmin(a, axis=axis)
+
def mean(self, a, axis=None):
return jnp.mean(a, axis=axis)
@@ -1235,7 +1361,7 @@ class JaxBackend(Backend):
return jnp.unique(a)
def logsumexp(self, a, axis=None):
- return jscipy.logsumexp(a, axis=axis)
+ return jspecial.logsumexp(a, axis=axis)
def stack(self, arrays, axis=0):
return jnp.stack(arrays, axis)
@@ -1293,8 +1419,11 @@ class JaxBackend(Backend):
# Currently, JAX does not support sparse matrices
return a
- def where(self, condition, x, y):
- return jnp.where(condition, x, y)
+ def where(self, condition, x=None, y=None):
+ if x is None and y is None:
+ return jnp.where(condition)
+ else:
+ return jnp.where(condition, x, y)
def copy(self, a):
# No need to copy, JAX arrays are immutable
@@ -1339,6 +1468,28 @@ class JaxBackend(Backend):
results[key] = (t1 - t0) / n_runs
return results
+ def solve(self, a, b):
+ return jnp.linalg.solve(a, b)
+
+ def trace(self, a):
+ return jnp.trace(a)
+
+ def inv(self, a):
+ return jnp.linalg.inv(a)
+
+ def sqrtm(self, a):
+ L, V = jnp.linalg.eigh(a)
+ return (V * jnp.sqrt(L)[None, :]) @ V.T
+
+ def isfinite(self, a):
+ return jnp.isfinite(a)
+
+ def array_equal(self, a, b):
+ return jnp.array_equal(a, b)
+
+ def is_floating_point(self, a):
+ return a.dtype.kind == "f"
+
class TorchBackend(Backend):
"""
@@ -1384,10 +1535,10 @@ class TorchBackend(Backend):
self.ValFunction = ValFunction
- def to_numpy(self, a):
+ def _to_numpy(self, a):
return a.cpu().detach().numpy()
- def from_numpy(self, a, type_as=None):
+ def _from_numpy(self, a, type_as=None):
if isinstance(a, float):
a = np.array(a)
if type_as is None:
@@ -1397,7 +1548,7 @@ class TorchBackend(Backend):
def set_gradients(self, val, inputs, grads):
- Func = self.ValFunction()
+ Func = self.ValFunction
res = Func.apply(val, grads, *inputs)
@@ -1564,6 +1715,9 @@ class TorchBackend(Backend):
def argmax(self, a, axis=None):
return torch.argmax(a, dim=axis)
+ def argmin(self, a, axis=None):
+ return torch.argmin(a, dim=axis)
+
def mean(self, a, axis=None):
if axis is not None:
return torch.mean(a, dim=axis)
@@ -1580,8 +1734,11 @@ class TorchBackend(Backend):
return torch.linspace(start, stop, num, dtype=torch.float64)
def meshgrid(self, a, b):
- X, Y = torch.meshgrid(a, b)
- return X.T, Y.T
+ try:
+ return torch.meshgrid(a, b, indexing="xy")
+ except TypeError:
+ X, Y = torch.meshgrid(a, b)
+ return X.T, Y.T
def diag(self, a, k=0):
return torch.diag(a, diagonal=k)
@@ -1659,8 +1816,11 @@ class TorchBackend(Backend):
else:
return a
- def where(self, condition, x, y):
- return torch.where(condition, x, y)
+ def where(self, condition, x=None, y=None):
+ if x is None and y is None:
+ return torch.where(condition)
+ else:
+ return torch.where(condition, x, y)
def copy(self, a):
return torch.clone(a)
@@ -1718,6 +1878,28 @@ class TorchBackend(Backend):
torch.cuda.empty_cache()
return results
+ def solve(self, a, b):
+ return torch.linalg.solve(a, b)
+
+ def trace(self, a):
+ return torch.trace(a)
+
+ def inv(self, a):
+ return torch.linalg.inv(a)
+
+ def sqrtm(self, a):
+ L, V = torch.linalg.eigh(a)
+ return (V * torch.sqrt(L)[None, :]) @ V.T
+
+ def isfinite(self, a):
+ return torch.isfinite(a)
+
+ def array_equal(self, a, b):
+ return torch.equal(a, b)
+
+ def is_floating_point(self, a):
+ return a.dtype.is_floating_point
+
class CupyBackend(Backend): # pragma: no cover
"""
@@ -1741,10 +1923,12 @@ class CupyBackend(Backend): # pragma: no cover
cp.array(1, dtype=cp.float64)
]
- def to_numpy(self, a):
+ def _to_numpy(self, a):
return cp.asnumpy(a)
- def from_numpy(self, a, type_as=None):
+ def _from_numpy(self, a, type_as=None):
+ if isinstance(a, float):
+ a = np.array(a)
if type_as is None:
return cp.asarray(a)
else:
@@ -1884,6 +2068,9 @@ class CupyBackend(Backend): # pragma: no cover
def argmax(self, a, axis=None):
return cp.argmax(a, axis=axis)
+ def argmin(self, a, axis=None):
+ return cp.argmin(a, axis=axis)
+
def mean(self, a, axis=None):
return cp.mean(a, axis=axis)
@@ -1982,8 +2169,11 @@ class CupyBackend(Backend): # pragma: no cover
else:
return a
- def where(self, condition, x, y):
- return cp.where(condition, x, y)
+ def where(self, condition, x=None, y=None):
+ if x is None and y is None:
+ return cp.where(condition)
+ else:
+ return cp.where(condition, x, y)
def copy(self, a):
return a.copy()
@@ -2035,6 +2225,28 @@ class CupyBackend(Backend): # pragma: no cover
pinned_mempool.free_all_blocks()
return results
+ def solve(self, a, b):
+ return cp.linalg.solve(a, b)
+
+ def trace(self, a):
+ return cp.trace(a)
+
+ def inv(self, a):
+ return cp.linalg.inv(a)
+
+ def sqrtm(self, a):
+ L, V = cp.linalg.eigh(a)
+ return (V * self.sqrt(L)[None, :]) @ V.T
+
+ def isfinite(self, a):
+ return cp.isfinite(a)
+
+ def array_equal(self, a, b):
+ return cp.array_equal(a, b)
+
+ def is_floating_point(self, a):
+ return a.dtype.kind == "f"
+
class TensorflowBackend(Backend):
@@ -2060,13 +2272,16 @@ class TensorflowBackend(Backend):
"To use TensorflowBackend, you need to activate the tensorflow "
"numpy API. You can activate it by running: \n"
"from tensorflow.python.ops.numpy_ops import np_config\n"
- "np_config.enable_numpy_behavior()"
+ "np_config.enable_numpy_behavior()",
+ stacklevel=2
)
- def to_numpy(self, a):
+ def _to_numpy(self, a):
return a.numpy()
- def from_numpy(self, a, type_as=None):
+ def _from_numpy(self, a, type_as=None):
+ if isinstance(a, float):
+ a = np.array(a)
if not isinstance(a, self.__type__):
if type_as is None:
return tf.convert_to_tensor(a)
@@ -2208,6 +2423,9 @@ class TensorflowBackend(Backend):
def argmax(self, a, axis=None):
return tnp.argmax(a, axis=axis)
+ def argmin(self, a, axis=None):
+ return tnp.argmin(a, axis=axis)
+
def mean(self, a, axis=None):
return tnp.mean(a, axis=axis)
@@ -2309,8 +2527,11 @@ class TensorflowBackend(Backend):
else:
return a
- def where(self, condition, x, y):
- return tnp.where(condition, x, y)
+ def where(self, condition, x=None, y=None):
+ if x is None and y is None:
+ return tnp.where(condition)
+ else:
+ return tnp.where(condition, x, y)
def copy(self, a):
return tf.identity(a)
@@ -2364,3 +2585,24 @@ class TensorflowBackend(Backend):
results[key] = (t1 - t0) / n_runs
return results
+
+ def solve(self, a, b):
+ return tf.linalg.solve(a, b)
+
+ def trace(self, a):
+ return tf.linalg.trace(a)
+
+ def inv(self, a):
+ return tf.linalg.inv(a)
+
+ def sqrtm(self, a):
+ return tf.linalg.sqrtm(a)
+
+ def isfinite(self, a):
+ return tnp.isfinite(a)
+
+ def array_equal(self, a, b):
+ return tnp.array_equal(a, b)
+
+ def is_floating_point(self, a):
+ return a.dtype.is_floating
diff --git a/ot/bregman.py b/ot/bregman.py
index fc20175..c06af2f 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -2525,8 +2525,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
# geometric interpolation
delta = nx.exp(alpha * nx.log(other) + (1 - alpha) * nx.log(inv_new))
K = projR(K, delta)
- K0 = nx.dot(nx.diag(nx.dot(D.T, delta / inv_new)), K0)
-
+ K0 = nx.dot(D.T, delta / inv_new)[:, None] * K0
err = nx.norm(nx.sum(K0, axis=1) - old)
old = new
if log:
@@ -2656,16 +2655,16 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
classes = nx.unique(Ys[d])
# build the corresponding D_1 and D_2 matrices
- Dtmp1 = nx.zeros((nbclasses, nsk), type_as=Xs[0])
- Dtmp2 = nx.zeros((nbclasses, nsk), type_as=Xs[0])
+ Dtmp1 = np.zeros((nbclasses, nsk))
+ Dtmp2 = np.zeros((nbclasses, nsk))
for c in classes:
- nbelemperclass = nx.sum(Ys[d] == c)
+ nbelemperclass = float(nx.sum(Ys[d] == c))
if nbelemperclass != 0:
- Dtmp1[int(c), Ys[d] == c] = 1.
- Dtmp2[int(c), Ys[d] == c] = 1. / (nbelemperclass)
- D1.append(Dtmp1)
- D2.append(Dtmp2)
+ Dtmp1[int(c), nx.to_numpy(Ys[d] == c)] = 1.
+ Dtmp2[int(c), nx.to_numpy(Ys[d] == c)] = 1. / (nbelemperclass)
+ D1.append(nx.from_numpy(Dtmp1, type_as=Xs[0]))
+ D2.append(nx.from_numpy(Dtmp2, type_as=Xs[0]))
# build the cost matrix and the Gibbs kernel
Mtmp = dist(Xs[d], Xt, metric=metric)
diff --git a/ot/da.py b/ot/da.py
index 841f31a..0b9737e 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -12,12 +12,12 @@ Domain adaptation with optimal transport
# License: MIT License
import numpy as np
-import scipy.linalg as linalg
+from .backend import get_backend
from .bregman import sinkhorn, jcpot_barycenter
from .lp import emd
from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots
-from .utils import check_params, BaseEstimator
+from .utils import list_to_array, check_params, BaseEstimator
from .unbalanced import sinkhorn_unbalanced
from .optim import cg
from .optim import gcg
@@ -60,13 +60,13 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
Parameters
----------
- a : np.ndarray (ns,)
+ a : array-like (ns,)
samples weights in the source domain
- labels_a : np.ndarray (ns,)
+ labels_a : array-like (ns,)
labels of samples in the source domain
- b : np.ndarray (nt,)
+ b : array-like (nt,)
samples weights in the target domain
- M : np.ndarray (ns,nt)
+ M : array-like (ns,nt)
loss matrix
reg : float
Regularization term for entropic regularization >0
@@ -86,7 +86,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
Returns
-------
- gamma : (ns, nt) ndarray
+ gamma : (ns, nt) array-like
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -111,26 +111,28 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
ot.optim.cg : General regularized OT
"""
+ a, labels_a, b, M = list_to_array(a, labels_a, b, M)
+ nx = get_backend(a, labels_a, b, M)
+
p = 0.5
epsilon = 1e-3
indices_labels = []
- classes = np.unique(labels_a)
+ classes = nx.unique(labels_a)
for c in classes:
- idxc, = np.where(labels_a == c)
+ idxc, = nx.where(labels_a == c)
indices_labels.append(idxc)
- W = np.zeros(M.shape)
-
+ W = nx.zeros(M.shape, type_as=M)
for cpt in range(numItermax):
Mreg = M + eta * W
transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
stopThr=stopInnerThr)
# the transport has been computed. Check if classes are really
# separated
- W = np.ones(M.shape)
+ W = nx.ones(M.shape, type_as=M)
for (i, c) in enumerate(classes):
- majs = np.sum(transp[indices_labels[i]], axis=0)
+ majs = nx.sum(transp[indices_labels[i]], axis=0)
majs = p * ((majs + epsilon) ** (p - 1))
W[indices_labels[i]] = majs
@@ -174,13 +176,13 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
Parameters
----------
- a : np.ndarray (ns,)
+ a : array-like (ns,)
samples weights in the source domain
- labels_a : np.ndarray (ns,)
+ labels_a : array-like (ns,)
labels of samples in the source domain
- b : np.ndarray (nt,)
+ b : array-like (nt,)
samples in the target domain
- M : np.ndarray (ns,nt)
+ M : array-like (ns,nt)
loss matrix
reg : float
Regularization term for entropic regularization >0
@@ -200,7 +202,7 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
Returns
-------
- gamma : (ns, nt) ndarray
+ gamma : (ns, nt) array-like
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -222,22 +224,25 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
ot.optim.gcg : Generalized conditional gradient for OT problems
"""
- lstlab = np.unique(labels_a)
+ a, labels_a, b, M = list_to_array(a, labels_a, b, M)
+ nx = get_backend(a, labels_a, b, M)
+
+ lstlab = nx.unique(labels_a)
def f(G):
res = 0
for i in range(G.shape[1]):
for lab in lstlab:
temp = G[labels_a == lab, i]
- res += np.linalg.norm(temp)
+ res += nx.norm(temp)
return res
def df(G):
- W = np.zeros(G.shape)
+ W = nx.zeros(G.shape, type_as=G)
for i in range(G.shape[1]):
for lab in lstlab:
temp = G[labels_a == lab, i]
- n = np.linalg.norm(temp)
+ n = nx.norm(temp)
if n:
W[labels_a == lab, i] = temp / n
return W
@@ -289,9 +294,9 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
Parameters
----------
- xs : np.ndarray (ns,d)
+ xs : array-like (ns,d)
samples in the source domain
- xt : np.ndarray (nt,d)
+ xt : array-like (nt,d)
samples in the target domain
mu : float,optional
Weight for the linear OT loss (>0)
@@ -315,9 +320,9 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
Returns
-------
- gamma : (ns, nt) ndarray
+ gamma : (ns, nt) array-like
Optimal transportation matrix for the given parameters
- L : (d, d) ndarray
+ L : (d, d) array-like
Linear mapping matrix ((:math:`d+1`, `d`) if bias)
log : dict
log dictionary return only if log==True in parameters
@@ -336,13 +341,15 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
ot.optim.cg : General regularized OT
"""
+ xs, xt = list_to_array(xs, xt)
+ nx = get_backend(xs, xt)
ns, nt, d = xs.shape[0], xt.shape[0], xt.shape[1]
if bias:
- xs1 = np.hstack((xs, np.ones((ns, 1))))
- xstxs = xs1.T.dot(xs1)
- Id = np.eye(d + 1)
+ xs1 = nx.concatenate((xs, nx.ones((ns, 1), type_as=xs)), axis=1)
+ xstxs = nx.dot(xs1.T, xs1)
+ Id = nx.eye(d + 1, type_as=xs)
Id[-1] = 0
I0 = Id[:, :-1]
@@ -350,8 +357,8 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
return x[:-1, :]
else:
xs1 = xs
- xstxs = xs1.T.dot(xs1)
- Id = np.eye(d)
+ xstxs = nx.dot(xs1.T, xs1)
+ Id = nx.eye(d, type_as=xs)
I0 = Id
def sel(x):
@@ -360,7 +367,8 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
if log:
log = {'err': []}
- a, b = unif(ns), unif(nt)
+ a = unif(ns, type_as=xs)
+ b = unif(nt, type_as=xt)
M = dist(xs, xt) * ns
G = emd(a, b, M)
@@ -368,23 +376,26 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
def loss(L, G):
"""Compute full loss"""
- return np.sum((xs1.dot(L) - ns * G.dot(xt)) ** 2) + mu * \
- np.sum(G * M) + eta * np.sum(sel(L - I0) ** 2)
+ return (
+ nx.sum((nx.dot(xs1, L) - ns * nx.dot(G, xt)) ** 2)
+ + mu * nx.sum(G * M)
+ + eta * nx.sum(sel(L - I0) ** 2)
+ )
def solve_L(G):
""" solve L problem with fixed G (least square)"""
- xst = ns * G.dot(xt)
- return np.linalg.solve(xstxs + eta * Id, xs1.T.dot(xst) + eta * I0)
+ xst = ns * nx.dot(G, xt)
+ return nx.solve(xstxs + eta * Id, nx.dot(xs1.T, xst) + eta * I0)
def solve_G(L, G0):
"""Update G with CG algorithm"""
- xsi = xs1.dot(L)
+ xsi = nx.dot(xs1, L)
def f(G):
- return np.sum((xsi - ns * G.dot(xt)) ** 2)
+ return nx.sum((xsi - ns * nx.dot(G, xt)) ** 2)
def df(G):
- return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T)
+ return -2 * ns * nx.dot(xsi - ns * nx.dot(G, xt), xt.T)
G = cg(a, b, M, 1.0 / mu, f, df, G0=G0,
numItermax=numInnerItermax, stopThr=stopInnerThr)
@@ -481,9 +492,9 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
Parameters
----------
- xs : np.ndarray (ns,d)
+ xs : array-like (ns,d)
samples in the source domain
- xt : np.ndarray (nt,d)
+ xt : array-like (nt,d)
samples in the target domain
mu : float,optional
Weight for the linear OT loss (>0)
@@ -513,9 +524,9 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
Returns
-------
- gamma : (ns, nt) ndarray
+ gamma : (ns, nt) array-like
Optimal transportation matrix for the given parameters
- L : (ns, d) ndarray
+ L : (ns, d) array-like
Nonlinear mapping matrix ((:math:`n_s+1`, `d`) if bias)
log : dict
log dictionary return only if log==True in parameters
@@ -534,15 +545,17 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
ot.optim.cg : General regularized OT
"""
+ xs, xt = list_to_array(xs, xt)
+ nx = get_backend(xs, xt)
ns, nt = xs.shape[0], xt.shape[0]
K = kernel(xs, xs, method=kerneltype, sigma=sigma)
if bias:
- K1 = np.hstack((K, np.ones((ns, 1))))
- Id = np.eye(ns + 1)
+ K1 = nx.concatenate((K, nx.ones((ns, 1), type_as=xs)), axis=1)
+ Id = nx.eye(ns + 1, type_as=xs)
Id[-1] = 0
- Kp = np.eye(ns + 1)
+ Kp = nx.eye(ns + 1, type_as=xs)
Kp[:ns, :ns] = K
# ls regu
@@ -550,12 +563,12 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
# Kreg=I
# RKHS regul
- K0 = K1.T.dot(K1) + eta * Kp
+ K0 = nx.dot(K1.T, K1) + eta * Kp
Kreg = Kp
else:
K1 = K
- Id = np.eye(ns)
+ Id = nx.eye(ns, type_as=xs)
# ls regul
# K0 = K1.T.dot(K1)+eta*I
@@ -568,7 +581,8 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
if log:
log = {'err': []}
- a, b = unif(ns), unif(nt)
+ a = unif(ns, type_as=xs)
+ b = unif(nt, type_as=xt)
M = dist(xs, xt) * ns
G = emd(a, b, M)
@@ -576,28 +590,31 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
def loss(L, G):
"""Compute full loss"""
- return np.sum((K1.dot(L) - ns * G.dot(xt)) ** 2) + mu * \
- np.sum(G * M) + eta * np.trace(L.T.dot(Kreg).dot(L))
+ return (
+ nx.sum((nx.dot(K1, L) - ns * nx.dot(G, xt)) ** 2)
+ + mu * nx.sum(G * M)
+ + eta * nx.trace(dots(L.T, Kreg, L))
+ )
def solve_L_nobias(G):
""" solve L problem with fixed G (least square)"""
- xst = ns * G.dot(xt)
- return np.linalg.solve(K0, xst)
+ xst = ns * nx.dot(G, xt)
+ return nx.solve(K0, xst)
def solve_L_bias(G):
""" solve L problem with fixed G (least square)"""
- xst = ns * G.dot(xt)
- return np.linalg.solve(K0, K1.T.dot(xst))
+ xst = ns * nx.dot(G, xt)
+ return nx.solve(K0, nx.dot(K1.T, xst))
def solve_G(L, G0):
"""Update G with CG algorithm"""
- xsi = K1.dot(L)
+ xsi = nx.dot(K1, L)
def f(G):
- return np.sum((xsi - ns * G.dot(xt)) ** 2)
+ return nx.sum((xsi - ns * nx.dot(G, xt)) ** 2)
def df(G):
- return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T)
+ return -2 * ns * nx.dot(xsi - ns * nx.dot(G, xt), xt.T)
G = cg(a, b, M, 1.0 / mu, f, df, G0=G0,
numItermax=numInnerItermax, stopThr=stopInnerThr)
@@ -681,15 +698,15 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
Parameters
----------
- xs : np.ndarray (ns,d)
+ xs : array-like (ns,d)
samples in the source domain
- xt : np.ndarray (nt,d)
+ xt : array-like (nt,d)
samples in the target domain
reg : float,optional
regularization added to the diagonals of covariances (>0)
- ws : np.ndarray (ns,1), optional
+ ws : array-like (ns,1), optional
weights for the source samples
- wt : np.ndarray (ns,1), optional
+ wt : array-like (ns,1), optional
weights for the target samples
bias: boolean, optional
estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True)
@@ -699,9 +716,9 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
Returns
-------
- A : (d, d) ndarray
+ A : (d, d) array-like
Linear operator
- b : (1, d) ndarray
+ b : (1, d) array-like
bias
log : dict
log dictionary return only if log==True in parameters
@@ -719,36 +736,38 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
"""
+ xs, xt = list_to_array(xs, xt)
+ nx = get_backend(xs, xt)
d = xs.shape[1]
if bias:
- mxs = xs.mean(0, keepdims=True)
- mxt = xt.mean(0, keepdims=True)
+ mxs = nx.mean(xs, axis=0)[None, :]
+ mxt = nx.mean(xt, axis=0)[None, :]
xs = xs - mxs
xt = xt - mxt
else:
- mxs = np.zeros((1, d))
- mxt = np.zeros((1, d))
+ mxs = nx.zeros((1, d), type_as=xs)
+ mxt = nx.zeros((1, d), type_as=xs)
if ws is None:
- ws = np.ones((xs.shape[0], 1)) / xs.shape[0]
+ ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0]
if wt is None:
- wt = np.ones((xt.shape[0], 1)) / xt.shape[0]
+ wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0]
- Cs = (xs * ws).T.dot(xs) / ws.sum() + reg * np.eye(d)
- Ct = (xt * wt).T.dot(xt) / wt.sum() + reg * np.eye(d)
+ Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs)
+ Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt)
- Cs12 = linalg.sqrtm(Cs)
- Cs_12 = linalg.inv(Cs12)
+ Cs12 = nx.sqrtm(Cs)
+ Cs_12 = nx.inv(Cs12)
- M0 = linalg.sqrtm(Cs12.dot(Ct.dot(Cs12)))
+ M0 = nx.sqrtm(dots(Cs12, Ct, Cs12))
- A = Cs_12.dot(M0.dot(Cs_12))
+ A = dots(Cs_12, M0, Cs_12)
- b = mxt - mxs.dot(A)
+ b = mxt - nx.dot(mxs, A)
if log:
log = {}
@@ -798,15 +817,15 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
Parameters
----------
- a : np.ndarray (ns,)
+ a : array-like (ns,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : array-like (nt,)
samples weights in the target domain
- xs : np.ndarray (ns,d)
+ xs : array-like (ns,d)
samples in the source domain
- xt : np.ndarray (nt,d)
+ xt : array-like (nt,d)
samples in the target domain
- M : np.ndarray (ns,nt)
+ M : array-like (ns,nt)
loss matrix
sim : string, optional
Type of similarity ('knn' or 'gauss') used to construct the Laplacian.
@@ -834,7 +853,7 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
Returns
-------
- gamma : (ns, nt) ndarray
+ gamma : (ns, nt) array-like
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -862,9 +881,12 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
raise ValueError(
'Similarity parameter should be an int or a float. Got {type} instead.'.format(type=type(sim_param).__name__))
+ a, b, xs, xt, M = list_to_array(a, b, xs, xt, M)
+ nx = get_backend(a, b, xs, xt, M)
+
if sim == 'gauss':
if sim_param is None:
- sim_param = 1 / (2 * (np.mean(dist(xs, xs, 'sqeuclidean')) ** 2))
+ sim_param = 1 / (2 * (nx.mean(dist(xs, xs, 'sqeuclidean')) ** 2))
sS = kernel(xs, xs, method=sim, sigma=sim_param)
sT = kernel(xt, xt, method=sim, sigma=sim_param)
@@ -874,9 +896,13 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
from sklearn.neighbors import kneighbors_graph
- sS = kneighbors_graph(X=xs, n_neighbors=int(sim_param)).toarray()
+ sS = nx.from_numpy(kneighbors_graph(
+ X=nx.to_numpy(xs), n_neighbors=int(sim_param)
+ ).toarray(), type_as=xs)
sS = (sS + sS.T) / 2
- sT = kneighbors_graph(xt, n_neighbors=int(sim_param)).toarray()
+ sT = nx.from_numpy(kneighbors_graph(
+ X=nx.to_numpy(xt), n_neighbors=int(sim_param)
+ ).toarray(), type_as=xt)
sT = (sT + sT.T) / 2
else:
raise ValueError('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".'.format(sim=sim))
@@ -885,12 +911,14 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
lT = laplacian(sT)
def f(G):
- return alpha * np.trace(np.dot(xt.T, np.dot(G.T, np.dot(lS, np.dot(G, xt))))) \
- + (1 - alpha) * np.trace(np.dot(xs.T, np.dot(G, np.dot(lT, np.dot(G.T, xs)))))
+ return (
+ alpha * nx.trace(dots(xt.T, G.T, lS, G, xt))
+ + (1 - alpha) * nx.trace(dots(xs.T, G, lT, G.T, xs))
+ )
ls2 = lS + lS.T
lt2 = lT + lT.T
- xt2 = np.dot(xt, xt.T)
+ xt2 = nx.dot(xt, xt.T)
if reg == 'disp':
Cs = -eta * alpha / xs.shape[0] * dots(ls2, xs, xt.T)
@@ -898,8 +926,10 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
M = M + Cs + Ct
def df(G):
- return alpha * np.dot(ls2, np.dot(G, xt2))\
- + (1 - alpha) * np.dot(xs, np.dot(xs.T, np.dot(G, lt2)))
+ return (
+ alpha * dots(ls2, G, xt2)
+ + (1 - alpha) * dots(xs, xs.T, G, lt2)
+ )
return cg(a, b, M, reg=eta, f=f, df=df, G0=None, numItermax=numItermax, numItermaxEmd=numInnerItermax,
stopThr=stopThr, stopThr2=stopInnerThr, verbose=verbose, log=log)
@@ -919,7 +949,7 @@ def distribution_estimation_uniform(X):
The uniform distribution estimated from :math:`\mathbf{X}`
"""
- return unif(X.shape[0])
+ return unif(X.shape[0], type_as=X)
class BaseTransport(BaseEstimator):
@@ -973,6 +1003,7 @@ class BaseTransport(BaseEstimator):
self : object
Returns self.
"""
+ nx = self._get_backend(Xs, ys, Xt, yt)
# check the necessary inputs parameters are here
if check_params(Xs=Xs, Xt=Xt):
@@ -984,14 +1015,14 @@ class BaseTransport(BaseEstimator):
if (ys is not None) and (yt is not None):
if self.limit_max != np.infty:
- self.limit_max = self.limit_max * np.max(self.cost_)
+ self.limit_max = self.limit_max * nx.max(self.cost_)
# assumes labeled source samples occupy the first rows
# and labeled target samples occupy the first columns
- classes = [c for c in np.unique(ys) if c != -1]
+ classes = [c for c in nx.unique(ys) if c != -1]
for c in classes:
- idx_s = np.where((ys != c) & (ys != -1))
- idx_t = np.where(yt == c)
+ idx_s = nx.where((ys != c) & (ys != -1))
+ idx_t = nx.where(yt == c)
# all the coefficients corresponding to a source sample
# and a target sample :
@@ -1062,23 +1093,24 @@ class BaseTransport(BaseEstimator):
transp_Xs : array-like, shape (n_source_samples, n_features)
The transport source samples.
"""
+ nx = self.nx
# check the necessary inputs parameters are here
if check_params(Xs=Xs):
- if np.array_equal(self.xs_, Xs):
+ if nx.array_equal(self.xs_, Xs):
# perform standard barycentric mapping
- transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
+ transp = self.coupling_ / nx.sum(self.coupling_, axis=1)[:, None]
# set nans to 0
- transp[~ np.isfinite(transp)] = 0
+ transp[~ nx.isfinite(transp)] = 0
# compute transported samples
- transp_Xs = np.dot(transp, self.xt_)
+ transp_Xs = nx.dot(transp, self.xt_)
else:
# perform out of sample mapping
- indices = np.arange(Xs.shape[0])
+ indices = nx.arange(Xs.shape[0])
batch_ind = [
indices[i:i + batch_size]
for i in range(0, len(indices), batch_size)]
@@ -1087,20 +1119,20 @@ class BaseTransport(BaseEstimator):
for bi in batch_ind:
# get the nearest neighbor in the source domain
D0 = dist(Xs[bi], self.xs_)
- idx = np.argmin(D0, axis=1)
+ idx = nx.argmin(D0, axis=1)
# transport the source samples
- transp = self.coupling_ / np.sum(
- self.coupling_, 1)[:, None]
- transp[~ np.isfinite(transp)] = 0
- transp_Xs_ = np.dot(transp, self.xt_)
+ transp = self.coupling_ / nx.sum(
+ self.coupling_, axis=1)[:, None]
+ transp[~ nx.isfinite(transp)] = 0
+ transp_Xs_ = nx.dot(transp, self.xt_)
# define the transported points
transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - self.xs_[idx, :]
transp_Xs.append(transp_Xs_)
- transp_Xs = np.concatenate(transp_Xs, axis=0)
+ transp_Xs = nx.concatenate(transp_Xs, axis=0)
return transp_Xs
@@ -1127,26 +1159,27 @@ class BaseTransport(BaseEstimator):
International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
"""
+ nx = self.nx
# check the necessary inputs parameters are here
if check_params(ys=ys):
- ysTemp = label_normalization(np.copy(ys))
- classes = np.unique(ysTemp)
+ ysTemp = label_normalization(nx.copy(ys))
+ classes = nx.unique(ysTemp)
n = len(classes)
- D1 = np.zeros((n, len(ysTemp)))
+ D1 = nx.zeros((n, len(ysTemp)), type_as=self.coupling_)
# perform label propagation
- transp = self.coupling_ / np.sum(self.coupling_, 0, keepdims=True)
+ transp = self.coupling_ / nx.sum(self.coupling_, axis=0)[None, :]
# set nans to 0
- transp[~ np.isfinite(transp)] = 0
+ transp[~ nx.isfinite(transp)] = 0
for c in classes:
D1[int(c), ysTemp == c] = 1
# compute propagated labels
- transp_ys = np.dot(D1, transp)
+ transp_ys = nx.dot(D1, transp)
return transp_ys.T
@@ -1176,23 +1209,24 @@ class BaseTransport(BaseEstimator):
transp_Xt : array-like, shape (n_source_samples, n_features)
The transported target samples.
"""
+ nx = self.nx
# check the necessary inputs parameters are here
if check_params(Xt=Xt):
- if np.array_equal(self.xt_, Xt):
+ if nx.array_equal(self.xt_, Xt):
# perform standard barycentric mapping
- transp_ = self.coupling_.T / np.sum(self.coupling_, 0)[:, None]
+ transp_ = self.coupling_.T / nx.sum(self.coupling_, 0)[:, None]
# set nans to 0
- transp_[~ np.isfinite(transp_)] = 0
+ transp_[~ nx.isfinite(transp_)] = 0
# compute transported samples
- transp_Xt = np.dot(transp_, self.xs_)
+ transp_Xt = nx.dot(transp_, self.xs_)
else:
# perform out of sample mapping
- indices = np.arange(Xt.shape[0])
+ indices = nx.arange(Xt.shape[0])
batch_ind = [
indices[i:i + batch_size]
for i in range(0, len(indices), batch_size)]
@@ -1200,20 +1234,20 @@ class BaseTransport(BaseEstimator):
transp_Xt = []
for bi in batch_ind:
D0 = dist(Xt[bi], self.xt_)
- idx = np.argmin(D0, axis=1)
+ idx = nx.argmin(D0, axis=1)
# transport the target samples
- transp_ = self.coupling_.T / np.sum(
+ transp_ = self.coupling_.T / nx.sum(
self.coupling_, 0)[:, None]
- transp_[~ np.isfinite(transp_)] = 0
- transp_Xt_ = np.dot(transp_, self.xs_)
+ transp_[~ nx.isfinite(transp_)] = 0
+ transp_Xt_ = nx.dot(transp_, self.xs_)
# define the transported points
transp_Xt_ = transp_Xt_[idx, :] + Xt[bi] - self.xt_[idx, :]
transp_Xt.append(transp_Xt_)
- transp_Xt = np.concatenate(transp_Xt, axis=0)
+ transp_Xt = nx.concatenate(transp_Xt, axis=0)
return transp_Xt
@@ -1230,26 +1264,27 @@ class BaseTransport(BaseEstimator):
transp_ys : array-like, shape (n_source_samples, nb_classes)
Estimated soft source labels.
"""
+ nx = self.nx
# check the necessary inputs parameters are here
if check_params(yt=yt):
- ytTemp = label_normalization(np.copy(yt))
- classes = np.unique(ytTemp)
+ ytTemp = label_normalization(nx.copy(yt))
+ classes = nx.unique(ytTemp)
n = len(classes)
- D1 = np.zeros((n, len(ytTemp)))
+ D1 = nx.zeros((n, len(ytTemp)), type_as=self.coupling_)
# perform label propagation
- transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
+ transp = self.coupling_ / nx.sum(self.coupling_, 1)[:, None]
# set nans to 0
- transp[~ np.isfinite(transp)] = 0
+ transp[~ nx.isfinite(transp)] = 0
for c in classes:
D1[int(c), ytTemp == c] = 1
# compute propagated samples
- transp_ys = np.dot(D1, transp.T)
+ transp_ys = nx.dot(D1, transp.T)
return transp_ys.T
@@ -1330,14 +1365,15 @@ class LinearTransport(BaseTransport):
self : object
Returns self.
"""
+ nx = self._get_backend(Xs, ys, Xt, yt)
self.mu_s = self.distribution_estimation(Xs)
self.mu_t = self.distribution_estimation(Xt)
# coupling estimation
returned_ = OT_mapping_linear(Xs, Xt, reg=self.reg,
- ws=self.mu_s.reshape((-1, 1)),
- wt=self.mu_t.reshape((-1, 1)),
+ ws=nx.reshape(self.mu_s, (-1, 1)),
+ wt=nx.reshape(self.mu_t, (-1, 1)),
bias=self.bias, log=self.log)
# deal with the value of log
@@ -1348,8 +1384,8 @@ class LinearTransport(BaseTransport):
self.log_ = dict()
# re compute inverse mapping
- self.A1_ = linalg.inv(self.A_)
- self.B1_ = -self.B_.dot(self.A1_)
+ self.A1_ = nx.inv(self.A_)
+ self.B1_ = -nx.dot(self.B_, self.A1_)
return self
@@ -1378,10 +1414,11 @@ class LinearTransport(BaseTransport):
transp_Xs : array-like, shape (n_source_samples, n_features)
The transport source samples.
"""
+ nx = self.nx
# check the necessary inputs parameters are here
if check_params(Xs=Xs):
- transp_Xs = Xs.dot(self.A_) + self.B_
+ transp_Xs = nx.dot(Xs, self.A_) + self.B_
return transp_Xs
@@ -1411,10 +1448,11 @@ class LinearTransport(BaseTransport):
transp_Xt : array-like, shape (n_source_samples, n_features)
The transported target samples.
"""
+ nx = self.nx
# check the necessary inputs parameters are here
if check_params(Xt=Xt):
- transp_Xt = Xt.dot(self.A1_) + self.B1_
+ transp_Xt = nx.dot(Xt, self.A1_) + self.B1_
return transp_Xt
@@ -2112,6 +2150,7 @@ class MappingTransport(BaseEstimator):
self : object
Returns self
"""
+ self._get_backend(Xs, ys, Xt, yt)
# check the necessary inputs parameters are here
if check_params(Xs=Xs, Xt=Xt):
@@ -2158,19 +2197,20 @@ class MappingTransport(BaseEstimator):
transp_Xs : array-like, shape (n_source_samples, n_features)
The transport source samples.
"""
+ nx = self.nx
# check the necessary inputs parameters are here
if check_params(Xs=Xs):
- if np.array_equal(self.xs_, Xs):
+ if nx.array_equal(self.xs_, Xs):
# perform standard barycentric mapping
- transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
+ transp = self.coupling_ / nx.sum(self.coupling_, 1)[:, None]
# set nans to 0
- transp[~ np.isfinite(transp)] = 0
+ transp[~ nx.isfinite(transp)] = 0
# compute transported samples
- transp_Xs = np.dot(transp, self.xt_)
+ transp_Xs = nx.dot(transp, self.xt_)
else:
if self.kernel == "gaussian":
K = kernel(Xs, self.xs_, method=self.kernel,
@@ -2178,8 +2218,10 @@ class MappingTransport(BaseEstimator):
elif self.kernel == "linear":
K = Xs
if self.bias:
- K = np.hstack((K, np.ones((Xs.shape[0], 1))))
- transp_Xs = K.dot(self.mapping_)
+ K = nx.concatenate(
+ [K, nx.ones((Xs.shape[0], 1), type_as=K)], axis=1
+ )
+ transp_Xs = nx.dot(K, self.mapping_)
return transp_Xs
@@ -2396,6 +2438,7 @@ class JCPOTTransport(BaseTransport):
self : object
Returns self.
"""
+ self._get_backend(*Xs, *ys, Xt, yt)
# check the necessary inputs parameters are here
if check_params(Xs=Xs, Xt=Xt, ys=ys):
@@ -2438,28 +2481,29 @@ class JCPOTTransport(BaseTransport):
batch_size : int, optional (default=128)
The batch size for out of sample inverse transform
"""
+ nx = self.nx
transp_Xs = []
# check the necessary inputs parameters are here
if check_params(Xs=Xs):
- if all([np.allclose(x, y) for x, y in zip(self.xs_, Xs)]):
+ if all([nx.allclose(x, y) for x, y in zip(self.xs_, Xs)]):
# perform standard barycentric mapping for each source domain
for coupling in self.coupling_:
- transp = coupling / np.sum(coupling, 1)[:, None]
+ transp = coupling / nx.sum(coupling, 1)[:, None]
# set nans to 0
- transp[~ np.isfinite(transp)] = 0
+ transp[~ nx.isfinite(transp)] = 0
# compute transported samples
- transp_Xs.append(np.dot(transp, self.xt_))
+ transp_Xs.append(nx.dot(transp, self.xt_))
else:
# perform out of sample mapping
- indices = np.arange(Xs.shape[0])
+ indices = nx.arange(Xs.shape[0])
batch_ind = [
indices[i:i + batch_size]
for i in range(0, len(indices), batch_size)]
@@ -2470,23 +2514,22 @@ class JCPOTTransport(BaseTransport):
transp_Xs_ = []
# get the nearest neighbor in the sources domains
- xs = np.concatenate(self.xs_, axis=0)
- idx = np.argmin(dist(Xs[bi], xs), axis=1)
+ xs = nx.concatenate(self.xs_, axis=0)
+ idx = nx.argmin(dist(Xs[bi], xs), axis=1)
# transport the source samples
for coupling in self.coupling_:
- transp = coupling / np.sum(
- coupling, 1)[:, None]
- transp[~ np.isfinite(transp)] = 0
- transp_Xs_.append(np.dot(transp, self.xt_))
+ transp = coupling / nx.sum(coupling, 1)[:, None]
+ transp[~ nx.isfinite(transp)] = 0
+ transp_Xs_.append(nx.dot(transp, self.xt_))
- transp_Xs_ = np.concatenate(transp_Xs_, axis=0)
+ transp_Xs_ = nx.concatenate(transp_Xs_, axis=0)
# define the transported points
transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - xs[idx, :]
transp_Xs.append(transp_Xs_)
- transp_Xs = np.concatenate(transp_Xs, axis=0)
+ transp_Xs = nx.concatenate(transp_Xs, axis=0)
return transp_Xs
@@ -2512,32 +2555,36 @@ class JCPOTTransport(BaseTransport):
"Optimal transport for multi-source domain adaptation under target shift",
International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
"""
+ nx = self.nx
# check the necessary inputs parameters are here
if check_params(ys=ys):
- yt = np.zeros((len(np.unique(np.concatenate(ys))), self.xt_.shape[0]))
+ yt = nx.zeros(
+ (len(nx.unique(nx.concatenate(ys))), self.xt_.shape[0]),
+ type_as=ys[0]
+ )
for i in range(len(ys)):
- ysTemp = label_normalization(np.copy(ys[i]))
- classes = np.unique(ysTemp)
+ ysTemp = label_normalization(nx.copy(ys[i]))
+ classes = nx.unique(ysTemp)
n = len(classes)
ns = len(ysTemp)
# perform label propagation
- transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None]
+ transp = self.coupling_[i] / nx.sum(self.coupling_[i], 1)[:, None]
# set nans to 0
- transp[~ np.isfinite(transp)] = 0
+ transp[~ nx.isfinite(transp)] = 0
if self.log:
D1 = self.log_['D1'][i]
else:
- D1 = np.zeros((n, ns))
+ D1 = nx.zeros((n, ns), type_as=transp)
for c in classes:
D1[int(c), ysTemp == c] = 1
# compute propagated labels
- yt = yt + np.dot(D1, transp) / len(ys)
+ yt = yt + nx.dot(D1, transp) / len(ys)
return yt.T
@@ -2555,14 +2602,15 @@ class JCPOTTransport(BaseTransport):
transp_ys : list of K array-like objects, shape K x (nk_source_samples, nb_classes)
A list of estimated soft source labels
"""
+ nx = self.nx
# check the necessary inputs parameters are here
if check_params(yt=yt):
transp_ys = []
- ytTemp = label_normalization(np.copy(yt))
- classes = np.unique(ytTemp)
+ ytTemp = label_normalization(nx.copy(yt))
+ classes = nx.unique(ytTemp)
n = len(classes)
- D1 = np.zeros((n, len(ytTemp)))
+ D1 = nx.zeros((n, len(ytTemp)), type_as=self.coupling_[0])
for c in classes:
D1[int(c), ytTemp == c] = 1
@@ -2570,12 +2618,12 @@ class JCPOTTransport(BaseTransport):
for i in range(len(self.xs_)):
# perform label propagation
- transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None]
+ transp = self.coupling_[i] / nx.sum(self.coupling_[i], 1)[:, None]
# set nans to 0
- transp[~ np.isfinite(transp)] = 0
+ transp[~ nx.isfinite(transp)] = 0
# compute propagated labels
- transp_ys.append(np.dot(D1, transp.T).T)
+ transp_ys.append(nx.dot(D1, transp.T).T)
return transp_ys
diff --git a/ot/dr.py b/ot/dr.py
index 1671ca0..0955c55 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -11,6 +11,7 @@ Dimension reduction with OT
# Author: Remi Flamary <remi.flamary@unice.fr>
# Minhui Huang <mhhuang@ucdavis.edu>
+# Jakub Zadrozny <jakub.r.zadrozny@gmail.com>
#
# License: MIT License
@@ -43,6 +44,28 @@ def sinkhorn(w1, w2, M, reg, k):
return G
+def logsumexp(M, axis):
+ r"""Log-sum-exp reduction compatible with autograd (no numpy implementation)
+ """
+ amax = np.amax(M, axis=axis, keepdims=True)
+ return np.log(np.sum(np.exp(M - amax), axis=axis)) + np.squeeze(amax, axis=axis)
+
+
+def sinkhorn_log(w1, w2, M, reg, k):
+ r"""Sinkhorn algorithm in log-domain with fixed number of iteration (autograd)
+ """
+ Mr = -M / reg
+ ui = np.zeros((M.shape[0],))
+ vi = np.zeros((M.shape[1],))
+ log_w1 = np.log(w1)
+ log_w2 = np.log(w2)
+ for i in range(k):
+ vi = log_w2 - logsumexp(Mr + ui[:, None], 0)
+ ui = log_w1 - logsumexp(Mr + vi[None, :], 1)
+ G = np.exp(ui[:, None] + Mr + vi[None, :])
+ return G
+
+
def split_classes(X, y):
r"""split samples in :math:`\mathbf{X}` by classes in :math:`\mathbf{y}`
"""
@@ -110,7 +133,7 @@ def fda(X, y, p=2, reg=1e-16):
return Popt, proj
-def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False):
+def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter=100, verbose=0, P0=None, normalize=False):
r"""
Wasserstein Discriminant Analysis :ref:`[11] <references-wda>`
@@ -126,6 +149,14 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
- :math:`W` is entropic regularized Wasserstein distances
- :math:`\mathbf{X}^i` are samples in the dataset corresponding to class i
+ **Choosing a Sinkhorn solver**
+
+ By default and when using a regularization parameter that is not too small
+ the default sinkhorn solver should be enough. If you need to use a small
+ regularization to get sparse cost matrices, you should use the
+ :py:func:`ot.dr.sinkhorn_log` solver that will avoid numerical
+ errors, but can be slow in practice.
+
Parameters
----------
X : ndarray, shape (n, d)
@@ -139,6 +170,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
solver : None | str, optional
None for steepest descent or 'TrustRegions' for trust regions algorithm
else should be a pymanopt.solvers
+ sinkhorn_method : str
+ method used for the Sinkhorn solver, either 'sinkhorn' or 'sinkhorn_log'
P0 : ndarray, shape (d, p)
Initial starting point for projection.
normalize : bool, optional
@@ -161,6 +194,13 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
""" # noqa
+ if sinkhorn_method.lower() == 'sinkhorn':
+ sinkhorn_solver = sinkhorn
+ elif sinkhorn_method.lower() == 'sinkhorn_log':
+ sinkhorn_solver = sinkhorn_log
+ else:
+ raise ValueError("Unknown Sinkhorn method '%s'." % sinkhorn_method)
+
mx = np.mean(X)
X -= mx.reshape((1, -1))
@@ -193,7 +233,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
for j, xj in enumerate(xc[i:]):
xj = np.dot(xj, P)
M = dist(xi, xj)
- G = sinkhorn(wc[i], wc[j + i], M, reg * regmean[i, j], k)
+ G = sinkhorn_solver(wc[i], wc[j + i], M, reg * regmean[i, j], k)
if j == 0:
loss_w += np.sum(G * M)
else:
diff --git a/ot/factored.py b/ot/factored.py
new file mode 100644
index 0000000..abc2445
--- /dev/null
+++ b/ot/factored.py
@@ -0,0 +1,145 @@
+"""
+Factored OT solvers (low rank, cost or OT plan)
+"""
+
+# Author: Remi Flamary <remi.flamary@polytehnique.edu>
+#
+# License: MIT License
+
+from .backend import get_backend
+from .utils import dist
+from .lp import emd
+from .bregman import sinkhorn
+
+__all__ = ['factored_optimal_transport']
+
+
+def factored_optimal_transport(Xa, Xb, a=None, b=None, reg=0.0, r=100, X0=None, stopThr=1e-7, numItermax=100, verbose=False, log=False, **kwargs):
+ r"""Solves factored OT problem and return OT plans and intermediate distribution
+
+ This function solve the following OT problem [40]_
+
+ .. math::
+ \mathop{\arg \min}_\mu \quad W_2^2(\mu_a,\mu)+ W_2^2(\mu,\mu_b)
+
+ where :
+
+ - :math:`\mu_a` and :math:`\mu_b` are empirical distributions.
+ - :math:`\mu` is an empirical distribution with r samples
+
+ And returns the two OT plans between
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
+ Uses the conditional gradient algorithm to solve the problem proposed in
+ :ref:`[39] <references-weak>`.
+
+ Parameters
+ ----------
+ Xa : (ns,d) array-like, float
+ Source samples
+ Xb : (nt,d) array-like, float
+ Target samples
+ a : (ns,) array-like, float
+ Source histogram (uniform weight if empty list)
+ b : (nt,) array-like, float
+ Target histogram (uniform weight if empty list))
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on the relative variation (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ Ga: array-like, shape (ns, r)
+ Optimal transportation matrix between source and the intermediate
+ distribution
+ Gb: array-like, shape (r, nt)
+ Optimal transportation matrix between the intermediate and target
+ distribution
+ X: array-like, shape (r, d)
+ Support of the intermediate distribution
+ log: dict, optional
+ If input log is true, a dictionary containing the cost and dual
+ variables and exit status
+
+
+ .. _references-factored:
+ References
+ ----------
+ .. [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger,
+ G., & Weed, J. (2019, April). Statistical optimal transport via factored
+ couplings. In The 22nd International Conference on Artificial
+ Intelligence and Statistics (pp. 2454-2465). PMLR.
+
+ See Also
+ --------
+ ot.bregman.sinkhorn : Entropic regularized OT ot.optim.cg : General
+ regularized OT
+ """
+
+ nx = get_backend(Xa, Xb)
+
+ n_a = Xa.shape[0]
+ n_b = Xb.shape[0]
+ d = Xa.shape[1]
+
+ if a is None:
+ a = nx.ones((n_a), type_as=Xa) / n_a
+ if b is None:
+ b = nx.ones((n_b), type_as=Xb) / n_b
+
+ if X0 is None:
+ X = nx.randn(r, d, type_as=Xa)
+ else:
+ X = X0
+
+ w = nx.ones(r, type_as=Xa) / r
+
+ def solve_ot(X1, X2, w1, w2):
+ M = dist(X1, X2)
+ if reg > 0:
+ G, log = sinkhorn(w1, w2, M, reg, log=True, **kwargs)
+ log['cost'] = nx.sum(G * M)
+ return G, log
+ else:
+ return emd(w1, w2, M, log=True, **kwargs)
+
+ norm_delta = []
+
+ # solve the barycenter
+ for i in range(numItermax):
+
+ old_X = X
+
+ # solve OT with template
+ Ga, loga = solve_ot(Xa, X, a, w)
+ Gb, logb = solve_ot(X, Xb, w, b)
+
+ X = 0.5 * (nx.dot(Ga.T, Xa) + nx.dot(Gb, Xb)) * r
+
+ delta = nx.norm(X - old_X)
+ if delta < stopThr:
+ break
+ if log:
+ norm_delta.append(delta)
+
+ if log:
+ log_dic = {'delta_iter': norm_delta,
+ 'ua': loga['u'],
+ 'va': loga['v'],
+ 'ub': logb['u'],
+ 'vb': logb['v'],
+ 'costa': loga['cost'],
+ 'costb': logb['cost'],
+ }
+ return Ga, Gb, X, log_dic
+
+ return Ga, Gb, X
diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py
deleted file mode 100644
index 12db605..0000000
--- a/ot/gpu/__init__.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-GPU implementation for several OT solvers and utility
-functions.
-
-The GPU backend in handled by `cupy
-<https://cupy.chainer.org/>`_.
-
-.. warning::
- This module is now deprecated and will be removed in future releases. POT
- now privides a backend mechanism that allows for solving prolem on GPU wth
- the pytorch backend.
-
-
-.. warning::
- Note that by default the module is not imported in :mod:`ot`. In order to
- use it you need to explicitely import :mod:`ot.gpu` .
-
-By default, the functions in this module accept and return numpy arrays
-in order to proide drop-in replacement for the other POT function but
-the transfer between CPU en GPU comes with a significant overhead.
-
-In order to get the best performances, we recommend to give only cupy
-arrays to the functions and desactivate the conversion to numpy of the
-result of the function with parameter ``to_numpy=False``.
-
-"""
-
-# Author: Remi Flamary <remi.flamary@unice.fr>
-# Leo Gautheron <https://github.com/aje>
-#
-# License: MIT License
-
-import warnings
-
-from . import bregman
-from . import da
-from .bregman import sinkhorn
-from .da import sinkhorn_lpl1_mm
-
-from . import utils
-from .utils import dist, to_gpu, to_np
-
-
-warnings.warn('This module is deprecated and will be removed in the next minor release of POT', category=DeprecationWarning)
-
-
-__all__ = ["utils", "dist", "sinkhorn",
- "sinkhorn_lpl1_mm", 'bregman', 'da', 'to_gpu', 'to_np']
-
diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py
deleted file mode 100644
index 76af00e..0000000
--- a/ot/gpu/bregman.py
+++ /dev/null
@@ -1,196 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Bregman projections for regularized OT with GPU
-"""
-
-# Author: Remi Flamary <remi.flamary@unice.fr>
-# Leo Gautheron <https://github.com/aje>
-#
-# License: MIT License
-
-import cupy as np # np used for matrix computation
-import cupy as cp # cp used for cupy specific operations
-from . import utils
-
-
-def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
- verbose=False, log=False, to_numpy=True, **kwargs):
- r"""
- Solve the entropic regularization optimal transport on GPU
-
- If the input matrix are in numpy format, they will be uploaded to the
- GPU first which can incur significant time overhead.
-
- The function solves the following optimization problem:
-
- .. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
-
- s.t. \gamma 1 = a
-
- \gamma^T 1= b
-
- \gamma\geq 0
- where :
-
- - M is the (ns,nt) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
-
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
-
-
- Parameters
- ----------
- a : np.ndarray (ns,)
- samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt,nbb)
- samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
- loss matrix
- reg : float
- Regularization term >0
- numItermax : int, optional
- Max number of iterations
- stopThr : float, optional
- Stop threshold on error (>0)
- verbose : bool, optional
- Print information along iterations
- log : bool, optional
- record log if True
- to_numpy : boolean, optional (default True)
- If true convert back the GPU array result to numpy format.
-
-
- Returns
- -------
- gamma : (ns x nt) ndarray
- Optimal transportation matrix for the given parameters
- log : dict
- log dictionary return only if log==True in parameters
-
-
- References
- ----------
-
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
-
-
- See Also
- --------
- ot.lp.emd : Unregularized OT
- ot.optim.cg : General regularized OT
-
- """
-
- a = cp.asarray(a)
- b = cp.asarray(b)
- M = cp.asarray(M)
-
- if len(a) == 0:
- a = np.ones((M.shape[0],)) / M.shape[0]
- if len(b) == 0:
- b = np.ones((M.shape[1],)) / M.shape[1]
-
- # init data
- Nini = len(a)
- Nfin = len(b)
-
- if len(b.shape) > 1:
- nbb = b.shape[1]
- else:
- nbb = 0
-
- if log:
- log = {'err': []}
-
- # we assume that no distances are null except those of the diagonal of
- # distances
- if nbb:
- u = np.ones((Nini, nbb)) / Nini
- v = np.ones((Nfin, nbb)) / Nfin
- else:
- u = np.ones(Nini) / Nini
- v = np.ones(Nfin) / Nfin
-
- # print(reg)
-
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
-
- # print(np.min(K))
- tmp2 = np.empty(b.shape, dtype=M.dtype)
-
- Kp = (1 / a).reshape(-1, 1) * K
- cpt = 0
- err = 1
- while (err > stopThr and cpt < numItermax):
- uprev = u
- vprev = v
-
- KtransposeU = np.dot(K.T, u)
- v = np.divide(b, KtransposeU)
- u = 1. / np.dot(Kp, v)
-
- if (np.any(KtransposeU == 0) or
- np.any(np.isnan(u)) or np.any(np.isnan(v)) or
- np.any(np.isinf(u)) or np.any(np.isinf(v))):
- # we have reached the machine precision
- # come back to previous solution and quit loop
- print('Warning: numerical errors at iteration', cpt)
- u = uprev
- v = vprev
- break
- if cpt % 10 == 0:
- # we can speed up the process by checking for the error only all
- # the 10th iterations
- if nbb:
- err = np.sqrt(
- np.sum((u - uprev)**2) / np.sum((u)**2)
- + np.sum((v - vprev)**2) / np.sum((v)**2)
- )
- else:
- # compute right marginal tmp2= (diag(u)Kdiag(v))^T1
- tmp2 = np.sum(u[:, None] * K * v[None, :], 0)
- #tmp2=np.einsum('i,ij,j->j', u, K, v)
- err = np.linalg.norm(tmp2 - b) # violation of marginal
- if log:
- log['err'].append(err)
-
- if verbose:
- if cpt % 200 == 0:
- print(
- '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
- cpt = cpt + 1
- if log:
- log['u'] = u
- log['v'] = v
-
- if nbb: # return only loss
- #res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) (explodes cupy memory)
- res = np.empty(nbb)
- for i in range(nbb):
- res[i] = np.sum(u[:, None, i] * (K * M) * v[None, :, i])
- if to_numpy:
- res = utils.to_np(res)
- if log:
- return res, log
- else:
- return res
-
- else: # return OT matrix
- res = u.reshape((-1, 1)) * K * v.reshape((1, -1))
- if to_numpy:
- res = utils.to_np(res)
- if log:
- return res, log
- else:
- return res
-
-
-# define sinkhorn as sinkhorn_knopp
-sinkhorn = sinkhorn_knopp
diff --git a/ot/gpu/da.py b/ot/gpu/da.py
deleted file mode 100644
index 7adb830..0000000
--- a/ot/gpu/da.py
+++ /dev/null
@@ -1,144 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Domain adaptation with optimal transport with GPU implementation
-"""
-
-# Author: Remi Flamary <remi.flamary@unice.fr>
-# Nicolas Courty <ncourty@irisa.fr>
-# Michael Perrot <michael.perrot@univ-st-etienne.fr>
-# Leo Gautheron <https://github.com/aje>
-#
-# License: MIT License
-
-
-import cupy as np # np used for matrix computation
-import cupy as cp # cp used for cupy specific operations
-import numpy as npp
-from . import utils
-
-from .bregman import sinkhorn
-
-
-def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
- numInnerItermax=200, stopInnerThr=1e-9, verbose=False,
- log=False, to_numpy=True):
- """
- Solve the entropic regularization optimal transport problem with nonconvex
- group lasso regularization on GPU
-
- If the input matrix are in numpy format, they will be uploaded to the
- GPU first which can incur significant time overhead.
-
-
- The function solves the following optimization problem:
-
- .. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)
- + \eta \Omega_g(\gamma)
-
- s.t. \gamma 1 = a
-
- \gamma^T 1= b
-
- \gamma\geq 0
- where :
-
- - M is the (ns,nt) metric cost matrix
- - :math:`\Omega_e` is the entropic regularization term
- :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`\Omega_g` is the group lasso regulaization term
- :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1`
- where :math:`\mathcal{I}_c` are the index of samples from class c
- in the source domain.
- - a and b are source and target weights (sum to 1)
-
- The algorithm used for solving the problem is the generalised conditional
- gradient as proposed in [5]_ [7]_
-
-
- Parameters
- ----------
- a : np.ndarray (ns,)
- samples weights in the source domain
- labels_a : np.ndarray (ns,)
- labels of samples in the source domain
- b : np.ndarray (nt,)
- samples weights in the target domain
- M : np.ndarray (ns,nt)
- loss matrix
- reg : float
- Regularization term for entropic regularization >0
- eta : float, optional
- Regularization term for group lasso regularization >0
- numItermax : int, optional
- Max number of iterations
- numInnerItermax : int, optional
- Max number of iterations (inner sinkhorn solver)
- stopInnerThr : float, optional
- Stop threshold on error (inner sinkhorn solver) (>0)
- verbose : bool, optional
- Print information along iterations
- log : bool, optional
- record log if True
- to_numpy : boolean, optional (default True)
- If true convert back the GPU array result to numpy format.
-
-
- Returns
- -------
- gamma : (ns x nt) ndarray
- Optimal transportation matrix for the given parameters
- log : dict
- log dictionary return only if log==True in parameters
-
-
- References
- ----------
-
- .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
- "Optimal Transport for Domain Adaptation," in IEEE
- Transactions on Pattern Analysis and Machine Intelligence ,
- vol.PP, no.99, pp.1-1
- .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
- Generalized conditional gradient: analysis of convergence
- and applications. arXiv preprint arXiv:1510.06567.
-
- See Also
- --------
- ot.lp.emd : Unregularized OT
- ot.bregman.sinkhorn : Entropic regularized OT
- ot.optim.cg : General regularized OT
-
- """
-
- a, labels_a, b, M = utils.to_gpu(a, labels_a, b, M)
-
- p = 0.5
- epsilon = 1e-3
-
- indices_labels = []
- labels_a2 = cp.asnumpy(labels_a)
- classes = npp.unique(labels_a2)
- for c in classes:
- idxc = utils.to_gpu(*npp.where(labels_a2 == c))
- indices_labels.append(idxc)
-
- W = np.zeros(M.shape)
-
- for cpt in range(numItermax):
- Mreg = M + eta * W
- transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
- stopThr=stopInnerThr, to_numpy=False)
- # the transport has been computed. Check if classes are really
- # separated
- W = np.ones(M.shape)
- for (i, c) in enumerate(classes):
-
- majs = np.sum(transp[indices_labels[i]], axis=0)
- majs = p * ((majs + epsilon)**(p - 1))
- W[indices_labels[i]] = majs
-
- if to_numpy:
- return utils.to_np(transp)
- else:
- return transp
diff --git a/ot/gpu/utils.py b/ot/gpu/utils.py
deleted file mode 100644
index 41e168a..0000000
--- a/ot/gpu/utils.py
+++ /dev/null
@@ -1,101 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Utility functions for GPU
-"""
-
-# Author: Remi Flamary <remi.flamary@unice.fr>
-# Nicolas Courty <ncourty@irisa.fr>
-# Leo Gautheron <https://github.com/aje>
-#
-# License: MIT License
-
-import cupy as np # np used for matrix computation
-import cupy as cp # cp used for cupy specific operations
-
-
-def euclidean_distances(a, b, squared=False, to_numpy=True):
- """
- Compute the pairwise euclidean distance between matrices a and b.
-
- If the input matrix are in numpy format, they will be uploaded to the
- GPU first which can incur significant time overhead.
-
- Parameters
- ----------
- a : np.ndarray (n, f)
- first matrix
- b : np.ndarray (m, f)
- second matrix
- to_numpy : boolean, optional (default True)
- If true convert back the GPU array result to numpy format.
- squared : boolean, optional (default False)
- if True, return squared euclidean distance matrix
-
- Returns
- -------
- c : (n x m) np.ndarray or cupy.ndarray
- pairwise euclidean distance distance matrix
- """
-
- a, b = to_gpu(a, b)
-
- a2 = np.sum(np.square(a), 1)
- b2 = np.sum(np.square(b), 1)
-
- c = -2 * np.dot(a, b.T)
- c += a2[:, None]
- c += b2[None, :]
-
- if not squared:
- np.sqrt(c, out=c)
- if to_numpy:
- return to_np(c)
- else:
- return c
-
-
-def dist(x1, x2=None, metric='sqeuclidean', to_numpy=True):
- """Compute distance between samples in x1 and x2 on gpu
-
- Parameters
- ----------
-
- x1 : np.array (n1,d)
- matrix with n1 samples of size d
- x2 : np.array (n2,d), optional
- matrix with n2 samples of size d (if None then x2=x1)
- metric : str
- Metric from 'sqeuclidean', 'euclidean',
-
-
- Returns
- -------
-
- M : np.array (n1,n2)
- distance matrix computed with given metric
-
- """
- if x2 is None:
- x2 = x1
- if metric == "sqeuclidean":
- return euclidean_distances(x1, x2, squared=True, to_numpy=to_numpy)
- elif metric == "euclidean":
- return euclidean_distances(x1, x2, squared=False, to_numpy=to_numpy)
- else:
- raise NotImplementedError
-
-
-def to_gpu(*args):
- """ Upload numpy arrays to GPU and return them"""
- if len(args) > 1:
- return (cp.asarray(x) for x in args)
- else:
- return cp.asarray(args[0])
-
-
-def to_np(*args):
- """ convert GPU arras to numpy and return them"""
- if len(args) > 1:
- return (cp.asnumpy(x) for x in args)
- else:
- return cp.asnumpy(args[0])
diff --git a/ot/gromov.py b/ot/gromov.py
index 6544260..55ab0bd 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -7,6 +7,7 @@ Gromov-Wasserstein and Fused-Gromov-Wasserstein solvers
# Nicolas Courty <ncourty@irisa.fr>
# Rémi Flamary <remi.flamary@unice.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
+# Cédric Vincent-Cuaz <cedric.vincent-cuaz@inria.fr>
#
# License: MIT License
@@ -17,7 +18,7 @@ from .bregman import sinkhorn
from .utils import dist, UndefinedParameter, list_to_array
from .optim import cg
from .lp import emd_1d, emd
-from .utils import check_random_state
+from .utils import check_random_state, unif
from .backend import get_backend
@@ -320,7 +321,7 @@ def update_kl_loss(p, lambdas, T, Cs):
return nx.exp(tmpsum / ppt)
-def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs):
+def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, G0=None, **kwargs):
r"""
Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
@@ -338,6 +339,10 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F
- :math:`\mathbf{q}`: distribution in the target space
- `L`: loss function to account for the misfit between the similarity matrices
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
Parameters
----------
C1 : array-like, shape (ns, ns)
@@ -361,6 +366,9 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F
armijo : bool, optional
If True the step of the line-search is found via an armijo research. Else closed form is used.
If there are convergence issues use False.
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
**kwargs : dict
parameters can be directly passed to the ot.optim.cg solver
@@ -385,18 +393,26 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F
"""
p, q = list_to_array(p, q)
-
p0, q0, C10, C20 = p, q, C1, C2
- nx = get_backend(p0, q0, C10, C20)
-
+ if G0 is None:
+ nx = get_backend(p0, q0, C10, C20)
+ else:
+ G0_ = G0
+ nx = get_backend(p0, q0, C10, C20, G0_)
p = nx.to_numpy(p)
q = nx.to_numpy(q)
C1 = nx.to_numpy(C10)
C2 = nx.to_numpy(C20)
- constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+ if G0 is None:
+ G0 = p[:, None] * q[None, :]
+ else:
+ G0 = nx.to_numpy(G0_)
+ # Check marginals of G0
+ np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
+ np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
- G0 = p[:, None] * q[None, :]
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
def f(G):
return gwloss(constC, hC1, hC2, G)
@@ -414,7 +430,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F
return nx.from_numpy(cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=False, **kwargs), type_as=C10)
-def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs):
+def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, G0=None, **kwargs):
r"""
Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
@@ -436,6 +452,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
Note that when using backends, this loss function is differentiable wrt the
marices and weights for quadratic loss using the gradients from [38]_.
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
Parameters
----------
C1 : array-like, shape (ns, ns)
@@ -459,6 +479,9 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
armijo : bool, optional
If True the step of the line-search is found via an armijo research. Else closed form is used.
If there are convergence issues use False.
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
Returns
-------
@@ -483,9 +506,12 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
"""
p, q = list_to_array(p, q)
-
p0, q0, C10, C20 = p, q, C1, C2
- nx = get_backend(p0, q0, C10, C20)
+ if G0 is None:
+ nx = get_backend(p0, q0, C10, C20)
+ else:
+ G0_ = G0
+ nx = get_backend(p0, q0, C10, C20, G0_)
p = nx.to_numpy(p)
q = nx.to_numpy(q)
@@ -494,7 +520,13 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
- G0 = p[:, None] * q[None, :]
+ if G0 is None:
+ G0 = p[:, None] * q[None, :]
+ else:
+ G0 = nx.to_numpy(G0_)
+ # Check marginals of G0
+ np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
+ np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
def f(G):
return gwloss(constC, hC1, hC2, G)
@@ -514,10 +546,13 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
gw = log_gw['gw_dist']
if loss_fun == 'square_loss':
- gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T))
- gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T))
+ gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)
+ gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)
+ gC1 = nx.from_numpy(gC1, type_as=C10)
+ gC2 = nx.from_numpy(gC2, type_as=C10)
gw = nx.set_gradients(gw, (p0, q0, C10, C20),
- (log_gw['u'], log_gw['v'], gC1, gC2))
+ (log_gw['u'] - nx.mean(log_gw['u']),
+ log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2))
if log:
return gw, log_gw
@@ -525,7 +560,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
return gw
-def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
+def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, G0=None, log=False, **kwargs):
r"""
Computes the FGW transport between two graphs (see :ref:`[24] <references-fused-gromov-wasserstein>`)
@@ -545,6 +580,10 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
- :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
- `L` is a loss function to account for the misfit between the similarity matrices
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] <references-fused-gromov-wasserstein>`
Parameters
@@ -566,6 +605,9 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
armijo : bool, optional
If True the step of the line-search is found via an armijo research. Else closed form is used.
If there are convergence issues use False.
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
log : bool, optional
record log if True
**kwargs : dict
@@ -588,20 +630,28 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
(ICML). 2019.
"""
p, q = list_to_array(p, q)
-
p0, q0, C10, C20, M0 = p, q, C1, C2, M
- nx = get_backend(p0, q0, C10, C20, M0)
+ if G0 is None:
+ nx = get_backend(p0, q0, C10, C20, M0)
+ else:
+ G0_ = G0
+ nx = get_backend(p0, q0, C10, C20, M0, G0_)
p = nx.to_numpy(p)
q = nx.to_numpy(q)
C1 = nx.to_numpy(C10)
C2 = nx.to_numpy(C20)
M = nx.to_numpy(M0)
+ if G0 is None:
+ G0 = p[:, None] * q[None, :]
+ else:
+ G0 = nx.to_numpy(G0_)
+ # Check marginals of G0
+ np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
+ np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
- G0 = p[:, None] * q[None, :]
-
def f(G):
return gwloss(constC, hC1, hC2, G)
@@ -610,19 +660,16 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
if log:
res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
-
fgw_dist = nx.from_numpy(log['loss'][-1], type_as=C10)
-
log['fgw_dist'] = fgw_dist
log['u'] = nx.from_numpy(log['u'], type_as=C10)
log['v'] = nx.from_numpy(log['v'], type_as=C10)
return nx.from_numpy(res, type_as=C10), log
-
else:
return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs), type_as=C10)
-def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
+def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, G0=None, log=False, **kwargs):
r"""
Computes the FGW distance between two graphs see (see :ref:`[24] <references-fused-gromov-wasserstein2>`)
@@ -645,6 +692,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
The algorithm used for solving the problem is conditional gradient as
discussed in :ref:`[24] <references-fused-gromov-wasserstein2>`
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
Note that when using backends, this loss function is differentiable wrt the
marices and weights for quadratic loss using the gradients from [38]_.
@@ -667,6 +718,9 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
armijo : bool, optional
If True the step of the line-search is found via an armijo research.
Else closed form is used. If there are convergence issues use False.
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
log : bool, optional
Record log if True.
**kwargs : dict
@@ -695,7 +749,11 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
p, q = list_to_array(p, q)
p0, q0, C10, C20, M0 = p, q, C1, C2, M
- nx = get_backend(p0, q0, C10, C20, M0)
+ if G0 is None:
+ nx = get_backend(p0, q0, C10, C20, M0)
+ else:
+ G0_ = G0
+ nx = get_backend(p0, q0, C10, C20, M0, G0_)
p = nx.to_numpy(p)
q = nx.to_numpy(q)
@@ -705,7 +763,13 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
- G0 = p[:, None] * q[None, :]
+ if G0 is None:
+ G0 = p[:, None] * q[None, :]
+ else:
+ G0 = nx.to_numpy(G0_)
+ # Check marginals of G0
+ np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
+ np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
def f(G):
return gwloss(constC, hC1, hC2, G)
@@ -725,10 +789,14 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
log_fgw['T'] = T0
if loss_fun == 'square_loss':
- gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T))
- gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T))
+ gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)
+ gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)
+ gC1 = nx.from_numpy(gC1, type_as=C10)
+ gC2 = nx.from_numpy(gC2, type_as=C10)
fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0),
- (log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0))
+ (log_fgw['u'] - nx.mean(log_fgw['u']),
+ log_fgw['v'] - nx.mean(log_fgw['v']),
+ alpha * gC1, alpha * gC2, (1 - alpha) * T0))
if log:
return fgw_dist, log_fgw
@@ -1780,3 +1848,988 @@ def update_feature_matrix(lambdas, Ys, Ts, p):
for s in range(len(Ts))
])
return tmpsum
+
+
+def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate=1., Cdict_init=None, projection='nonnegative_symmetric', use_log=True,
+ tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs):
+ r"""
+ Infer Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, q) \}_{d \in [D]}` from the list of structures :math:`\{ (\mathbf{C_s},\mathbf{p_s}) \}_s`
+
+ .. math::
+ \min_{\mathbf{C_{dict}}, \{\mathbf{w_s} \}_{s \leq S}} \sum_{s=1}^S GW_2(\mathbf{C_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) - reg\| \mathbf{w_s} \|_2^2
+
+ such that, :math:`\forall s \leq S` :
+
+ - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w_s} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt.
+ - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}`
+ - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space.
+ - reg is the regularization coefficient.
+
+ The stochastic algorithm used for estimating the graph dictionary atoms as proposed in [38]
+
+ Parameters
+ ----------
+ Cs : list of S symmetric array-like, shape (ns, ns)
+ List of Metric/Graph cost matrices of variable size (ns, ns).
+ D: int
+ Number of dictionary atoms to learn
+ nt: int
+ Number of samples within each dictionary atoms
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0.
+ ps : list of S array-like, shape (ns,), optional
+ Distribution in each source space C of Cs. Default is None and corresponds to uniform distibutions.
+ q : array-like, shape (nt,), optional
+ Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions.
+ epochs: int, optional
+ Number of epochs used to learn the dictionary. Default is 32.
+ batch_size: int, optional
+ Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32.
+ learning_rate: float, optional
+ Learning rate used for the stochastic gradient descent. Default is 1.
+ Cdict_init: list of D array-like with shape (nt, nt), optional
+ Used to initialize the dictionary.
+ If set to None (Default), the dictionary will be initialized randomly.
+ Else Cdict must have shape (D, nt, nt) i.e match provided shape features.
+ projection: str , optional
+ If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary
+ Else the set of atoms is :math:`R^{nt * nt}`. Default is 'nonnegative_symmetric'
+ log: bool, optional
+ If set to True, losses evolution by batches and epochs are tracked. Default is False.
+ use_adam_optimizer: bool, optional
+ If set to True, adam optimizer with default settings is used as adaptative learning rate strategy.
+ Else perform SGD with fixed learning rate. Default is True.
+ tol_outer : float, optional
+ Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
+ tol_inner : float, optional
+ Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
+ max_iter_outer : int, optional
+ Maximum number of iterations for the BCD. Default is 20.
+ max_iter_inner : int, optional
+ Maximum number of iterations for the Conjugate Gradient. Default is 200.
+ verbose : bool, optional
+ Print the reconstruction loss every epoch. Default is False.
+
+ Returns
+ -------
+
+ Cdict_best_state : D array-like, shape (D,nt,nt)
+ Metric/Graph cost matrices composing the dictionary.
+ The dictionary leading to the best loss over an epoch is saved and returned.
+ log: dict
+ If use_log is True, contains loss evolutions by batches and epochs.
+ References
+ -------
+
+ ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty.
+ "Online Graph Dictionary Learning"
+ International Conference on Machine Learning (ICML). 2021.
+ """
+ # Handle backend of non-optional arguments
+ Cs0 = Cs
+ nx = get_backend(*Cs0)
+ Cs = [nx.to_numpy(C) for C in Cs0]
+ dataset_size = len(Cs)
+ # Handle backend of optional arguments
+ if ps is None:
+ ps = [unif(C.shape[0]) for C in Cs]
+ else:
+ ps = [nx.to_numpy(p) for p in ps]
+ if q is None:
+ q = unif(nt)
+ else:
+ q = nx.to_numpy(q)
+ if Cdict_init is None:
+ # Initialize randomly structures of dictionary atoms based on samples
+ dataset_means = [C.mean() for C in Cs]
+ Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
+ else:
+ Cdict = nx.to_numpy(Cdict_init).copy()
+ assert Cdict.shape == (D, nt, nt)
+
+ if 'symmetric' in projection:
+ Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
+ if 'nonnegative' in projection:
+ Cdict[Cdict < 0.] = 0
+ if use_adam_optimizer:
+ adam_moments = _initialize_adam_optimizer(Cdict)
+
+ log = {'loss_batches': [], 'loss_epochs': []}
+ const_q = q[:, None] * q[None, :]
+ Cdict_best_state = Cdict.copy()
+ loss_best_state = np.inf
+ if batch_size > dataset_size:
+ batch_size = dataset_size
+ iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0)
+
+ for epoch in range(epochs):
+ cumulated_loss_over_epoch = 0.
+
+ for _ in range(iter_by_epoch):
+ # batch sampling
+ batch = np.random.choice(range(dataset_size), size=batch_size, replace=False)
+ cumulated_loss_over_batch = 0.
+ unmixings = np.zeros((batch_size, D))
+ Cs_embedded = np.zeros((batch_size, nt, nt))
+ Ts = [None] * batch_size
+
+ for batch_idx, C_idx in enumerate(batch):
+ # BCD solver for Gromov-Wassersteisn linear unmixing used independently on each structure of the sampled batch
+ unmixings[batch_idx], Cs_embedded[batch_idx], Ts[batch_idx], current_loss = gromov_wasserstein_linear_unmixing(
+ Cs[C_idx], Cdict, reg=reg, p=ps[C_idx], q=q, tol_outer=tol_outer, tol_inner=tol_inner,
+ max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner
+ )
+ cumulated_loss_over_batch += current_loss
+ cumulated_loss_over_epoch += cumulated_loss_over_batch
+
+ if use_log:
+ log['loss_batches'].append(cumulated_loss_over_batch)
+
+ # Stochastic projected gradient step over dictionary atoms
+ grad_Cdict = np.zeros_like(Cdict)
+ for batch_idx, C_idx in enumerate(batch):
+ shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx])
+ grad_Cdict += unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :]
+ grad_Cdict *= 2 / batch_size
+ if use_adam_optimizer:
+ Cdict, adam_moments = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate, adam_moments)
+ else:
+ Cdict -= learning_rate * grad_Cdict
+ if 'symmetric' in projection:
+ Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
+ if 'nonnegative' in projection:
+ Cdict[Cdict < 0.] = 0.
+
+ if use_log:
+ log['loss_epochs'].append(cumulated_loss_over_epoch)
+ if loss_best_state > cumulated_loss_over_epoch:
+ loss_best_state = cumulated_loss_over_epoch
+ Cdict_best_state = Cdict.copy()
+ if verbose:
+ print('--- epoch =', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch)
+
+ return nx.from_numpy(Cdict_best_state), log
+
+
+def _initialize_adam_optimizer(variable):
+
+ # Initialization for our numpy implementation of adam optimizer
+ atoms_adam_m = np.zeros_like(variable) # Initialize first moment tensor
+ atoms_adam_v = np.zeros_like(variable) # Initialize second moment tensor
+ atoms_adam_count = 1
+
+ return {'mean': atoms_adam_m, 'var': atoms_adam_v, 'count': atoms_adam_count}
+
+
+def _adam_stochastic_updates(variable, grad, learning_rate, adam_moments, beta_1=0.9, beta_2=0.99, eps=1e-09):
+
+ adam_moments['mean'] = beta_1 * adam_moments['mean'] + (1 - beta_1) * grad
+ adam_moments['var'] = beta_2 * adam_moments['var'] + (1 - beta_2) * (grad**2)
+ unbiased_m = adam_moments['mean'] / (1 - beta_1**adam_moments['count'])
+ unbiased_v = adam_moments['var'] / (1 - beta_2**adam_moments['count'])
+ variable -= learning_rate * unbiased_m / (np.sqrt(unbiased_v) + eps)
+ adam_moments['count'] += 1
+
+ return variable, adam_moments
+
+
+def gromov_wasserstein_linear_unmixing(C, Cdict, reg=0., p=None, q=None, tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, **kwargs):
+ r"""
+ Returns the Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}`.
+
+ .. math::
+ \min_{ \mathbf{w}} GW_2(\mathbf{C}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2
+
+ such that:
+
+ - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of size nt.
+ - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights.
+ - reg is the regularization coefficient.
+
+ The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38], algorithm 1.
+
+ Parameters
+ ----------
+ C : array-like, shape (ns, ns)
+ Metric/Graph cost matrix.
+ Cdict : D array-like, shape (D,nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed C.
+ reg : float, optional.
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0.
+ p : array-like, shape (ns,), optional
+ Distribution in the source space C. Default is None and corresponds to uniform distribution.
+ q : array-like, shape (nt,), optional
+ Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution.
+ tol_outer : float, optional
+ Solver precision for the BCD algorithm.
+ tol_inner : float, optional
+ Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`.
+ max_iter_outer : int, optional
+ Maximum number of iterations for the BCD. Default is 20.
+ max_iter_inner : int, optional
+ Maximum number of iterations for the Conjugate Gradient. Default is 200.
+
+ Returns
+ -------
+ w: array-like, shape (D,)
+ gromov-wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the span of the dictionary.
+ Cembedded: array-like, shape (nt,nt)
+ embedded structure of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`.
+ T: array-like (ns, nt)
+ Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \mathbf{q})`
+ current_loss: float
+ reconstruction error
+ References
+ -------
+
+ ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty.
+ "Online Graph Dictionary Learning"
+ International Conference on Machine Learning (ICML). 2021.
+ """
+ C0, Cdict0 = C, Cdict
+ nx = get_backend(C0, Cdict0)
+ C = nx.to_numpy(C0)
+ Cdict = nx.to_numpy(Cdict0)
+ if p is None:
+ p = unif(C.shape[0])
+ else:
+ p = nx.to_numpy(p)
+
+ if q is None:
+ q = unif(Cdict.shape[-1])
+ else:
+ q = nx.to_numpy(q)
+
+ T = p[:, None] * q[None, :]
+ D = len(Cdict)
+
+ w = unif(D) # Initialize uniformly the unmixing w
+ Cembedded = np.sum(w[:, None, None] * Cdict, axis=0)
+
+ const_q = q[:, None] * q[None, :]
+ # Trackers for BCD convergence
+ convergence_criterion = np.inf
+ current_loss = 10**15
+ outer_count = 0
+
+ while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer):
+ previous_loss = current_loss
+ # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w
+ T, log = gromov_wasserstein(C1=C, C2=Cembedded, p=p, q=q, loss_fun='square_loss', G0=T, log=True, armijo=False, **kwargs)
+ current_loss = log['gw_dist']
+ if reg != 0:
+ current_loss -= reg * np.sum(w**2)
+
+ # 2. Solve linear unmixing problem over w with a fixed transport plan T
+ w, Cembedded, current_loss = _cg_gromov_wasserstein_unmixing(
+ C=C, Cdict=Cdict, Cembedded=Cembedded, w=w, const_q=const_q, T=T,
+ starting_loss=current_loss, reg=reg, tol=tol_inner, max_iter=max_iter_inner, **kwargs
+ )
+
+ if previous_loss != 0:
+ convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
+ else: # handle numerical issues around 0
+ convergence_criterion = abs(previous_loss - current_loss) / 10**(-15)
+ outer_count += 1
+
+ return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(T), nx.from_numpy(current_loss)
+
+
+def _cg_gromov_wasserstein_unmixing(C, Cdict, Cembedded, w, const_q, T, starting_loss, reg=0., tol=10**(-5), max_iter=200, **kwargs):
+ r"""
+ Returns for a fixed admissible transport plan,
+ the linear unmixing w minimizing the Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w[d]*\mathbf{C_{dict}[d]}, \mathbf{q})`
+
+ .. math::
+ \min_{\mathbf{w}} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d*C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg* \| \mathbf{w} \|_2^2
+
+
+ Such that:
+
+ - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of nt points.
+ - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights.
+ - :math:`\mathbf{w}` is the linear unmixing of :math:`(\mathbf{C}, \mathbf{p})` onto :math:`(\sum_d w_d \mathbf{Cdict[d]}, \mathbf{q})`.
+ - :math:`\mathbf{T}` is the optimal transport plan conditioned by the current state of :math:`\mathbf{w}`.
+ - reg is the regularization coefficient.
+
+ The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38]
+
+ Parameters
+ ----------
+
+ C : array-like, shape (ns, ns)
+ Metric/Graph cost matrix.
+ Cdict : list of D array-like, shape (nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed C.
+ Each matrix in the dictionary must have the same size (nt,nt).
+ Cembedded: array-like, shape (nt,nt)
+ Embedded structure :math:`(\sum_d w[d]*Cdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations.
+ w: array-like, shape (D,)
+ Linear unmixing of the input structure onto the dictionary
+ const_q: array-like, shape (nt,nt)
+ product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations.
+ T: array-like, shape (ns,nt)
+ fixed transport plan between the input structure and its representation in the dictionary.
+ p : array-like, shape (ns,)
+ Distribution in the source space.
+ q : array-like, shape (nt,)
+ Distribution in the embedding space depicted by the dictionary.
+ reg : float, optional.
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0.
+
+ Returns
+ -------
+ w: ndarray (D,)
+ optimal unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary span given OT starting from previously optimal unmixing.
+ """
+ convergence_criterion = np.inf
+ current_loss = starting_loss
+ count = 0
+ const_TCT = np.transpose(C.dot(T)).dot(T)
+
+ while (convergence_criterion > tol) and (count < max_iter):
+
+ previous_loss = current_loss
+ # 1) Compute gradient at current point w
+ grad_w = 2 * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2))
+ grad_w -= 2 * reg * w
+
+ # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w
+ min_ = np.min(grad_w)
+ x = (grad_w == min_).astype(np.float64)
+ x /= np.sum(x)
+
+ # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c
+ gamma, a, b, Cembedded_diff = _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg)
+
+ # 4) Updates: w <-- (1-gamma)*w + gamma*x
+ w += gamma * (x - w)
+ Cembedded += gamma * Cembedded_diff
+ current_loss += a * (gamma**2) + b * gamma
+
+ if previous_loss != 0: # not that the loss can be negative if reg >0
+ convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
+ else: # handle numerical issues around 0
+ convergence_criterion = abs(previous_loss - current_loss) / 10**(-15)
+ count += 1
+
+ return w, Cembedded, current_loss
+
+
+def _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg, **kwargs):
+ r"""
+ Compute optimal steps for the line search problem of Gromov-Wasserstein linear unmixing
+ .. math::
+ \min_{\gamma \in [0,1]} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg\| \mathbf{z}(\gamma) \|_2^2
+
+
+ Such that:
+
+ - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}`
+
+ Parameters
+ ----------
+
+ w : array-like, shape (D,)
+ Unmixing.
+ grad_w : array-like, shape (D, D)
+ Gradient of the reconstruction loss with respect to w.
+ x: array-like, shape (D,)
+ Conditional gradient direction.
+ Cdict : list of D array-like, shape (nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed C.
+ Each matrix in the dictionary must have the same size (nt,nt).
+ Cembedded: array-like, shape (nt,nt)
+ Embedded structure :math:`(\sum_d w_dCdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations.
+ const_q: array-like, shape (nt,nt)
+ product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations.
+ const_TCT: array-like, shape (nt, nt)
+ :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations.
+ Returns
+ -------
+ gamma: float
+ Optimal value for the line-search step
+ a: float
+ Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss
+ b: float
+ Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss
+ Cembedded_diff: numpy array, shape (nt, nt)
+ Difference between models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`.
+ reg : float, optional.
+ Coefficient of the negative quadratic regularization used to promote sparsity of :math:`\mathbf{w}`.
+ """
+
+ # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c
+ Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0)
+ Cembedded_diff = Cembedded_x - Cembedded
+ trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q)
+ trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q)
+ a = trace_diffx - trace_diffw
+ b = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT))
+ if reg != 0:
+ a -= reg * np.sum((x - w)**2)
+ b -= 2 * reg * np.sum(w * (x - w))
+
+ if a > 0:
+ gamma = min(1, max(0, - b / (2 * a)))
+ elif a + b < 0:
+ gamma = 1
+ else:
+ gamma = 0
+
+ return gamma, a, b, Cembedded_diff
+
+
+def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate_C=1., learning_rate_Y=1.,
+ Cdict_init=None, Ydict_init=None, projection='nonnegative_symmetric', use_log=False,
+ tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs):
+ r"""
+ Infer Fused Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` from the list of S attributed structures :math:`\{ (\mathbf{C_s}, \mathbf{Y_s},\mathbf{p_s}) \}_s`
+
+ .. math::
+ \min_{\mathbf{C_{dict}},\mathbf{Y_{dict}}, \{\mathbf{w_s}\}_{s}} \sum_{s=1}^S FGW_{2,\alpha}(\mathbf{C_s}, \mathbf{Y_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]},\sum_{d=1}^D w_{s,d}\mathbf{Y_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) \\ - reg\| \mathbf{w_s} \|_2^2
+
+
+ Such that :math:`\forall s \leq S` :
+
+ - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w_s} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns.
+ - :math:`\forall s \leq S, \mathbf{Y_s}` is a (ns,d) features matrix of variable size ns and fixed dimension d.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt.
+ - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d.
+ - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}`
+ - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space.
+ - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein
+ - reg is the regularization coefficient.
+
+
+ The stochastic algorithm used for estimating the attributed graph dictionary atoms as proposed in [38]
+
+ Parameters
+ ----------
+ Cs : list of S symmetric array-like, shape (ns, ns)
+ List of Metric/Graph cost matrices of variable size (ns,ns).
+ Ys : list of S array-like, shape (ns, d)
+ List of feature matrix of variable size (ns,d) with d fixed.
+ D: int
+ Number of dictionary atoms to learn
+ nt: int
+ Number of samples within each dictionary atoms
+ alpha : float
+ Trade-off parameter of Fused Gromov-Wasserstein
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0.
+ ps : list of S array-like, shape (ns,), optional
+ Distribution in each source space C of Cs. Default is None and corresponds to uniform distibutions.
+ q : array-like, shape (nt,), optional
+ Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions.
+ epochs: int, optional
+ Number of epochs used to learn the dictionary. Default is 32.
+ batch_size: int, optional
+ Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32.
+ learning_rate_C: float, optional
+ Learning rate used for the stochastic gradient descent on Cdict. Default is 1.
+ learning_rate_Y: float, optional
+ Learning rate used for the stochastic gradient descent on Ydict. Default is 1.
+ Cdict_init: list of D array-like with shape (nt, nt), optional
+ Used to initialize the dictionary structures Cdict.
+ If set to None (Default), the dictionary will be initialized randomly.
+ Else Cdict must have shape (D, nt, nt) i.e match provided shape features.
+ Ydict_init: list of D array-like with shape (nt, d), optional
+ Used to initialize the dictionary features Ydict.
+ If set to None, the dictionary features will be initialized randomly.
+ Else Ydict must have shape (D, nt, d) where d is the features dimension of inputs Ys and also match provided shape features.
+ projection: str, optional
+ If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary
+ Else the set of atoms is :math:`R^{nt * nt}`. Default is 'nonnegative_symmetric'
+ log: bool, optional
+ If set to True, losses evolution by batches and epochs are tracked. Default is False.
+ use_adam_optimizer: bool, optional
+ If set to True, adam optimizer with default settings is used as adaptative learning rate strategy.
+ Else perform SGD with fixed learning rate. Default is True.
+ tol_outer : float, optional
+ Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
+ tol_inner : float, optional
+ Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
+ max_iter_outer : int, optional
+ Maximum number of iterations for the BCD. Default is 20.
+ max_iter_inner : int, optional
+ Maximum number of iterations for the Conjugate Gradient. Default is 200.
+ verbose : bool, optional
+ Print the reconstruction loss every epoch. Default is False.
+
+ Returns
+ -------
+
+ Cdict_best_state : D array-like, shape (D,nt,nt)
+ Metric/Graph cost matrices composing the dictionary.
+ The dictionary leading to the best loss over an epoch is saved and returned.
+ Ydict_best_state : D array-like, shape (D,nt,d)
+ Feature matrices composing the dictionary.
+ The dictionary leading to the best loss over an epoch is saved and returned.
+ log: dict
+ If use_log is True, contains loss evolutions by batches and epoches.
+ References
+ -------
+
+ ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty.
+ "Online Graph Dictionary Learning"
+ International Conference on Machine Learning (ICML). 2021.
+ """
+ Cs0, Ys0 = Cs, Ys
+ nx = get_backend(*Cs0, *Ys0)
+ Cs = [nx.to_numpy(C) for C in Cs0]
+ Ys = [nx.to_numpy(Y) for Y in Ys0]
+
+ d = Ys[0].shape[-1]
+ dataset_size = len(Cs)
+
+ if ps is None:
+ ps = [unif(C.shape[0]) for C in Cs]
+ else:
+ ps = [nx.to_numpy(p) for p in ps]
+ if q is None:
+ q = unif(nt)
+ else:
+ q = nx.to_numpy(q)
+
+ if Cdict_init is None:
+ # Initialize randomly structures of dictionary atoms based on samples
+ dataset_means = [C.mean() for C in Cs]
+ Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
+ else:
+ Cdict = nx.to_numpy(Cdict_init).copy()
+ assert Cdict.shape == (D, nt, nt)
+ if Ydict_init is None:
+ # Initialize randomly features of dictionary atoms based on samples distribution by feature component
+ dataset_feature_means = np.stack([F.mean(axis=0) for F in Ys])
+ Ydict = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(D, nt, d))
+ else:
+ Ydict = nx.to_numpy(Ydict_init).copy()
+ assert Ydict.shape == (D, nt, d)
+
+ if 'symmetric' in projection:
+ Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
+ if 'nonnegative' in projection:
+ Cdict[Cdict < 0.] = 0.
+
+ if use_adam_optimizer:
+ adam_moments_C = _initialize_adam_optimizer(Cdict)
+ adam_moments_Y = _initialize_adam_optimizer(Ydict)
+
+ log = {'loss_batches': [], 'loss_epochs': []}
+ const_q = q[:, None] * q[None, :]
+ diag_q = np.diag(q)
+ Cdict_best_state = Cdict.copy()
+ Ydict_best_state = Ydict.copy()
+ loss_best_state = np.inf
+ if batch_size > dataset_size:
+ batch_size = dataset_size
+ iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0)
+
+ for epoch in range(epochs):
+ cumulated_loss_over_epoch = 0.
+
+ for _ in range(iter_by_epoch):
+
+ # Batch iterations
+ batch = np.random.choice(range(dataset_size), size=batch_size, replace=False)
+ cumulated_loss_over_batch = 0.
+ unmixings = np.zeros((batch_size, D))
+ Cs_embedded = np.zeros((batch_size, nt, nt))
+ Ys_embedded = np.zeros((batch_size, nt, d))
+ Ts = [None] * batch_size
+
+ for batch_idx, C_idx in enumerate(batch):
+ # BCD solver for Gromov-Wassersteisn linear unmixing used independently on each structure of the sampled batch
+ unmixings[batch_idx], Cs_embedded[batch_idx], Ys_embedded[batch_idx], Ts[batch_idx], current_loss = fused_gromov_wasserstein_linear_unmixing(
+ Cs[C_idx], Ys[C_idx], Cdict, Ydict, alpha, reg=reg, p=ps[C_idx], q=q,
+ tol_outer=tol_outer, tol_inner=tol_inner, max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner
+ )
+ cumulated_loss_over_batch += current_loss
+ cumulated_loss_over_epoch += cumulated_loss_over_batch
+ if use_log:
+ log['loss_batches'].append(cumulated_loss_over_batch)
+
+ # Stochastic projected gradient step over dictionary atoms
+ grad_Cdict = np.zeros_like(Cdict)
+ grad_Ydict = np.zeros_like(Ydict)
+
+ for batch_idx, C_idx in enumerate(batch):
+ shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx])
+ shared_term_features = diag_q.dot(Ys_embedded[batch_idx]) - Ts[batch_idx].T.dot(Ys[C_idx])
+ grad_Cdict += alpha * unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :]
+ grad_Ydict += (1 - alpha) * unmixings[batch_idx][:, None, None] * shared_term_features[None, :, :]
+ grad_Cdict *= 2 / batch_size
+ grad_Ydict *= 2 / batch_size
+
+ if use_adam_optimizer:
+ Cdict, adam_moments_C = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate_C, adam_moments_C)
+ Ydict, adam_moments_Y = _adam_stochastic_updates(Ydict, grad_Ydict, learning_rate_Y, adam_moments_Y)
+ else:
+ Cdict -= learning_rate_C * grad_Cdict
+ Ydict -= learning_rate_Y * grad_Ydict
+
+ if 'symmetric' in projection:
+ Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
+ if 'nonnegative' in projection:
+ Cdict[Cdict < 0.] = 0.
+
+ if use_log:
+ log['loss_epochs'].append(cumulated_loss_over_epoch)
+ if loss_best_state > cumulated_loss_over_epoch:
+ loss_best_state = cumulated_loss_over_epoch
+ Cdict_best_state = Cdict.copy()
+ Ydict_best_state = Ydict.copy()
+ if verbose:
+ print('--- epoch: ', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch)
+
+ return nx.from_numpy(Cdict_best_state), nx.from_numpy(Ydict_best_state), log
+
+
+def fused_gromov_wasserstein_linear_unmixing(C, Y, Cdict, Ydict, alpha, reg=0., p=None, q=None, tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, **kwargs):
+ r"""
+ Returns the Fused Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the attributed dictionary atoms :math:`\{ (\mathbf{C_{dict}[d]},\mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}`
+
+ .. math::
+ \min_{\mathbf{w}} FGW_{2,\alpha}(\mathbf{C},\mathbf{Y}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]},\sum_{d=1}^D w_d\mathbf{Y_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2
+
+ such that, :math:`\forall s \leq S` :
+
+ - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w_s} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns.
+ - :math:`\mathbf{Y}` is a (ns,d) features matrix of variable size ns and fixed dimension d.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt.
+ - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d.
+ - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C_s}`
+ - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space.
+ - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein
+ - reg is the regularization coefficient.
+
+ The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38], algorithm 6.
+
+ Parameters
+ ----------
+ C : array-like, shape (ns, ns)
+ Metric/Graph cost matrix.
+ Y : array-like, shape (ns, d)
+ Feature matrix.
+ Cdict : D array-like, shape (D,nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed (C,Y).
+ Ydict : D array-like, shape (D,nt,d)
+ Feature matrices composing the dictionary on which to embed (C,Y).
+ alpha: float,
+ Trade-off parameter of Fused Gromov-Wasserstein.
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0.
+ p : array-like, shape (ns,), optional
+ Distribution in the source space C. Default is None and corresponds to uniform distribution.
+ q : array-like, shape (nt,), optional
+ Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution.
+ tol_outer : float, optional
+ Solver precision for the BCD algorithm.
+ tol_inner : float, optional
+ Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`.
+ max_iter_outer : int, optional
+ Maximum number of iterations for the BCD. Default is 20.
+ max_iter_inner : int, optional
+ Maximum number of iterations for the Conjugate Gradient. Default is 200.
+
+ Returns
+ -------
+ w: array-like, shape (D,)
+ fused gromov-wasserstein linear unmixing of (C,Y,p) onto the span of the dictionary.
+ Cembedded: array-like, shape (nt,nt)
+ embedded structure of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`.
+ Yembedded: array-like, shape (nt,d)
+ embedded features of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{Y_{dict}[d]}`.
+ T: array-like (ns,nt)
+ Fused Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \sum_d w_d\mathbf{Y_{dict}[d]},\mathbf{q})`.
+ current_loss: float
+ reconstruction error
+ References
+ -------
+
+ ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty.
+ "Online Graph Dictionary Learning"
+ International Conference on Machine Learning (ICML). 2021.
+ """
+ C0, Y0, Cdict0, Ydict0 = C, Y, Cdict, Ydict
+ nx = get_backend(C0, Y0, Cdict0, Ydict0)
+ C = nx.to_numpy(C0)
+ Y = nx.to_numpy(Y0)
+ Cdict = nx.to_numpy(Cdict0)
+ Ydict = nx.to_numpy(Ydict0)
+
+ if p is None:
+ p = unif(C.shape[0])
+ else:
+ p = nx.to_numpy(p)
+ if q is None:
+ q = unif(Cdict.shape[-1])
+ else:
+ q = nx.to_numpy(q)
+
+ T = p[:, None] * q[None, :]
+ D = len(Cdict)
+ d = Y.shape[-1]
+ w = unif(D) # Initialize with uniform weights
+ ns = C.shape[-1]
+ nt = Cdict.shape[-1]
+
+ # modeling (C,Y)
+ Cembedded = np.sum(w[:, None, None] * Cdict, axis=0)
+ Yembedded = np.sum(w[:, None, None] * Ydict, axis=0)
+
+ # constants depending on q
+ const_q = q[:, None] * q[None, :]
+ diag_q = np.diag(q)
+ # Trackers for BCD convergence
+ convergence_criterion = np.inf
+ current_loss = 10**15
+ outer_count = 0
+ Ys_constM = (Y**2).dot(np.ones((d, nt))) # constant in computing euclidean pairwise feature matrix
+
+ while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer):
+ previous_loss = current_loss
+
+ # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w
+ Yt_varM = (np.ones((ns, d))).dot((Yembedded**2).T)
+ M = Ys_constM + Yt_varM - 2 * Y.dot(Yembedded.T) # euclidean distance matrix between features
+ T, log = fused_gromov_wasserstein(M, C, Cembedded, p, q, loss_fun='square_loss', alpha=alpha, armijo=False, G0=T, log=True)
+ current_loss = log['fgw_dist']
+ if reg != 0:
+ current_loss -= reg * np.sum(w**2)
+
+ # 2. Solve linear unmixing problem over w with a fixed transport plan T
+ w, Cembedded, Yembedded, current_loss = _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w,
+ T, p, q, const_q, diag_q, current_loss, alpha, reg,
+ tol=tol_inner, max_iter=max_iter_inner, **kwargs)
+ if previous_loss != 0:
+ convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
+ else:
+ convergence_criterion = abs(previous_loss - current_loss) / 10**(-12)
+ outer_count += 1
+
+ return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(Yembedded), nx.from_numpy(T), nx.from_numpy(current_loss)
+
+
+def _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w, T, p, q, const_q, diag_q, starting_loss, alpha, reg, tol=10**(-6), max_iter=200, **kwargs):
+ r"""
+ Returns for a fixed admissible transport plan,
+ the optimal linear unmixing :math:`\mathbf{w}` minimizing the Fused Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` and :math:`(\sum_d w_d \mathbf{C_{dict}[d]},\sum_d w_d*\mathbf{Y_{dict}[d]}, \mathbf{q})`
+
+ .. math::
+ \min_{\mathbf{w}} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\+ (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d w_d \mathbf{Y_{dict}[d]_j} \|_2^2 T_{ij}- reg \| \mathbf{w} \|_2^2
+
+ Such that :
+
+ - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns.
+ - :math:`\mathbf{Y}` is a (ns,d) features matrix of variable size ns and fixed dimension d.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt.
+ - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d.
+ - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C_s}`
+ - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space.
+ - :math:`\mathbf{T}` is the optimal transport plan conditioned by the previous state of :math:`\mathbf{w}`
+ - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein
+ - reg is the regularization coefficient.
+
+ The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38], algorithm 7.
+
+ Parameters
+ ----------
+
+ C : array-like, shape (ns, ns)
+ Metric/Graph cost matrix.
+ Y : array-like, shape (ns, d)
+ Feature matrix.
+ Cdict : list of D array-like, shape (nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed (C,Y).
+ Each matrix in the dictionary must have the same size (nt,nt).
+ Ydict : list of D array-like, shape (nt,d)
+ Feature matrices composing the dictionary on which to embed (C,Y).
+ Each matrix in the dictionary must have the same size (nt,d).
+ Cembedded: array-like, shape (nt,nt)
+ Embedded structure of (C,Y) onto the dictionary
+ Yembedded: array-like, shape (nt,d)
+ Embedded features of (C,Y) onto the dictionary
+ w: array-like, shape (n_D,)
+ Linear unmixing of (C,Y) onto (Cdict,Ydict)
+ const_q: array-like, shape (nt,nt)
+ product matrix :math:`\mathbf{qq}^\top` where :math:`\mathbf{q}` is the target space distribution.
+ diag_q: array-like, shape (nt,nt)
+ diagonal matrix with values of q on the diagonal.
+ T: array-like, shape (ns,nt)
+ fixed transport plan between (C,Y) and its model
+ p : array-like, shape (ns,)
+ Distribution in the source space (C,Y).
+ q : array-like, shape (nt,)
+ Distribution in the embedding space depicted by the dictionary.
+ alpha: float,
+ Trade-off parameter of Fused Gromov-Wasserstein.
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w.
+
+ Returns
+ -------
+ w: ndarray (D,)
+ linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the span of :math:`(C_{dict},Y_{dict})` given OT corresponding to previous unmixing.
+ """
+ convergence_criterion = np.inf
+ current_loss = starting_loss
+ count = 0
+ const_TCT = np.transpose(C.dot(T)).dot(T)
+ ones_ns_d = np.ones(Y.shape)
+
+ while (convergence_criterion > tol) and (count < max_iter):
+ previous_loss = current_loss
+
+ # 1) Compute gradient at current point w
+ # structure
+ grad_w = alpha * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2))
+ # feature
+ grad_w += (1 - alpha) * np.sum(Ydict * (diag_q.dot(Yembedded)[None, :, :] - T.T.dot(Y)[None, :, :]), axis=(1, 2))
+ grad_w -= reg * w
+ grad_w *= 2
+
+ # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w
+ min_ = np.min(grad_w)
+ x = (grad_w == min_).astype(np.float64)
+ x /= np.sum(x)
+
+ # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c
+ gamma, a, b, Cembedded_diff, Yembedded_diff = _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg)
+
+ # 4) Updates: w <-- (1-gamma)*w + gamma*x
+ w += gamma * (x - w)
+ Cembedded += gamma * Cembedded_diff
+ Yembedded += gamma * Yembedded_diff
+ current_loss += a * (gamma**2) + b * gamma
+
+ if previous_loss != 0:
+ convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
+ else:
+ convergence_criterion = abs(previous_loss - current_loss) / 10**(-12)
+ count += 1
+
+ return w, Cembedded, Yembedded, current_loss
+
+
+def _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg, **kwargs):
+ r"""
+ Compute optimal steps for the line search problem of Fused Gromov-Wasserstein linear unmixing
+ .. math::
+ \min_{\gamma \in [0,1]} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\ + (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d z_d(\gamma) \mathbf{Y_{dict}[d]_j} \|_2^2 - reg\| \mathbf{z}(\gamma) \|_2^2
+
+
+ Such that :
+
+ - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}`
+
+ Parameters
+ ----------
+
+ w : array-like, shape (D,)
+ Unmixing.
+ grad_w : array-like, shape (D, D)
+ Gradient of the reconstruction loss with respect to w.
+ x: array-like, shape (D,)
+ Conditional gradient direction.
+ Y: arrat-like, shape (ns,d)
+ Feature matrix of the input space
+ Cdict : list of D array-like, shape (nt, nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed (C,Y).
+ Each matrix in the dictionary must have the same size (nt,nt).
+ Ydict : list of D array-like, shape (nt, d)
+ Feature matrices composing the dictionary on which to embed (C,Y).
+ Each matrix in the dictionary must have the same size (nt,d).
+ Cembedded: array-like, shape (nt, nt)
+ Embedded structure of (C,Y) onto the dictionary
+ Yembedded: array-like, shape (nt, d)
+ Embedded features of (C,Y) onto the dictionary
+ T: array-like, shape (ns, nt)
+ Fixed transport plan between (C,Y) and its current model.
+ const_q: array-like, shape (nt,nt)
+ product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations.
+ const_TCT: array-like, shape (nt, nt)
+ :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations.
+ ones_ns_d: array-like, shape (ns, d)
+ :math:`\mathbf{1}_{ ns \times d}`. Used to avoid redundant computations.
+ alpha: float,
+ Trade-off parameter of Fused Gromov-Wasserstein.
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w.
+
+ Returns
+ -------
+ gamma: float
+ Optimal value for the line-search step
+ a: float
+ Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss
+ b: float
+ Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss
+ Cembedded_diff: numpy array, shape (nt, nt)
+ Difference between structure matrix of models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`.
+ Yembedded_diff: numpy array, shape (nt, nt)
+ Difference between feature matrix of models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`.
+ """
+ # polynomial coefficients from quadratic objective (with respect to w) on structures
+ Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0)
+ Cembedded_diff = Cembedded_x - Cembedded
+ trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q)
+ trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q)
+ # Constant factor appearing in the factorization a*gamma^2 + b*g + c of the Gromov-Wasserstein reconstruction loss
+ a_gw = trace_diffx - trace_diffw
+ b_gw = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT))
+
+ # polynomial coefficient from quadratic objective (with respect to w) on features
+ Yembedded_x = np.sum(x[:, None, None] * Ydict, axis=0)
+ Yembedded_diff = Yembedded_x - Yembedded
+ # Constant factor appearing in the factorization a*gamma^2 + b*g + c of the Gromov-Wasserstein reconstruction loss
+ a_w = np.sum(ones_ns_d.dot((Yembedded_diff**2).T) * T)
+ b_w = 2 * np.sum(T * (ones_ns_d.dot((Yembedded * Yembedded_diff).T) - Y.dot(Yembedded_diff.T)))
+
+ a = alpha * a_gw + (1 - alpha) * a_w
+ b = alpha * b_gw + (1 - alpha) * b_w
+ if reg != 0:
+ a -= reg * np.sum((x - w)**2)
+ b -= 2 * reg * np.sum(w * (x - w))
+ if a > 0:
+ gamma = min(1, max(0, -b / (2 * a)))
+ elif a + b < 0:
+ gamma = 1
+ else:
+ gamma = 0
+
+ return gamma, a, b, Cembedded_diff, Yembedded_diff
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 5da897d..390c32d 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -26,6 +26,8 @@ from ..utils import dist, list_to_array
from ..utils import parmap
from ..backend import get_backend
+
+
__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted',
'emd_1d', 'emd2_1d', 'wasserstein_1d']
@@ -220,7 +222,15 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
format
.. note:: This function is backend-compatible and will work on arrays
- from all compatible backends.
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
+ .. note:: This function will cast the computed transport plan to the data type
+ of the provided input with the following priority: :math:`\mathbf{a}`,
+ then :math:`\mathbf{b}`, then :math:`\mathbf{M}` if marginals are not provided.
+ Casting to an integer tensor might result in a loss of precision.
+ If this behaviour is unwanted, please make sure to provide a
+ floating point input.
Uses the algorithm proposed in :ref:`[1] <references-emd>`.
@@ -287,12 +297,16 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
a, b, M = list_to_array(a, b, M)
a0, b0, M0 = a, b, M
+ if len(a0) != 0:
+ type_as = a0
+ elif len(b0) != 0:
+ type_as = b0
+ else:
+ type_as = M0
nx = get_backend(M0, a0, b0)
# convert to numpy
- M = nx.to_numpy(M)
- a = nx.to_numpy(a)
- b = nx.to_numpy(b)
+ M, a, b = nx.to_numpy(M, a, b)
# ensure float64
a = np.asarray(a, dtype=np.float64)
@@ -327,15 +341,23 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
u, v = estimate_dual_null_weights(u, v, a, b, M)
result_code_string = check_result(result_code)
+ if not nx.is_floating_point(type_as):
+ warnings.warn(
+ "Input histogram consists of integer. The transport plan will be "
+ "casted accordingly, possibly resulting in a loss of precision. "
+ "If this behaviour is unwanted, please make sure your input "
+ "histogram consists of floating point elements.",
+ stacklevel=2
+ )
if log:
log = {}
log['cost'] = cost
- log['u'] = nx.from_numpy(u, type_as=a0)
- log['v'] = nx.from_numpy(v, type_as=b0)
+ log['u'] = nx.from_numpy(u, type_as=type_as)
+ log['v'] = nx.from_numpy(v, type_as=type_as)
log['warning'] = result_code_string
log['result_code'] = result_code
- return nx.from_numpy(G, type_as=M0), log
- return nx.from_numpy(G, type_as=M0)
+ return nx.from_numpy(G, type_as=type_as), log
+ return nx.from_numpy(G, type_as=type_as)
def emd2(a, b, M, processes=1,
@@ -358,7 +380,16 @@ def emd2(a, b, M, processes=1,
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
.. note:: This function is backend-compatible and will work on arrays
- from all compatible backends.
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
+ .. note:: This function will cast the computed transport plan and
+ transportation loss to the data type of the provided input with the
+ following priority: :math:`\mathbf{a}`, then :math:`\mathbf{b}`,
+ then :math:`\mathbf{M}` if marginals are not provided.
+ Casting to an integer tensor might result in a loss of precision.
+ If this behaviour is unwanted, please make sure to provide a
+ floating point input.
Uses the algorithm proposed in :ref:`[1] <references-emd2>`.
@@ -428,12 +459,16 @@ def emd2(a, b, M, processes=1,
a, b, M = list_to_array(a, b, M)
a0, b0, M0 = a, b, M
+ if len(a0) != 0:
+ type_as = a0
+ elif len(b0) != 0:
+ type_as = b0
+ else:
+ type_as = M0
nx = get_backend(M0, a0, b0)
# convert to numpy
- M = nx.to_numpy(M)
- a = nx.to_numpy(a)
- b = nx.to_numpy(b)
+ M, a, b = nx.to_numpy(M, a, b)
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
@@ -466,15 +501,24 @@ def emd2(a, b, M, processes=1,
result_code_string = check_result(result_code)
log = {}
- G = nx.from_numpy(G, type_as=M0)
+ if not nx.is_floating_point(type_as):
+ warnings.warn(
+ "Input histogram consists of integer. The transport plan will be "
+ "casted accordingly, possibly resulting in a loss of precision. "
+ "If this behaviour is unwanted, please make sure your input "
+ "histogram consists of floating point elements.",
+ stacklevel=2
+ )
+ G = nx.from_numpy(G, type_as=type_as)
if return_matrix:
log['G'] = G
- log['u'] = nx.from_numpy(u, type_as=a0)
- log['v'] = nx.from_numpy(v, type_as=b0)
+ log['u'] = nx.from_numpy(u, type_as=type_as)
+ log['v'] = nx.from_numpy(v, type_as=type_as)
log['warning'] = result_code_string
log['result_code'] = result_code
- cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0),
- (a0, b0, M0), (log['u'], log['v'], G))
+ cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as),
+ (a0, b0, M0), (log['u'] - nx.mean(log['u']),
+ log['v'] - nx.mean(log['v']), G))
return [cost, log]
else:
def f(b):
@@ -487,10 +531,18 @@ def emd2(a, b, M, processes=1,
if np.any(~asel) or np.any(~bsel):
u, v = estimate_dual_null_weights(u, v, a, b, M)
- G = nx.from_numpy(G, type_as=M0)
- cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0),
- (a0, b0, M0), (nx.from_numpy(u, type_as=a0),
- nx.from_numpy(v, type_as=b0), G))
+ if not nx.is_floating_point(type_as):
+ warnings.warn(
+ "Input histogram consists of integer. The transport plan will be "
+ "casted accordingly, possibly resulting in a loss of precision. "
+ "If this behaviour is unwanted, please make sure your input "
+ "histogram consists of floating point elements.",
+ stacklevel=2
+ )
+ G = nx.from_numpy(G, type_as=type_as)
+ cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as),
+ (a0, b0, M0), (nx.from_numpy(u - np.mean(u), type_as=type_as),
+ nx.from_numpy(v - np.mean(v), type_as=type_as), G))
check_result(result_code)
return cost
@@ -535,18 +587,18 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
Parameters
----------
- measures_locations : list of N (k_i,d) numpy.ndarray
+ measures_locations : list of N (k_i,d) array-like
The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space
(:math:`k_i` can be different for each element of the list)
- measures_weights : list of N (k_i,) numpy.ndarray
+ measures_weights : list of N (k_i,) array-like
Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one
representing the weights of each discrete input measure
- X_init : (k,d) np.ndarray
+ X_init : (k,d) array-like
Initialization of the support locations (on `k` atoms) of the barycenter
- b : (k,) np.ndarray
+ b : (k,) array-like
Initialization of the weights of the barycenter (non-negatives, sum to 1)
- weights : (N,) np.ndarray
+ weights : (N,) array-like
Initialization of the coefficients of the barycenter (non-negatives, sum to 1)
numItermax : int, optional
@@ -564,7 +616,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
Returns
-------
- X : (k,d) np.ndarray
+ X : (k,d) array-like
Support locations (on k atoms) of the barycenter
@@ -577,15 +629,17 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
"""
+ nx = get_backend(*measures_locations,*measures_weights,X_init)
+
iter_count = 0
N = len(measures_locations)
k = X_init.shape[0]
d = X_init.shape[1]
if b is None:
- b = np.ones((k,)) / k
+ b = nx.ones((k,),type_as=X_init) / k
if weights is None:
- weights = np.ones((N,)) / N
+ weights = nx.ones((N,),type_as=X_init) / N
X = X_init
@@ -596,15 +650,15 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
while (displacement_square_norm > stopThr and iter_count < numItermax):
- T_sum = np.zeros((k, d))
+ T_sum = nx.zeros((k, d),type_as=X_init)
+
- for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights,
- weights.tolist()):
+ for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights):
M_i = dist(X, measure_locations_i)
T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads)
- T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)
+ T_sum = T_sum + weight_i * 1. / b[:,None] * nx.dot(T_i, measure_locations_i)
- displacement_square_norm = np.sum(np.square(T_sum - X))
+ displacement_square_norm = nx.sum((T_sum - X)**2)
if log:
displacement_square_norms.append(displacement_square_norm)
@@ -620,3 +674,4 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
return X, log_dict
else:
return X
+
diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py
index 869d450..fbf3c0e 100644
--- a/ot/lp/cvx.py
+++ b/ot/lp/cvx.py
@@ -11,7 +11,6 @@ import numpy as np
import scipy as sp
import scipy.sparse as sps
-
try:
import cvxopt
from cvxopt import solvers, matrix, spmatrix
diff --git a/ot/optim.py b/ot/optim.py
index f25e2c9..5a1d605 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -9,12 +9,19 @@ Generic solvers for regularized OT
# License: MIT License
import numpy as np
-from scipy.optimize.linesearch import scalar_search_armijo
+import warnings
from .lp import emd
from .bregman import sinkhorn
-from ot.utils import list_to_array
+from .utils import list_to_array
from .backend import get_backend
+with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ try:
+ from scipy.optimize import scalar_search_armijo
+ except ImportError:
+ from scipy.optimize.linesearch import scalar_search_armijo
+
# The corresponding scipy function does not work for matrices
diff --git a/ot/partial.py b/ot/partial.py
index b7093e4..0a9e450 100755
--- a/ot/partial.py
+++ b/ot/partial.py
@@ -7,7 +7,6 @@ Partial OT solvers
# License: MIT License
import numpy as np
-
from .lp import emd
@@ -29,7 +28,8 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
\gamma &\geq 0
- \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
+ \mathbf{1}^T \gamma^T \mathbf{1} = m &
+ \leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
or equivalently (see Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X.
@@ -50,7 +50,8 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
- :math:`\lambda` is the lagrangian cost. Tuning its value allows attaining
a given mass to be transported `m`
- The formulation of the problem has been proposed in :ref:`[28] <references-partial-wasserstein-lagrange>`
+ The formulation of the problem has been proposed in
+ :ref:`[28] <references-partial-wasserstein-lagrange>`
Parameters
@@ -261,7 +262,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies)
a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies)
M_extended = np.zeros((len(a_extended), len(b_extended)))
- M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e5
+ M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 2
M_extended[:len(a), :len(b)] = M
gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True,
@@ -455,7 +456,8 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
- `m` is the amount of mass to be transported
- The formulation of the problem has been proposed in :ref:`[29] <references-partial-gromov-wasserstein>`
+ The formulation of the problem has been proposed in
+ :ref:`[29] <references-partial-gromov-wasserstein>`
Parameters
@@ -469,7 +471,8 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
q : ndarray, shape (nt,)
Distribution in the target space
m : float, optional
- Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
+ Amount of mass to be transported
+ (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
nb_dummies : int, optional
Number of dummy points to add (avoid instabilities in the EMD solver)
G0 : ndarray, shape (ns, nt), optional
@@ -623,16 +626,19 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
\gamma &\geq 0
- \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
+ \mathbf{1}^T \gamma^T \mathbf{1} = m
+ &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
where :
- :math:`\mathbf{M}` is the metric cost matrix
- - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\Omega` is the entropic regularization term,
+ :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
- `m` is the amount of mass to be transported
- The formulation of the problem has been proposed in :ref:`[29] <references-partial-gromov-wasserstein2>`
+ The formulation of the problem has been proposed in
+ :ref:`[29] <references-partial-gromov-wasserstein2>`
Parameters
@@ -646,7 +652,8 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
q : ndarray, shape (nt,)
Distribution in the target space
m : float, optional
- Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
+ Amount of mass to be transported
+ (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
nb_dummies : int, optional
Number of dummy points to add (avoid instabilities in the EMD solver)
G0 : ndarray, shape (ns, nt), optional
@@ -728,21 +735,25 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
The function considers the following problem:
.. math::
- \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma,
+ \mathbf{M} \rangle_F + \mathrm{reg} \cdot\Omega(\gamma)
s.t. \gamma \mathbf{1} &\leq \mathbf{a} \\
\gamma^T \mathbf{1} &\leq \mathbf{b} \\
\gamma &\geq 0 \\
- \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} \\
+ \mathbf{1}^T \gamma^T \mathbf{1} = m
+ &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} \\
where :
- :math:`\mathbf{M}` is the metric cost matrix
- - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\Omega` is the entropic regularization term,
+ :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
- `m` is the amount of mass to be transported
- The formulation of the problem has been proposed in :ref:`[3] <references-entropic-partial-wasserstein>` (prop. 5)
+ The formulation of the problem has been proposed in
+ :ref:`[3] <references-entropic-partial-wasserstein>` (prop. 5)
Parameters
@@ -829,12 +840,23 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
np.multiply(K, m / np.sum(K), out=K)
err, cpt = 1, 0
+ q1 = np.ones(K.shape)
+ q2 = np.ones(K.shape)
+ q3 = np.ones(K.shape)
while (err > stopThr and cpt < numItermax):
Kprev = K
+ K = K * q1
K1 = np.dot(np.diag(np.minimum(a / np.sum(K, axis=1), dx)), K)
+ q1 = q1 * Kprev / K1
+ K1prev = K1
+ K1 = K1 * q2
K2 = np.dot(K1, np.diag(np.minimum(b / np.sum(K1, axis=0), dy)))
+ q2 = q2 * K1prev / K2
+ K2prev = K2
+ K2 = K2 * q3
K = K2 * (m / np.sum(K2))
+ q3 = q3 * K2prev / K
if np.any(np.isnan(K)) or np.any(np.isinf(K)):
print('Warning: numerical errors at iteration', cpt)
@@ -861,7 +883,8 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None,
numItermax=1000, tol=1e-7, log=False,
verbose=False):
r"""
- Returns the partial Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
+ Returns the partial Gromov-Wasserstein transport between
+ :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
The function solves the following optimization problem:
@@ -877,7 +900,8 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None,
\gamma^T \mathbf{1} &\leq \mathbf{b}
- \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
+ \mathbf{1}^T \gamma^T \mathbf{1} = m
+ &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
where :
@@ -885,10 +909,13 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None,
- :math:`\mathbf{C_2}` is the metric cost matrix in the target space
- :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights
- `L`: quadratic loss function
- - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\Omega` is the entropic regularization term,
+ :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- `m` is the amount of mass to be transported
- The formulation of the GW problem has been proposed in :ref:`[12] <references-entropic-partial-gromov-wassertein>` and the partial GW in :ref:`[29] <references-entropic-partial-gromov-wassertein>`
+ The formulation of the GW problem has been proposed in
+ :ref:`[12] <references-entropic-partial-gromov-wassertein>` and the
+ partial GW in :ref:`[29] <references-entropic-partial-gromov-wassertein>`
Parameters
----------
@@ -903,7 +930,8 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None,
reg: float
entropic regularization parameter
m : float, optional
- Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
+ Amount of mass to be transported (default:
+ :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
G0 : ndarray, shape (ns, nt), optional
Initialisation of the transportation matrix
numItermax : int, optional
@@ -1005,13 +1033,15 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None,
numItermax=1000, tol=1e-7, log=False,
verbose=False):
r"""
- Returns the partial Gromov-Wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
+ Returns the partial Gromov-Wasserstein discrepancy between
+ :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
The function solves the following optimization problem:
.. math::
- GW = \min_{\gamma} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})\cdot
- \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma)
+ GW = \min_{\gamma} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k},
+ \mathbf{C_2}_{j,l})\cdot
+ \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma)
.. math::
s.t. \ \gamma &\geq 0
@@ -1028,10 +1058,13 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None,
- :math:`\mathbf{C_2}` is the metric cost matrix in the target space
- :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights
- `L` : quadratic loss function
- - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\Omega` is the entropic regularization term,
+ :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- `m` is the amount of mass to be transported
- The formulation of the GW problem has been proposed in :ref:`[12] <references-entropic-partial-gromov-wassertein2>` and the partial GW in :ref:`[29] <references-entropic-partial-gromov-wassertein2>`
+ The formulation of the GW problem has been proposed in
+ :ref:`[12] <references-entropic-partial-gromov-wassertein2>` and the
+ partial GW in :ref:`[29] <references-entropic-partial-gromov-wassertein2>`
Parameters
@@ -1047,7 +1080,8 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None,
reg: float
entropic regularization parameter
m : float, optional
- Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
+ Amount of mass to be transported (default:
+ :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
G0 : ndarray, shape (ns, nt), optional
Initialisation of the transportation matrix
numItermax : int, optional
diff --git a/ot/plot.py b/ot/plot.py
index 2208c90..8ade2eb 100644
--- a/ot/plot.py
+++ b/ot/plot.py
@@ -85,8 +85,13 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs):
if ('color' not in kwargs) and ('c' not in kwargs):
kwargs['color'] = 'k'
mx = G.max()
+ if 'alpha' in kwargs:
+ scale = kwargs['alpha']
+ del kwargs['alpha']
+ else:
+ scale = 1
for i in range(xs.shape[0]):
for j in range(xt.shape[0]):
if G[i, j] / mx > thr:
pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]],
- alpha=G[i, j] / mx, **kwargs)
+ alpha=G[i, j] / mx * scale, **kwargs)
diff --git a/ot/regpath.py b/ot/regpath.py
index 269937a..e745288 100644
--- a/ot/regpath.py
+++ b/ot/regpath.py
@@ -11,34 +11,48 @@ import scipy.sparse as sp
def recast_ot_as_lasso(a, b, C):
- r"""This function recasts the l2-penalized UOT problem as a Lasso problem
+ r"""This function recasts the l2-penalized UOT problem as a Lasso problem.
+
+ Recall the l2-penalized UOT problem defined in
+ :ref:`[41] <references-regpath>`
- Recall the l2-penalized UOT problem defined in [Chapel et al., 2021]
.. math::
- UOT = \min_T <C, T> + \lambda \|T 1_m - a\|_2^2 +
- \lambda \|T^T 1_n - b\|_2^2
+ \text{UOT}_{\lambda} = \min_T <C, T> + \lambda \|T 1_m -
+ \mathbf{a}\|_2^2 +
+ \lambda \|T^T 1_n - \mathbf{b}\|_2^2
+
s.t.
T \geq 0
+
where :
- - C is the (dim_a, dim_b) metric cost matrix
- - :math:`\lambda` is the l2-regularization coefficient
- - a and b are source and target distributions
- - T is the transport plan to optimize
- The problem above can be reformulated to a non-negative penalized
+ - :math:`C` is the cost matrix
+ - :math:`\lambda` is the l2-regularization parameter
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the source and target \
+ distributions
+ - :math:`T` is the transport plan to optimize
+
+ The problem above can be reformulated as a non-negative penalized
linear regression problem, particularly Lasso
+
.. math::
- UOT2 = \min_t \gamma c^T t + 0.5 * \|H t - y\|_2^2
+ \text{UOT2}_{\lambda} = \min_{\mathbf{t}} \gamma \mathbf{c}^T
+ \mathbf{t} + 0.5 * \|H \mathbf{t} - \mathbf{y}\|_2^2
+
s.t.
- t \geq 0
+ \mathbf{t} \geq 0
+
where :
- - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C)
- - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient
- - y is the concatenation of vectors a and b, defined as y^T = [a^T b^T]
- - H is a (dim_a + dim_b, dim_a * dim_b) metric matrix,
- see [Chapel et al., 2021] for the design of H. The matrix product H t
- computes both the source marginal and the target marginal.
- - t is a (dim_a * dim_b, ) metric vector (flattened version of T)
+
+ - :math:`\mathbf{c}` is the flattened version of the cost matrix :math:`C`
+ - :math:`\mathbf{y}` is the concatenation of vectors :math:`\mathbf{a}` \
+ and :math:`\mathbf{b}`
+ - :math:`H` is a metric matrix, see :ref:`[41] <references-regpath>` for \
+ the design of :math:`H`. The matrix product :math:`H\mathbf{t}` \
+ computes both the source marginal and the target marginals.
+ - :math:`\mathbf{t}` is the flattened version of the transport plan \
+ :math:`T`
+
Parameters
----------
a : np.ndarray (dim_a,)
@@ -47,14 +61,16 @@ def recast_ot_as_lasso(a, b, C):
Histogram of dimension dim_b
C : np.ndarray, shape (dim_a, dim_b)
Cost matrix
+
Returns
-------
H : np.ndarray (dim_a+dim_b, dim_a*dim_b)
- Auxiliary matrix constituted by 0 and 1
+ Design matrix that contains only 0 and 1
y : np.ndarray (ns + nt, )
- Concatenation of histogram a and histogram b
+ Concatenation of histograms :math:`\mathbf{a}` and :math:`\mathbf{b}`
c : np.ndarray (ns * nt, )
- Flattened array of cost matrix
+ Flattened array of the cost matrix
+
Examples
--------
>>> import ot
@@ -73,12 +89,12 @@ def recast_ot_as_lasso(a, b, C):
>>> c
array([16., 25., 28., 16., 40., 36.])
+
References
----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
"""
dim_a = np.shape(a)[0]
@@ -97,33 +113,47 @@ def recast_ot_as_lasso(a, b, C):
def recast_semi_relaxed_as_lasso(a, b, C):
- r"""This function recasts the semi-relaxed l2-UOT problem as Lasso problem
+ r"""This function recasts the semi-relaxed l2-UOT problem as Lasso problem.
.. math::
- semi-relaxed UOT = \min_T <C, T> + \lambda \|T 1_m - a\|_2^2
+
+ \text{semi-relaxed UOT} = \min_T <C, T>
+ + \lambda \|T 1_m - \mathbf{a}\|_2^2
+
s.t.
- T^T 1_n = b
- t \geq 0
+ T^T 1_n = \mathbf{b}
+
+ \mathbf{t} \geq 0
+
where :
- - C is the (dim_a, dim_b) metric cost matrix
- - :math:`\lambda` is the l2-regularization coefficient
- - a and b are source and target distributions
- - T is the transport plan to optimize
+
+ - :math:`C` is the metric cost matrix
+ - :math:`\lambda` is the l2-regularization parameter
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the source and target \
+ distributions
+ - :math:`T` is the transport plan to optimize
The problem above can be reformulated as follows
+
.. math::
- semi-relaxed UOT2 = \min_t \gamma c^T t + 0.5 * \|H_r t - a\|_2^2
+ \text{semi-relaxed UOT2} = \min_t \gamma \mathbf{c}^T t
+ + 0.5 * \|H_r \mathbf{t} - \mathbf{a}\|_2^2
+
s.t.
- H_c t = b
- t \geq 0
+ H_c \mathbf{t} = \mathbf{b}
+
+ \mathbf{t} \geq 0
+
where :
- - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C)
- - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient
- - H_r is a (dim_a, dim_a * dim_b) metric matrix,
- which computes the sum along the rows of transport plan T
- - H_c is a (dim_b, dim_a * dim_b) metric matrix,
- which computes the sum along the columns of transport plan T
- - t is a (dim_a * dim_b, ) metric vector (flattened version of T)
+
+ - :math:`\mathbf{c}` is flattened version of the cost matrix :math:`C`
+ - :math:`\gamma = 1/\lambda` is the l2-regularization parameter
+ - :math:`H_r` is a metric matrix which computes the sum along the \
+ rows of the transport plan :math:`T`
+ - :math:`H_c` is a metric matrix which computes the sum along the \
+ columns of the transport plan :math:`T`
+ - :math:`\mathbf{t}` is the flattened version of :math:`T`
+
Parameters
----------
a : np.ndarray (dim_a,)
@@ -132,16 +162,18 @@ def recast_semi_relaxed_as_lasso(a, b, C):
Histogram of dimension dim_b
C : np.ndarray, shape (dim_a, dim_b)
Cost matrix
+
Returns
-------
Hr : np.ndarray (dim_a, dim_a * dim_b)
Auxiliary matrix constituted by 0 and 1, which computes
- the sum along the rows of transport plan T
+ the sum along the rows of transport plan :math:`T`
Hc : np.ndarray (dim_b, dim_a * dim_b)
Auxiliary matrix constituted by 0 and 1, which computes
- the sum along the columns of transport plan T
+ the sum along the columns of transport plan :math:`T`
c : np.ndarray (ns * nt, )
- Flattened array of cost matrix
+ Flattened array of the cost matrix
+
Examples
--------
>>> import ot
@@ -179,49 +211,60 @@ def recast_semi_relaxed_as_lasso(a, b, C):
def ot_next_gamma(phi, delta, HtH, Hty, c, active_index, current_gamma):
r""" This function computes the next value of gamma if a variable
- will be added in next iteration of the regularization path
+ is added in the next iteration of the regularization path.
We look for the largest value of gamma such that
the gradient of an inactive variable vanishes
+
.. math::
- \max_{i \in \bar{A}} \frac{h_i^T(H_A \phi - y)}{h_i^T H_A \delta - c_i}
+ \max_{i \in \bar{A}} \frac{\mathbf{h}_i^T(H_A \phi - \mathbf{y})}
+ {\mathbf{h}_i^T H_A \delta - \mathbf{c}_i}
+
where :
+
- A is the current active set
- - h_i is the ith column of auxiliary matrix H
- - H_A is the sub-matrix constructed by the columns of H
- whose indices belong to the active set A
- - c_i is the ith element of cost vector c
- - y is the concatenation of source and target distribution
- - :math:`\phi` is the intercept of the solutions in current iteration
- - :math:`\delta` is the slope of the solutions in current iteration
+ - :math:`\mathbf{h}_i` is the :math:`i` th column of the design \
+ matrix :math:`{H}`
+ - :math:`{H}_A` is the sub-matrix constructed by the columns of \
+ :math:`{H}` whose indices belong to the active set A
+ - :math:`\mathbf{c}_i` is the :math:`i` th element of the cost vector \
+ :math:`\mathbf{c}`
+ - :math:`\mathbf{y}` is the concatenation of the source and target \
+ distributions
+ - :math:`\phi` is the intercept of the solutions at the current iteration
+ - :math:`\delta` is the slope of the solutions at the current iteration
+
Parameters
----------
- phi : np.ndarray (|A|, )
- Intercept of the solutions in current iteration (t is piecewise linear)
- delta : np.ndarray (|A|, )
- Slope of the solutions in current iteration (t is piecewise linear)
+ phi : np.ndarray (size(A), )
+ Intercept of the solutions at the current iteration
+ delta : np.ndarray (size(A), )
+ Slope of the solutions at the current iteration
HtH : np.ndarray (dim_a * dim_b, dim_a * dim_b)
- Matrix product of H^T H
+ Matrix product of :math:`{H}^T {H}`
Hty : np.ndarray (dim_a + dim_b, )
- Matrix product of H^T y
+ Matrix product of :math:`{H}^T \mathbf{y}`
c: np.ndarray (dim_a * dim_b, )
- Flattened array of cost matrix C
+ Flattened array of the cost matrix :math:`{C}`
active_index : list
Indices of active variables
current_gamma : float
- Value of regularization coefficient at the start of current iteration
+ Value of the regularization parameter at the beginning of the current \
+ iteration
+
Returns
-------
next_gamma : float
Value of gamma if a variable is added to active set in next iteration
next_active_index : int
Index of variable to be activated
+
+
References
----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
"""
M = (HtH[:, active_index].dot(phi) - Hty) / \
(HtH[:, active_index].dot(delta) - c + 1e-16)
@@ -237,56 +280,65 @@ def semi_relaxed_next_gamma(phi, delta, phi_u, delta_u, HrHr, Hc, Hra,
By taking the Lagrangian form of the problem, we obtain a similar update
as the two-sided relaxed UOT
+
.. math::
- \max_{i \in \bar{A}} \frac{h_{r i}^T(H_{r A} \phi - a) + h_{c i}^T
- \phi_u}{h_{r i}^T H_{r A} \delta + h_{c i} \delta_u - c_i}
+
+ \max_{i \in \bar{A}} \frac{\mathbf{h}_{ri}^T(H_{rA} \phi - \mathbf{a})
+ + \mathbf{h}_{c i}^T\phi_u}{\mathbf{h}_{r i}^T H_{r A} \delta + \
+ \mathbf{h}_{c i} \delta_u - \mathbf{c}_i}
+
where :
+
- A is the current active set
- - h_{r i} is the ith column of the matrix H_r
- - h_{c i} is the ith column of the matrix H_c
- - H_{r A} is the sub-matrix constructed by the columns of H_r
- whose indices belong to the active set A
- - c_i is the ith element of cost vector c
- - y is the concatenation of source and target distribution
+ - :math:`\mathbf{h}_{r i}` is the ith column of the matrix :math:`H_r`
+ - :math:`\mathbf{h}_{c i}` is the ith column of the matrix :math:`H_c`
+ - :math:`H_{r A}` is the sub-matrix constructed by the columns of \
+ :math:`H_r` whose indices belong to the active set A
+ - :math:`\mathbf{c}_i` is the :math:`i` th element of cost vector \
+ :math:`\mathbf{c}`
- :math:`\phi` is the intercept of the solutions in current iteration
- :math:`\delta` is the slope of the solutions in current iteration
- - :math:`\phi_u` is the intercept of Lagrange parameter in current
- iteration
- - :math:`\delta_u` is the slope of Lagrange parameter in current iteration
+ - :math:`\phi_u` is the intercept of Lagrange parameter at the \
+ current iteration
+ - :math:`\delta_u` is the slope of Lagrange parameter at the \
+ current iteration
+
Parameters
----------
- phi : np.ndarray (|A|, )
- Intercept of the solutions in current iteration (t is piecewise linear)
- delta : np.ndarray (|A|, )
- Slope of the solutions in current iteration (t is piecewise linear)
+ phi : np.ndarray (size(A), )
+ Intercept of the solutions at the current iteration
+ delta : np.ndarray (size(A), )
+ Slope of the solutions at the current iteration
phi_u : np.ndarray (dim_b, )
- Intercept of the Lagrange parameter in current iteration (also linear)
+ Intercept of the Lagrange parameter at the current iteration
delta_u : np.ndarray (dim_b, )
- Slope of the Lagrange parameter in current iteration (also linear)
+ Slope of the Lagrange parameter at the current iteration
HrHr : np.ndarray (dim_a * dim_b, dim_a * dim_b)
- Matrix product of H_r^T H_r
+ Matrix product of :math:`H_r^T H_r`
Hc : np.ndarray (dim_b, dim_a * dim_b)
- Matrix that computes the sum along the columns of transport plan T
+ Matrix that computes the sum along the columns of the transport plan \
+ :math:`T`
Hra : np.ndarray (dim_a * dim_b, )
- Matrix product of H_r^T a
+ Matrix product of :math:`H_r^T \mathbf{a}`
c: np.ndarray (dim_a * dim_b, )
- Flattened array of cost matrix C
+ Flattened array of cost matrix :math:`C`
active_index : list
Indices of active variables
current_gamma : float
Value of regularization coefficient at the start of current iteration
+
Returns
-------
next_gamma : float
Value of gamma if a variable is added to active set in next iteration
next_active_index : int
Index of variable to be activated
+
References
----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
"""
M = (HrHr[:, active_index].dot(phi) - Hra + Hc.T.dot(phi_u)) / \
@@ -297,37 +349,48 @@ def semi_relaxed_next_gamma(phi, delta, phi_u, delta_u, HrHr, Hc, Hra,
def compute_next_removal(phi, delta, current_gamma):
- r""" This function computes the next value of gamma if a variable
- is removed in next iteration of regularization path
+ r""" This function computes the next gamma value if a variable
+ is removed at the next iteration of the regularization path.
+
+ We look for the largest value of the regularization parameter such that
+ an element of the current solution vanishes
- We look for the largest value of gamma such that
- an element of current solution vanishes
.. math::
\max_{j \in A} \frac{\phi_j}{\delta_j}
+
where :
+
- A is the current active set
- - phi_j is the jth element of the intercept of current solution
- - delta_j is the jth elemnt of the slope of current solution
+ - :math:`\phi_j` is the :math:`j` th element of the intercept of the \
+ current solution
+ - :math:`\delta_j` is the :math:`j` th element of the slope of the \
+ current solution
+
+
Parameters
----------
- phi : np.ndarray (|A|, )
- Intercept of the solutions in current iteration (t is piecewise linear)
- delta : np.ndarray (|A|, )
- Slope of the solutions in current iteration (t is piecewise linear)
+ phi : ndarray, shape (size(A), )
+ Intercept of the solution at the current iteration
+ delta : ndarray, shape (size(A), )
+ Slope of the solution at the current iteration
current_gamma : float
- Value of regularization coefficient at the start of current iteration
+ Value of the regularization parameter at the beginning of the \
+ current iteration
+
Returns
-------
next_removal_gamma : float
- Value of gamma if a variable is removed in next iteration
+ Gamma value if a variable is removed at the next iteration
next_removal_index : int
- Index of the variable to remove in next iteration
+ Index of the variable to be removed at the next iteration
+
+
+ .. _references-regpath:
References
----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
"""
r_candidate = phi / (delta - 1e-16)
r_candidate[r_candidate >= (1 - 1e-8) * current_gamma] = 0
@@ -335,56 +398,74 @@ def compute_next_removal(phi, delta, current_gamma):
def complement_schur(M_current, b, d, id_pop):
- r""" This function computes the inverse of matrix in regularization path
- using Schur complement
+ r""" This function computes the inverse of the design matrix in the \
+ regularization path using the Schur complement. Two cases may arise:
+
+ Case 1: one variable is added to the active set
+
- Two cases may arise: Firstly one variable is added to the active set
.. math::
M_{k+1}^{-1} =
\begin{bmatrix}
- M_{k}^{-1} + s^{-1} M_{k}^{-1} b b^T M_{k}^{-1} & -s^{-1} \\
- - s^{-1} b^T M_{k}^{-1} & s^{-1}
+ M_{k}^{-1} + s^{-1} M_{k}^{-1} \mathbf{b} \mathbf{b}^T M_{k}^{-1} \
+ & - M_{k}^{-1} \mathbf{b} s^{-1} \\
+ - s^{-1} \mathbf{b}^T M_{k}^{-1} & s^{-1}
\end{bmatrix}
+
+
where :
- - :math:`M_k^{-1}` is the inverse of matrix in previous iteration and
- :math:`M_k` is the upper left block matrix in Schur formulation
- - b is the upper right block matrix in Schur formulation. In our case,
- b is reduced to a column vector and b^T is the lower left block matrix
- - s is the Schur complement, given by
- :math:`s = d - b^T M_{k}^{-1} b` in our case
-
- Secondly, one variable is removed from the active set
+
+ - :math:`M_k^{-1}` is the inverse of the design matrix :math:`H_A^tH_A` \
+ of the previous iteration
+ - :math:`\mathbf{b}` is the last column of :math:`M_{k}`
+ - :math:`s` is the Schur complement, given by \
+ :math:`s = \mathbf{d} - \mathbf{b}^T M_{k}^{-1} \mathbf{b}`
+
+ Case 2: one variable is removed from the active set.
+
.. math::
- M_{k+1}^{-1} = M^{-1}_{A_k \backslash q} -
+ M_{k+1}^{-1} = M^{-1}_{k \backslash q} -
\frac{r_{-q,q} r^{T}_{-q,q}}{r_{q,q}}
+
where :
- - q is the index of column and row to delete
- - :math:`M^{-1}_{A_k \backslash q}` is the previous inverse matrix
- without qth column and qth row
- - r_{-q,q} is the qth column of :math:`M^{-1}_{k}` without the qth element
- - r_{q, q} is the element of qth column and qth row in :math:`M^{-1}_{k}`
+
+ - :math:`q` is the index of column and row to delete
+ - :math:`M^{-1}_{k \backslash q}` is the previous inverse matrix deprived \
+ of the :math:`q` th column and :math:`q` th row
+ - :math:`r_{-q,q}` is the :math:`q` th column of :math:`M^{-1}_{k}` \
+ without the :math:`q` th element
+ - :math:`r_{q, q}` is the element of :math:`q` th column and :math:`q` th \
+ row in :math:`M^{-1}_{k}`
+
+
Parameters
----------
- M_current : np.ndarray (|A|-1, |A|-1)
- Inverse matrix in previous iteration
- b : np.ndarray (|A|-1, )
- Upper right matrix in Schur complement, a column vector in our case
+ M_current : ndarray, shape (size(A)-1, size(A)-1)
+ Inverse matrix of :math:`H_A^tH_A` at the previous iteration, with \
+ size(A) the size of the active set
+ b : ndarray, shape (size(A)-1, )
+ None for case 2 (removal), last column of :math:`M_{k}` for case 1 \
+ (addition)
d : float
- Lower right matrix in Schur complement, a scalar in our case
- id_pop
+ should be equal to 2 when UOT and 1 for the semi-relaxed OT
+ id_pop : int
Index of the variable to be removed, equal to -1
- if none of the variables is deleted in current iteration
+ if no variable is deleted at the current iteration
+
+
Returns
-------
- M : np.ndarray (|A|, |A|)
- Inverse matrix needed in current iteration
+ M : ndarray, shape (size(A), size(A))
+ Inverse matrix of :math:`H_A^tH_A` of the current iteration
+
+
References
----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
"""
+
if b is None:
b = M_current[id_pop, :]
b = np.delete(b, id_pop)
@@ -409,33 +490,39 @@ def complement_schur(M_current, b, d, id_pop):
def construct_augmented_H(active_index, m, Hc, HrHr):
- r""" This function construct an augmented matrix for the first iteration of
- semi-relaxed regularization path
+ r""" This function constructs an augmented matrix for the first iteration
+ of the semi-relaxed regularization path
.. math::
- Augmented_H =
+ \text{Augmented}_H =
\begin{bmatrix}
0 & H_{c A} \\
H_{c A}^T & H_{r A}^T H_{r A}
\end{bmatrix}
+
where :
- - H_{r A} is the sub-matrix constructed by the columns of H_r
- whose indices belong to the active set A
- - H_{c A} is the sub-matrix constructed by the columns of H_c
- whose indices belong to the active set A
+
+ - :math:`H_{r A}` is the sub-matrix constructed by the columns of \
+ :math:`H_r` whose indices belong to the active set A
+ - :math:`H_{c A}` is the sub-matrix constructed by the columns of \
+ :math:`H_c` whose indices belong to the active set A
+
+
Parameters
----------
active_index : list
- Indices of active variables
+ Indices of the active variables
m : int
Length of the target distribution
Hc : np.ndarray (dim_b, dim_a * dim_b)
- Matrix that computes the sum along the columns of transport plan T
+ Matrix that computes the sum along the columns of the transport plan \
+ :math:`T`
HrHr : np.ndarray (dim_a * dim_b, dim_a * dim_b)
- Matrix product of H_r^T H_r
+ Matrix product of :math:`H_r^T H_r`
+
Returns
-------
- H_augmented : np.ndarray (dim_b + |A|, dim_b + |A|)
+ H_augmented : np.ndarray (dim_b + size(A), dim_b + size(A))
Augmented matrix for the first iteration of the semi-relaxed
regularization path
"""
@@ -451,18 +538,27 @@ def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
r"""This function gives the regularization path of l2-penalized UOT problem
The problem to optimize is the Lasso reformulation of the l2-penalized UOT:
+
.. math::
- \min_t \gamma c^T t + 0.5 * \|H t - y\|_2^2
+ \min_t \gamma \mathbf{c}^T \mathbf{t}
+ + 0.5 * \|{H} \mathbf{t} - \mathbf{y}\|_2^2
+
s.t.
- t \geq 0
+ \mathbf{t} \geq 0
+
where :
- - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C)
+
+ - :math:`\mathbf{c}` is the flattened version of the cost matrix \
+ :math:`{C}`
- :math:`\gamma = 1/\lambda` is the l2-regularization coefficient
- - y is the concatenation of vectors a and b, defined as y^T = [a^T b^T]
- - H is a (dim_a + dim_b, dim_a * dim_b) metric matrix,
- see [Chapel et al., 2021] for the design of H. The matrix product Ht
- computes both the source marginal and the target marginal.
- - t is a (dim_a * dim_b, ) metric vector (flattened version of T)
+ - :math:`\mathbf{y}` is the concatenation of vectors :math:`\mathbf{a}` \
+ and :math:`\mathbf{b}`, defined as \
+ :math:`\mathbf{y}^T = [\mathbf{a}^T \mathbf{b}^T]`
+ - :math:`{H}` is a design matrix, see :ref:`[41] <references-regpath>` \
+ for the design of :math:`{H}`. The matrix product :math:`H\mathbf{t}` \
+ computes both the source marginal and the target marginals.
+ - :math:`\mathbf{t}` is the flattened version of the transport matrix
+
Parameters
----------
a : np.ndarray (dim_a,)
@@ -478,11 +574,12 @@ def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
Returns
-------
t : np.ndarray (dim_a*dim_b, )
- Flattened vector of optimal transport matrix
+ Flattened vector of the optimal transport matrix
t_list : list
- List of solutions in regularization path
+ List of solutions in the regularization path
gamma_list : list
- List of regularization coefficient in regularization path
+ List of regularization coefficients in the regularization path
+
Examples
--------
>>> import ot
@@ -502,10 +599,9 @@ def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
References
----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
"""
n = np.shape(a)[0]
@@ -580,22 +676,32 @@ def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
itmax=50000):
r"""This function gives the regularization path of semi-relaxed
- l2-UOT problem
+ l2-UOT problem.
The problem to optimize is the Lasso reformulation of the l2-penalized UOT:
+
.. math::
- \min_t \gamma c^T t + 0.5 * \|H_r t - a\|_2^2
+
+ \min_t \gamma \mathbf{c}^T t
+ + 0.5 * \|H_r \mathbf{t} - \mathbf{a}\|_2^2
+
s.t.
- H_c t = b
- t \geq 0
+ H_c \mathbf{t} = \mathbf{b}
+
+ \mathbf{t} \geq 0
+
where :
- - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C)
- - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient
- - H_r is a (dim_a, dim_a * dim_b) metric matrix,
- which computes the sum along the rows of transport plan T
- - H_c is a (dim_b, dim_a * dim_b) metric matrix,
- which computes the sum along the columns of transport plan T
- - t is a (dim_a * dim_b, ) metric vector (flattened version of T)
+
+ - :math:`\mathbf{c}` is the flattened version of the cost matrix \
+ :math:`C`
+ - :math:`\gamma = 1/\lambda` is the l2-regularization parameter
+ - :math:`H_r` is a matrix that computes the sum along the rows of \
+ the transport plan :math:`T`
+ - :math:`H_c` is a matrix that computes the sum along the columns of \
+ the transport plan :math:`T`
+ - :math:`\mathbf{t}` is the flattened version of the transport plan \
+ :math:`T`
+
Parameters
----------
a : np.ndarray (dim_a,)
@@ -608,14 +714,16 @@ def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
l2-regularization coefficient
itmax: int (optional)
Maximum number of iteration
+
Returns
-------
t : np.ndarray (dim_a*dim_b, )
- Flattened vector of optimal transport matrix
+ Flattened vector of the (unregularized) optimal transport matrix
t_list : list
- List of solutions in regularization path
+ List of all the optimal transport vectors of the regularization path
gamma_list : list
- List of regularization coefficient in regularization path
+ List of the regularization parameters in the path
+
Examples
--------
>>> import ot
@@ -635,10 +743,9 @@ def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
References
----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
"""
n = np.shape(a)[0]
@@ -722,8 +829,44 @@ def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
def regularization_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
semi_relaxed=False, itmax=50000):
- r"""This function combines both the semi-relaxed and the fully-relaxed
- regularization paths of l2-UOT problem
+ r"""This function provides all the solutions of the regularization path \
+ of the l2-UOT problem :ref:`[41] <references-regpath>`.
+
+ The problem to optimize is the Lasso reformulation of the l2-penalized UOT:
+
+ .. math::
+ \min_t \gamma \mathbf{c}^T \mathbf{t}
+ + 0.5 * \|{H} \mathbf{t} - \mathbf{y}\|_2^2
+
+ s.t.
+ \mathbf{t} \geq 0
+
+ where :
+
+ - :math:`\mathbf{c}` is the flattened version of the cost matrix \
+ :math:`{C}`
+ - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient
+ - :math:`\mathbf{y}` is the concatenation of vectors :math:`\mathbf{a}` \
+ and :math:`\mathbf{b}`, defined as \
+ :math:`\mathbf{y}^T = [\mathbf{a}^T \mathbf{b}^T]`
+ - :math:`{H}` is a design matrix, see :ref:`[41] <references-regpath>` \
+ for the design of :math:`{H}`. The matrix product :math:`H\mathbf{t}` \
+ computes both the source marginal and the target marginals.
+ - :math:`\mathbf{t}` is the flattened version of the transport matrix
+
+ For the semi-relaxed problem, it optimizes the Lasso reformulation of the
+ l2-penalized UOT:
+
+ .. math::
+
+ \min_t \gamma \mathbf{c}^T \mathbf{t}
+ + 0.5 * \|H_r \mathbf{t} - \mathbf{a}\|_2^2
+
+ s.t.
+ H_c \mathbf{t} = \mathbf{b}
+
+ \mathbf{t} \geq 0
+
Parameters
----------
@@ -736,23 +879,24 @@ def regularization_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
reg: float (optional)
l2-regularization coefficient
semi_relaxed : bool (optional)
- Give the semi-relaxed path if true
+ Give the semi-relaxed path if True
itmax: int (optional)
Maximum number of iteration
+
Returns
-------
t : np.ndarray (dim_a*dim_b, )
- Flattened vector of optimal transport matrix
+ Flattened vector of the (unregularized) optimal transport matrix
t_list : list
- List of solutions in regularization path
+ List of all the optimal transport vectors of the regularization path
gamma_list : list
- List of regularization coefficient in regularization path
+ List of the regularization parameters in the path
+
References
----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
"""
if semi_relaxed:
t, t_list, gamma_list = semi_relaxed_path(a, b, C, reg=reg,
@@ -765,27 +909,33 @@ def regularization_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
def compute_transport_plan(gamma, gamma_list, Pi_list):
r""" Given the regularization path, this function computes the transport
- plan for any value of gamma by the piecewise linearity of the path
+ plan for any value of gamma thanks to the piecewise linearity of the path.
.. math::
t(\gamma) = \phi(\gamma) - \gamma \delta(\gamma)
- where :
- - :math:`\gamma` is the regularization coefficient
+
+ where:
+
+ - :math:`\gamma` is the regularization parameter
- :math:`\phi(\gamma)` is the corresponding intercept
- :math:`\delta(\gamma)` is the corresponding slope
- - t is a (dim_a * dim_b, ) vector (flattened version of transport matrix)
+ - :math:`\mathbf{t}` is the flattened version of the transport matrix
+
Parameters
----------
gamma : float
Regularization coefficient
gamma_list : list
- List of regularization coefficients in regularization path
+ List of regularization parameters of the regularization path
Pi_list : list
- List of solutions in regularization path
+ List of all the solutions of the regularization path
+
Returns
-------
t : np.ndarray (dim_a*dim_b, )
- Transport vector corresponding to the given value of gamma
+ Vectorization of the transport plan corresponding to the given value
+ of gamma
+
Examples
--------
>>> import ot
@@ -804,12 +954,13 @@ def compute_transport_plan(gamma, gamma_list, Pi_list):
array([0. , 0. , 0. , 0.19722222, 0.05555556,
0. , 0. , 0.24722222, 0. ])
+
+ .. _references-regpath:
References
----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
"""
if gamma >= gamma_list[0]:
diff --git a/ot/stochastic.py b/ot/stochastic.py
index 693675f..61be9bb 100644
--- a/ot/stochastic.py
+++ b/ot/stochastic.py
@@ -4,12 +4,14 @@ Stochastic solvers for regularized OT.
"""
-# Author: Kilian Fatras <kilian.fatras@gmail.com>
+# Authors: Kilian Fatras <kilian.fatras@gmail.com>
+# Rémi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License
import numpy as np
-
+from .utils import dist
+from .backend import get_backend
##############################################################################
# Optimization toolbox for SEMI - DUAL problems
@@ -747,3 +749,239 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
return pi, log
else:
return pi
+
+
+################################################################################
+# Losses for stochastic optimization
+################################################################################
+
+def loss_dual_entropic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+ r"""
+ Compute the dual loss of the entropic OT as in equation (6)-(7) of [19]
+
+ This loss is backend compatible and can be used for stochastic optimization
+ of the dual potentials. It can be used on the full dataset (beware of
+ memory) or on minibatches.
+
+
+ Parameters
+ ----------
+ u : array-like, shape (ns,)
+ Source dual potential
+ v : array-like, shape (nt,)
+ Target dual potential
+ xs : array-like, shape (ns,d)
+ Source samples
+ xt : array-like, shape (ns,d)
+ Target samples
+ reg : float
+ Regularization term > 0 (default=1)
+ ws : array-like, shape (ns,), optional
+ Source sample weights (default unif)
+ wt : array-like, shape (ns,), optional
+ Target sample weights (default unif)
+ metric : string, callable
+ Ground metric for OT (default quadratic). Can be given as a callable
+ function taking (xs,xt) as parameters.
+
+ Returns
+ -------
+ dual_loss : array-like
+ Dual loss (to maximize)
+
+
+ References
+ ----------
+ .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018)
+ """
+
+ nx = get_backend(u, v, xs, xt)
+
+ if ws is None:
+ ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+ if wt is None:
+ wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+ if callable(metric):
+ M = metric(xs, xt)
+ else:
+ M = dist(xs, xt, metric=metric)
+
+ F = -reg * nx.exp((u[:, None] + v[None, :] - M) / reg)
+
+ return nx.sum(u * ws) + nx.sum(v * wt) + nx.sum(ws[:, None] * F * wt[None, :])
+
+
+def plan_dual_entropic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+ r"""
+ Compute the primal OT plan the entropic OT as in equation (8) of [19]
+
+ This loss is backend compatible and can be used for stochastic optimization
+ of the dual potentials. It can be used on the full dataset (beware of
+ memory) or on minibatches.
+
+
+ Parameters
+ ----------
+ u : array-like, shape (ns,)
+ Source dual potential
+ v : array-like, shape (nt,)
+ Target dual potential
+ xs : array-like, shape (ns,d)
+ Source samples
+ xt : array-like, shape (ns,d)
+ Target samples
+ reg : float
+ Regularization term > 0 (default=1)
+ ws : array-like, shape (ns,), optional
+ Source sample weights (default unif)
+ wt : array-like, shape (ns,), optional
+ Target sample weights (default unif)
+ metric : string, callable
+ Ground metric for OT (default quadratic). Can be given as a callable
+ function taking (xs,xt) as parameters.
+
+ Returns
+ -------
+ G : array-like
+ Primal OT plan
+
+
+ References
+ ----------
+ .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018)
+ """
+
+ nx = get_backend(u, v, xs, xt)
+
+ if ws is None:
+ ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+ if wt is None:
+ wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+ if callable(metric):
+ M = metric(xs, xt)
+ else:
+ M = dist(xs, xt, metric=metric)
+
+ H = nx.exp((u[:, None] + v[None, :] - M) / reg)
+
+ return ws[:, None] * H * wt[None, :]
+
+
+def loss_dual_quadratic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+ r"""
+ Compute the dual loss of the quadratic regularized OT as in equation (6)-(7) of [19]
+
+ This loss is backend compatible and can be used for stochastic optimization
+ of the dual potentials. It can be used on the full dataset (beware of
+ memory) or on minibatches.
+
+
+ Parameters
+ ----------
+ u : array-like, shape (ns,)
+ Source dual potential
+ v : array-like, shape (nt,)
+ Target dual potential
+ xs : array-like, shape (ns,d)
+ Source samples
+ xt : array-like, shape (ns,d)
+ Target samples
+ reg : float
+ Regularization term > 0 (default=1)
+ ws : array-like, shape (ns,), optional
+ Source sample weights (default unif)
+ wt : array-like, shape (ns,), optional
+ Target sample weights (default unif)
+ metric : string, callable
+ Ground metric for OT (default quadratic). Can be given as a callable
+ function taking (xs,xt) as parameters.
+
+ Returns
+ -------
+ dual_loss : array-like
+ Dual loss (to maximize)
+
+
+ References
+ ----------
+ .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018)
+ """
+
+ nx = get_backend(u, v, xs, xt)
+
+ if ws is None:
+ ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+ if wt is None:
+ wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+ if callable(metric):
+ M = metric(xs, xt)
+ else:
+ M = dist(xs, xt, metric=metric)
+
+ F = -1.0 / (4 * reg) * nx.maximum(u[:, None] + v[None, :] - M, 0.0)**2
+
+ return nx.sum(u * ws) + nx.sum(v * wt) + nx.sum(ws[:, None] * F * wt[None, :])
+
+
+def plan_dual_quadratic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+ r"""
+ Compute the primal OT plan the quadratic regularized OT as in equation (8) of [19]
+
+ This loss is backend compatible and can be used for stochastic optimization
+ of the dual potentials. It can be used on the full dataset (beware of
+ memory) or on minibatches.
+
+
+ Parameters
+ ----------
+ u : array-like, shape (ns,)
+ Source dual potential
+ v : array-like, shape (nt,)
+ Target dual potential
+ xs : array-like, shape (ns,d)
+ Source samples
+ xt : array-like, shape (ns,d)
+ Target samples
+ reg : float
+ Regularization term > 0 (default=1)
+ ws : array-like, shape (ns,), optional
+ Source sample weights (default unif)
+ wt : array-like, shape (ns,), optional
+ Target sample weights (default unif)
+ metric : string, callable
+ Ground metric for OT (default quadratic). Can be given as a callable
+ function taking (xs,xt) as parameters.
+
+ Returns
+ -------
+ G : array-like
+ Primal OT plan
+
+
+ References
+ ----------
+ .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018)
+ """
+
+ nx = get_backend(u, v, xs, xt)
+
+ if ws is None:
+ ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+ if wt is None:
+ wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+ if callable(metric):
+ M = metric(xs, xt)
+ else:
+ M = dist(xs, xt, metric=metric)
+
+ H = 1.0 / (2 * reg) * nx.maximum(u[:, None] + v[None, :] - M, 0.0)
+
+ return ws[:, None] * H * wt[None, :]
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index 15e180b..90c920c 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -4,13 +4,14 @@ Regularized Unbalanced OT solvers
"""
# Author: Hicham Janati <hicham.janati@inria.fr>
+# Laetitia Chapel <laetitia.chapel@univ-ubs.fr>
# License: MIT License
from __future__ import division
import warnings
-import numpy as np
-from scipy.special import logsumexp
+from .backend import get_backend
+from .utils import list_to_array
# from .utils import unif, dist
@@ -43,12 +44,12 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
Parameters
----------
- a : np.ndarray (dim_a,)
+ a : array-like (dim_a,)
Unnormalized histogram of dimension `dim_a`
- b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ b : array-like (dim_b,) or array-like (dim_b, n_hists)
One or multiple unnormalized histograms of dimension `dim_b`.
If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i`
- M : np.ndarray (dim_a, dim_b)
+ M : array-like (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
@@ -70,12 +71,12 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
Returns
-------
if n_hists == 1:
- - gamma : (dim_a, dim_b) ndarray
+ - gamma : (dim_a, dim_b) array-like
Optimal transportation matrix for the given parameters
- log : dict
log dictionary returned only if `log` is `True`
else:
- - ot_distance : (n_hists,) ndarray
+ - ot_distance : (n_hists,) array-like
the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
- log : dict
log dictionary returned only if `log` is `True`
@@ -172,12 +173,12 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
Parameters
----------
- a : np.ndarray (dim_a,)
+ a : array-like (dim_a,)
Unnormalized histogram of dimension `dim_a`
- b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ b : array-like (dim_b,) or array-like (dim_b, n_hists)
One or multiple unnormalized histograms of dimension `dim_b`.
If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i`
- M : np.ndarray (dim_a, dim_b)
+ M : array-like (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
@@ -198,7 +199,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
Returns
-------
- ot_distance : (n_hists,) ndarray
+ ot_distance : (n_hists,) array-like
the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
log : dict
log dictionary returned only if `log` is `True`
@@ -239,9 +240,10 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epslilon scaling :ref:`[9, 10] <references-sinkhorn-unbalanced2>`
"""
- b = np.asarray(b, dtype=np.float64)
+ b = list_to_array(b)
if len(b.shape) < 2:
b = b[:, None]
+
if method.lower() == 'sinkhorn':
return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
numItermax=numItermax,
@@ -291,12 +293,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
Parameters
----------
- a : np.ndarray (dim_a,)
+ a : array-like (dim_a,)
Unnormalized histogram of dimension `dim_a`
- b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ b : array-like (dim_b,) or array-like (dim_b, n_hists)
One or multiple unnormalized histograms of dimension `dim_b`
If many, compute all the OT distances (a, b_i)
- M : np.ndarray (dim_a, dim_b)
+ M : array-like (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
@@ -315,12 +317,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
Returns
-------
if n_hists == 1:
- - gamma : (dim_a, dim_b) ndarray
+ - gamma : (dim_a, dim_b) array-like
Optimal transportation matrix for the given parameters
- log : dict
log dictionary returned only if `log` is `True`
else:
- - ot_distance : (n_hists,) ndarray
+ - ot_distance : (n_hists,) array-like
the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
- log : dict
log dictionary returned only if `log` is `True`
@@ -354,17 +356,15 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
ot.optim.cg : General regularized OT
"""
-
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ M, a, b = list_to_array(M, a, b)
+ nx = get_backend(M, a, b)
dim_a, dim_b = M.shape
if len(a) == 0:
- a = np.ones(dim_a, dtype=np.float64) / dim_a
+ a = nx.ones(dim_a, type_as=M) / dim_a
if len(b) == 0:
- b = np.ones(dim_b, dtype=np.float64) / dim_b
+ b = nx.ones(dim_b, type_as=M) / dim_b
if len(b.shape) > 1:
n_hists = b.shape[1]
@@ -377,17 +377,14 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
# we assume that no distances are null except those of the diagonal of
# distances
if n_hists:
- u = np.ones((dim_a, 1)) / dim_a
- v = np.ones((dim_b, n_hists)) / dim_b
+ u = nx.ones((dim_a, 1), type_as=M) / dim_a
+ v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
a = a.reshape(dim_a, 1)
else:
- u = np.ones(dim_a) / dim_a
- v = np.ones(dim_b) / dim_b
+ u = nx.ones(dim_a, type_as=M) / dim_a
+ v = nx.ones(dim_b, type_as=M) / dim_b
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
+ K = nx.exp(M / (-reg))
fi = reg_m / (reg_m + reg)
@@ -397,14 +394,14 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
uprev = u
vprev = v
- Kv = K.dot(v)
+ Kv = nx.dot(K, v)
u = (a / Kv) ** fi
- Ktu = K.T.dot(u)
+ Ktu = nx.dot(K.T, u)
v = (b / Ktu) ** fi
- if (np.any(Ktu == 0.)
- or np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ if (nx.any(Ktu == 0.)
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
+ or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
warnings.warn('Numerical errors at iteration %s' % i)
@@ -412,8 +409,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
v = vprev
break
- err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.)
- err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.)
+ err_u = nx.max(nx.abs(u - uprev)) / max(
+ nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.
+ )
+ err_v = nx.max(nx.abs(v - vprev)) / max(
+ nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1.
+ )
err = 0.5 * (err_u + err_v)
if log:
log['err'].append(err)
@@ -426,11 +427,11 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
break
if log:
- log['logu'] = np.log(u + 1e-300)
- log['logv'] = np.log(v + 1e-300)
+ log['logu'] = nx.log(u + 1e-300)
+ log['logv'] = nx.log(v + 1e-300)
if n_hists: # return only loss
- res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
+ res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M)
if log:
return res, log
else:
@@ -475,12 +476,12 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
Parameters
----------
- a : np.ndarray (dim_a,)
+ a : array-like (dim_a,)
Unnormalized histogram of dimension `dim_a`
- b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ b : array-like (dim_b,) or array-like (dim_b, n_hists)
One or multiple unnormalized histograms of dimension `dim_b`.
If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i`
- M : np.ndarray (dim_a, dim_b)
+ M : array-like (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
@@ -501,12 +502,12 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
Returns
-------
if n_hists == 1:
- - gamma : (dim_a, dim_b) ndarray
+ - gamma : (dim_a, dim_b) array-like
Optimal transportation matrix for the given parameters
- log : dict
log dictionary returned only if `log` is `True`
else:
- - ot_distance : (n_hists,) ndarray
+ - ot_distance : (n_hists,) array-like
the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
- log : dict
log dictionary returned only if `log` is `True`
@@ -538,17 +539,15 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
ot.optim.cg : General regularized OT
"""
-
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+ nx = get_backend(M, a, b)
dim_a, dim_b = M.shape
if len(a) == 0:
- a = np.ones(dim_a, dtype=np.float64) / dim_a
+ a = nx.ones(dim_a, type_as=M) / dim_a
if len(b) == 0:
- b = np.ones(dim_b, dtype=np.float64) / dim_b
+ b = nx.ones(dim_b, type_as=M) / dim_b
if len(b.shape) > 1:
n_hists = b.shape[1]
@@ -561,56 +560,52 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
# we assume that no distances are null except those of the diagonal of
# distances
if n_hists:
- u = np.ones((dim_a, n_hists)) / dim_a
- v = np.ones((dim_b, n_hists)) / dim_b
+ u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
+ v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
a = a.reshape(dim_a, 1)
else:
- u = np.ones(dim_a) / dim_a
- v = np.ones(dim_b) / dim_b
+ u = nx.ones(dim_a, type_as=M) / dim_a
+ v = nx.ones(dim_b, type_as=M) / dim_b
# print(reg)
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
+ K = nx.exp(-M / reg)
fi = reg_m / (reg_m + reg)
cpt = 0
err = 1.
- alpha = np.zeros(dim_a)
- beta = np.zeros(dim_b)
+ alpha = nx.zeros(dim_a, type_as=M)
+ beta = nx.zeros(dim_b, type_as=M)
while (err > stopThr and cpt < numItermax):
uprev = u
vprev = v
- Kv = K.dot(v)
- f_alpha = np.exp(- alpha / (reg + reg_m))
- f_beta = np.exp(- beta / (reg + reg_m))
+ Kv = nx.dot(K, v)
+ f_alpha = nx.exp(- alpha / (reg + reg_m))
+ f_beta = nx.exp(- beta / (reg + reg_m))
if n_hists:
f_alpha = f_alpha[:, None]
f_beta = f_beta[:, None]
u = ((a / (Kv + 1e-16)) ** fi) * f_alpha
- Ktu = K.T.dot(u)
+ Ktu = nx.dot(K.T, u)
v = ((b / (Ktu + 1e-16)) ** fi) * f_beta
absorbing = False
- if (u > tau).any() or (v > tau).any():
+ if nx.any(u > tau) or nx.any(v > tau):
absorbing = True
if n_hists:
- alpha = alpha + reg * np.log(np.max(u, 1))
- beta = beta + reg * np.log(np.max(v, 1))
+ alpha = alpha + reg * nx.log(nx.max(u, 1))
+ beta = beta + reg * nx.log(nx.max(v, 1))
else:
- alpha = alpha + reg * np.log(np.max(u))
- beta = beta + reg * np.log(np.max(v))
- K = np.exp((alpha[:, None] + beta[None, :] -
- M) / reg)
- v = np.ones_like(v)
- Kv = K.dot(v)
-
- if (np.any(Ktu == 0.)
- or np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ alpha = alpha + reg * nx.log(nx.max(u))
+ beta = beta + reg * nx.log(nx.max(v))
+ K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg)
+ v = nx.ones(v.shape, type_as=v)
+ Kv = nx.dot(K, v)
+
+ if (nx.any(Ktu == 0.)
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
+ or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
warnings.warn('Numerical errors at iteration %s' % cpt)
@@ -620,8 +615,9 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
if (cpt % 10 == 0 and not absorbing) or cpt == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(),
- 1.)
+ err = nx.max(nx.abs(u - uprev)) / max(
+ nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.
+ )
if log:
log['err'].append(err)
if verbose:
@@ -636,25 +632,30 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
"Try a larger entropy `reg` or a lower mass `reg_m`." +
"Or a larger absorption threshold `tau`.")
if n_hists:
- logu = alpha[:, None] / reg + np.log(u)
- logv = beta[:, None] / reg + np.log(v)
+ logu = alpha[:, None] / reg + nx.log(u)
+ logv = beta[:, None] / reg + nx.log(v)
else:
- logu = alpha / reg + np.log(u)
- logv = beta / reg + np.log(v)
+ logu = alpha / reg + nx.log(u)
+ logv = beta / reg + nx.log(v)
if log:
log['logu'] = logu
log['logv'] = logv
if n_hists: # return only loss
- res = logsumexp(np.log(M + 1e-100)[:, :, None] + logu[:, None, :] +
- logv[None, :, :] - M[:, :, None] / reg, axis=(0, 1))
- res = np.exp(res)
+ res = nx.logsumexp(
+ nx.log(M + 1e-100)[:, :, None]
+ + logu[:, None, :]
+ + logv[None, :, :]
+ - M[:, :, None] / reg,
+ axis=(0, 1)
+ )
+ res = nx.exp(res)
if log:
return res, log
else:
return res
else: # return OT matrix
- ot_matrix = np.exp(logu[:, None] + logv[None, :] - M / reg)
+ ot_matrix = nx.exp(logu[:, None] + logv[None, :] - M / reg)
if log:
return ot_matrix, log
else:
@@ -683,9 +684,9 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
Parameters
----------
- A : np.ndarray (dim, n_hists)
+ A : array-like (dim, n_hists)
`n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim`
- M : np.ndarray (dim, dim)
+ M : array-like (dim, dim)
ground metric matrix for OT.
reg : float
Entropy regularization term > 0
@@ -693,7 +694,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
Marginal relaxation term > 0
tau : float
Stabilization threshold for log domain absorption.
- weights : np.ndarray (n_hists,) optional
+ weights : array-like (n_hists,) optional
Weight of each distribution (barycentric coodinates)
If None, uniform weights are used.
numItermax : int, optional
@@ -708,7 +709,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
Returns
-------
- a : (dim,) ndarray
+ a : (dim,) array-like
Unbalanced Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
@@ -726,9 +727,12 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
"""
+ A, M = list_to_array(A, M)
+ nx = get_backend(A, M)
+
dim, n_hists = A.shape
if weights is None:
- weights = np.ones(n_hists) / n_hists
+ weights = nx.ones(n_hists, type_as=A) / n_hists
else:
assert(len(weights) == A.shape[1])
@@ -737,47 +741,43 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
fi = reg_m / (reg_m + reg)
- u = np.ones((dim, n_hists)) / dim
- v = np.ones((dim, n_hists)) / dim
+ u = nx.ones((dim, n_hists), type_as=A) / dim
+ v = nx.ones((dim, n_hists), type_as=A) / dim
# print(reg)
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
+ K = nx.exp(-M / reg)
fi = reg_m / (reg_m + reg)
cpt = 0
err = 1.
- alpha = np.zeros(dim)
- beta = np.zeros(dim)
- q = np.ones(dim) / dim
+ alpha = nx.zeros(dim, type_as=A)
+ beta = nx.zeros(dim, type_as=A)
+ q = nx.ones(dim, type_as=A) / dim
for i in range(numItermax):
- qprev = q.copy()
- Kv = K.dot(v)
- f_alpha = np.exp(- alpha / (reg + reg_m))
- f_beta = np.exp(- beta / (reg + reg_m))
+ qprev = nx.copy(q)
+ Kv = nx.dot(K, v)
+ f_alpha = nx.exp(- alpha / (reg + reg_m))
+ f_beta = nx.exp(- beta / (reg + reg_m))
f_alpha = f_alpha[:, None]
f_beta = f_beta[:, None]
u = ((A / (Kv + 1e-16)) ** fi) * f_alpha
- Ktu = K.T.dot(u)
+ Ktu = nx.dot(K.T, u)
q = (Ktu ** (1 - fi)) * f_beta
- q = q.dot(weights) ** (1 / (1 - fi))
+ q = nx.dot(q, weights) ** (1 / (1 - fi))
Q = q[:, None]
v = ((Q / (Ktu + 1e-16)) ** fi) * f_beta
absorbing = False
- if (u > tau).any() or (v > tau).any():
+ if nx.any(u > tau) or nx.any(v > tau):
absorbing = True
- alpha = alpha + reg * np.log(np.max(u, 1))
- beta = beta + reg * np.log(np.max(v, 1))
- K = np.exp((alpha[:, None] + beta[None, :] -
- M) / reg)
- v = np.ones_like(v)
- Kv = K.dot(v)
- if (np.any(Ktu == 0.)
- or np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ alpha = alpha + reg * nx.log(nx.max(u, 1))
+ beta = beta + reg * nx.log(nx.max(v, 1))
+ K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg)
+ v = nx.ones(v.shape, type_as=v)
+ Kv = nx.dot(K, v)
+ if (nx.any(Ktu == 0.)
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
+ or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
warnings.warn('Numerical errors at iteration %s' % cpt)
@@ -786,8 +786,9 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
if (i % 10 == 0 and not absorbing) or i == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = abs(q - qprev).max() / max(abs(q).max(),
- abs(qprev).max(), 1.)
+ err = nx.max(nx.abs(q - qprev)) / max(
+ nx.max(nx.abs(q)), nx.max(nx.abs(qprev)), 1.
+ )
if log:
log['err'].append(err)
if verbose:
@@ -804,8 +805,8 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
"Or a larger absorption threshold `tau`.")
if log:
log['niter'] = i
- log['logu'] = np.log(u + 1e-300)
- log['logv'] = np.log(v + 1e-300)
+ log['logu'] = nx.log(u + 1e-300)
+ log['logv'] = nx.log(v + 1e-300)
return q, log
else:
return q
@@ -833,15 +834,15 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
Parameters
----------
- A : np.ndarray (dim, n_hists)
+ A : array-like (dim, n_hists)
`n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim`
- M : np.ndarray (dim, dim)
+ M : array-like (dim, dim)
ground metric matrix for OT.
reg : float
Entropy regularization term > 0
reg_m: float
Marginal relaxation term > 0
- weights : np.ndarray (n_hists,) optional
+ weights : array-like (n_hists,) optional
Weight of each distribution (barycentric coodinates)
If None, uniform weights are used.
numItermax : int, optional
@@ -856,7 +857,7 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
Returns
-------
- a : (dim,) ndarray
+ a : (dim,) array-like
Unbalanced Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
@@ -874,40 +875,43 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
"""
+ A, M = list_to_array(A, M)
+ nx = get_backend(A, M)
+
dim, n_hists = A.shape
if weights is None:
- weights = np.ones(n_hists) / n_hists
+ weights = nx.ones(n_hists, type_as=A) / n_hists
else:
assert(len(weights) == A.shape[1])
if log:
log = {'err': []}
- K = np.exp(- M / reg)
+ K = nx.exp(-M / reg)
fi = reg_m / (reg_m + reg)
- v = np.ones((dim, n_hists))
- u = np.ones((dim, 1))
- q = np.ones(dim)
+ v = nx.ones((dim, n_hists), type_as=A)
+ u = nx.ones((dim, 1), type_as=A)
+ q = nx.ones(dim, type_as=A)
err = 1.
for i in range(numItermax):
- uprev = u.copy()
- vprev = v.copy()
- qprev = q.copy()
+ uprev = nx.copy(u)
+ vprev = nx.copy(v)
+ qprev = nx.copy(q)
- Kv = K.dot(v)
+ Kv = nx.dot(K, v)
u = (A / Kv) ** fi
- Ktu = K.T.dot(u)
- q = ((Ktu ** (1 - fi)).dot(weights))
+ Ktu = nx.dot(K.T, u)
+ q = nx.dot(Ktu ** (1 - fi), weights)
q = q ** (1 / (1 - fi))
Q = q[:, None]
v = (Q / Ktu) ** fi
- if (np.any(Ktu == 0.)
- or np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ if (nx.any(Ktu == 0.)
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
+ or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
warnings.warn('Numerical errors at iteration %s' % i)
@@ -916,8 +920,9 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
q = qprev
break
# compute change in barycenter
- err = abs(q - qprev).max()
- err /= max(abs(q).max(), abs(qprev).max(), 1.)
+ err = nx.max(nx.abs(q - qprev)) / max(
+ nx.max(nx.abs(q)), nx.max(nx.abs(qprev)), 1.0
+ )
if log:
log['err'].append(err)
# if barycenter did not change + at least 10 iterations - stop
@@ -932,8 +937,8 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
if log:
log['niter'] = i
- log['logu'] = np.log(u + 1e-300)
- log['logv'] = np.log(v + 1e-300)
+ log['logu'] = nx.log(u + 1e-300)
+ log['logv'] = nx.log(v + 1e-300)
return q, log
else:
return q
@@ -961,15 +966,15 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
Parameters
----------
- A : np.ndarray (dim, n_hists)
+ A : array-like (dim, n_hists)
`n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim`
- M : np.ndarray (dim, dim)
+ M : array-like (dim, dim)
ground metric matrix for OT.
reg : float
Entropy regularization term > 0
reg_m: float
Marginal relaxation term > 0
- weights : np.ndarray (n_hists,) optional
+ weights : array-like (n_hists,) optional
Weight of each distribution (barycentric coodinates)
If None, uniform weights are used.
numItermax : int, optional
@@ -984,7 +989,7 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
Returns
-------
- a : (dim,) ndarray
+ a : (dim,) array-like
Unbalanced Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
@@ -1025,3 +1030,225 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
log=log, **kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
+
+
+def mm_unbalanced(a, b, M, reg_m, div='kl', G0=None, numItermax=1000,
+ stopThr=1e-15, verbose=False, log=False):
+ r"""
+ Solve the unbalanced optimal transport problem and return the OT plan.
+ The function solves the following optimization problem:
+
+ .. math::
+ W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) +
+ \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b})
+ s.t.
+ \gamma \geq 0
+
+ where:
+
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ unbalanced distributions
+ - div is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence
+
+ The algorithm used for solving the problem is a maximization-
+ minimization algorithm as proposed in :ref:`[41] <references-regpath>`
+
+ Parameters
+ ----------
+ a : array-like (dim_a,)
+ Unnormalized histogram of dimension `dim_a`
+ b : array-like (dim_b,)
+ Unnormalized histogram of dimension `dim_b`
+ M : array-like (dim_a, dim_b)
+ loss matrix
+ reg_m: float
+ Marginal relaxation term > 0
+ div: string, optional
+ Divergence to quantify the difference between the marginals.
+ Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic)
+ G0: array-like (dim_a, dim_b)
+ Initialization of the transport matrix
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ Returns
+ -------
+ gamma : (dim_a, dim_b) array-like
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+ Examples
+ --------
+ >>> import ot
+ >>> import numpy as np
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[1., 36.],[9., 4.]]
+ >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 1, 'kl'), 2)
+ array([[0.3 , 0. ],
+ [0. , 0.07]])
+ >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 1, 'l2'), 2)
+ array([[0.25, 0. ],
+ [0. , 0. ]])
+
+
+ .. _references-regpath:
+ References
+ ----------
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression. NeurIPS.
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.unbalanced.sinkhorn_unbalanced : Entropic regularized OT
+ """
+ M, a, b = list_to_array(M, a, b)
+ nx = get_backend(M, a, b)
+
+ dim_a, dim_b = M.shape
+
+ if len(a) == 0:
+ a = nx.ones(dim_a, type_as=M) / dim_a
+ if len(b) == 0:
+ b = nx.ones(dim_b, type_as=M) / dim_b
+
+ if G0 is None:
+ G = a[:, None] * b[None, :]
+ else:
+ G = G0
+
+ if log:
+ log = {'err': [], 'G': []}
+
+ if div == 'kl':
+ K = nx.exp(M / - reg_m / 2)
+ elif div == 'l2':
+ K = nx.maximum(a[:, None] + b[None, :] - M / reg_m / 2,
+ nx.zeros((dim_a, dim_b), type_as=M))
+ else:
+ warnings.warn("The div parameter should be either equal to 'kl' or \
+ 'l2': it has been set to 'kl'.")
+ div = 'kl'
+ K = nx.exp(M / - reg_m / 2)
+
+ for i in range(numItermax):
+ Gprev = G
+
+ if div == 'kl':
+ u = nx.sqrt(a / (nx.sum(G, 1) + 1e-16))
+ v = nx.sqrt(b / (nx.sum(G, 0) + 1e-16))
+ G = G * K * u[:, None] * v[None, :]
+ elif div == 'l2':
+ Gd = nx.sum(G, 0, keepdims=True) + nx.sum(G, 1, keepdims=True) + 1e-16
+ G = G * K / Gd
+
+ err = nx.sqrt(nx.sum((G - Gprev) ** 2))
+ if log:
+ log['err'].append(err)
+ log['G'].append(G)
+ if verbose:
+ print('{:5d}|{:8e}|'.format(i, err))
+ if err < stopThr:
+ break
+
+ if log:
+ log['cost'] = nx.sum(G * M)
+ return G, log
+ else:
+ return G
+
+
+def mm_unbalanced2(a, b, M, reg_m, div='kl', G0=None, numItermax=1000,
+ stopThr=1e-15, verbose=False, log=False):
+ r"""
+ Solve the unbalanced optimal transport problem and return the OT plan.
+ The function solves the following optimization problem:
+
+ .. math::
+ W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) +
+ \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b})
+
+ s.t.
+ \gamma \geq 0
+
+ where:
+
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ unbalanced distributions
+ - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence
+
+ The algorithm used for solving the problem is a maximization-
+ minimization algorithm as proposed in :ref:`[41] <references-regpath>`
+
+ Parameters
+ ----------
+ a : array-like (dim_a,)
+ Unnormalized histogram of dimension `dim_a`
+ b : array-like (dim_b,)
+ Unnormalized histogram of dimension `dim_b`
+ M : array-like (dim_a, dim_b)
+ loss matrix
+ reg_m: float
+ Marginal relaxation term > 0
+ div: string, optional
+ Divergence to quantify the difference between the marginals.
+ Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic)
+ G0: array-like (dim_a, dim_b)
+ Initialization of the transport matrix
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ ot_distance : array-like
+ the OT distance between :math:`\mathbf{a}` and :math:`\mathbf{b}`
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+ Examples
+ --------
+ >>> import ot
+ >>> import numpy as np
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[1., 36.],[9., 4.]]
+ >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'l2'),2)
+ 0.25
+ >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'kl'),2)
+ 0.57
+
+ References
+ ----------
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression. NeurIPS.
+ See Also
+ --------
+ ot.lp.emd2 : Unregularized OT loss
+ ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss
+ """
+ _, log_mm = mm_unbalanced(a, b, M, reg_m, div=div, G0=G0,
+ numItermax=numItermax, stopThr=stopThr,
+ verbose=verbose, log=True)
+
+ if log:
+ return log_mm['cost'], log_mm
+ else:
+ return log_mm['cost']
diff --git a/ot/utils.py b/ot/utils.py
index e6c93c8..a23ce7e 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -15,7 +15,7 @@ from scipy.spatial.distance import cdist
import sys
import warnings
from inspect import signature
-from .backend import get_backend
+from .backend import get_backend, Backend
__time_tic_toc = time.time()
@@ -51,7 +51,8 @@ def kernel(x1, x2, method='gaussian', sigma=1, **kwargs):
def laplacian(x):
r"""Compute Laplacian matrix"""
- L = np.diag(np.sum(x, axis=0)) - x
+ nx = get_backend(x)
+ L = nx.diag(nx.sum(x, axis=0)) - x
return L
@@ -116,7 +117,7 @@ def proj_simplex(v, z=1):
return w
-def unif(n):
+def unif(n, type_as=None):
r"""
Return a uniform histogram of length `n` (simplex).
@@ -124,13 +125,19 @@ def unif(n):
----------
n : int
number of bins in the histogram
+ type_as : array_like
+ array of the same type of the expected output (numpy/pytorch/jax)
Returns
-------
- h : np.array (`n`,)
+ h : array_like (`n`,)
histogram of length `n` such that :math:`\forall i, \mathbf{h}_i = \frac{1}{n}`
"""
- return np.ones((n,)) / n
+ if type_as is None:
+ return np.ones((n,)) / n
+ else:
+ nx = get_backend(type_as)
+ return nx.ones((n,), type_as=type_as) / n
def clean_zeros(a, b, M):
@@ -290,7 +297,8 @@ def cost_normalization(C, norm=None):
def dots(*args):
r""" dots function for multiple matrix multiply """
- return reduce(np.dot, args)
+ nx = get_backend(*args)
+ return reduce(nx.dot, args)
def label_normalization(y, start=0):
@@ -308,8 +316,9 @@ def label_normalization(y, start=0):
y : array-like, shape (`n1`, )
The input vector of labels normalized according to given start value.
"""
+ nx = get_backend(y)
- diff = np.min(np.unique(y)) - start
+ diff = nx.min(nx.unique(y)) - start
if diff != 0:
y -= diff
return y
@@ -476,6 +485,19 @@ class BaseEstimator(object):
arguments (no ``*args`` or ``**kwargs``).
"""
+ nx: Backend = None
+
+ def _get_backend(self, *arrays):
+ nx = get_backend(
+ *[input_ for input_ in arrays if input_ is not None]
+ )
+ if nx.__name__ in ("jax", "tf"):
+ raise TypeError(
+ """JAX or TF arrays have been received but domain
+ adaptation does not support those backend.""")
+ self.nx = nx
+ return nx
+
@classmethod
def _get_param_names(cls):
r"""Get parameter names for the estimator"""
diff --git a/ot/weak.py b/ot/weak.py
new file mode 100644
index 0000000..f7d5b23
--- /dev/null
+++ b/ot/weak.py
@@ -0,0 +1,124 @@
+"""
+Weak optimal ransport solvers
+"""
+
+# Author: Remi Flamary <remi.flamary@polytehnique.edu>
+#
+# License: MIT License
+
+from .backend import get_backend
+from .optim import cg
+import numpy as np
+
+__all__ = ['weak_optimal_transport']
+
+
+def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=None, **kwargs):
+ r"""Solves the weak optimal transport problem between two empirical distributions
+
+
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad \|X_a-diag(1/a)\gammaX_b\|_F^2
+
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
+
+ \gamma^T \mathbf{1} = \mathbf{b}
+
+ \gamma \geq 0
+
+ where :
+
+ - :math:`X_a` :math:`X_b` are the sample matrices.
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
+
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
+ Uses the conditional gradient algorithm to solve the problem proposed
+ in :ref:`[39] <references-weak>`.
+
+ Parameters
+ ----------
+ Xa : (ns,d) array-like, float
+ Source samples
+ Xb : (nt,d) array-like, float
+ Target samples
+ a : (ns,) array-like, float
+ Source histogram (uniform weight if empty list)
+ b : (nt,) array-like, float
+ Target histogram (uniform weight if empty list))
+ numItermax : int, optional
+ Max number of iterations
+ numItermaxEmd : int, optional
+ Max number of iterations for emd
+ stopThr : float, optional
+ Stop threshold on the relative variation (>0)
+ stopThr2 : float, optional
+ Stop threshold on the absolute variation (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma: array-like, shape (ns, nt)
+ Optimal transportation matrix for the given
+ parameters
+ log: dict, optional
+ If input log is true, a dictionary containing the
+ cost and dual variables and exit status
+
+
+ .. _references-weak:
+ References
+ ----------
+ .. [39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017).
+ Kantorovich duality for general transport costs and applications.
+ Journal of Functional Analysis, 273(11), 3327-3405.
+
+ See Also
+ --------
+ ot.bregman.sinkhorn : Entropic regularized OT
+ ot.optim.cg : General regularized OT
+ """
+
+ nx = get_backend(Xa, Xb)
+
+ Xa2 = nx.to_numpy(Xa)
+ Xb2 = nx.to_numpy(Xb)
+
+ if a is None:
+ a2 = np.ones((Xa.shape[0])) / Xa.shape[0]
+ else:
+ a2 = nx.to_numpy(a)
+ if b is None:
+ b2 = np.ones((Xb.shape[0])) / Xb.shape[0]
+ else:
+ b2 = nx.to_numpy(b)
+
+ # init uniform
+ if G0 is None:
+ T0 = a2[:, None] * b2[None, :]
+ else:
+ T0 = nx.to_numpy(G0)
+
+ # weak OT loss
+ def f(T):
+ return np.dot(a2, np.sum((Xa2 - np.dot(T, Xb2) / a2[:, None])**2, 1))
+
+ # weak OT gradient
+ def df(T):
+ return -2 * np.dot(Xa2 - np.dot(T, Xb2) / a2[:, None], Xb2.T)
+
+ # solve with conditional gradient and return solution
+ if log:
+ res, log = cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs)
+ log['u'] = nx.from_numpy(log['u'], type_as=Xa)
+ log['v'] = nx.from_numpy(log['v'], type_as=Xb)
+ return nx.from_numpy(res, type_as=Xa), log
+ else:
+ return nx.from_numpy(cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs), type_as=Xa)
diff --git a/pyproject.toml b/pyproject.toml
index 93ebab3..3789206 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,3 +1,3 @@
[build-system]
-requires = ["setuptools", "wheel", "numpy>=1.20", "cython>=0.23"]
+requires = ["setuptools", "wheel", "oldest-supported-numpy", "cython>=0.23"]
build-backend = "setuptools.build_meta" \ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index f9934ce..7cbb29a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,5 @@
numpy>=1.20
scipy>=1.3
-cython
matplotlib
autograd
pymanopt==0.2.4; python_version <'3'
diff --git a/setup.py b/setup.py
index d46ae1c..c03191a 100644
--- a/setup.py
+++ b/setup.py
@@ -68,8 +68,9 @@ setup(
license='MIT',
scripts=[],
data_files=[],
- setup_requires=["numpy>=1.20", "cython>=0.23"],
- install_requires=["numpy>=1.20", "scipy>=1.0"],
+ setup_requires=["oldest-supported-numpy", "cython>=0.23"],
+ install_requires=["numpy>=1.16", "scipy>=1.0"],
+ python_requires=">=3.6",
classifiers=[
'Development Status :: 5 - Production/Stable',
'Intended Audience :: Developers',
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
index 6a42cfe..20f307a 100644
--- a/test/test_1d_solver.py
+++ b/test/test_1d_solver.py
@@ -66,9 +66,7 @@ def test_wasserstein_1d(nx):
rho_v = np.abs(rng.randn(n))
rho_v /= rho_v.sum()
- xb = nx.from_numpy(x)
- rho_ub = nx.from_numpy(rho_u)
- rho_vb = nx.from_numpy(rho_v)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v)
# test 1 : wasserstein_1d should be close to scipy W_1 implementation
np.testing.assert_almost_equal(wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1),
@@ -98,9 +96,7 @@ def test_wasserstein_1d_type_devices(nx):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- xb = nx.from_numpy(x, type_as=tp)
- rho_ub = nx.from_numpy(rho_u, type_as=tp)
- rho_vb = nx.from_numpy(rho_v, type_as=tp)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp)
res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
@@ -122,17 +118,13 @@ def test_wasserstein_1d_device_tf():
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- xb = nx.from_numpy(x)
- rho_ub = nx.from_numpy(rho_u)
- rho_vb = nx.from_numpy(rho_v)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v)
res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
nx.assert_same_dtype_device(xb, res)
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- xb = nx.from_numpy(x)
- rho_ub = nx.from_numpy(rho_u)
- rho_vb = nx.from_numpy(rho_v)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v)
res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
nx.assert_same_dtype_device(xb, res)
assert nx.dtype_device(res)[1].startswith("GPU")
@@ -190,9 +182,7 @@ def test_emd1d_type_devices(nx):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- xb = nx.from_numpy(x, type_as=tp)
- rho_ub = nx.from_numpy(rho_u, type_as=tp)
- rho_vb = nx.from_numpy(rho_v, type_as=tp)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp)
emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
@@ -214,9 +204,7 @@ def test_emd1d_device_tf():
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- xb = nx.from_numpy(x)
- rho_ub = nx.from_numpy(rho_u)
- rho_vb = nx.from_numpy(rho_v)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v)
emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
nx.assert_same_dtype_device(xb, emd)
@@ -224,9 +212,7 @@ def test_emd1d_device_tf():
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- xb = nx.from_numpy(x)
- rho_ub = nx.from_numpy(rho_u)
- rho_vb = nx.from_numpy(rho_v)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v)
emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
nx.assert_same_dtype_device(xb, emd)
diff --git a/test/test_backend.py b/test/test_backend.py
index 027c4cd..311c075 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -218,6 +218,8 @@ def test_empty_backend():
with pytest.raises(NotImplementedError):
nx.argmax(M)
with pytest.raises(NotImplementedError):
+ nx.argmin(M)
+ with pytest.raises(NotImplementedError):
nx.mean(M)
with pytest.raises(NotImplementedError):
nx.std(M)
@@ -264,12 +266,27 @@ def test_empty_backend():
nx.device_type(M)
with pytest.raises(NotImplementedError):
nx._bench(lambda x: x, M, n_runs=1)
+ with pytest.raises(NotImplementedError):
+ nx.solve(M, v)
+ with pytest.raises(NotImplementedError):
+ nx.trace(M)
+ with pytest.raises(NotImplementedError):
+ nx.inv(M)
+ with pytest.raises(NotImplementedError):
+ nx.sqrtm(M)
+ with pytest.raises(NotImplementedError):
+ nx.isfinite(M)
+ with pytest.raises(NotImplementedError):
+ nx.array_equal(M, M)
+ with pytest.raises(NotImplementedError):
+ nx.is_floating_point(M)
def test_func_backends(nx):
rnd = np.random.RandomState(0)
M = rnd.randn(10, 3)
+ SquareM = rnd.randn(10, 10)
v = rnd.randn(3)
val = np.array([1.0])
@@ -288,6 +305,7 @@ def test_func_backends(nx):
lst_name = []
Mb = nx.from_numpy(M)
+ SquareMb = nx.from_numpy(SquareM)
vb = nx.from_numpy(v)
val = nx.from_numpy(val)
@@ -467,6 +485,10 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(A))
lst_name.append('argmax')
+ A = nx.argmin(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('argmin')
+
A = nx.mean(Mb)
lst_b.append(nx.to_numpy(A))
lst_name.append('mean')
@@ -529,7 +551,11 @@ def test_func_backends(nx):
A = nx.where(Mb >= nx.stack([nx.linspace(0, 1, 10)] * 3, axis=1), Mb, 0.0)
lst_b.append(nx.to_numpy(A))
- lst_name.append('where')
+ lst_name.append('where (cond, x, y)')
+
+ A = nx.where(nx.from_numpy(np.array([True, False])))
+ lst_b.append(nx.to_numpy(nx.stack(A)))
+ lst_name.append('where (cond)')
A = nx.copy(Mb)
lst_b.append(nx.to_numpy(A))
@@ -550,15 +576,47 @@ def test_func_backends(nx):
nx._bench(lambda x: x, M, n_runs=1)
+ A = nx.solve(SquareMb, Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('solve')
+
+ A = nx.trace(SquareMb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('trace')
+
+ A = nx.inv(SquareMb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('matrix inverse')
+
+ A = nx.sqrtm(SquareMb.T @ SquareMb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("matrix square root")
+
+ A = nx.concatenate([vb, nx.from_numpy(np.array([np.inf, np.nan]))], axis=0)
+ A = nx.isfinite(A)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("isfinite")
+
+ assert not nx.array_equal(Mb, vb), "array_equal (shape)"
+ assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true"
+ assert not nx.array_equal(
+ Mb, Mb + nx.eye(*list(Mb.shape))
+ ), "array_equal (elements) - expected false"
+
+ assert nx.is_floating_point(Mb), "is_floating_point - expected true"
+ assert not nx.is_floating_point(
+ nx.from_numpy(np.array([0, 1, 2], dtype=int))
+ ), "is_floating_point - expected false"
+
lst_tot.append(lst_b)
lst_np = lst_tot[0]
lst_b = lst_tot[1]
for a1, a2, name in zip(lst_np, lst_b, lst_name):
- if not np.allclose(a1, a2):
- print('Assert fail on: ', name)
- assert np.allclose(a1, a2, atol=1e-7)
+ np.testing.assert_allclose(
+ a2, a1, atol=1e-7, err_msg=f'ASSERT FAILED ON: {name}'
+ )
def test_random_backends(nx):
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 6e90aa4..6c37984 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -60,7 +60,7 @@ def test_convergence_warning(method):
ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1)
-def test_not_impemented_method():
+def test_not_implemented_method():
# test sinkhorn
w = 10
n = w ** 2
@@ -155,8 +155,7 @@ def test_sinkhorn_backends(nx):
G = ot.sinkhorn(a, a, M, 1)
- ab = nx.from_numpy(a)
- M_nx = nx.from_numpy(M)
+ ab, M_nx = nx.from_numpy(a, M)
Gb = ot.sinkhorn(ab, ab, M_nx, 1)
@@ -176,8 +175,7 @@ def test_sinkhorn2_backends(nx):
G = ot.sinkhorn(a, a, M, 1)
- ab = nx.from_numpy(a)
- M_nx = nx.from_numpy(M)
+ ab, M_nx = nx.from_numpy(a, M)
Gb = ot.sinkhorn2(ab, ab, M_nx, 1)
@@ -260,8 +258,7 @@ def test_sinkhorn_variants(nx):
M = ot.dist(x, x)
- ub = nx.from_numpy(u)
- M_nx = nx.from_numpy(M)
+ ub, M_nx = nx.from_numpy(u, M)
G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
@@ -298,8 +295,7 @@ def test_sinkhorn_variants_dtype_device(nx, method):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- ub = nx.from_numpy(u, type_as=tp)
- Mb = nx.from_numpy(M, type_as=tp)
+ ub, Mb = nx.from_numpy(u, M, type_as=tp)
Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
@@ -318,8 +314,7 @@ def test_sinkhorn2_variants_dtype_device(nx, method):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- ub = nx.from_numpy(u, type_as=tp)
- Mb = nx.from_numpy(M, type_as=tp)
+ ub, Mb = nx.from_numpy(u, M, type_as=tp)
lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
@@ -337,8 +332,7 @@ def test_sinkhorn2_variants_device_tf(method):
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- ub = nx.from_numpy(u)
- Mb = nx.from_numpy(M)
+ ub, Mb = nx.from_numpy(u, M)
Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
nx.assert_same_dtype_device(Mb, Gb)
@@ -346,8 +340,7 @@ def test_sinkhorn2_variants_device_tf(method):
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- ub = nx.from_numpy(u)
- Mb = nx.from_numpy(M)
+ ub, Mb = nx.from_numpy(u, M)
Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
nx.assert_same_dtype_device(Mb, Gb)
@@ -370,9 +363,7 @@ def test_sinkhorn_variants_multi_b(nx):
M = ot.dist(x, x)
- ub = nx.from_numpy(u)
- bb = nx.from_numpy(b)
- M_nx = nx.from_numpy(M)
+ ub, bb, M_nx = nx.from_numpy(u, b, M)
G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
@@ -400,9 +391,7 @@ def test_sinkhorn2_variants_multi_b(nx):
M = ot.dist(x, x)
- ub = nx.from_numpy(u)
- bb = nx.from_numpy(b)
- M_nx = nx.from_numpy(M)
+ ub, bb, M_nx = nx.from_numpy(u, b, M)
G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
@@ -483,9 +472,7 @@ def test_barycenter(nx, method, verbose, warn):
alpha = 0.5 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])
- A_nx = nx.from_numpy(A)
- M_nx = nx.from_numpy(M)
- weights_nx = nx.from_numpy(weights)
+ A_nx, M_nx, weights_nx = nx.from_numpy(A, M, weights)
reg = 1e-2
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
@@ -523,9 +510,7 @@ def test_barycenter_debiased(nx, method, verbose, warn):
alpha = 0.5 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])
- A_nx = nx.from_numpy(A)
- M_nx = nx.from_numpy(M)
- weights_nx = nx.from_numpy(weights)
+ A_nx, M_nx, weights_nx = nx.from_numpy(A, M, weights)
# wasserstein
reg = 1e-2
@@ -594,9 +579,7 @@ def test_barycenter_stabilization(nx):
alpha = 0.5 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])
- A_nx = nx.from_numpy(A)
- M_nx = nx.from_numpy(M)
- weights_b = nx.from_numpy(weights)
+ A_nx, M_nx, weights_b = nx.from_numpy(A, M, weights)
# wasserstein
reg = 1e-2
@@ -635,7 +618,7 @@ def test_wasserstein_bary_2d(nx, method):
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
else:
- bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method)
+ bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method, verbose=True, log=True)
bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method))
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
@@ -667,7 +650,7 @@ def test_wasserstein_bary_2d_debiased(nx, method):
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
else:
- bary_wass_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method)
+ bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method, verbose=True, log=True)
bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method))
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
@@ -697,11 +680,7 @@ def test_unmix(nx):
M0 /= M0.max()
h0 = ot.unif(2)
- ab = nx.from_numpy(a)
- Db = nx.from_numpy(D)
- M_nx = nx.from_numpy(M)
- M0b = nx.from_numpy(M0)
- h0b = nx.from_numpy(h0)
+ ab, Db, M_nx, M0b, h0b = nx.from_numpy(a, D, M, M0, h0)
# wasserstein
reg = 1e-3
@@ -727,12 +706,7 @@ def test_empirical_sinkhorn(nx):
M = ot.dist(X_s, X_t)
M_m = ot.dist(X_s, X_t, metric='euclidean')
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- X_sb = nx.from_numpy(X_s)
- X_tb = nx.from_numpy(X_t)
- M_nx = nx.from_numpy(M, type_as=ab)
- M_mb = nx.from_numpy(M_m, type_as=ab)
+ ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m)
G_sqe = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1))
sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1))
@@ -776,12 +750,7 @@ def test_lazy_empirical_sinkhorn(nx):
M = ot.dist(X_s, X_t)
M_m = ot.dist(X_s, X_t, metric='euclidean')
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- X_sb = nx.from_numpy(X_s)
- X_tb = nx.from_numpy(X_t)
- M_nx = nx.from_numpy(M, type_as=ab)
- M_mb = nx.from_numpy(M_m, type_as=ab)
+ ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m)
f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
f, g = nx.to_numpy(f), nx.to_numpy(g)
@@ -825,19 +794,13 @@ def test_empirical_sinkhorn_divergence(nx):
a = np.linspace(1, n, n)
a /= a.sum()
b = ot.unif(n)
- X_s = np.reshape(np.arange(n), (n, 1))
- X_t = np.reshape(np.arange(0, n * 2, 2), (n, 1))
+ X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1))
+ X_t = np.reshape(np.arange(0, n * 2, 2, dtype=np.float64), (n, 1))
M = ot.dist(X_s, X_t)
M_s = ot.dist(X_s, X_s)
M_t = ot.dist(X_t, X_t)
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- X_sb = nx.from_numpy(X_s)
- X_tb = nx.from_numpy(X_t)
- M_nx = nx.from_numpy(M, type_as=ab)
- M_sb = nx.from_numpy(M_s, type_as=ab)
- M_tb = nx.from_numpy(M_t, type_as=ab)
+ ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy(a, b, X_s, X_t, M, M_s, M_t)
emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb))
sinkhorn_div = nx.to_numpy(
@@ -872,9 +835,7 @@ def test_stabilized_vs_sinkhorn_multidim(nx):
M /= np.median(M)
epsilon = 0.1
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- M_nx = nx.from_numpy(M, type_as=ab)
+ ab, bb, M_nx = nx.from_numpy(a, b, M)
G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True)
G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon,
@@ -936,18 +897,13 @@ def test_screenkhorn(nx):
x = rng.randn(n, 2)
M = ot.dist(x, x)
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- M_nx = nx.from_numpy(M, type_as=ab)
+ ab, bb, M_nx = nx.from_numpy(a, b, M)
- # np sinkhorn
- G_sink_np = ot.sinkhorn(a, b, M, 1e-03)
# sinkhorn
- G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-03))
+ G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1))
# screenkhorn
- G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-03, uniform=True, verbose=True))
+ G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-1, uniform=True, verbose=True))
# check marginals
- np.testing.assert_allclose(G_sink_np, G_sink)
np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)
diff --git a/test/test_da.py b/test/test_da.py
index 9f2bb50..4bf0ab1 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -19,7 +19,32 @@ except ImportError:
nosklearn = True
-def test_sinkhorn_lpl1_transport_class():
+def test_class_jax_tf():
+ backends = []
+ from ot.backend import jax, tf
+ if jax:
+ backends.append(ot.backend.JaxBackend())
+ if tf:
+ backends.append(ot.backend.TensorflowBackend())
+
+ for nx in backends:
+ ns = 150
+ nt = 200
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+
+ otda = ot.da.SinkhornLpl1Transport()
+
+ with pytest.raises(TypeError):
+ otda.fit(Xs=Xs, ys=ys, Xt=Xt)
+
+
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_sinkhorn_lpl1_transport_class(nx):
"""test_sinkhorn_transport
"""
@@ -29,6 +54,8 @@ def test_sinkhorn_lpl1_transport_class():
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
+ Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+
otda = ot.da.SinkhornLpl1Transport()
# test its computed
@@ -44,15 +71,15 @@ def test_sinkhorn_lpl1_transport_class():
mu_s = unif(ns)
mu_t = unif(nt)
assert_allclose(
- np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3)
assert_allclose(
- np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)
assert_equal(transp_Xs.shape, Xs.shape)
- Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0])
transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
@@ -62,7 +89,7 @@ def test_sinkhorn_lpl1_transport_class():
transp_Xt = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt.shape, Xt.shape)
- Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0])
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
# check that the oos method is working
@@ -85,24 +112,26 @@ def test_sinkhorn_lpl1_transport_class():
# test unsupervised vs semi-supervised mode
otda_unsup = ot.da.SinkhornLpl1Transport()
otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
- n_unsup = np.sum(otda_unsup.cost_)
+ n_unsup = nx.sum(otda_unsup.cost_)
otda_semi = ot.da.SinkhornLpl1Transport()
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
- n_semisup = np.sum(otda_semi.cost_)
+ n_semisup = nx.sum(otda_semi.cost_)
# check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
# check that the coupling forbids mass transport between labeled source
# and labeled target samples
- mass_semi = np.sum(
+ mass_semi = nx.sum(
otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
assert mass_semi == 0, "semisupervised mode not working"
-def test_sinkhorn_l1l2_transport_class():
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_sinkhorn_l1l2_transport_class(nx):
"""test_sinkhorn_transport
"""
@@ -112,6 +141,8 @@ def test_sinkhorn_l1l2_transport_class():
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
+ Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+
otda = ot.da.SinkhornL1l2Transport()
# test its computed
@@ -128,15 +159,15 @@ def test_sinkhorn_l1l2_transport_class():
mu_s = unif(ns)
mu_t = unif(nt)
assert_allclose(
- np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3)
assert_allclose(
- np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)
assert_equal(transp_Xs.shape, Xs.shape)
- Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0])
transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
@@ -156,7 +187,7 @@ def test_sinkhorn_l1l2_transport_class():
assert_equal(transp_ys.shape[0], ys.shape[0])
assert_equal(transp_ys.shape[1], len(np.unique(yt)))
- Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0])
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
# check that the oos method is working
@@ -169,22 +200,22 @@ def test_sinkhorn_l1l2_transport_class():
# test unsupervised vs semi-supervised mode
otda_unsup = ot.da.SinkhornL1l2Transport()
otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
- n_unsup = np.sum(otda_unsup.cost_)
+ n_unsup = nx.sum(otda_unsup.cost_)
otda_semi = ot.da.SinkhornL1l2Transport()
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
- n_semisup = np.sum(otda_semi.cost_)
+ n_semisup = nx.sum(otda_semi.cost_)
# check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
# check that the coupling forbids mass transport between labeled source
# and labeled target samples
- mass_semi = np.sum(
+ mass_semi = nx.sum(
otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
mass_semi = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]
- assert_allclose(mass_semi, np.zeros_like(mass_semi),
+ assert_allclose(nx.to_numpy(mass_semi), np.zeros(list(mass_semi.shape)),
rtol=1e-9, atol=1e-9)
# check everything runs well with log=True
@@ -193,7 +224,9 @@ def test_sinkhorn_l1l2_transport_class():
assert len(otda.log_.keys()) != 0
-def test_sinkhorn_transport_class():
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_sinkhorn_transport_class(nx):
"""test_sinkhorn_transport
"""
@@ -203,6 +236,8 @@ def test_sinkhorn_transport_class():
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
+ Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+
otda = ot.da.SinkhornTransport()
# test its computed
@@ -219,15 +254,15 @@ def test_sinkhorn_transport_class():
mu_s = unif(ns)
mu_t = unif(nt)
assert_allclose(
- np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3)
assert_allclose(
- np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)
assert_equal(transp_Xs.shape, Xs.shape)
- Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0])
transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
@@ -247,7 +282,7 @@ def test_sinkhorn_transport_class():
assert_equal(transp_ys.shape[0], ys.shape[0])
assert_equal(transp_ys.shape[1], len(np.unique(yt)))
- Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0])
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
# check that the oos method is working
@@ -260,19 +295,19 @@ def test_sinkhorn_transport_class():
# test unsupervised vs semi-supervised mode
otda_unsup = ot.da.SinkhornTransport()
otda_unsup.fit(Xs=Xs, Xt=Xt)
- n_unsup = np.sum(otda_unsup.cost_)
+ n_unsup = nx.sum(otda_unsup.cost_)
otda_semi = ot.da.SinkhornTransport()
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
- n_semisup = np.sum(otda_semi.cost_)
+ n_semisup = nx.sum(otda_semi.cost_)
# check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
# check that the coupling forbids mass transport between labeled source
# and labeled target samples
- mass_semi = np.sum(
+ mass_semi = nx.sum(
otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
assert mass_semi == 0, "semisupervised mode not working"
@@ -282,7 +317,9 @@ def test_sinkhorn_transport_class():
assert len(otda.log_.keys()) != 0
-def test_unbalanced_sinkhorn_transport_class():
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_unbalanced_sinkhorn_transport_class(nx):
"""test_sinkhorn_transport
"""
@@ -292,6 +329,8 @@ def test_unbalanced_sinkhorn_transport_class():
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
+ Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+
otda = ot.da.UnbalancedSinkhornTransport()
# test its computed
@@ -318,7 +357,7 @@ def test_unbalanced_sinkhorn_transport_class():
assert_equal(transp_ys.shape[0], ys.shape[0])
assert_equal(transp_ys.shape[1], len(np.unique(yt)))
- Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0])
transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
@@ -328,7 +367,7 @@ def test_unbalanced_sinkhorn_transport_class():
transp_Xt = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt.shape, Xt.shape)
- Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0])
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
# check that the oos method is working
@@ -341,12 +380,12 @@ def test_unbalanced_sinkhorn_transport_class():
# test unsupervised vs semi-supervised mode
otda_unsup = ot.da.SinkhornTransport()
otda_unsup.fit(Xs=Xs, Xt=Xt)
- n_unsup = np.sum(otda_unsup.cost_)
+ n_unsup = nx.sum(otda_unsup.cost_)
otda_semi = ot.da.SinkhornTransport()
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
- n_semisup = np.sum(otda_semi.cost_)
+ n_semisup = nx.sum(otda_semi.cost_)
# check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
@@ -357,7 +396,9 @@ def test_unbalanced_sinkhorn_transport_class():
assert len(otda.log_.keys()) != 0
-def test_emd_transport_class():
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_emd_transport_class(nx):
"""test_sinkhorn_transport
"""
@@ -367,6 +408,8 @@ def test_emd_transport_class():
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
+ Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+
otda = ot.da.EMDTransport()
# test its computed
@@ -382,15 +425,15 @@ def test_emd_transport_class():
mu_s = unif(ns)
mu_t = unif(nt)
assert_allclose(
- np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3)
assert_allclose(
- np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)
assert_equal(transp_Xs.shape, Xs.shape)
- Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0])
transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
@@ -410,7 +453,7 @@ def test_emd_transport_class():
assert_equal(transp_ys.shape[0], ys.shape[0])
assert_equal(transp_ys.shape[1], len(np.unique(yt)))
- Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0])
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
# check that the oos method is working
@@ -423,28 +466,32 @@ def test_emd_transport_class():
# test unsupervised vs semi-supervised mode
otda_unsup = ot.da.EMDTransport()
otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
- n_unsup = np.sum(otda_unsup.cost_)
+ n_unsup = nx.sum(otda_unsup.cost_)
otda_semi = ot.da.EMDTransport()
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
- n_semisup = np.sum(otda_semi.cost_)
+ n_semisup = nx.sum(otda_semi.cost_)
# check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
# check that the coupling forbids mass transport between labeled source
# and labeled target samples
- mass_semi = np.sum(
+ mass_semi = nx.sum(
otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
mass_semi = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]
# we need to use a small tolerance here, otherwise the test breaks
- assert_allclose(mass_semi, np.zeros_like(mass_semi),
+ assert_allclose(nx.to_numpy(mass_semi), np.zeros(list(mass_semi.shape)),
rtol=1e-2, atol=1e-2)
-def test_mapping_transport_class():
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+@pytest.mark.parametrize("kernel", ["linear", "gaussian"])
+@pytest.mark.parametrize("bias", ["unbiased", "biased"])
+def test_mapping_transport_class(nx, kernel, bias):
"""test_mapping_transport
"""
@@ -455,101 +502,29 @@ def test_mapping_transport_class():
Xt, yt = make_data_classif('3gauss2', nt)
Xs_new, _ = make_data_classif('3gauss', ns + 1)
- ##########################################################################
- # kernel == linear mapping tests
- ##########################################################################
+ Xs, Xt, Xs_new = nx.from_numpy(Xs, Xt, Xs_new)
- # check computation and dimensions if bias == False
- otda = ot.da.MappingTransport(kernel="linear", bias=False)
+ # Mapping tests
+ bias = bias == "biased"
+ otda = ot.da.MappingTransport(kernel=kernel, bias=bias)
otda.fit(Xs=Xs, Xt=Xt)
assert hasattr(otda, "coupling_")
assert hasattr(otda, "mapping_")
assert hasattr(otda, "log_")
assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
- assert_equal(otda.mapping_.shape, ((Xs.shape[1], Xt.shape[1])))
+ S = Xs.shape[0] if kernel == "gaussian" else Xs.shape[1] # if linear
+ if bias:
+ S += 1
+ assert_equal(otda.mapping_.shape, ((S, Xt.shape[1])))
# test margin constraints
mu_s = unif(ns)
mu_t = unif(nt)
assert_allclose(
- np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3)
assert_allclose(
- np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
-
- # test transform
- transp_Xs = otda.transform(Xs=Xs)
- assert_equal(transp_Xs.shape, Xs.shape)
-
- transp_Xs_new = otda.transform(Xs_new)
-
- # check that the oos method is working
- assert_equal(transp_Xs_new.shape, Xs_new.shape)
-
- # check computation and dimensions if bias == True
- otda = ot.da.MappingTransport(kernel="linear", bias=True)
- otda.fit(Xs=Xs, Xt=Xt)
- assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
- assert_equal(otda.mapping_.shape, ((Xs.shape[1] + 1, Xt.shape[1])))
-
- # test margin constraints
- mu_s = unif(ns)
- mu_t = unif(nt)
- assert_allclose(
- np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(
- np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
-
- # test transform
- transp_Xs = otda.transform(Xs=Xs)
- assert_equal(transp_Xs.shape, Xs.shape)
-
- transp_Xs_new = otda.transform(Xs_new)
-
- # check that the oos method is working
- assert_equal(transp_Xs_new.shape, Xs_new.shape)
-
- ##########################################################################
- # kernel == gaussian mapping tests
- ##########################################################################
-
- # check computation and dimensions if bias == False
- otda = ot.da.MappingTransport(kernel="gaussian", bias=False)
- otda.fit(Xs=Xs, Xt=Xt)
-
- assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
- assert_equal(otda.mapping_.shape, ((Xs.shape[0], Xt.shape[1])))
-
- # test margin constraints
- mu_s = unif(ns)
- mu_t = unif(nt)
- assert_allclose(
- np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(
- np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
-
- # test transform
- transp_Xs = otda.transform(Xs=Xs)
- assert_equal(transp_Xs.shape, Xs.shape)
-
- transp_Xs_new = otda.transform(Xs_new)
-
- # check that the oos method is working
- assert_equal(transp_Xs_new.shape, Xs_new.shape)
-
- # check computation and dimensions if bias == True
- otda = ot.da.MappingTransport(kernel="gaussian", bias=True)
- otda.fit(Xs=Xs, Xt=Xt)
- assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
- assert_equal(otda.mapping_.shape, ((Xs.shape[0] + 1, Xt.shape[1])))
-
- # test margin constraints
- mu_s = unif(ns)
- mu_t = unif(nt)
- assert_allclose(
- np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(
- np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)
@@ -561,29 +536,39 @@ def test_mapping_transport_class():
assert_equal(transp_Xs_new.shape, Xs_new.shape)
# check everything runs well with log=True
- otda = ot.da.MappingTransport(kernel="gaussian", log=True)
+ otda = ot.da.MappingTransport(kernel=kernel, bias=bias, log=True)
otda.fit(Xs=Xs, Xt=Xt)
assert len(otda.log_.keys()) != 0
+
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_mapping_transport_class_specific_seed(nx):
# check that it does not crash when derphi is very close to 0
+ ns = 20
+ nt = 30
np.random.seed(39)
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
otda = ot.da.MappingTransport(kernel="gaussian", bias=False)
- otda.fit(Xs=Xs, Xt=Xt)
+ otda.fit(Xs=nx.from_numpy(Xs), Xt=nx.from_numpy(Xt))
np.random.seed(None)
-def test_linear_mapping():
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_linear_mapping(nx):
ns = 150
nt = 200
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
- A, b = ot.da.OT_mapping_linear(Xs, Xt)
+ Xsb, Xtb = nx.from_numpy(Xs, Xt)
- Xst = Xs.dot(A) + b
+ A, b = ot.da.OT_mapping_linear(Xsb, Xtb)
+
+ Xst = nx.to_numpy(nx.dot(Xsb, A) + b)
Ct = np.cov(Xt.T)
Cst = np.cov(Xst.T)
@@ -591,22 +576,26 @@ def test_linear_mapping():
np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
-def test_linear_mapping_class():
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_linear_mapping_class(nx):
ns = 150
nt = 200
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
+ Xsb, Xtb = nx.from_numpy(Xs, Xt)
+
otmap = ot.da.LinearTransport()
- otmap.fit(Xs=Xs, Xt=Xt)
+ otmap.fit(Xs=Xsb, Xt=Xtb)
assert hasattr(otmap, "A_")
assert hasattr(otmap, "B_")
assert hasattr(otmap, "A1_")
assert hasattr(otmap, "B1_")
- Xst = otmap.transform(Xs=Xs)
+ Xst = nx.to_numpy(otmap.transform(Xs=Xsb))
Ct = np.cov(Xt.T)
Cst = np.cov(Xst.T)
@@ -614,7 +603,9 @@ def test_linear_mapping_class():
np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
-def test_jcpot_transport_class():
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_jcpot_transport_class(nx):
"""test_jcpot_transport
"""
@@ -627,6 +618,8 @@ def test_jcpot_transport_class():
Xt, yt = make_data_classif('3gauss2', nt)
+ Xs1, ys1, Xs2, ys2, Xt, yt = nx.from_numpy(Xs1, ys1, Xs2, ys2, Xt, yt)
+
Xs = [Xs1, Xs2]
ys = [ys1, ys2]
@@ -649,19 +642,24 @@ def test_jcpot_transport_class():
for i in range(len(Xs)):
# test margin constraints w.r.t. uniform target weights for each coupling matrix
assert_allclose(
- np.sum(otda.coupling_[i], axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_[i], axis=0)), mu_t, rtol=1e-3, atol=1e-3)
# test margin constraints w.r.t. modified source weights for each source domain
assert_allclose(
- np.dot(otda.log_['D1'][i], np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3,
- atol=1e-3)
+ nx.to_numpy(
+ nx.dot(otda.log_['D1'][i], nx.sum(otda.coupling_[i], axis=1))
+ ),
+ nx.to_numpy(otda.proportions_),
+ rtol=1e-3,
+ atol=1e-3
+ )
# test transform
transp_Xs = otda.transform(Xs=Xs)
[assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)]
- Xs_new, _ = make_data_classif('3gauss', ns1 + 1)
+ Xs_new = nx.from_numpy(make_data_classif('3gauss', ns1 + 1)[0])
transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
@@ -670,15 +668,16 @@ def test_jcpot_transport_class():
# check label propagation
transp_yt = otda.transform_labels(ys)
assert_equal(transp_yt.shape[0], yt.shape[0])
- assert_equal(transp_yt.shape[1], len(np.unique(ys)))
+ assert_equal(transp_yt.shape[1], len(np.unique(nx.to_numpy(*ys))))
# check inverse label propagation
transp_ys = otda.inverse_transform_labels(yt)
- [assert_equal(x.shape[0], y.shape[0]) for x, y in zip(transp_ys, ys)]
- [assert_equal(x.shape[1], len(np.unique(y))) for x, y in zip(transp_ys, ys)]
+ for x, y in zip(transp_ys, ys):
+ assert_equal(x.shape[0], y.shape[0])
+ assert_equal(x.shape[1], len(np.unique(nx.to_numpy(y))))
-def test_jcpot_barycenter():
+def test_jcpot_barycenter(nx):
"""test_jcpot_barycenter
"""
@@ -695,19 +694,23 @@ def test_jcpot_barycenter():
Xs1, ys1 = make_data_classif('2gauss_prop', ns1, nz=sigma, p=ps1)
Xs2, ys2 = make_data_classif('2gauss_prop', ns2, nz=sigma, p=ps2)
- Xt, yt = make_data_classif('2gauss_prop', nt, nz=sigma, p=pt)
+ Xt, _ = make_data_classif('2gauss_prop', nt, nz=sigma, p=pt)
- Xs = [Xs1, Xs2]
- ys = [ys1, ys2]
+ Xs1b, ys1b, Xs2b, ys2b, Xtb = nx.from_numpy(Xs1, ys1, Xs2, ys2, Xt)
- prop = ot.bregman.jcpot_barycenter(Xs, ys, Xt, reg=.5, metric='sqeuclidean',
+ Xsb = [Xs1b, Xs2b]
+ ysb = [ys1b, ys2b]
+
+ prop = ot.bregman.jcpot_barycenter(Xsb, ysb, Xtb, reg=.5, metric='sqeuclidean',
numItermax=10000, stopThr=1e-9, verbose=False, log=False)
- np.testing.assert_allclose(prop, [1 - pt, pt], rtol=1e-3, atol=1e-3)
+ np.testing.assert_allclose(nx.to_numpy(prop), [1 - pt, pt], rtol=1e-3, atol=1e-3)
@pytest.mark.skipif(nosklearn, reason="No sklearn available")
-def test_emd_laplace_class():
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_emd_laplace_class(nx):
"""test_emd_laplace_transport
"""
ns = 150
@@ -716,6 +719,8 @@ def test_emd_laplace_class():
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
+ Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+
otda = ot.da.EMDLaplaceTransport(reg_lap=0.01, max_iter=1000, tol=1e-9, verbose=False, log=True)
# test its computed
@@ -732,15 +737,15 @@ def test_emd_laplace_class():
mu_t = unif(nt)
assert_allclose(
- np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3)
assert_allclose(
- np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)
[assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)]
- Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0])
transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
@@ -750,7 +755,7 @@ def test_emd_laplace_class():
transp_Xt = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt.shape, Xt.shape)
- Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0])
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
# check that the oos method is working
@@ -763,9 +768,9 @@ def test_emd_laplace_class():
# check label propagation
transp_yt = otda.transform_labels(ys)
assert_equal(transp_yt.shape[0], yt.shape[0])
- assert_equal(transp_yt.shape[1], len(np.unique(ys)))
+ assert_equal(transp_yt.shape[1], len(np.unique(nx.to_numpy(ys))))
# check inverse label propagation
transp_ys = otda.inverse_transform_labels(yt)
assert_equal(transp_ys.shape[0], ys.shape[0])
- assert_equal(transp_ys.shape[1], len(np.unique(yt)))
+ assert_equal(transp_ys.shape[1], len(np.unique(nx.to_numpy(yt))))
diff --git a/test/test_dr.py b/test/test_dr.py
index 741f2ad..6d7fc9a 100644
--- a/test/test_dr.py
+++ b/test/test_dr.py
@@ -61,6 +61,28 @@ def test_wda():
@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
+def test_wda_low_reg():
+
+ n_samples = 100 # nb samples in source and target datasets
+ np.random.seed(0)
+
+ # generate gaussian dataset
+ xs, ys = ot.datasets.make_data_classif('gaussrot', n_samples)
+
+ n_features_noise = 8
+
+ xs = np.hstack((xs, np.random.randn(n_samples, n_features_noise)))
+
+ p = 2
+
+ Pwda, projwda = ot.dr.wda(xs, ys, p, reg=0.01, maxiter=10, sinkhorn_method='sinkhorn_log')
+
+ projwda(xs)
+
+ np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p))
+
+
+@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
def test_wda_normalized():
n_samples = 100 # nb samples in source and target datasets
diff --git a/test/test_factored.py b/test/test_factored.py
new file mode 100644
index 0000000..fd2fd01
--- /dev/null
+++ b/test/test_factored.py
@@ -0,0 +1,56 @@
+"""Tests for main module ot.weak """
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import ot
+import numpy as np
+
+
+def test_factored_ot():
+ # test weak ot solver and identity stationary point
+ n = 50
+ rng = np.random.RandomState(0)
+
+ xs = rng.randn(n, 2)
+ xt = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ Ga, Gb, X, log = ot.factored_optimal_transport(xs, xt, u, u, r=10, log=True)
+
+ # check constraints
+ np.testing.assert_allclose(u, Ga.sum(1))
+ np.testing.assert_allclose(u, Gb.sum(0))
+
+ Ga, Gb, X, log = ot.factored_optimal_transport(xs, xt, u, u, reg=1, r=10, log=True)
+
+ # check constraints
+ np.testing.assert_allclose(u, Ga.sum(1))
+ np.testing.assert_allclose(u, Gb.sum(0))
+
+
+def test_factored_ot_backends(nx):
+ # test weak ot solver for different backends
+ n = 50
+ rng = np.random.RandomState(0)
+
+ xs = rng.randn(n, 2)
+ xt = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ xs2 = nx.from_numpy(xs)
+ xt2 = nx.from_numpy(xt)
+ u2 = nx.from_numpy(u)
+
+ Ga2, Gb2, X2 = ot.factored_optimal_transport(xs2, xt2, u2, u2, r=10)
+
+ # check constraints
+ np.testing.assert_allclose(u, nx.to_numpy(Ga2).sum(1))
+ np.testing.assert_allclose(u, nx.to_numpy(Gb2).sum(0))
+
+ Ga2, Gb2, X2 = ot.factored_optimal_transport(xs2, xt2, reg=1, r=10, X0=X2)
+
+ # check constraints
+ np.testing.assert_allclose(u, nx.to_numpy(Ga2).sum(1))
+ np.testing.assert_allclose(u, nx.to_numpy(Gb2).sum(0))
diff --git a/test/test_gpu.py b/test/test_gpu.py
deleted file mode 100644
index 8e62a74..0000000
--- a/test/test_gpu.py
+++ /dev/null
@@ -1,106 +0,0 @@
-"""Tests for module gpu for gpu acceleration """
-
-# Author: Remi Flamary <remi.flamary@unice.fr>
-#
-# License: MIT License
-
-import numpy as np
-import ot
-import pytest
-
-try: # test if cudamat installed
- import ot.gpu
- nogpu = False
-except ImportError:
- nogpu = True
-
-
-@pytest.mark.skipif(nogpu, reason="No GPU available")
-def test_gpu_old_doctests():
- a = [.5, .5]
- b = [.5, .5]
- M = [[0., 1.], [1., 0.]]
- G = ot.sinkhorn(a, b, M, 1)
- np.testing.assert_allclose(G, np.array([[0.36552929, 0.13447071],
- [0.13447071, 0.36552929]]))
-
-
-@pytest.mark.skipif(nogpu, reason="No GPU available")
-def test_gpu_dist():
-
- rng = np.random.RandomState(0)
-
- for n_samples in [50, 100, 500, 1000]:
- print(n_samples)
- a = rng.rand(n_samples // 4, 100)
- b = rng.rand(n_samples, 100)
-
- M = ot.dist(a.copy(), b.copy())
- M2 = ot.gpu.dist(a.copy(), b.copy())
-
- np.testing.assert_allclose(M, M2, rtol=1e-10)
-
- M2 = ot.gpu.dist(a.copy(), b.copy(), metric='euclidean', to_numpy=False)
-
- # check raise not implemented wrong metric
- with pytest.raises(NotImplementedError):
- M2 = ot.gpu.dist(a.copy(), b.copy(), metric='cityblock', to_numpy=False)
-
-
-@pytest.mark.skipif(nogpu, reason="No GPU available")
-def test_gpu_sinkhorn():
-
- rng = np.random.RandomState(0)
-
- for n_samples in [50, 100, 500, 1000]:
- a = rng.rand(n_samples // 4, 100)
- b = rng.rand(n_samples, 100)
-
- wa = ot.unif(n_samples // 4)
- wb = ot.unif(n_samples)
-
- wb2 = np.random.rand(n_samples, 20)
- wb2 /= wb2.sum(0, keepdims=True)
-
- M = ot.dist(a.copy(), b.copy())
- M2 = ot.gpu.dist(a.copy(), b.copy(), to_numpy=False)
-
- reg = 1
-
- G = ot.sinkhorn(wa, wb, M, reg)
- G1 = ot.gpu.sinkhorn(wa, wb, M, reg)
-
- np.testing.assert_allclose(G1, G, rtol=1e-10)
-
- # run all on gpu
- ot.gpu.sinkhorn(wa, wb, M2, reg, to_numpy=False, log=True)
-
- # run sinkhorn for multiple targets
- ot.gpu.sinkhorn(wa, wb2, M2, reg, to_numpy=False, log=True)
-
-
-@pytest.mark.skipif(nogpu, reason="No GPU available")
-def test_gpu_sinkhorn_lpl1():
-
- rng = np.random.RandomState(0)
-
- for n_samples in [50, 100, 500]:
- print(n_samples)
- a = rng.rand(n_samples // 4, 100)
- labels_a = np.random.randint(10, size=(n_samples // 4))
- b = rng.rand(n_samples, 100)
-
- wa = ot.unif(n_samples // 4)
- wb = ot.unif(n_samples)
-
- M = ot.dist(a.copy(), b.copy())
- M2 = ot.gpu.dist(a.copy(), b.copy(), to_numpy=False)
-
- reg = 1
-
- G = ot.da.sinkhorn_lpl1_mm(wa, labels_a, wb, M, reg)
- G1 = ot.gpu.da.sinkhorn_lpl1_mm(wa, labels_a, wb, M, reg)
-
- np.testing.assert_allclose(G1, G, rtol=1e-10)
-
- ot.gpu.da.sinkhorn_lpl1_mm(wa, labels_a, wb, M2, reg, to_numpy=False, log=True)
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 4b995d5..9c85b92 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -3,6 +3,7 @@
# Author: Erwan Vautier <erwan.vautier@gmail.com>
# Nicolas Courty <ncourty@irisa.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
+# Cédric Vincent-Cuaz <cedric.vincent-cuaz@inria.fr>
#
# License: MIT License
@@ -26,6 +27,7 @@ def test_gromov(nx):
p = ot.unif(n_samples)
q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
C1 = ot.dist(xs, xs)
C2 = ot.dist(xt, xt)
@@ -33,13 +35,10 @@ def test_gromov(nx):
C1 /= C1.max()
C2 /= C2.max()
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- pb = nx.from_numpy(p)
- qb = nx.from_numpy(q)
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
- G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True)
- Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True))
+ G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, verbose=True)
+ Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True))
# check constraints
np.testing.assert_allclose(G, Gb, atol=1e-06)
@@ -56,9 +55,9 @@ def test_gromov(nx):
gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=True)
gwb = nx.to_numpy(gwb)
- gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False)
+ gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', G0=G0, log=False)
gw_valb = nx.to_numpy(
- ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False)
)
G = log['T']
@@ -91,6 +90,7 @@ def test_gromov_dtype_device(nx):
p = ot.unif(n_samples)
q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
C1 = ot.dist(xs, xs)
C2 = ot.dist(xt, xt)
@@ -101,13 +101,10 @@ def test_gromov_dtype_device(nx):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- C1b = nx.from_numpy(C1, type_as=tp)
- C2b = nx.from_numpy(C2, type_as=tp)
- pb = nx.from_numpy(p, type_as=tp)
- qb = nx.from_numpy(q, type_as=tp)
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0, type_as=tp)
- Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
- gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False)
nx.assert_same_dtype_device(C1b, Gb)
nx.assert_same_dtype_device(C1b, gw_valb)
@@ -123,6 +120,7 @@ def test_gromov_device_tf():
xt = xs[::-1].copy()
p = ot.unif(n_samples)
q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
C1 = ot.dist(xs, xs)
C2 = ot.dist(xt, xt)
C1 /= C1.max()
@@ -130,21 +128,15 @@ def test_gromov_device_tf():
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- pb = nx.from_numpy(p)
- qb = nx.from_numpy(q)
- Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
- gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
+ Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False)
nx.assert_same_dtype_device(C1b, Gb)
nx.assert_same_dtype_device(C1b, gw_valb)
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- pb = nx.from_numpy(p)
- qb = nx.from_numpy(q)
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
nx.assert_same_dtype_device(C1b, Gb)
@@ -173,25 +165,30 @@ def test_gromov2_gradients():
if torch:
- p1 = torch.tensor(p, requires_grad=True)
- q1 = torch.tensor(q, requires_grad=True)
- C11 = torch.tensor(C1, requires_grad=True)
- C12 = torch.tensor(C2, requires_grad=True)
+ devices = [torch.device("cpu")]
+ if torch.cuda.is_available():
+ devices.append(torch.device("cuda"))
+ for device in devices:
+ p1 = torch.tensor(p, requires_grad=True, device=device)
+ q1 = torch.tensor(q, requires_grad=True, device=device)
+ C11 = torch.tensor(C1, requires_grad=True, device=device)
+ C12 = torch.tensor(C2, requires_grad=True, device=device)
- val = ot.gromov_wasserstein2(C11, C12, p1, q1)
+ val = ot.gromov_wasserstein2(C11, C12, p1, q1)
- val.backward()
+ val.backward()
- assert q1.shape == q1.grad.shape
- assert p1.shape == p1.grad.shape
- assert C11.shape == C11.grad.shape
- assert C12.shape == C12.grad.shape
+ assert val.device == p1.device
+ assert q1.shape == q1.grad.shape
+ assert p1.shape == p1.grad.shape
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
@pytest.skip_backend("jax", reason="test very slow with jax backend")
@pytest.skip_backend("tf", reason="test very slow with tf backend")
def test_entropic_gromov(nx):
- n_samples = 50 # nb samples
+ n_samples = 10 # nb samples
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
@@ -209,10 +206,7 @@ def test_entropic_gromov(nx):
C1 /= C1.max()
C2 /= C2.max()
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- pb = nx.from_numpy(p)
- qb = nx.from_numpy(q)
+ C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q)
G = ot.gromov.entropic_gromov_wasserstein(
C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True)
@@ -228,9 +222,9 @@ def test_entropic_gromov(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
gw, log = ot.gromov.entropic_gromov_wasserstein2(
- C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True)
+ C1, C2, p, q, 'kl_loss', max_iter=10, epsilon=1e-2, log=True)
gwb, logb = ot.gromov.entropic_gromov_wasserstein2(
- C1b, C2b, pb, qb, 'kl_loss', epsilon=1e-2, log=True)
+ C1b, C2b, pb, qb, 'kl_loss', max_iter=10, epsilon=1e-2, log=True)
gwb = nx.to_numpy(gwb)
G = log['T']
@@ -251,7 +245,7 @@ def test_entropic_gromov(nx):
@pytest.skip_backend("tf", reason="test very slow with tf backend")
def test_entropic_gromov_dtype_device(nx):
# setup
- n_samples = 50 # nb samples
+ n_samples = 5 # nb samples
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
@@ -272,10 +266,7 @@ def test_entropic_gromov_dtype_device(nx):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- C1b = nx.from_numpy(C1, type_as=tp)
- C2b = nx.from_numpy(C2, type_as=tp)
- pb = nx.from_numpy(p, type_as=tp)
- qb = nx.from_numpy(q, type_as=tp)
+ C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q, type_as=tp)
Gb = ot.gromov.entropic_gromov_wasserstein(
C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
@@ -289,7 +280,7 @@ def test_entropic_gromov_dtype_device(nx):
def test_pointwise_gromov(nx):
- n_samples = 50 # nb samples
+ n_samples = 5 # nb samples
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
@@ -307,10 +298,7 @@ def test_pointwise_gromov(nx):
C1 /= C1.max()
C2 /= C2.max()
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- pb = nx.from_numpy(p)
- qb = nx.from_numpy(q)
+ C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q)
def loss(x, y):
return np.abs(x - y)
@@ -343,14 +331,12 @@ def test_pointwise_gromov(nx):
Gb = nx.to_numpy(nx.todense(Gb))
np.testing.assert_allclose(G, Gb, atol=1e-06)
- np.testing.assert_allclose(float(logb['gw_dist_estimated']), 0.10342276348494964, atol=1e-8)
- np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0015952535464736394, atol=1e-8)
@pytest.skip_backend("tf", reason="test very slow with tf backend")
@pytest.skip_backend("jax", reason="test very slow with jax backend")
def test_sampled_gromov(nx):
- n_samples = 50 # nb samples
+ n_samples = 5 # nb samples
mu_s = np.array([0, 0], dtype=np.float64)
cov_s = np.array([[1, 0], [0, 1]], dtype=np.float64)
@@ -368,10 +354,7 @@ def test_sampled_gromov(nx):
C1 /= C1.max()
C2 /= C2.max()
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- pb = nx.from_numpy(p)
- qb = nx.from_numpy(q)
+ C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q)
def loss(x, y):
return np.abs(x - y)
@@ -380,9 +363,9 @@ def test_sampled_gromov(nx):
return nx.abs(x - y)
G, log = ot.gromov.sampled_gromov_wasserstein(
- C1, C2, p, q, loss, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42)
+ C1, C2, p, q, loss, max_iter=20, nb_samples_grad=2, epsilon=1, log=True, verbose=True, random_state=42)
Gb, logb = ot.gromov.sampled_gromov_wasserstein(
- C1b, C2b, pb, qb, lossb, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42)
+ C1b, C2b, pb, qb, lossb, max_iter=20, nb_samples_grad=2, epsilon=1, log=True, verbose=True, random_state=42)
Gb = nx.to_numpy(Gb)
# check constraints
@@ -392,13 +375,10 @@ def test_sampled_gromov(nx):
np.testing.assert_allclose(
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
- np.testing.assert_allclose(float(logb['gw_dist_estimated']), 0.05679474884977278, atol=1e-08)
- np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0005986592106971995, atol=1e-08)
-
def test_gromov_barycenter(nx):
- ns = 10
- nt = 20
+ ns = 5
+ nt = 8
Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
@@ -410,19 +390,15 @@ def test_gromov_barycenter(nx):
n_samples = 3
p = ot.unif(n_samples)
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- p1b = nx.from_numpy(p1)
- p2b = nx.from_numpy(p2)
- pb = nx.from_numpy(p)
+ C1b, C2b, p1b, p2b, pb = nx.from_numpy(C1, C2, p1, p2, p)
Cb = ot.gromov.gromov_barycenters(
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
- 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42
+ 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42
)
Cbb = nx.to_numpy(ot.gromov.gromov_barycenters(
n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
- 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42
+ 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42
))
np.testing.assert_allclose(Cb, Cbb, atol=1e-06)
np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples))
@@ -430,15 +406,15 @@ def test_gromov_barycenter(nx):
# test of gromov_barycenters with `log` on
Cb_, err_ = ot.gromov.gromov_barycenters(
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
- 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True
)
Cbb_, errb_ = ot.gromov.gromov_barycenters(
n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
- 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True
)
Cbb_ = nx.to_numpy(Cbb_)
np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06)
- np.testing.assert_array_almost_equal(err_['err'], errb_['err'])
+ np.testing.assert_array_almost_equal(err_['err'], nx.to_numpy(*errb_['err']))
np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples))
Cb2 = ot.gromov.gromov_barycenters(
@@ -455,22 +431,22 @@ def test_gromov_barycenter(nx):
# test of gromov_barycenters with `log` on
Cb2_, err2_ = ot.gromov.gromov_barycenters(
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
- 'kl_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ 'kl_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True
)
Cb2b_, err2b_ = ot.gromov.gromov_barycenters(
n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
- 'kl_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ 'kl_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True
)
Cb2b_ = nx.to_numpy(Cb2b_)
np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06)
- np.testing.assert_array_almost_equal(err2_['err'], err2_['err'])
+ np.testing.assert_array_almost_equal(err2_['err'], nx.to_numpy(*err2b_['err']))
np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples))
@pytest.mark.filterwarnings("ignore:divide")
def test_gromov_entropic_barycenter(nx):
- ns = 10
- nt = 20
+ ns = 5
+ nt = 10
Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
@@ -482,11 +458,7 @@ def test_gromov_entropic_barycenter(nx):
n_samples = 2
p = ot.unif(n_samples)
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- p1b = nx.from_numpy(p1)
- p2b = nx.from_numpy(p2)
- pb = nx.from_numpy(p)
+ C1b, C2b, p1b, p2b, pb = nx.from_numpy(C1, C2, p1, p2, p)
Cb = ot.gromov.entropic_gromov_barycenters(
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
@@ -510,7 +482,7 @@ def test_gromov_entropic_barycenter(nx):
)
Cbb_ = nx.to_numpy(Cbb_)
np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06)
- np.testing.assert_array_almost_equal(err_['err'], errb_['err'])
+ np.testing.assert_array_almost_equal(err_['err'], nx.to_numpy(*errb_['err']))
np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples))
Cb2 = ot.gromov.entropic_gromov_barycenters(
@@ -535,12 +507,12 @@ def test_gromov_entropic_barycenter(nx):
)
Cb2b_ = nx.to_numpy(Cb2b_)
np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06)
- np.testing.assert_array_almost_equal(err2_['err'], err2_['err'])
+ np.testing.assert_array_almost_equal(err2_['err'], nx.to_numpy(*err2b_['err']))
np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples))
def test_fgw(nx):
- n_samples = 50 # nb samples
+ n_samples = 20 # nb samples
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
@@ -554,6 +526,7 @@ def test_fgw(nx):
p = ot.unif(n_samples)
q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
C1 = ot.dist(xs, xs)
C2 = ot.dist(xt, xt)
@@ -564,14 +537,10 @@ def test_fgw(nx):
M = ot.dist(ys, yt)
M /= M.max()
- Mb = nx.from_numpy(M)
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- pb = nx.from_numpy(p)
- qb = nx.from_numpy(q)
+ Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
- G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
- Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True)
+ G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True)
+ Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, G0=G0b, log=True)
Gb = nx.to_numpy(Gb)
# check constraints
@@ -586,8 +555,8 @@ def test_fgw(nx):
np.testing.assert_allclose(
Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov
- fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
- fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True)
+ fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', G0=None, alpha=0.5, log=True)
+ fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', G0=G0b, alpha=0.5, log=True)
fgwb = nx.to_numpy(fgwb)
G = log['T']
@@ -605,7 +574,7 @@ def test_fgw(nx):
def test_fgw2_gradients():
- n_samples = 50 # nb samples
+ n_samples = 20 # nb samples
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
@@ -626,28 +595,33 @@ def test_fgw2_gradients():
if torch:
- p1 = torch.tensor(p, requires_grad=True)
- q1 = torch.tensor(q, requires_grad=True)
- C11 = torch.tensor(C1, requires_grad=True)
- C12 = torch.tensor(C2, requires_grad=True)
- M1 = torch.tensor(M, requires_grad=True)
+ devices = [torch.device("cpu")]
+ if torch.cuda.is_available():
+ devices.append(torch.device("cuda"))
+ for device in devices:
+ p1 = torch.tensor(p, requires_grad=True, device=device)
+ q1 = torch.tensor(q, requires_grad=True, device=device)
+ C11 = torch.tensor(C1, requires_grad=True, device=device)
+ C12 = torch.tensor(C2, requires_grad=True, device=device)
+ M1 = torch.tensor(M, requires_grad=True, device=device)
- val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1)
+ val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1)
- val.backward()
+ val.backward()
- assert q1.shape == q1.grad.shape
- assert p1.shape == p1.grad.shape
- assert C11.shape == C11.grad.shape
- assert C12.shape == C12.grad.shape
- assert M1.shape == M1.grad.shape
+ assert val.device == p1.device
+ assert q1.shape == q1.grad.shape
+ assert p1.shape == p1.grad.shape
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+ assert M1.shape == M1.grad.shape
def test_fgw_barycenter(nx):
np.random.seed(42)
- ns = 50
- nt = 60
+ ns = 10
+ nt = 20
Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
@@ -661,13 +635,7 @@ def test_fgw_barycenter(nx):
n_samples = 3
p = ot.unif(n_samples)
- ysb = nx.from_numpy(ys)
- ytb = nx.from_numpy(yt)
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- p1b = nx.from_numpy(p1)
- p2b = nx.from_numpy(p2)
- pb = nx.from_numpy(p)
+ ysb, ytb, C1b, C2b, p1b, p2b, pb = nx.from_numpy(ys, yt, C1, C2, p1, p2, p)
Xb, Cb = ot.gromov.fgw_barycenters(
n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, fixed_structure=False,
@@ -698,3 +666,523 @@ def test_fgw_barycenter(nx):
Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb)
np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1]))
+
+
+def test_gromov_wasserstein_linear_unmixing(nx):
+ n = 4
+
+ X1, y1 = ot.datasets.make_data_classif('3gauss', n, random_state=42)
+ X2, y2 = ot.datasets.make_data_classif('3gauss2', n, random_state=42)
+
+ C1 = ot.dist(X1)
+ C2 = ot.dist(X2)
+ Cdict = np.stack([C1, C2])
+ p = ot.unif(n)
+
+ C1b, C2b, Cdictb, pb = nx.from_numpy(C1, C2, Cdict, p)
+
+ tol = 10**(-5)
+ # Tests without regularization
+ reg = 0.
+ unmixing1, C1_emb, OT, reconstruction1 = ot.gromov.gromov_wasserstein_linear_unmixing(
+ C1, Cdict, reg=reg, p=p, q=p,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+
+ unmixing1b, C1b_emb, OTb, reconstruction1b = ot.gromov.gromov_wasserstein_linear_unmixing(
+ C1b, Cdictb, reg=reg, p=None, q=None,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+
+ unmixing2, C2_emb, OT, reconstruction2 = ot.gromov.gromov_wasserstein_linear_unmixing(
+ C2, Cdict, reg=reg, p=None, q=None,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+
+ unmixing2b, C2b_emb, OTb, reconstruction2b = ot.gromov.gromov_wasserstein_linear_unmixing(
+ C2b, Cdictb, reg=reg, p=pb, q=pb,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+
+ np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=5e-06)
+ np.testing.assert_allclose(unmixing1, [1., 0.], atol=5e-01)
+ np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=5e-06)
+ np.testing.assert_allclose(unmixing2, [0., 1.], atol=5e-01)
+ np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06)
+ np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-06)
+ np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06)
+ np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06)
+ np.testing.assert_allclose(C1b_emb.shape, (n, n))
+ np.testing.assert_allclose(C2b_emb.shape, (n, n))
+
+ # Tests with regularization
+
+ reg = 0.001
+ unmixing1, C1_emb, OT, reconstruction1 = ot.gromov.gromov_wasserstein_linear_unmixing(
+ C1, Cdict, reg=reg, p=p, q=p,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+
+ unmixing1b, C1b_emb, OTb, reconstruction1b = ot.gromov.gromov_wasserstein_linear_unmixing(
+ C1b, Cdictb, reg=reg, p=None, q=None,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+
+ unmixing2, C2_emb, OT, reconstruction2 = ot.gromov.gromov_wasserstein_linear_unmixing(
+ C2, Cdict, reg=reg, p=None, q=None,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+
+ unmixing2b, C2b_emb, OTb, reconstruction2b = ot.gromov.gromov_wasserstein_linear_unmixing(
+ C2b, Cdictb, reg=reg, p=pb, q=pb,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+
+ np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06)
+ np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01)
+ np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06)
+ np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01)
+ np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06)
+ np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-06)
+ np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06)
+ np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06)
+ np.testing.assert_allclose(C1b_emb.shape, (n, n))
+ np.testing.assert_allclose(C2b_emb.shape, (n, n))
+
+
+def test_gromov_wasserstein_dictionary_learning(nx):
+
+ # create dataset composed from 2 structures which are repeated 5 times
+ shape = 4
+ n_samples = 2
+ n_atoms = 2
+ projection = 'nonnegative_symmetric'
+ X1, y1 = ot.datasets.make_data_classif('3gauss', shape, random_state=42)
+ X2, y2 = ot.datasets.make_data_classif('3gauss2', shape, random_state=42)
+ C1 = ot.dist(X1)
+ C2 = ot.dist(X2)
+ Cs = [C1.copy() for _ in range(n_samples // 2)] + [C2.copy() for _ in range(n_samples // 2)]
+ ps = [ot.unif(shape) for _ in range(n_samples)]
+ q = ot.unif(shape)
+
+ # Provide initialization for the graph dictionary of shape (n_atoms, shape, shape)
+ # following the same procedure than implemented in gromov_wasserstein_dictionary_learning.
+ dataset_means = [C.mean() for C in Cs]
+ np.random.seed(0)
+ Cdict_init = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(n_atoms, shape, shape))
+
+ if projection == 'nonnegative_symmetric':
+ Cdict_init = 0.5 * (Cdict_init + Cdict_init.transpose((0, 2, 1)))
+ Cdict_init[Cdict_init < 0.] = 0.
+
+ Csb = nx.from_numpy(*Cs)
+ psb = nx.from_numpy(*ps)
+ qb, Cdict_initb = nx.from_numpy(q, Cdict_init)
+
+ # Test: compare reconstruction error using initial dictionary and dictionary learned using this initialization
+ # > Compute initial reconstruction of samples on this random dictionary without backend
+ use_adam_optimizer = True
+ verbose = False
+ tol = 10**(-5)
+ epochs = 1
+
+ initial_total_reconstruction = 0
+ for i in range(n_samples):
+ _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing(
+ Cs[i], Cdict_init, p=ps[i], q=q, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ initial_total_reconstruction += reconstruction
+
+ # > Learn the dictionary using this init
+ Cdict, log = ot.gromov.gromov_wasserstein_dictionary_learning(
+ Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init,
+ epochs=epochs, batch_size=2 * n_samples, learning_rate=1., reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50,
+ projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # > Compute reconstruction of samples on learned dictionary without backend
+ total_reconstruction = 0
+ for i in range(n_samples):
+ _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing(
+ Cs[i], Cdict, p=None, q=None, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ total_reconstruction += reconstruction
+
+ np.testing.assert_array_less(total_reconstruction, initial_total_reconstruction)
+
+ # Test: Perform same experiments after going through backend
+
+ Cdictb, log = ot.gromov.gromov_wasserstein_dictionary_learning(
+ Csb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb,
+ epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50,
+ projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # Compute reconstruction of samples on learned dictionary
+ total_reconstruction_b = 0
+ for i in range(n_samples):
+ _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing(
+ Csb[i], Cdictb, p=psb[i], q=qb, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ total_reconstruction_b += reconstruction
+
+ total_reconstruction_b = nx.to_numpy(total_reconstruction_b)
+ np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction)
+ np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05)
+ np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05)
+ np.testing.assert_allclose(Cdict, nx.to_numpy(Cdictb), atol=1e-03)
+
+ # Test: Perform same comparison without providing the initial dictionary being an optional input
+ # knowing than the initialization scheme is the same than implemented to set the benchmarked initialization.
+ np.random.seed(0)
+ Cdict_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning(
+ Cs, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None,
+ epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50,
+ projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction_bis = 0
+ for i in range(n_samples):
+ _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing(
+ Cs[i], Cdict_bis, p=ps[i], q=q, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ total_reconstruction_bis += reconstruction
+
+ np.testing.assert_allclose(total_reconstruction_bis, total_reconstruction, atol=1e-05)
+
+ # Test: Same after going through backend
+ np.random.seed(0)
+ Cdictb_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning(
+ Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=None,
+ epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50,
+ projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction_b_bis = 0
+ for i in range(n_samples):
+ _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing(
+ Csb[i], Cdictb_bis, p=None, q=None, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ total_reconstruction_b_bis += reconstruction
+
+ total_reconstruction_b_bis = nx.to_numpy(total_reconstruction_b_bis)
+ np.testing.assert_allclose(total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05)
+ np.testing.assert_allclose(Cdict_bis, nx.to_numpy(Cdictb_bis), atol=1e-03)
+
+ # Test: Perform same comparison without providing the initial dictionary being an optional input
+ # and testing other optimization settings untested until now.
+ # We pass previously estimated dictionaries to speed up the process.
+ use_adam_optimizer = False
+ verbose = True
+ use_log = True
+
+ np.random.seed(0)
+ Cdict_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning(
+ Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict,
+ epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50,
+ projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction_bis2 = 0
+ for i in range(n_samples):
+ _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing(
+ Cs[i], Cdict_bis2, p=ps[i], q=q, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ total_reconstruction_bis2 += reconstruction
+
+ np.testing.assert_array_less(total_reconstruction_bis2, total_reconstruction)
+
+ # Test: Same after going through backend
+ np.random.seed(0)
+ Cdictb_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning(
+ Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=Cdictb,
+ epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50,
+ projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction_b_bis2 = 0
+ for i in range(n_samples):
+ _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing(
+ Csb[i], Cdictb_bis2, p=psb[i], q=qb, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ total_reconstruction_b_bis2 += reconstruction
+
+ total_reconstruction_b_bis2 = nx.to_numpy(total_reconstruction_b_bis2)
+ np.testing.assert_allclose(total_reconstruction_b_bis2, total_reconstruction_bis2, atol=1e-05)
+
+
+def test_fused_gromov_wasserstein_linear_unmixing(nx):
+
+ n = 4
+ X1, y1 = ot.datasets.make_data_classif('3gauss', n, random_state=42)
+ X2, y2 = ot.datasets.make_data_classif('3gauss2', n, random_state=42)
+ F, y = ot.datasets.make_data_classif('3gauss', n, random_state=42)
+
+ C1 = ot.dist(X1)
+ C2 = ot.dist(X2)
+ Cdict = np.stack([C1, C2])
+ Ydict = np.stack([F, F])
+ p = ot.unif(n)
+
+ C1b, C2b, Fb, Cdictb, Ydictb, pb = nx.from_numpy(C1, C2, F, Cdict, Ydict, p)
+
+ # Tests without regularization
+ reg = 0.
+
+ unmixing1, C1_emb, Y1_emb, OT, reconstruction1 = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ C1, F, Cdict, Ydict, p=p, q=p, alpha=0.5, reg=reg,
+ tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50
+ )
+
+ unmixing1b, C1b_emb, Y1b_emb, OTb, reconstruction1b = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ C1b, Fb, Cdictb, Ydictb, p=None, q=None, alpha=0.5, reg=reg,
+ tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50
+ )
+
+ unmixing2, C2_emb, Y2_emb, OT, reconstruction2 = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ C2, F, Cdict, Ydict, p=None, q=None, alpha=0.5, reg=reg,
+ tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50
+ )
+
+ unmixing2b, C2b_emb, Y2b_emb, OTb, reconstruction2b = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ C2b, Fb, Cdictb, Ydictb, p=pb, q=pb, alpha=0.5, reg=reg,
+ tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50
+ )
+
+ np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=4e-06)
+ np.testing.assert_allclose(unmixing1, [1., 0.], atol=4e-01)
+ np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=4e-06)
+ np.testing.assert_allclose(unmixing2, [0., 1.], atol=4e-01)
+ np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-03)
+ np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03)
+ np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), atol=1e-03)
+ np.testing.assert_allclose(Y2_emb, nx.to_numpy(Y2b_emb), atol=1e-03)
+ np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06)
+ np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06)
+ np.testing.assert_allclose(C1b_emb.shape, (n, n))
+ np.testing.assert_allclose(C2b_emb.shape, (n, n))
+
+ # Tests with regularization
+ reg = 0.001
+
+ unmixing1, C1_emb, Y1_emb, OT, reconstruction1 = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ C1, F, Cdict, Ydict, p=p, q=p, alpha=0.5, reg=reg,
+ tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50
+ )
+
+ unmixing1b, C1b_emb, Y1b_emb, OTb, reconstruction1b = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ C1b, Fb, Cdictb, Ydictb, p=None, q=None, alpha=0.5, reg=reg,
+ tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50
+ )
+
+ unmixing2, C2_emb, Y2_emb, OT, reconstruction2 = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ C2, F, Cdict, Ydict, p=None, q=None, alpha=0.5, reg=reg,
+ tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50
+ )
+
+ unmixing2b, C2b_emb, Y2b_emb, OTb, reconstruction2b = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ C2b, Fb, Cdictb, Ydictb, p=pb, q=pb, alpha=0.5, reg=reg,
+ tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50
+ )
+
+ np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06)
+ np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01)
+ np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06)
+ np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01)
+ np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-03)
+ np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03)
+ np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), atol=1e-03)
+ np.testing.assert_allclose(Y2_emb, nx.to_numpy(Y2b_emb), atol=1e-03)
+ np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06)
+ np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06)
+ np.testing.assert_allclose(C1b_emb.shape, (n, n))
+ np.testing.assert_allclose(C2b_emb.shape, (n, n))
+
+
+def test_fused_gromov_wasserstein_dictionary_learning(nx):
+
+ # create dataset composed from 2 structures which are repeated 5 times
+ shape = 4
+ n_samples = 2
+ n_atoms = 2
+ projection = 'nonnegative_symmetric'
+ X1, y1 = ot.datasets.make_data_classif('3gauss', shape, random_state=42)
+ X2, y2 = ot.datasets.make_data_classif('3gauss2', shape, random_state=42)
+ F, y = ot.datasets.make_data_classif('3gauss', shape, random_state=42)
+
+ C1 = ot.dist(X1)
+ C2 = ot.dist(X2)
+ Cs = [C1.copy() for _ in range(n_samples // 2)] + [C2.copy() for _ in range(n_samples // 2)]
+ Ys = [F.copy() for _ in range(n_samples)]
+ ps = [ot.unif(shape) for _ in range(n_samples)]
+ q = ot.unif(shape)
+
+ # Provide initialization for the graph dictionary of shape (n_atoms, shape, shape)
+ # following the same procedure than implemented in gromov_wasserstein_dictionary_learning.
+ dataset_structure_means = [C.mean() for C in Cs]
+ np.random.seed(0)
+ Cdict_init = np.random.normal(loc=np.mean(dataset_structure_means), scale=np.std(dataset_structure_means), size=(n_atoms, shape, shape))
+ if projection == 'nonnegative_symmetric':
+ Cdict_init = 0.5 * (Cdict_init + Cdict_init.transpose((0, 2, 1)))
+ Cdict_init[Cdict_init < 0.] = 0.
+ dataset_feature_means = np.stack([Y.mean(axis=0) for Y in Ys])
+ Ydict_init = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(n_atoms, shape, 2))
+
+ Csb = nx.from_numpy(*Cs)
+ Ysb = nx.from_numpy(*Ys)
+ psb = nx.from_numpy(*ps)
+ qb, Cdict_initb, Ydict_initb = nx.from_numpy(q, Cdict_init, Ydict_init)
+
+ # Test: Compute initial reconstruction of samples on this random dictionary
+ alpha = 0.5
+ use_adam_optimizer = True
+ verbose = False
+ tol = 1e-05
+ epochs = 1
+
+ initial_total_reconstruction = 0
+ for i in range(n_samples):
+ _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ Cs[i], Ys[i], Cdict_init, Ydict_init, p=ps[i], q=q,
+ alpha=alpha, reg=0., tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ initial_total_reconstruction += reconstruction
+
+ # > Learn a dictionary using this given initialization and check that the reconstruction loss
+ # on the learned dictionary is lower than the one using its initialization.
+ Cdict, Ydict, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
+ Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init, Ydict_init=Ydict_init,
+ epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50,
+ projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction = 0
+ for i in range(n_samples):
+ _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ Cs[i], Ys[i], Cdict, Ydict, p=None, q=None, alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ total_reconstruction += reconstruction
+ # Compare both
+ np.testing.assert_array_less(total_reconstruction, initial_total_reconstruction)
+
+ # Test: Perform same experiments after going through backend
+
+ Cdictb, Ydictb, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
+ Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb, Ydict_init=Ydict_initb,
+ epochs=epochs, batch_size=2 * n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50,
+ projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction_b = 0
+ for i in range(n_samples):
+ _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ Csb[i], Ysb[i], Cdictb, Ydictb, p=psb[i], q=qb, alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ total_reconstruction_b += reconstruction
+
+ total_reconstruction_b = nx.to_numpy(total_reconstruction_b)
+ np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction)
+ np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05)
+ np.testing.assert_allclose(Cdict, nx.to_numpy(Cdictb), atol=1e-03)
+ np.testing.assert_allclose(Ydict, nx.to_numpy(Ydictb), atol=1e-03)
+
+ # Test: Perform similar experiment without providing the initial dictionary being an optional input
+ np.random.seed(0)
+ Cdict_bis, Ydict_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
+ Cs, Ys, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None,
+ epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50,
+ projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction_bis = 0
+ for i in range(n_samples):
+ _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ Cs[i], Ys[i], Cdict_bis, Ydict_bis, p=ps[i], q=q, alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ total_reconstruction_bis += reconstruction
+
+ np.testing.assert_allclose(total_reconstruction_bis, total_reconstruction, atol=1e-05)
+
+ # > Same after going through backend
+ np.random.seed(0)
+ Cdictb_bis, Ydictb_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
+ Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None,
+ epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50,
+ projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction_b_bis = 0
+ for i in range(n_samples):
+ _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ Csb[i], Ysb[i], Cdictb_bis, Ydictb_bis, p=psb[i], q=qb, alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ total_reconstruction_b_bis += reconstruction
+
+ total_reconstruction_b_bis = nx.to_numpy(total_reconstruction_b_bis)
+ np.testing.assert_allclose(total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05)
+
+ # Test: without using adam optimizer, with log and verbose set to True
+ use_adam_optimizer = False
+ verbose = True
+ use_log = True
+
+ # > Experiment providing previously estimated dictionary to speed up the test compared to providing initial random init.
+ np.random.seed(0)
+ Cdict_bis2, Ydict_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
+ Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict, Ydict_init=Ydict,
+ epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50,
+ projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction_bis2 = 0
+ for i in range(n_samples):
+ _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ Cs[i], Ys[i], Cdict_bis2, Ydict_bis2, p=ps[i], q=q, alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ total_reconstruction_bis2 += reconstruction
+
+ np.testing.assert_array_less(total_reconstruction_bis2, total_reconstruction)
+
+ # > Same after going through backend
+ np.random.seed(0)
+ Cdictb_bis2, Ydictb_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
+ Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdictb, Ydict_init=Ydictb,
+ epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50,
+ projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction_b_bis2 = 0
+ for i in range(n_samples):
+ _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ Csb[i], Ysb[i], Cdictb_bis2, Ydictb_bis2, p=None, q=None, alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ total_reconstruction_b_bis2 += reconstruction
+
+ # > Compare results with/without backend
+ total_reconstruction_b_bis2 = nx.to_numpy(total_reconstruction_b_bis2)
+ np.testing.assert_allclose(total_reconstruction_bis2, total_reconstruction_b_bis2, atol=1e-05)
diff --git a/test/test_optim.py b/test/test_optim.py
index 41f9cbe..67e9d13 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -32,9 +32,7 @@ def test_conditional_gradient(nx):
def fb(G):
return 0.5 * nx.sum(G ** 2)
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- Mb = nx.from_numpy(M, type_as=ab)
+ ab, bb, Mb = nx.from_numpy(a, b, M)
reg = 1e-1
@@ -74,9 +72,7 @@ def test_conditional_gradient_itermax(nx):
def fb(G):
return 0.5 * nx.sum(G ** 2)
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- Mb = nx.from_numpy(M, type_as=ab)
+ ab, bb, Mb = nx.from_numpy(a, b, M)
reg = 1e-1
@@ -118,9 +114,7 @@ def test_generalized_conditional_gradient(nx):
reg1 = 1e-3
reg2 = 1e-1
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- Mb = nx.from_numpy(M, type_as=ab)
+ ab, bb, Mb = nx.from_numpy(a, b, M)
G, log = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True, log=True)
Gb, log = ot.optim.gcg(ab, bb, Mb, reg1, reg2, fb, df, verbose=True, log=True)
@@ -142,9 +136,12 @@ def test_line_search_armijo(nx):
pk = np.array([[-0.25, 0.25], [0.25, -0.25]])
gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]])
old_fval = -123
+
+ xkb, pkb, gfkb = nx.from_numpy(xk, pk, gfk)
+
# Should not throw an exception and return 0. for alpha
alpha, a, b = ot.optim.line_search_armijo(
- lambda x: 1, nx.from_numpy(xk), nx.from_numpy(pk), nx.from_numpy(gfk), old_fval
+ lambda x: 1, xkb, pkb, gfkb, old_fval
)
alpha_np, anp, bnp = ot.optim.line_search_armijo(
lambda x: 1, xk, pk, gfk, old_fval
diff --git a/test/test_ot.py b/test/test_ot.py
index 53edf4f..bf832f6 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -47,8 +47,7 @@ def test_emd_backends(nx):
G = ot.emd(a, a, M)
- ab = nx.from_numpy(a)
- Mb = nx.from_numpy(M)
+ ab, Mb = nx.from_numpy(a, M)
Gb = ot.emd(ab, ab, Mb)
@@ -68,8 +67,7 @@ def test_emd2_backends(nx):
val = ot.emd2(a, a, M)
- ab = nx.from_numpy(a)
- Mb = nx.from_numpy(M)
+ ab, Mb = nx.from_numpy(a, M)
valb = ot.emd2(ab, ab, Mb)
@@ -90,8 +88,7 @@ def test_emd_emd2_types_devices(nx):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- ab = nx.from_numpy(a, type_as=tp)
- Mb = nx.from_numpy(M, type_as=tp)
+ ab, Mb = nx.from_numpy(a, M, type_as=tp)
Gb = ot.emd(ab, ab, Mb)
@@ -117,8 +114,7 @@ def test_emd_emd2_devices_tf():
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- ab = nx.from_numpy(a)
- Mb = nx.from_numpy(M)
+ ab, Mb = nx.from_numpy(a, M)
Gb = ot.emd(ab, ab, Mb)
w = ot.emd2(ab, ab, Mb)
nx.assert_same_dtype_device(Mb, Gb)
@@ -126,8 +122,7 @@ def test_emd_emd2_devices_tf():
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- ab = nx.from_numpy(a)
- Mb = nx.from_numpy(M)
+ ab, Mb = nx.from_numpy(a, M)
Gb = ot.emd(ab, ab, Mb)
w = ot.emd2(ab, ab, Mb)
nx.assert_same_dtype_device(Mb, Gb)
@@ -152,7 +147,7 @@ def test_emd2_gradients():
b1 = torch.tensor(a, requires_grad=True)
M1 = torch.tensor(M, requires_grad=True)
- val = ot.emd2(a1, b1, M1)
+ val, log = ot.emd2(a1, b1, M1, log=True)
val.backward()
@@ -160,6 +155,12 @@ def test_emd2_gradients():
assert b1.shape == b1.grad.shape
assert M1.shape == M1.grad.shape
+ assert np.allclose(a1.grad.cpu().detach().numpy(),
+ log['u'].cpu().detach().numpy() - log['u'].cpu().detach().numpy().mean())
+
+ assert np.allclose(b1.grad.cpu().detach().numpy(),
+ log['v'].cpu().detach().numpy() - log['v'].cpu().detach().numpy().mean())
+
# Testing for bug #309, checking for scaling of gradient
a2 = torch.tensor(a, requires_grad=True)
b2 = torch.tensor(a, requires_grad=True)
@@ -232,7 +233,7 @@ def test_emd2_multi():
# Gaussian distributions
a = gauss(n, m=20, s=5) # m= mean, s= std
- ls = np.arange(20, 500, 20)
+ ls = np.arange(20, 500, 100)
nb = len(ls)
b = np.zeros((n, nb))
for i in range(nb):
@@ -302,6 +303,23 @@ def test_free_support_barycenter():
np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)
+def test_free_support_barycenter_backends(nx):
+
+ measures_locations = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))]
+ measures_weights = [np.array([1.]), np.array([1.])]
+ X_init = np.array([-12.]).reshape((1, 1))
+
+ X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init)
+
+ measures_locations2 = nx.from_numpy(*measures_locations)
+ measures_weights2 = nx.from_numpy(*measures_weights)
+ X_init2 = nx.from_numpy(X_init)
+
+ X2 = ot.lp.free_support_barycenter(measures_locations2, measures_weights2, X_init2)
+
+ np.testing.assert_allclose(X, nx.to_numpy(X2))
+
+
@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
def test_lp_barycenter_cvxopt():
a1 = np.array([1.0, 0, 0])[:, None]
diff --git a/test/test_sliced.py b/test/test_sliced.py
index 91e0961..08ab4fb 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -123,9 +123,7 @@ def test_sliced_backend(nx):
n_projections = 20
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
val0 = ot.sliced_wasserstein_distance(x, y, projections=P)
@@ -153,9 +151,7 @@ def test_sliced_backend_type_devices(nx):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- xb = nx.from_numpy(x, type_as=tp)
- yb = nx.from_numpy(y, type_as=tp)
- Pb = nx.from_numpy(P, type_as=tp)
+ xb, yb, Pb = nx.from_numpy(x, y, P, type_as=tp)
valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
@@ -174,17 +170,13 @@ def test_sliced_backend_device_tf():
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
assert nx.dtype_device(valb)[1].startswith("GPU")
@@ -203,9 +195,7 @@ def test_max_sliced_backend(nx):
n_projections = 20
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
val0 = ot.max_sliced_wasserstein_distance(x, y, projections=P)
@@ -233,9 +223,7 @@ def test_max_sliced_backend_type_devices(nx):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- xb = nx.from_numpy(x, type_as=tp)
- yb = nx.from_numpy(y, type_as=tp)
- Pb = nx.from_numpy(P, type_as=tp)
+ xb, yb, Pb = nx.from_numpy(x, y, P, type_as=tp)
valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
@@ -254,17 +242,13 @@ def test_max_sliced_backend_device_tf():
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
assert nx.dtype_device(valb)[1].startswith("GPU")
diff --git a/test/test_stochastic.py b/test/test_stochastic.py
index 736df32..2b5c0fb 100644
--- a/test/test_stochastic.py
+++ b/test/test_stochastic.py
@@ -8,7 +8,8 @@ for descrete and semicontinous measures from the POT library.
"""
-# Author: Kilian Fatras <kilian.fatras@gmail.com>
+# Authors: Kilian Fatras <kilian.fatras@gmail.com>
+# Rémi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License
@@ -213,3 +214,115 @@ def test_dual_sgd_sinkhorn():
G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
np.testing.assert_allclose(
G_sgd, G_sinkhorn, atol=1e-03) # cf convergence sgd
+
+
+def test_loss_dual_entropic(nx):
+
+ nx.seed(0)
+
+ xs = nx.randn(50, 2)
+ xt = nx.randn(40, 2) + 2
+
+ ws = nx.rand(50)
+ ws = ws / nx.sum(ws)
+
+ wt = nx.rand(40)
+ wt = wt / nx.sum(wt)
+
+ u = nx.randn(50)
+ v = nx.randn(40)
+
+ def metric(x, y):
+ return -nx.dot(x, y.T)
+
+ ot.stochastic.loss_dual_entropic(u, v, xs, xt)
+
+ ot.stochastic.loss_dual_entropic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+
+def test_plan_dual_entropic(nx):
+
+ nx.seed(0)
+
+ xs = nx.randn(50, 2)
+ xt = nx.randn(40, 2) + 2
+
+ ws = nx.rand(50)
+ ws = ws / nx.sum(ws)
+
+ wt = nx.rand(40)
+ wt = wt / nx.sum(wt)
+
+ u = nx.randn(50)
+ v = nx.randn(40)
+
+ def metric(x, y):
+ return -nx.dot(x, y.T)
+
+ G1 = ot.stochastic.plan_dual_entropic(u, v, xs, xt)
+
+ assert np.all(nx.to_numpy(G1) >= 0)
+ assert G1.shape[0] == 50
+ assert G1.shape[1] == 40
+
+ G2 = ot.stochastic.plan_dual_entropic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+ assert np.all(nx.to_numpy(G2) >= 0)
+ assert G2.shape[0] == 50
+ assert G2.shape[1] == 40
+
+
+def test_loss_dual_quadratic(nx):
+
+ nx.seed(0)
+
+ xs = nx.randn(50, 2)
+ xt = nx.randn(40, 2) + 2
+
+ ws = nx.rand(50)
+ ws = ws / nx.sum(ws)
+
+ wt = nx.rand(40)
+ wt = wt / nx.sum(wt)
+
+ u = nx.randn(50)
+ v = nx.randn(40)
+
+ def metric(x, y):
+ return -nx.dot(x, y.T)
+
+ ot.stochastic.loss_dual_quadratic(u, v, xs, xt)
+
+ ot.stochastic.loss_dual_quadratic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+
+def test_plan_dual_quadratic(nx):
+
+ nx.seed(0)
+
+ xs = nx.randn(50, 2)
+ xt = nx.randn(40, 2) + 2
+
+ ws = nx.rand(50)
+ ws = ws / nx.sum(ws)
+
+ wt = nx.rand(40)
+ wt = wt / nx.sum(wt)
+
+ u = nx.randn(50)
+ v = nx.randn(40)
+
+ def metric(x, y):
+ return -nx.dot(x, y.T)
+
+ G1 = ot.stochastic.plan_dual_quadratic(u, v, xs, xt)
+
+ assert np.all(nx.to_numpy(G1) >= 0)
+ assert G1.shape[0] == 50
+ assert G1.shape[1] == 40
+
+ G2 = ot.stochastic.plan_dual_quadratic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+ assert np.all(nx.to_numpy(G2) >= 0)
+ assert G2.shape[0] == 50
+ assert G2.shape[1] == 40
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index e8349d1..02b3fc3 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -1,6 +1,7 @@
"""Tests for module Unbalanced OT with entropy regularization"""
# Author: Hicham Janati <hicham.janati@inria.fr>
+# Laetitia Chapel <laetitia.chapel@univ-ubs.fr>
#
# License: MIT License
@@ -9,11 +10,9 @@ import ot
import pytest
from ot.unbalanced import barycenter_unbalanced
-from scipy.special import logsumexp
-
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
-def test_unbalanced_convergence(method):
+def test_unbalanced_convergence(nx, method):
# test generalized sinkhorn for unbalanced OT
n = 100
rng = np.random.RandomState(42)
@@ -28,36 +27,51 @@ def test_unbalanced_convergence(method):
epsilon = 1.
reg_m = 1.
+ a, b, M = nx.from_numpy(a, b, M)
+
G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
reg_m=reg_m,
method=method,
log=True,
verbose=True)
- loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
- method=method,
- verbose=True)
+ loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2(
+ a, b, M, epsilon, reg_m, method=method, verbose=True
+ ))
# check fixed point equations
# in log-domain
fi = reg_m / (reg_m + epsilon)
- logb = np.log(b + 1e-16)
- loga = np.log(a + 1e-16)
- logKtu = logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1)
- logKv = logsumexp(log["logv"][None, :] - M / epsilon, axis=1)
+ logb = nx.log(b + 1e-16)
+ loga = nx.log(a + 1e-16)
+ logKtu = nx.logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1)
+ logKv = nx.logsumexp(log["logv"][None, :] - M / epsilon, axis=1)
v_final = fi * (logb - logKtu)
u_final = fi * (loga - logKv)
np.testing.assert_allclose(
- u_final, log["logu"], atol=1e-05)
+ nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05)
np.testing.assert_allclose(
- v_final, log["logv"], atol=1e-05)
+ nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05)
# check if sinkhorn_unbalanced2 returns the correct loss
- np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5)
+ np.testing.assert_allclose(nx.to_numpy(nx.sum(G * M)), loss, atol=1e-5)
+
+ # check in case no histogram is provided
+ M_np = nx.to_numpy(M)
+ a_np, b_np = np.array([]), np.array([])
+ a, b = nx.from_numpy(a_np, b_np)
+
+ G = ot.unbalanced.sinkhorn_unbalanced(
+ a, b, M, reg=epsilon, reg_m=reg_m, method=method, verbose=True
+ )
+ G_np = ot.unbalanced.sinkhorn_unbalanced(
+ a_np, b_np, M_np, reg=epsilon, reg_m=reg_m, method=method, verbose=True
+ )
+ np.testing.assert_allclose(G_np, nx.to_numpy(G))
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
-def test_unbalanced_multiple_inputs(method):
+def test_unbalanced_multiple_inputs(nx, method):
# test generalized sinkhorn for unbalanced OT
n = 100
rng = np.random.RandomState(42)
@@ -72,6 +86,8 @@ def test_unbalanced_multiple_inputs(method):
epsilon = 1.
reg_m = 1.
+ a, b, M = nx.from_numpy(a, b, M)
+
loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
reg_m=reg_m,
method=method,
@@ -80,23 +96,24 @@ def test_unbalanced_multiple_inputs(method):
# check fixed point equations
# in log-domain
fi = reg_m / (reg_m + epsilon)
- logb = np.log(b + 1e-16)
- loga = np.log(a + 1e-16)[:, None]
- logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
- axis=0)
- logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
+ logb = nx.log(b + 1e-16)
+ loga = nx.log(a + 1e-16)[:, None]
+ logKtu = nx.logsumexp(
+ log["logu"][:, None, :] - M[:, :, None] / epsilon, axis=0
+ )
+ logKv = nx.logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
v_final = fi * (logb - logKtu)
u_final = fi * (loga - logKv)
np.testing.assert_allclose(
- u_final, log["logu"], atol=1e-05)
+ nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05)
np.testing.assert_allclose(
- v_final, log["logv"], atol=1e-05)
+ nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05)
assert len(loss) == b.shape[1]
-def test_stabilized_vs_sinkhorn():
+def test_stabilized_vs_sinkhorn(nx):
# test if stable version matches sinkhorn
n = 100
@@ -112,19 +129,27 @@ def test_stabilized_vs_sinkhorn():
M /= np.median(M)
epsilon = 0.1
reg_m = 1.
- G, log = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=epsilon,
- method="sinkhorn_stabilized",
- reg_m=reg_m,
- log=True,
- verbose=True)
- G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
- method="sinkhorn", log=True)
+
+ ab, bb, Mb = nx.from_numpy(a, b, M)
+
+ G, _ = ot.unbalanced.sinkhorn_unbalanced2(
+ ab, bb, Mb, epsilon, reg_m, method="sinkhorn_stabilized", log=True
+ )
+ G2, _ = ot.unbalanced.sinkhorn_unbalanced2(
+ ab, bb, Mb, epsilon, reg_m, method="sinkhorn", log=True
+ )
+ G2_np, _ = ot.unbalanced.sinkhorn_unbalanced2(
+ a, b, M, epsilon, reg_m, method="sinkhorn", log=True
+ )
+ G = nx.to_numpy(G)
+ G2 = nx.to_numpy(G2)
np.testing.assert_allclose(G, G2, atol=1e-5)
+ np.testing.assert_allclose(G2, G2_np, atol=1e-5)
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
-def test_unbalanced_barycenter(method):
+def test_unbalanced_barycenter(nx, method):
# test generalized sinkhorn for unbalanced OT barycenter
n = 100
rng = np.random.RandomState(42)
@@ -138,25 +163,29 @@ def test_unbalanced_barycenter(method):
epsilon = 1.
reg_m = 1.
- q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
- method=method, log=True, verbose=True)
+ A, M = nx.from_numpy(A, M)
+
+ q, log = barycenter_unbalanced(
+ A, M, reg=epsilon, reg_m=reg_m, method=method, log=True, verbose=True
+ )
# check fixed point equations
fi = reg_m / (reg_m + epsilon)
- logA = np.log(A + 1e-16)
- logq = np.log(q + 1e-16)[:, None]
- logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
- axis=0)
- logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
+ logA = nx.log(A + 1e-16)
+ logq = nx.log(q + 1e-16)[:, None]
+ logKtu = nx.logsumexp(
+ log["logu"][:, None, :] - M[:, :, None] / epsilon, axis=0
+ )
+ logKv = nx.logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
v_final = fi * (logq - logKtu)
u_final = fi * (logA - logKv)
np.testing.assert_allclose(
- u_final, log["logu"], atol=1e-05)
+ nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05)
np.testing.assert_allclose(
- v_final, log["logv"], atol=1e-05)
+ nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05)
-def test_barycenter_stabilized_vs_sinkhorn():
+def test_barycenter_stabilized_vs_sinkhorn(nx):
# test generalized sinkhorn for unbalanced OT barycenter
n = 100
rng = np.random.RandomState(42)
@@ -170,21 +199,24 @@ def test_barycenter_stabilized_vs_sinkhorn():
epsilon = 0.5
reg_m = 10
- qstable, log = barycenter_unbalanced(A, M, reg=epsilon,
- reg_m=reg_m, log=True,
- tau=100,
- method="sinkhorn_stabilized",
- verbose=True
- )
- q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
- method="sinkhorn",
- log=True)
+ Ab, Mb = nx.from_numpy(A, M)
- np.testing.assert_allclose(
- q, qstable, atol=1e-05)
+ qstable, _ = barycenter_unbalanced(
+ Ab, Mb, reg=epsilon, reg_m=reg_m, log=True, tau=100,
+ method="sinkhorn_stabilized", verbose=True
+ )
+ q, _ = barycenter_unbalanced(
+ Ab, Mb, reg=epsilon, reg_m=reg_m, method="sinkhorn", log=True
+ )
+ q_np, _ = barycenter_unbalanced(
+ A, M, reg=epsilon, reg_m=reg_m, method="sinkhorn", log=True
+ )
+ q, qstable = nx.to_numpy(q, qstable)
+ np.testing.assert_allclose(q, qstable, atol=1e-05)
+ np.testing.assert_allclose(q, q_np, atol=1e-05)
-def test_wrong_method():
+def test_wrong_method(nx):
n = 10
rng = np.random.RandomState(42)
@@ -199,19 +231,20 @@ def test_wrong_method():
epsilon = 1.
reg_m = 1.
+ a, b, M = nx.from_numpy(a, b, M)
+
with pytest.raises(ValueError):
- ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
- reg_m=reg_m,
- method='badmethod',
- log=True,
- verbose=True)
+ ot.unbalanced.sinkhorn_unbalanced(
+ a, b, M, reg=epsilon, reg_m=reg_m, method='badmethod',
+ log=True, verbose=True
+ )
with pytest.raises(ValueError):
- ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
- method='badmethod',
- verbose=True)
+ ot.unbalanced.sinkhorn_unbalanced2(
+ a, b, M, epsilon, reg_m, method='badmethod', verbose=True
+ )
-def test_implemented_methods():
+def test_implemented_methods(nx):
IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling']
NOT_VALID_TOKENS = ['foo']
@@ -228,6 +261,9 @@ def test_implemented_methods():
M = ot.dist(x, x)
epsilon = 1.
reg_m = 1.
+
+ a, b, M, A = nx.from_numpy(a, b, M, A)
+
for method in IMPLEMENTED_METHODS:
ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m,
method=method)
@@ -251,3 +287,52 @@ def test_implemented_methods():
method=method)
barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
method=method)
+
+
+def test_mm_convergence(nx):
+ n = 100
+ rng = np.random.RandomState(42)
+ x = rng.randn(n, 2)
+ rng = np.random.RandomState(75)
+ y = rng.randn(n, 2)
+ a = ot.utils.unif(n)
+ b = ot.utils.unif(n)
+
+ M = ot.dist(x, y)
+ M = M / M.max()
+ reg_m = 100
+ a, b, M = nx.from_numpy(a, b, M)
+
+ G_kl, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl',
+ verbose=True, log=True)
+ loss_kl = nx.to_numpy(ot.unbalanced.mm_unbalanced2(
+ a, b, M, reg_m, div='kl', verbose=True))
+ G_l2, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2',
+ verbose=False, log=True)
+
+ # check if the marginals come close to the true ones when large reg
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b, atol=1e-03)
+
+ # check if mm_unbalanced2 returns the correct loss
+ np.testing.assert_allclose(nx.to_numpy(nx.sum(G_kl * M)), loss_kl,
+ atol=1e-5)
+
+ # check in case no histogram is provided
+ a_np, b_np = np.array([]), np.array([])
+ a, b = nx.from_numpy(a_np, b_np)
+
+ G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl')
+ G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2')
+ np.testing.assert_allclose(G_kl_null, G_kl)
+ np.testing.assert_allclose(G_l2_null, G_l2)
+
+ # test when G0 is given
+ G0 = ot.emd(a, b, M)
+ reg_m = 10000
+ G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0)
+ G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0)
+ np.testing.assert_allclose(G0, G_kl, atol=1e-05)
+ np.testing.assert_allclose(G0, G_l2, atol=1e-05)
diff --git a/test/test_utils.py b/test/test_utils.py
index 6b476b2..3cfd295 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -62,12 +62,14 @@ def test_tic_toc():
import time
ot.tic()
- time.sleep(0.5)
+ time.sleep(0.1)
t = ot.toc()
t2 = ot.toq()
# test timing
- np.testing.assert_allclose(0.5, t, rtol=1e-1, atol=1e-1)
+ # np.testing.assert_allclose(0.1, t, rtol=1e-1, atol=1e-1)
+ # very slow macos github action equality not possible
+ assert t > 0.09
# test toc vs toq
np.testing.assert_allclose(t, t2, rtol=1e-1, atol=1e-1)
@@ -94,10 +96,22 @@ def test_unif():
np.testing.assert_allclose(1, np.sum(u))
-def test_dist():
+def test_unif_backend(nx):
n = 100
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ u = ot.unif(n, type_as=tp)
+
+ np.testing.assert_allclose(1, np.sum(nx.to_numpy(u)), atol=1e-6)
+
+
+def test_dist():
+
+ n = 10
+
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
@@ -122,7 +136,7 @@ def test_dist():
'braycurtis', 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice',
'euclidean', 'hamming', 'jaccard', 'kulsinski',
'matching', 'minkowski', 'rogerstanimoto', 'russellrao',
- 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'
+ 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'yule'
] # those that support weights
metrics = ['mahalanobis', 'seuclidean'] # do not support weights depending on scipy's version
diff --git a/test/test_weak.py b/test/test_weak.py
new file mode 100644
index 0000000..945efb1
--- /dev/null
+++ b/test/test_weak.py
@@ -0,0 +1,52 @@
+"""Tests for main module ot.weak """
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import ot
+import numpy as np
+
+
+def test_weak_ot():
+ # test weak ot solver and identity stationary point
+ n = 50
+ rng = np.random.RandomState(0)
+
+ xs = rng.randn(n, 2)
+ xt = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ G, log = ot.weak_optimal_transport(xs, xt, u, u, log=True)
+
+ # check constraints
+ np.testing.assert_allclose(u, G.sum(1))
+ np.testing.assert_allclose(u, G.sum(0))
+
+ # chaeck that identity is recovered
+ G = ot.weak_optimal_transport(xs, xs, G0=np.eye(n) / n)
+
+ # check G is identity
+ np.testing.assert_allclose(G, np.eye(n) / n)
+
+ # check constraints
+ np.testing.assert_allclose(u, G.sum(1))
+ np.testing.assert_allclose(u, G.sum(0))
+
+
+def test_weak_ot_bakends(nx):
+ # test weak ot solver for different backends
+ n = 50
+ rng = np.random.RandomState(0)
+
+ xs = rng.randn(n, 2)
+ xt = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ G = ot.weak_optimal_transport(xs, xt, u, u)
+
+ xs2, xt2, u2 = nx.from_numpy(xs, xt, u)
+
+ G2 = ot.weak_optimal_transport(xs2, xt2, u2, u2)
+
+ np.testing.assert_allclose(nx.to_numpy(G2), G)