summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGard Spreemann <gspr@nonempty.org>2021-12-29 19:26:32 +0100
committerGard Spreemann <gspr@nonempty.org>2021-12-29 19:26:32 +0100
commit367366a649f57a147456f11f7e803de12ced3b8f (patch)
treea900af1302f4a6923323d203ae8cc22550b59e8f
parent88d850422a838c29d70ef757d04ab57707d7cd26 (diff)
parentedab1c60630f95b38db430017585d06253c92817 (diff)
Merge branch 'dfsg/latest' into debian/sid
-rw-r--r--.circleci/config.yml37
-rw-r--r--.github/CONTRIBUTING.md7
-rw-r--r--.github/PULL_REQUEST_TEMPLATE.md13
-rw-r--r--.github/requirements_test_windows.txt2
-rw-r--r--.github/workflows/build_tests.yml21
-rw-r--r--.github/workflows/build_wheels.yml4
-rw-r--r--.github/workflows/build_wheels_weekly.yml2
-rw-r--r--README.md7
-rw-r--r--RELEASES.md36
-rw-r--r--benchmarks/__init__.py5
-rw-r--r--benchmarks/benchmark.py105
-rw-r--r--benchmarks/emd.py40
-rw-r--r--benchmarks/sinkhorn_knopp.py42
-rw-r--r--docs/requirements.txt1
-rw-r--r--docs/requirements_rtd.txt1
-rw-r--r--docs/source/.github/CODE_OF_CONDUCT.rst6
-rw-r--r--docs/source/.github/CONTRIBUTING.rst6
-rw-r--r--docs/source/_templates/versions.html10
-rw-r--r--docs/source/conf.py5
-rw-r--r--docs/source/index.rst7
-rw-r--r--docs/source/readme.rst539
-rw-r--r--docs/source/releases.rst469
-rw-r--r--examples/plot_Intro_OT.py2
-rw-r--r--ot/__init__.py2
-rw-r--r--ot/backend.py876
-rw-r--r--ot/bregman.py72
-rw-r--r--ot/da.py44
-rw-r--r--ot/datasets.py2
-rw-r--r--ot/dr.py2
-rw-r--r--ot/gromov.py59
-rw-r--r--ot/lp/solver_1d.py4
-rw-r--r--ot/optim.py47
-rw-r--r--ot/plot.py4
-rw-r--r--ot/utils.py10
-rw-r--r--pyproject.toml2
-rw-r--r--requirements.txt5
-rw-r--r--setup.cfg2
-rw-r--r--setup.py5
-rw-r--r--test/conftest.py12
-rw-r--r--test/test_1d_solver.py68
-rw-r--r--test/test_backend.py71
-rw-r--r--test/test_bregman.py46
-rw-r--r--test/test_gromov.py119
-rw-r--r--test/test_optim.py4
-rw-r--r--test/test_ot.py52
-rw-r--r--test/test_sliced.py57
-rw-r--r--test/test_utils.py20
47 files changed, 1765 insertions, 1187 deletions
diff --git a/.circleci/config.yml b/.circleci/config.yml
index 85f8073..32c211f 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -106,8 +106,10 @@ jobs:
echo "Deploying dev docs for ${CIRCLE_BRANCH}.";
cd master
cp -a /tmp/build/html/* .;
+ cp -a /tmp/build/html/.github .github;
touch .nojekyll;
git add -A;
+ git add -f .github/*.html ;
git commit -m "CircleCI update of dev docs (${CIRCLE_BUILD_NUM}).";
git push origin master;
else
@@ -134,24 +136,23 @@ jobs:
name: Deploy docs
command: |
set -e;
- if [ "${CIRCLE_BRANCH}" == "master" ]; then
- git config --global user.email "circle@PythonOT.com";
- git config --global user.name "Circle CI";
- cd ~/PythonOT.github.io;
- git checkout master
- git remote -v
- git fetch origin
- git reset --hard origin/master
- git clean -xdf
- echo "Deploying dev docs for ${CIRCLE_BRANCH}.";
- cp -a /tmp/build/html/* .;
- touch .nojekyll;
- git add -A;
- git commit -m "CircleCI update of dev docs (${CIRCLE_BUILD_NUM}).";
- git push origin master;
- else
- echo "No deployment (build: ${CIRCLE_BRANCH}).";
- fi
+ git config --global user.email "circle@PythonOT.com";
+ git config --global user.name "Circle CI";
+ cd ~/PythonOT.github.io;
+ git checkout master
+ git remote -v
+ git fetch origin
+ git reset --hard origin/master
+ git clean -xdf
+ echo "Deploying dev docs for ${CIRCLE_BRANCH}.";
+ cp -a /tmp/build/html/* .;
+ cp -a /tmp/build/html/.github .github;
+ touch .nojekyll;
+ git add -A;
+ git add -f .github/*.html ;
+ git commit -m "CircleCI update of dev docs (${CIRCLE_BUILD_NUM}).";
+ git push origin master;
+
workflows:
diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md
index 54e7e42..9bc8e87 100644
--- a/.github/CONTRIBUTING.md
+++ b/.github/CONTRIBUTING.md
@@ -182,11 +182,10 @@ reStructuredText documents live in the source code repository under the
doc/ directory.
You can edit the documentation using any text editor and then generate
-the HTML output by typing ``make html`` from the doc/ directory.
+the HTML output by typing ``make html`` from the ``docs/`` directory.
Alternatively, ``make`` can be used to quickly generate the
-documentation without the example gallery. The resulting HTML files will
-be placed in ``_build/html/`` and are viewable in a web browser. See the
-``README`` file in the ``doc/`` directory for more information.
+documentation without the example gallery with `make html-noplot`. The resulting HTML files will
+be placed in `docs/build/html/` and are viewable in a web browser.
For building the documentation, you will need
[sphinx](http://sphinx.pocoo.org/),
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index 7cfe4e6..f2c6606 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -1,10 +1,6 @@
## Types of changes
<!--- What types of changes does your code introduce? Put an `x` in all the boxes that apply: -->
-- [ ] Docs change / refactoring / dependency upgrade
-- [ ] Bug fix (non-breaking change which fixes an issue)
-- [ ] New feature (non-breaking change which adds functionality)
-- [ ] Breaking change (fix or feature that would cause existing functionality to change)
## Motivation and context / Related issue
@@ -13,16 +9,19 @@
<!--- (we recommend to have an existing issue for each pull request) -->
+
## How has this been tested (if it applies)
<!--- Please describe here how your modifications have been tested. -->
-## Checklist
+
+## PR checklist
<!-- - Go over all the following points, and put an `x` in all the boxes that apply. -->
<!--- If you're unsure about any of these, don't hesitate to ask. We're here to help! -->
-- [ ] The documentation is up-to-date with the changes I made.
- [ ] I have read the [**CONTRIBUTING**](CONTRIBUTING.md) document.
-- [ ] All tests passed, and additional code has been covered with new tests.
+- [ ] The documentation is up-to-date with the changes I made (check build artifacts).
+- [ ] All tests passed, and additional code has been **covered with new tests**.
+- [ ] I have added the PR and Issue fix to the [**RELEASES.md**](RELEASES.md) file.
<!--- In any case, don't hesitate to join and ask questions if you need on slack (https://pot-toolbox.slack.com/), gitter (https://gitter.im/PythonOT/community), or the mailing list (https://mail.python.org/mm3/mailman3/lists/pot.python.org/). -->
diff --git a/.github/requirements_test_windows.txt b/.github/requirements_test_windows.txt
index 331dd57..b94392f 100644
--- a/.github/requirements_test_windows.txt
+++ b/.github/requirements_test_windows.txt
@@ -4,7 +4,7 @@ cython
matplotlib
autograd
pymanopt==0.2.4; python_version <'3'
-pymanopt; python_version >= '3'
+pymanopt==0.2.6rc1; python_version >= '3'
cvxopt
scikit-learn
pytest \ No newline at end of file
diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml
index ee5a435..3c99da8 100644
--- a/.github/workflows/build_tests.yml
+++ b/.github/workflows/build_tests.yml
@@ -22,7 +22,7 @@ jobs:
strategy:
max-parallel: 4
matrix:
- python-version: [ "3.6", "3.7", "3.8", "3.9"]
+ python-version: ["3.7", "3.8", "3.9"]
steps:
- uses: actions/checkout@v1
@@ -128,12 +128,29 @@ jobs:
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
+ - name: RC.exe
+ run: |
+ function Invoke-VSDevEnvironment {
+ $vswhere = "${env:ProgramFiles(x86)}\Microsoft Visual Studio\Installer\vswhere.exe"
+ $installationPath = & $vswhere -prerelease -legacy -latest -property installationPath
+ $Command = Join-Path $installationPath "Common7\Tools\vsdevcmd.bat"
+ & "${env:COMSPEC}" /s /c "`"$Command`" -no_logo && set" | Foreach-Object {
+ if ($_ -match '^([^=]+)=(.*)') {
+ [System.Environment]::SetEnvironmentVariable($matches[1], $matches[2])
+ }
+ }
+ }
+ Invoke-VSDevEnvironment
+ Get-Command rc.exe | Format-Table -AutoSize
+ - name: Update pip
+ run : |
+ python -m pip install --upgrade pip setuptools
+ python -m pip install cython
- name: Install POT
run: |
python -m pip install -e .
- name: Install dependencies
run: |
- python -m pip install --upgrade pip
python -m pip install -r .github/requirements_test_windows.txt
python -m pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install pytest "pytest-cov<2.6"
diff --git a/.github/workflows/build_wheels.yml b/.github/workflows/build_wheels.yml
index a935a5e..c746eb8 100644
--- a/.github/workflows/build_wheels.yml
+++ b/.github/workflows/build_wheels.yml
@@ -36,7 +36,7 @@ jobs:
- name: Build wheels
env:
- CIBW_SKIP: "pp*-win* pp*-macosx* cp2* pp*" # remove pypy on mac and win (wrong version)
+ CIBW_SKIP: "pp*-win* pp*-macosx* cp2* pp* cp36*" # remove pypy on mac and win (wrong version)
CIBW_BEFORE_BUILD: "pip install numpy cython"
run: |
python -m cibuildwheel --output-dir wheelhouse
@@ -80,7 +80,7 @@ jobs:
- name: Build wheels
env:
- CIBW_SKIP: "pp*-win* pp*-macosx* cp2* pp* cp*musl*" # remove pypy on mac and win (wrong version)
+ CIBW_SKIP: "pp*-win* pp*-macosx* cp2* pp* cp*musl* cp36*" # remove pypy on mac and win (wrong version)
CIBW_BEFORE_BUILD: "pip install numpy cython"
CIBW_ARCHS_LINUX: auto aarch64 # force aarch64 with QEMU
CIBW_ARCHS_MACOS: x86_64 universal2 arm64
diff --git a/.github/workflows/build_wheels_weekly.yml b/.github/workflows/build_wheels_weekly.yml
index 2964844..dbf342f 100644
--- a/.github/workflows/build_wheels_weekly.yml
+++ b/.github/workflows/build_wheels_weekly.yml
@@ -41,7 +41,7 @@ jobs:
- name: Build wheels
env:
- CIBW_SKIP: "pp*-win* pp*-macosx* cp2* pp* cp*musl*" # remove pypy on mac and win (wrong version)
+ CIBW_SKIP: "pp*-win* pp*-macosx* cp2* pp* cp*musl* cp36*" # remove pypy on mac and win (wrong version)
CIBW_BEFORE_BUILD: "pip install numpy cython"
CIBW_ARCHS_LINUX: auto aarch64 # force aarch64 with QEMU
CIBW_ARCHS_MACOS: x86_64 universal2 arm64
diff --git a/README.md b/README.md
index 08db003..17fbe81 100644
--- a/README.md
+++ b/README.md
@@ -35,7 +35,7 @@ POT provides the following generic OT solvers (links to examples):
* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3]
formulations).
* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36].
-* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/) arrays.
+* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.
POT provides the following Machine Learning related solvers:
@@ -196,12 +196,13 @@ The contributors to this library are
* [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance)
* [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein)
* [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance)
+* [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends)
This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
* [Gabriel Peyré](http://gpeyre.github.io/) (Wasserstein Barycenters in Matlab)
* [Mathieu Blondel](https://mblondel.org/) (original implementation smooth OT)
-* [Nicolas Bonneel](http://liris.cnrs.fr/~nbonneel/) ( C++ code for EMD)
+* [Nicolas Bonneel](http://liris.cnrs.fr/~nbonneel/) (C++ code for EMD)
* [Marco Cuturi](http://marcocuturi.net/) (Sinkhorn Knopp in Matlab/Cuda)
## Contributions and code of conduct
@@ -299,4 +300,4 @@ Machine Learning (pp. 4104-4113). PMLR.
Conference on Machine Learning, PMLR 119:4692-4701, 2020
[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph
-Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021. \ No newline at end of file
+Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021.
diff --git a/RELEASES.md b/RELEASES.md
index 6eb1502..2a45465 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -1,6 +1,42 @@
# Releases
+## 0.8.1
+*December 2021*
+
+This release fixes several bugs and introduces two new backends: Cupy
+and Tensorflow. Note that the tensorflow backend will work only when tensorflow
+has enabled the Numpy behavior (for transpose that is not by default in
+tensorflow). We also introduce a simple benchmark on CPU GPU for the sinkhorn
+solver that will be provided in the
+[backend](https://pythonot.github.io/gen_modules/ot.backend.html) documentation.
+
+This release also brings a few changes in dependencies and compatibility. First
+we removed tests for Python 3.6 that will not be updated in the future.
+Also note that POT now depends on Numpy (>= 1.20) because a recent change in ABI is making the
+wheels non-compatible with older numpy versions. If you really need an older
+numpy POT will work with no problems but you will need to build it from source.
+
+As always we want to that the contributors who helped make POT better (and bug free).
+
+#### New features
+
+- New benchmark for sinkhorn solver on CPU/GPU and between backends (PR #316)
+- New tensorflow backend (PR #316)
+- New Cupy backend (PR #315)
+- Documentation always up-to-date with README, RELEASES, CONTRIBUTING and
+ CODE_OF_CONDUCT files (PR #316, PR #322).
+
+#### Closed issues
+
+- Fix bug in older Numpy ABI (<1.20) (Issue #308, PR #326)
+- Fix bug in `ot.dist` function when non euclidean distance (Issue #305, PR #306)
+- Fix gradient scaling for functions using `nx.set_gradients` (Issue #309, PR
+ #310)
+- Fix bug in generalized Conditional gradient solver and SinkhornL1L2 (Issue
+ #311, PR #313)
+- Fix log error in `gromov_barycenters` (Issue #317, PR #3018)
+
## 0.8.0
*November 2021*
diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py
new file mode 100644
index 0000000..37f5e56
--- /dev/null
+++ b/benchmarks/__init__.py
@@ -0,0 +1,5 @@
+from . import benchmark
+from . import sinkhorn_knopp
+from . import emd
+
+__all__= ["benchmark", "sinkhorn_knopp", "emd"]
diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py
new file mode 100644
index 0000000..7973c6b
--- /dev/null
+++ b/benchmarks/benchmark.py
@@ -0,0 +1,105 @@
+# /usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+from ot.backend import get_backend_list, jax, tf
+import gc
+
+
+def setup_backends():
+ if jax:
+ from jax.config import config
+ config.update("jax_enable_x64", True)
+
+ if tf:
+ from tensorflow.python.ops.numpy_ops import np_config
+ np_config.enable_numpy_behavior()
+
+
+def exec_bench(setup, tested_function, param_list, n_runs, warmup_runs):
+ backend_list = get_backend_list()
+ for i, nx in enumerate(backend_list):
+ if nx.__name__ == "tf" and i < len(backend_list) - 1:
+ # Tensorflow should be the last one to be benchmarked because
+ # as far as I'm aware, there is no way to force it to release
+ # GPU memory. Hence, if any other backend is benchmarked after
+ # Tensorflow and requires the usage of a GPU, it will not have the
+ # full memory available and you may have a GPU Out Of Memory error
+ # even though your GPU can technically hold your tensors in memory.
+ backend_list.pop(i)
+ backend_list.append(nx)
+ break
+
+ inputs = [setup(param) for param in param_list]
+ results = dict()
+ for nx in backend_list:
+ for i in range(len(param_list)):
+ print(nx, param_list[i])
+ args = inputs[i]
+ results_nx = nx._bench(
+ tested_function,
+ *args,
+ n_runs=n_runs,
+ warmup_runs=warmup_runs
+ )
+ gc.collect()
+ results_nx_with_param_in_key = dict()
+ for key in results_nx:
+ new_key = (param_list[i], *key)
+ results_nx_with_param_in_key[new_key] = results_nx[key]
+ results.update(results_nx_with_param_in_key)
+ return results
+
+
+def convert_to_html_table(results, param_name, main_title=None, comments=None):
+ string = "<table>\n"
+ keys = list(results.keys())
+ params, names, devices, bitsizes = zip(*keys)
+
+ devices_names = sorted(list(set(zip(devices, names))))
+ params = sorted(list(set(params)))
+ bitsizes = sorted(list(set(bitsizes)))
+ length = len(devices_names) + 1
+ cpus_cols = list(devices).count("CPU") / len(bitsizes) / len(params)
+ gpus_cols = list(devices).count("GPU") / len(bitsizes) / len(params)
+ assert cpus_cols + gpus_cols == len(devices_names)
+
+ if main_title is not None:
+ string += f'<tr><th align="center" colspan="{length}">{str(main_title)}</th></tr>\n'
+
+ for i, bitsize in enumerate(bitsizes):
+
+ if i != 0:
+ string += f'<tr><td colspan="{length}">&nbsp;</td></tr>\n'
+
+ # make bitsize header
+ text = f"{bitsize} bits"
+ if comments is not None:
+ text += " - "
+ if isinstance(comments, (tuple, list)) and len(comments) == len(bitsizes):
+ text += str(comments[i])
+ else:
+ text += str(comments)
+ string += f'<tr><th align="center">Bitsize</th>'
+ string += f'<th align="center" colspan="{length - 1}">{text}</th></tr>\n'
+
+ # make device header
+ string += f'<tr><th align="center">Device</th>'
+ string += f'<th align="center" colspan="{cpus_cols}">CPU</th>'
+ string += f'<th align="center" colspan="{gpus_cols}">GPU</th></tr>\n'
+
+ # make param_name / backend header
+ string += f'<tr><th align="center">{param_name}</th>'
+ for device, name in devices_names:
+ string += f'<th align="center">{name}</th>'
+ string += "</tr>\n"
+
+ # make results rows
+ for param in params:
+ string += f'<tr><td align="center">{param}</td>'
+ for device, name in devices_names:
+ key = (param, name, device, bitsize)
+ string += f'<td align="center">{results[key]:.4f}</td>'
+ string += "</tr>\n"
+
+ string += "</table>"
+ return string
diff --git a/benchmarks/emd.py b/benchmarks/emd.py
new file mode 100644
index 0000000..9f64863
--- /dev/null
+++ b/benchmarks/emd.py
@@ -0,0 +1,40 @@
+# /usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import numpy as np
+import ot
+from .benchmark import (
+ setup_backends,
+ exec_bench,
+ convert_to_html_table
+)
+
+
+def setup(n_samples):
+ rng = np.random.RandomState(789465132)
+ x = rng.randn(n_samples, 2)
+ y = rng.randn(n_samples, 2)
+
+ a = ot.utils.unif(n_samples)
+ M = ot.dist(x, y)
+ return a, M
+
+
+if __name__ == "__main__":
+ n_runs = 100
+ warmup_runs = 10
+ param_list = [50, 100, 500, 1000, 2000, 5000]
+
+ setup_backends()
+ results = exec_bench(
+ setup=setup,
+ tested_function=lambda a, M: ot.emd(a, a, M),
+ param_list=param_list,
+ n_runs=n_runs,
+ warmup_runs=warmup_runs
+ )
+ print(convert_to_html_table(
+ results,
+ param_name="Sample size",
+ main_title=f"EMD - Averaged on {n_runs} runs"
+ ))
diff --git a/benchmarks/sinkhorn_knopp.py b/benchmarks/sinkhorn_knopp.py
new file mode 100644
index 0000000..3a1ef3f
--- /dev/null
+++ b/benchmarks/sinkhorn_knopp.py
@@ -0,0 +1,42 @@
+# /usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import numpy as np
+import ot
+from .benchmark import (
+ setup_backends,
+ exec_bench,
+ convert_to_html_table
+)
+
+
+def setup(n_samples):
+ rng = np.random.RandomState(123456789)
+ a = rng.rand(n_samples // 4, 100)
+ b = rng.rand(n_samples, 100)
+
+ wa = ot.unif(n_samples // 4)
+ wb = ot.unif(n_samples)
+
+ M = ot.dist(a.copy(), b.copy())
+ return wa, wb, M
+
+
+if __name__ == "__main__":
+ n_runs = 100
+ warmup_runs = 10
+ param_list = [50, 100, 500, 1000, 2000, 5000]
+
+ setup_backends()
+ results = exec_bench(
+ setup=setup,
+ tested_function=lambda *args: ot.bregman.sinkhorn(*args, reg=1, stopThr=1e-7),
+ param_list=param_list,
+ n_runs=n_runs,
+ warmup_runs=warmup_runs
+ )
+ print(convert_to_html_table(
+ results,
+ param_name="Sample size",
+ main_title=f"Sinkhorn Knopp - Averaged on {n_runs} runs"
+ ))
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 256706b..2e060b9 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -4,3 +4,4 @@ numpydoc
memory_profiler
pillow
networkx
+myst-parser \ No newline at end of file
diff --git a/docs/requirements_rtd.txt b/docs/requirements_rtd.txt
index e3999d6..11957fb 100644
--- a/docs/requirements_rtd.txt
+++ b/docs/requirements_rtd.txt
@@ -3,6 +3,7 @@ numpydoc
memory_profiler
pillow
networkx
+myst-parser
numpy
scipy>=1.0
cython
diff --git a/docs/source/.github/CODE_OF_CONDUCT.rst b/docs/source/.github/CODE_OF_CONDUCT.rst
new file mode 100644
index 0000000..d4c5cec
--- /dev/null
+++ b/docs/source/.github/CODE_OF_CONDUCT.rst
@@ -0,0 +1,6 @@
+Code of Conduct
+===============
+
+.. include:: ../../../.github/CODE_OF_CONDUCT.md
+ :parser: myst_parser.sphinx_
+ :start-line: 2
diff --git a/docs/source/.github/CONTRIBUTING.rst b/docs/source/.github/CONTRIBUTING.rst
new file mode 100644
index 0000000..aef24e9
--- /dev/null
+++ b/docs/source/.github/CONTRIBUTING.rst
@@ -0,0 +1,6 @@
+Contributing to POT
+===================
+
+.. include:: ../../../.github/CONTRIBUTING.md
+ :parser: myst_parser.sphinx_
+ :start-line: 3
diff --git a/docs/source/_templates/versions.html b/docs/source/_templates/versions.html
index 10d60d7..f48ab86 100644
--- a/docs/source/_templates/versions.html
+++ b/docs/source/_templates/versions.html
@@ -1,4 +1,6 @@
-<div class="rst-versions shift-up" data-toggle="rst-versions" role="note" aria-label="versions">
+<div class="rst-versions" data-toggle="rst-versions" role="note"
+aria-label="versions">
+ <!-- add shift_up to the class for force viewing -->
<span class="rst-current-version" data-toggle="rst-current-version">
<span class="fa fa-book"> Python Optimal Transport</span>
versions
@@ -12,10 +14,12 @@
<dl>
<dt>Versions</dt>
+
+ <dd><a href="https://pythonot.github.io/">Release</a></dd>
- <dd><a href="https://pythonot.github.io/master">latest</a></dd>
+ <dd><a href="https://pythonot.github.io/master">Development</a></dd>
- <dd><a href="https://pythonot.github.io/">stable</a></dd>
+
</dl>
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 9b5a719..849e97c 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -69,6 +69,7 @@ extensions = [
'sphinx.ext.viewcode',
'sphinx.ext.napoleon',
'sphinx_gallery.gen_gallery',
+ 'myst_parser'
]
autosummary_generate = True
@@ -81,8 +82,8 @@ templates_path = ['_templates']
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
-# source_suffix = ['.rst', '.md']
-source_suffix = '.rst'
+source_suffix = ['.rst', '.md']
+# source_suffix = '.rst'
# The encoding of source files.
source_encoding = 'utf-8-sig'
diff --git a/docs/source/index.rst b/docs/source/index.rst
index be01343..8de31ae 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -17,10 +17,11 @@ Contents
all
auto_examples/index
releases
+ .github/CONTRIBUTING
+ .github/CODE_OF_CONDUCT
-.. include:: readme.rst
- :start-line: 2
-
+.. include:: ../../README.md
+ :parser: myst_parser.sphinx_
Indices and tables
diff --git a/docs/source/readme.rst b/docs/source/readme.rst
deleted file mode 100644
index a8f1bc0..0000000
--- a/docs/source/readme.rst
+++ /dev/null
@@ -1,539 +0,0 @@
-POT: Python Optimal Transport
-=============================
-
-|PyPI version| |Anaconda Cloud| |Build Status| |Codecov Status|
-|Downloads| |Anaconda downloads| |License|
-
-This open source Python library provide several solvers for optimization
-problems related to Optimal Transport for signal, image processing and
-machine learning.
-
-Website and documentation: https://PythonOT.github.io/
-
-Source Code (MIT): https://github.com/PythonOT/POT
-
-POT provides the following generic OT solvers (links to examples):
-
-- `OT Network Simplex
- solver <auto_examples/plot_OT_1D.html>`__
- for the linear program/ Earth Movers Distance [1] .
-- `Conditional
- gradient <auto_examples/plot_optim_OTreg.html>`__
- [6] and `Generalized conditional
- gradient <auto_examples/plot_optim_OTreg.html>`__
- for regularized OT [7].
-- Entropic regularization OT solver with `Sinkhorn Knopp
- Algorithm <auto_examples/plot_OT_1D.html>`__
- [2] , stabilized version [9] [10] [34], greedy Sinkhorn [22] and
- `Screening Sinkhorn
- [26] <auto_examples/plot_screenkhorn_1D.html>`__.
-- Bregman projections for `Wasserstein
- barycenter <auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html>`__
- [3], `convolutional
- barycenter <auto_examples/barycenters/plot_convolutional_barycenter.html>`__
- [21] and unmixing [4].
-- Sinkhorn divergence [23] and entropic regularization OT from
- empirical data.
-- Debiased Sinkhorn barycenters `Sinkhorn divergence
- barycenter <auto_examples/barycenters/plot_debiased_barycenter.html>`__
- [37]
-- `Smooth optimal transport
- solvers <auto_examples/plot_OT_1D_smooth.html>`__
- (dual and semi-dual) for KL and squared L2 regularizations [17].
-- Non regularized `Wasserstein barycenters
- [16] <auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html>`__)
- with LP solver (only small scale).
-- `Gromov-Wasserstein
- distances <auto_examples/gromov/plot_gromov.html>`__
- and `GW
- barycenters <auto_examples/gromov/plot_gromov_barycenter.html>`__
- (exact [13] and regularized [12]), differentiable using gradients
- from
-- `Fused-Gromov-Wasserstein distances
- solver <auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py>`__
- and `FGW
- barycenters <auto_examples/gromov/plot_barycenter_fgw.html>`__
- [24]
-- `Stochastic
- solver <auto_examples/plot_stochastic.html>`__
- for Large-scale Optimal Transport (semi-dual problem [18] and dual
- problem [19])
-- `Stochastic solver of Gromov
- Wasserstein <auto_examples/gromov/plot_gromov.html>`__
- for large-scale problem with any loss functions [33]
-- Non regularized `free support Wasserstein
- barycenters <auto_examples/barycenters/plot_free_support_barycenter.html>`__
- [20].
-- `Unbalanced
- OT <auto_examples/unbalanced-partial/plot_UOT_1D.html>`__
- with KL relaxation and
- `barycenter <auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html>`__
- [10, 25].
-- `Partial Wasserstein and
- Gromov-Wasserstein <auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html>`__
- (exact [29] and entropic [3] formulations).
-- `Sliced
- Wasserstein <auto_examples/sliced-wasserstein/plot_variance.html>`__
- [31, 32] and Max-sliced Wasserstein [35] that can be used for
- gradient flows [36].
-- `Several
- backends <https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends>`__
- for easy use of POT with
- `Pytorch <https://pytorch.org/>`__/`jax <https://github.com/google/jax>`__/`Numpy <https://numpy.org/>`__
- arrays.
-
-POT provides the following Machine Learning related solvers:
-
-- `Optimal transport for domain
- adaptation <auto_examples/domain-adaptation/plot_otda_classes.html>`__
- with `group lasso
- regularization <auto_examples/domain-adaptation/plot_otda_classes.html>`__,
- `Laplacian
- regularization <auto_examples/domain-adaptation/plot_otda_laplacian.html>`__
- [5] [30] and `semi supervised
- setting <auto_examples/domain-adaptation/plot_otda_semi_supervised.html>`__.
-- `Linear OT
- mapping <auto_examples/domain-adaptation/plot_otda_linear_mapping.html>`__
- [14] and `Joint OT mapping
- estimation <auto_examples/domain-adaptation/plot_otda_mapping.html>`__
- [8].
-- `Wasserstein Discriminant
- Analysis <auto_examples/others/plot_WDA.html>`__
- [11] (requires autograd + pymanopt).
-- `JCPOT algorithm for multi-source domain adaptation with target
- shift <auto_examples/domain-adaptation/plot_otda_jcpot.html>`__
- [27].
-
-Some other examples are available in the
-`documentation <auto_examples/index.html>`__.
-
-Using and citing the toolbox
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-If you use this toolbox in your research and find it useful, please cite
-POT using the following reference from our `JMLR
-paper <https://jmlr.org/papers/v22/20-451.html>`__:
-
-::
-
- Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer,
- POT Python Optimal Transport library,
- Journal of Machine Learning Research, 22(78):1−8, 2021.
- Website: https://pythonot.github.io/
-
-In Bibtex format:
-
-.. code:: bibtex
-
- @article{flamary2021pot,
- author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer},
- title = {POT: Python Optimal Transport},
- journal = {Journal of Machine Learning Research},
- year = {2021},
- volume = {22},
- number = {78},
- pages = {1-8},
- url = {http://jmlr.org/papers/v22/20-451.html}
- }
-
-Installation
-------------
-
-The library has been tested on Linux, MacOSX and Windows. It requires a
-C++ compiler for building/installing the EMD solver and relies on the
-following Python modules:
-
-- Numpy (>=1.16)
-- Scipy (>=1.0)
-- Cython (>=0.23) (build only, not necessary when installing from pip
- or conda)
-
-Pip installation
-^^^^^^^^^^^^^^^^
-
-You can install the toolbox through PyPI with:
-
-.. code:: console
-
- pip install POT
-
-or get the very latest version by running:
-
-.. code:: console
-
- pip install -U https://github.com/PythonOT/POT/archive/master.zip # with --user for user install (no root)
-
-Anaconda installation with conda-forge
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-If you use the Anaconda python distribution, POT is available in
-`conda-forge <https://conda-forge.org>`__. To install it and the
-required dependencies:
-
-.. code:: console
-
- conda install -c conda-forge pot
-
-Post installation check
-^^^^^^^^^^^^^^^^^^^^^^^
-
-After a correct installation, you should be able to import the module
-without errors:
-
-.. code:: python
-
- import ot
-
-Note that for easier access the module is named ``ot`` instead of
-``pot``.
-
-Dependencies
-~~~~~~~~~~~~
-
-Some sub-modules require additional dependences which are discussed
-below
-
-- **ot.dr** (Wasserstein dimensionality reduction) depends on autograd
- and pymanopt that can be installed with:
-
-.. code:: shell
-
- pip install pymanopt autograd
-
-- **ot.gpu** (GPU accelerated OT) depends on cupy that have to be
- installed following instructions on `this
- page <https://docs-cupy.chainer.org/en/stable/install.html>`__.
- Obviously you will need CUDA installed and a compatible GPU. Note
- that this module is deprecated since version 0.8 and will be deleted
- in the future. GPU is now handled automatically through the backends
- and several solver already can run on GPU using the Pytorch backend.
-
-Examples
---------
-
-Short examples
-~~~~~~~~~~~~~~
-
-- Import the toolbox
-
-.. code:: python
-
- import ot
-
-- Compute Wasserstein distances
-
-.. code:: python
-
- # a,b are 1D histograms (sum to 1 and positive)
- # M is the ground cost matrix
- Wd = ot.emd2(a, b, M) # exact linear program
- Wd_reg = ot.sinkhorn2(a, b, M, reg) # entropic regularized OT
- # if b is a matrix compute all distances to a and return a vector
-
-- Compute OT matrix
-
-.. code:: python
-
- # a,b are 1D histograms (sum to 1 and positive)
- # M is the ground cost matrix
- T = ot.emd(a, b, M) # exact linear program
- T_reg = ot.sinkhorn(a, b, M, reg) # entropic regularized OT
-
-- Compute Wasserstein barycenter
-
-.. code:: python
-
- # A is a n*d matrix containing d 1D histograms
- # M is the ground cost matrix
- ba = ot.barycenter(A, M, reg) # reg is regularization parameter
-
-Examples and Notebooks
-~~~~~~~~~~~~~~~~~~~~~~
-
-The examples folder contain several examples and use case for the
-library. The full documentation with examples and output is available on
-https://PythonOT.github.io/.
-
-Acknowledgements
-----------------
-
-This toolbox has been created and is maintained by
-
-- `Rémi Flamary <http://remi.flamary.com/>`__
-- `Nicolas Courty <http://people.irisa.fr/Nicolas.Courty/>`__
-
-The contributors to this library are
-
-- `Alexandre Gramfort <http://alexandre.gramfort.net/>`__ (CI,
- documentation)
-- `Laetitia Chapel <http://people.irisa.fr/Laetitia.Chapel/>`__
- (Partial OT)
-- `Michael Perrot <http://perso.univ-st-etienne.fr/pem82055/>`__
- (Mapping estimation)
-- `Léo Gautheron <https://github.com/aje>`__ (GPU implementation)
-- `Nathalie
- Gayraud <https://www.linkedin.com/in/nathalie-t-h-gayraud/?ppe=1>`__
- (DA classes)
-- `Stanislas Chambon <https://slasnista.github.io/>`__ (DA classes)
-- `Antoine Rolet <https://arolet.github.io/>`__ (EMD solver debug)
-- Erwan Vautier (Gromov-Wasserstein)
-- `Kilian Fatras <https://kilianfatras.github.io/>`__ (Stochastic
- solvers)
-- `Alain
- Rakotomamonjy <https://sites.google.com/site/alainrakotomamonjy/home>`__
-- `Vayer Titouan <https://tvayer.github.io/>`__ (Gromov-Wasserstein -,
- Fused-Gromov-Wasserstein)
-- `Hicham Janati <https://hichamjanati.github.io/>`__ (Unbalanced OT,
- Debiased barycenters)
-- `Romain Tavenard <https://rtavenar.github.io/>`__ (1d Wasserstein)
-- `Mokhtar Z. Alaya <http://mzalaya.github.io/>`__ (Screenkhorn)
-- `Ievgen Redko <https://ievred.github.io/>`__ (Laplacian DA, JCPOT)
-- `Adrien Corenflos <https://adriencorenflos.github.io/>`__ (Sliced
- Wasserstein Distance)
-- `Tanguy Kerdoncuff <https://hv0nnus.github.io/>`__ (Sampled Gromov
- Wasserstein)
-- `Minhui Huang <https://mhhuang95.github.io>`__ (Projection Robust
- Wasserstein Distance)
-
-This toolbox benefit a lot from open source research and we would like
-to thank the following persons for providing some code (in various
-languages):
-
-- `Gabriel Peyré <http://gpeyre.github.io/>`__ (Wasserstein Barycenters
- in Matlab)
-- `Mathieu Blondel <https://mblondel.org/>`__ (original implementation
- smooth OT)
-- `Nicolas Bonneel <http://liris.cnrs.fr/~nbonneel/>`__ ( C++ code for
- EMD)
-- `Marco Cuturi <http://marcocuturi.net/>`__ (Sinkhorn Knopp in
- Matlab/Cuda)
-
-Contributions and code of conduct
----------------------------------
-
-Every contribution is welcome and should respect the `contribution
-guidelines <.github/CONTRIBUTING.md>`__. Each member of the project is
-expected to follow the `code of conduct <.github/CODE_OF_CONDUCT.md>`__.
-
-Support
--------
-
-You can ask questions and join the development discussion:
-
-- On the POT `slack channel <https://pot-toolbox.slack.com>`__
-- On the POT `gitter channel <https://gitter.im/PythonOT/community>`__
-- On the POT `mailing
- list <https://mail.python.org/mm3/mailman3/lists/pot.python.org/>`__
-
-You can also post bug reports and feature requests in Github issues.
-Make sure to read our `guidelines <.github/CONTRIBUTING.md>`__ first.
-
-References
-----------
-
-[1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011,
-December). `Displacement interpolation using Lagrangian mass
-transport <https://people.csail.mit.edu/sparis/publi/2011/sigasia/Bonneel_11_Displacement_Interpolation.pdf>`__.
-In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
-
-[2] Cuturi, M. (2013). `Sinkhorn distances: Lightspeed computation of
-optimal transport <https://arxiv.org/pdf/1306.0895.pdf>`__. In Advances
-in Neural Information Processing Systems (pp. 2292-2300).
-
-[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
-(2015). `Iterative Bregman projections for regularized transportation
-problems <https://arxiv.org/pdf/1412.5154.pdf>`__. SIAM Journal on
-Scientific Computing, 37(2), A1111-A1138.
-
-[4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti,
-`Supervised planetary unmixing with optimal
-transport <https://hal.archives-ouvertes.fr/hal-01377236/document>`__,
-Whorkshop on Hyperspectral Image and Signal Processing : Evolution in
-Remote Sensing (WHISPERS), 2016.
-
-[5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, `Optimal Transport
-for Domain Adaptation <https://arxiv.org/pdf/1507.00504.pdf>`__, in IEEE
-Transactions on Pattern Analysis and Machine Intelligence , vol.PP,
-no.99, pp.1-1
-
-[6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
-`Regularized discrete optimal
-transport <https://arxiv.org/pdf/1307.5551.pdf>`__. SIAM Journal on
-Imaging Sciences, 7(3), 1853-1882.
-
-[7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). `Generalized
-conditional gradient: analysis of convergence and
-applications <https://arxiv.org/pdf/1510.06567.pdf>`__. arXiv preprint
-arXiv:1510.06567.
-
-[8] M. Perrot, N. Courty, R. Flamary, A. Habrard (2016), `Mapping
-estimation for discrete optimal
-transport <http://remi.flamary.com/biblio/perrot2016mapping.pdf>`__,
-Neural Information Processing Systems (NIPS).
-
-[9] Schmitzer, B. (2016). `Stabilized Sparse Scaling Algorithms for
-Entropy Regularized Transport
-Problems <https://arxiv.org/pdf/1610.06519.pdf>`__. arXiv preprint
-arXiv:1610.06519.
-
-[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
-`Scaling algorithms for unbalanced transport
-problems <https://arxiv.org/pdf/1607.05816.pdf>`__. arXiv preprint
-arXiv:1607.05816.
-
-[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
-`Wasserstein Discriminant
-Analysis <https://arxiv.org/pdf/1608.08063.pdf>`__. arXiv preprint
-arXiv:1608.08063.
-
-[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016),
-`Gromov-Wasserstein averaging of kernel and distance
-matrices <http://proceedings.mlr.press/v48/peyre16.html>`__
-International Conference on Machine Learning (ICML).
-
-[13] Mémoli, Facundo (2011). `Gromov–Wasserstein distances and the
-metric approach to object
-matching <https://media.adelaide.edu.au/acvt/Publications/2011/2011-Gromov%E2%80%93Wasserstein%20Distances%20and%20the%20Metric%20Approach%20to%20Object%20Matching.pdf>`__.
-Foundations of computational mathematics 11.4 : 417-487.
-
-[14] Knott, M. and Smith, C. S. (1984).`On the optimal mapping of
-distributions <https://link.springer.com/article/10.1007/BF00934745>`__,
-Journal of Optimization Theory and Applications Vol 43.
-
-[15] Peyré, G., & Cuturi, M. (2018). `Computational Optimal
-Transport <https://arxiv.org/pdf/1803.00567.pdf>`__ .
-
-[16] Agueh, M., & Carlier, G. (2011). `Barycenters in the Wasserstein
-space <https://hal.archives-ouvertes.fr/hal-00637399/document>`__. SIAM
-Journal on Mathematical Analysis, 43(2), 904-924.
-
-[17] Blondel, M., Seguy, V., & Rolet, A. (2018). `Smooth and Sparse
-Optimal Transport <https://arxiv.org/abs/1710.06276>`__. Proceedings of
-the Twenty-First International Conference on Artificial Intelligence and
-Statistics (AISTATS).
-
-[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) `Stochastic
-Optimization for Large-scale Optimal
-Transport <https://arxiv.org/abs/1605.08527>`__. Advances in Neural
-Information Processing Systems (2016).
-
-[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet,
-A.& Blondel, M. `Large-scale Optimal Transport and Mapping
-Estimation <https://arxiv.org/pdf/1711.02283.pdf>`__. International
-Conference on Learning Representation (2018)
-
-[20] Cuturi, M. and Doucet, A. (2014) `Fast Computation of Wasserstein
-Barycenters <http://proceedings.mlr.press/v32/cuturi14.html>`__.
-International Conference in Machine Learning
-
-[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A.,
-Nguyen, A. & Guibas, L. (2015). `Convolutional wasserstein distances:
-Efficient optimal transportation on geometric
-domains <https://dl.acm.org/citation.cfm?id=2766963>`__. ACM
-Transactions on Graphics (TOG), 34(4), 66.
-
-[22] J. Altschuler, J.Weed, P. Rigollet, (2017) `Near-linear time
-approximation algorithms for optimal transport via Sinkhorn
-iteration <https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf>`__,
-Advances in Neural Information Processing Systems (NIPS) 31
-
-[23] Aude, G., Peyré, G., Cuturi, M., `Learning Generative Models with
-Sinkhorn Divergences <https://arxiv.org/abs/1706.00292>`__, Proceedings
-of the Twenty-First International Conference on Artficial Intelligence
-and Statistics, (AISTATS) 21, 2018
-
-[24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N.
-(2019). `Optimal Transport for structured data with application on
-graphs <http://proceedings.mlr.press/v97/titouan19a.html>`__ Proceedings
-of the 36th International Conference on Machine Learning (ICML).
-
-[25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. (2015).
-`Learning with a Wasserstein Loss <http://cbcl.mit.edu/wasserstein/>`__
-Advances in Neural Information Processing Systems (NIPS).
-
-[26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019).
-`Screening Sinkhorn Algorithm for Regularized Optimal
-Transport <https://papers.nips.cc/paper/9386-screening-sinkhorn-algorithm-for-regularized-optimal-transport>`__,
-Advances in Neural Information Processing Systems 33 (NeurIPS).
-
-[27] Redko I., Courty N., Flamary R., Tuia D. (2019). `Optimal Transport
-for Multi-source Domain Adaptation under Target
-Shift <http://proceedings.mlr.press/v89/redko19a.html>`__, Proceedings
-of the Twenty-Second International Conference on Artificial Intelligence
-and Statistics (AISTATS) 22, 2019.
-
-[28] Caffarelli, L. A., McCann, R. J. (2010). `Free boundaries in
-optimal transport and Monge-Ampere obstacle
-problems <http://www.math.toronto.edu/~mccann/papers/annals2010.pdf>`__,
-Annals of mathematics, 673-730.
-
-[29] Chapel, L., Alaya, M., Gasso, G. (2020). `Partial Optimal Transport
-with Applications on Positive-Unlabeled
-Learning <https://arxiv.org/abs/2002.08276>`__, Advances in Neural
-Information Processing Systems (NeurIPS), 2020.
-
-[30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). `Optimal
-transport with Laplacian regularization: Applications to domain
-adaptation and shape
-matching <https://remi.flamary.com/biblio/flamary2014optlaplace.pdf>`__,
-NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
-
-[31] Bonneel, Nicolas, et al. `Sliced and radon wasserstein barycenters
-of
-measures <https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf>`__,
-Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
-
-[32] Huang, M., Ma S., Lai, L. (2021). `A Riemannian Block Coordinate
-Descent Method for Computing the Projection Robust Wasserstein
-Distance <http://proceedings.mlr.press/v139/huang21e.html>`__,
-Proceedings of the 38th International Conference on Machine Learning
-(ICML).
-
-[33] Kerdoncuff T., Emonet R., Marc S. `Sampled Gromov
-Wasserstein <https://hal.archives-ouvertes.fr/hal-03232509/document>`__,
-Machine Learning Journal (MJL), 2021
-
-[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A.,
-& Peyré, G. (2019, April). `Interpolating between optimal transport and
-MMD using Sinkhorn
-divergences <http://proceedings.mlr.press/v89/feydy19a/feydy19a.pdf>`__.
-In The 22nd International Conference on Artificial Intelligence and
-Statistics (pp. 2681-2690). PMLR.
-
-[35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N.,
-Koyejo, S., ... & Schwing, A. G. (2019). `Max-sliced wasserstein
-distance and its use for
-gans <https://openaccess.thecvf.com/content_CVPR_2019/papers/Deshpande_Max-Sliced_Wasserstein_Distance_and_Its_Use_for_GANs_CVPR_2019_paper.pdf>`__.
-In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern
-Recognition (pp. 10648-10656).
-
-[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F.
-R. (2019, May). `Sliced-Wasserstein flows: Nonparametric generative
-modeling via optimal transport and
-diffusions <http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf>`__.
-In International Conference on Machine Learning (pp. 4104-4113). PMLR.
-
-[37] Janati, H., Cuturi, M., Gramfort, A. `Debiased sinkhorn
-barycenters <http://proceedings.mlr.press/v119/janati20a/janati20a.pdf>`__
-Proceedings of the 37th International Conference on Machine Learning,
-PMLR 119:4692-4701, 2020
-
-[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty,
-`Online Graph Dictionary
-Learning <https://arxiv.org/pdf/2102.06555.pdf>`__, International
-Conference on Machine Learning (ICML), 2021.
-
-.. |PyPI version| image:: https://badge.fury.io/py/POT.svg
- :target: https://badge.fury.io/py/POT
-.. |Anaconda Cloud| image:: https://anaconda.org/conda-forge/pot/badges/version.svg
- :target: https://anaconda.org/conda-forge/pot
-.. |Build Status| image:: https://github.com/PythonOT/POT/workflows/build/badge.svg?branch=master&event=push
- :target: https://github.com/PythonOT/POT/actions
-.. |Codecov Status| image:: https://codecov.io/gh/PythonOT/POT/branch/master/graph/badge.svg
- :target: https://codecov.io/gh/PythonOT/POT
-.. |Downloads| image:: https://pepy.tech/badge/pot
- :target: https://pepy.tech/project/pot
-.. |Anaconda downloads| image:: https://anaconda.org/conda-forge/pot/badges/downloads.svg
- :target: https://anaconda.org/conda-forge/pot
-.. |License| image:: https://anaconda.org/conda-forge/pot/badges/license.svg
- :target: https://github.com/PythonOT/POT/blob/master/LICENSE
diff --git a/docs/source/releases.rst b/docs/source/releases.rst
index aa06105..8250a4d 100644
--- a/docs/source/releases.rst
+++ b/docs/source/releases.rst
@@ -1,469 +1,6 @@
Releases
========
-0.8.0
------
-
-*November 2021*
-
-This new stable release introduces several important features.
-
-First we now have an OpenMP compatible exact ot solver in ``ot.emd``.
-The OpenMP version is used when the parameter ``numThreads`` is greater
-than one and can lead to nice speedups on multi-core machines.
-
-| Second we have introduced a backend mechanism that allows to use
- standard POT function seamlessly on Numpy, Pytorch and Jax arrays.
- Other backends are coming but right now POT can be used seamlessly for
- training neural networks in Pytorch. Notably we propose the first
- differentiable computation of the exact OT loss with ``ot.emd2`` (can
- be differentiated w.r.t. both cost matrix and sample weights), but
- also for the classical Sinkhorn loss with ``ot.sinkhorn2``, the
- Wasserstein distance in 1D with ``ot.wasserstein_1d``, sliced
- Wasserstein with ``ot.sliced_wasserstein_distance`` and
- Gromov-Wasserstein with ``ot.gromov_wasserstein2``. Examples of how
- this new feature can be used are now available in the documentation
- where the Pytorch backend is used to estimate a `minimal Wasserstein
- estimator <https://PythonOT.github.io/auto_examples/backends/plot_unmix_optim_torch.html>`__,
- a `Generative Network
- (GAN) <https://PythonOT.github.io/auto_examples/backends/plot_wass2_gan_torch.html>`__,
- for a `sliced Wasserstein gradient
- flow <https://PythonOT.github.io/auto_examples/backends/plot_sliced_wass_grad_flow_pytorch.html>`__
- and `optimizing the Gromov-Wassersein
- distance <https://PythonOT.github.io/auto_examples/backends/plot_optim_gromov_pytorch.html>`__.
- Note that the Jax backend is still in early development and quite slow
- at the moment, we strongly recommend for Jax users to use the `OTT
- toolbox <https://github.com/google-research/ott>`__ when possible.
-| As a result of this new feature, the old ``ot.gpu`` submodule is now
- deprecated since GPU implementations can be done using GPU arrays on
- the torch backends.
-
-Other novel features include implementation for `Sampled Gromov
-Wasserstein and Pointwise Gromov
-Wasserstein <https://PythonOT.github.io/auto_examples/gromov/plot_gromov.html#compute-gw-with-a-scalable-stochastic-method-with-any-loss-function>`__,
-Sinkhorn in log space with ``method='sinkhorn_log'``, `Projection Robust
-Wasserstein <https://PythonOT.github.io/gen_modules/ot.dr.html?highlight=robust#ot.dr.projection_robust_wasserstein>`__,
-ans `deviased Sinkorn
-barycenters <https://PythonOT.github.ioauto_examples/barycenters/plot_debiased_barycenter.html>`__.
-
-This release will also simplify the installation process. We have now a
-``pyproject.toml`` that defines the build dependency and POT should now
-build even when cython is not installed yet. Also we now provide
-pe-compiled wheels for linux ``aarch64`` that is used on Raspberry PI
-and android phones and for MacOS on ARM processors.
-
-Finally POT was accepted for publication in the Journal of Machine
-Learning Research (JMLR) open source software track and we ask the POT
-users to cite `this
-paper <https://www.jmlr.org/papers/v22/20-451.html>`__ from now on. The
-documentation has been improved in particular by adding a "Why OT?"
-section to the quick start guide and several new examples illustrating
-the new features. The documentation now has two version : the stable
-version https://pythonot.github.io/ corresponding to the last release
-and the master version https://pythonot.github.io/master that
-corresponds to the current master branch on GitHub.
-
-As usual, we want to thank all the POT contributors (now 37 people have
-contributed to the toolbox). But for this release we thank in particular
-Nathan Cassereau and Kamel Guerda from the AI support team at
-`IDRIS <http://www.idris.fr/>`__ for their support to the development of
-the backend and OpenMP implementations.
-
-New features
-^^^^^^^^^^^^
-
-- OpenMP support for exact OT solvers (PR #260)
-- Backend for running POT in numpy/torch + exact solver (PR #249)
-- Backend implementation of most functions in ``ot.bregman`` (PR #280)
-- Backend implementation of most functions in ``ot.optim`` (PR #282)
-- Backend implementation of most functions in ``ot.gromov`` (PR #294,
- PR #302)
-- Test for arrays of different type and device (CPU/GPU) (PR #304,
- #303)
-- Implementation of Sinkhorn in log space with
- ``method='sinkhorn_log'`` (PR #290)
-- Implementation of regularization path for L2 Unbalanced OT (PR #274)
-- Implementation of Projection Robust Wasserstein (PR #267)
-- Implementation of Debiased Sinkhorn Barycenters (PR #291)
-- Implementation of Sampled Gromov Wasserstein and Pointwise Gromov
- Wasserstein (PR #275)
-- Add ``pyproject.toml`` and build POT without installing cython first
- (PR #293)
-- Lazy implementation in log space for sinkhorn on samples (PR #259)
-- Documentation cleanup (PR #298)
-- Two up-to-date documentations `for stable
- release <https://PythonOT.github.io/>`__ and for `master
- branch <https://pythonot.github.io/master/>`__.
-- Building wheels on ARM for Raspberry PI and smartphones (PR #238)
-- Update build wheels to new version and new pythons (PR #236, #253)
-- Implementation of sliced Wasserstein distance (Issue #202, PR #203)
-- Add minimal build to CI and perform pep8 test separately (PR #210)
-- Speedup of tests and return run time (PR #262)
-- Add "Why OT" discussion to the documentation (PR #220)
-- New introductory example to discrete OT in the documentation (PR
- #191)
-- Add templates for Issues/PR on Github (PR#181)
-
-Closed issues
-^^^^^^^^^^^^^
-
-- Debug Memory leak in GAN example (#254)
-- DEbug GPU bug (Issue #284, #287, PR #288)
-- set\_gradients method for JAX backend (PR #278)
-- Quicker GAN example for CircleCI build (PR #258)
-- Better formatting in Readme (PR #234)
-- Debug CI tests (PR #240, #241, #242)
-- Bug in Partial OT solver dummy points (PR #215)
-- Bug when Armijo linesearch (Issue #184, #198, #281, PR #189, #199,
- #286)
-- Bug Barycenter Sinkhorn (Issue 134, PR #195)
-- Infeasible solution in exact OT (Issues #126,#93, PR #217)
-- Doc for SUpport Barycenters (Issue #200, PR #201)
-- Fix labels transport in BaseTransport (Issue #207, PR #208)
-- Bug in ``emd_1d``, non respected bounds (Issue #169, PR #170)
-- Removed Python 2.7 support and update codecov file (PR #178)
-- Add normalization for WDA and test it (PR #172, #296)
-- Cleanup code for new version of ``flake8`` (PR #176)
-- Fixed requirements in ``setup.py`` (PR #174)
-- Removed specific MacOS flags (PR #175)
-
-0.7.0
------
-
-*May 2020*
-
-This is the new stable release for POT. We made a lot of changes in the
-documentation and added several new features such as Partial OT,
-Unbalanced and Multi Sources OT Domain Adaptation and several bug fixes.
-One important change is that we have created the GitHub organization
-`PythonOT <https://github.com/PythonOT>`__ that now owns the main POT
-repository https://github.com/PythonOT/POT and the repository for the
-new documentation is now hosted at https://PythonOT.github.io/.
-
-This is the first release where the Python 2.7 tests have been removed.
-Most of the toolbox should still work but we do not offer support for
-Python 2.7 and will close related Issues.
-
-A lot of changes have been done to the documentation that is now hosted
-on https://PythonOT.github.io/ instead of readthedocs. It was a hard
-choice but readthedocs did not allow us to run sphinx-gallery to update
-our beautiful examples and it was a huge amount of work to maintain. The
-documentation is now automatically compiled and updated on merge. We
-also removed the notebooks from the repository for space reason and also
-because they are all available in the `example
-gallery <auto_examples/index.html>`__. Note
-that now the output of the documentation build for each commit in the PR
-is available to check that the doc builds correctly before merging which
-was not possible with readthedocs.
-
-The CI framework has also been changed with a move from Travis to Github
-Action which allows to get faster tests on Windows, MacOS and Linux. We
-also now report our coverage on
-`Codecov.io <https://codecov.io/gh/PythonOT/POT>`__ and we have a
-reasonable 92% coverage. We also now generate wheels for a number of OS
-and Python versions at each merge in the master branch. They are
-available as outputs of this
-`action <https://github.com/PythonOT/POT/actions?query=workflow%3A%22Build+dist+and+wheels%22>`__.
-This will allow simpler multi-platform releases from now on.
-
-In terms of new features we now have `OTDA Classes for unbalanced
-OT <https://pythonot.github.io/gen_modules/ot.da.html#ot.da.UnbalancedSinkhornTransport>`__,
-a new Domain adaptation class form `multi domain problems
-(JCPOT) <auto_examples/domain-adaptation/plot_otda_jcpot.html#sphx-glr-auto-examples-domain-adaptation-plot-otda-jcpot-py>`__,
-and several solvers to solve the `Partial Optimal
-Transport <auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html#sphx-glr-auto-examples-unbalanced-partial-plot-partial-wass-and-gromov-py>`__
-problems.
-
-This release is also the moment to thank all the POT contributors (old
-and new) for helping making POT such a nice toolbox. A lot of changes
-(also in the API) are coming for the next versions.
-
-Features
-^^^^^^^^
-
-- New documentation on https://PythonOT.github.io/ (PR #160, PR #143,
- PR #144)
-- Documentation build on CircleCI with sphinx-gallery (PR #145,PR #146,
- #155)
-- Run sphinx gallery in CI (PR #146)
-- Remove notebooks from repo because available in doc (PR #156)
-- Build wheels in CI (#157)
-- Move from travis to GitHub Action for Windows, MacOS and Linux (PR
- #148, PR #150)
-- Partial Optimal Transport (PR#141 and PR #142)
-- Laplace regularized OTDA (PR #140)
-- Multi source DA with target shift (PR #137)
-- Screenkhorn algorithm (PR #121)
-
-Closed issues
-^^^^^^^^^^^^^
-
-- Add JMLR paper to teh readme ad Mathieu Blondel to the Acknoledgments
- (PR #231, #232)
-- Bug in Unbalanced OT example (Issue #127)
-- Clean Cython output when calling setup.py clean (Issue #122)
-- Various Macosx compilation problems (Issue #113, Issue #118, PR#130)
-- EMD dimension mismatch (Issue #114, Fixed in PR #116)
-- 2D barycenter bug for non square images (Issue #124, fixed in PR
- #132)
-- Bad value in EMD 1D (Issue #138, fixed in PR #139)
-- Log bugs for Gromov-Wassertein solver (Issue #107, fixed in PR #108)
-- Weight issues in barycenter function (PR #106)
-
-0.6.0
------
-
-*July 2019*
-
-This is the first official stable release of POT and this means a jump
-to 0.6! The library has been used in the wild for a while now and we
-have reached a state where a lot of fundamental OT solvers are available
-and tested. It has been quite stable in the last months but kept the
-beta flag in its Pypi classifiers until now.
-
-Note that this release will be the last one supporting officially Python
-2.7 (See https://python3statement.org/ for more reasons). For next
-release we will keep the travis tests for Python 2 but will make them
-non necessary for merge in 2020.
-
-The features are never complete in a toolbox designed for solving
-mathematical problems and research but with the new contributions we now
-implement algorithms and solvers from 24 scientific papers (listed in
-the README.md file). New features include a direct implementation of the
-`empirical Sinkhorn
-divergence <all.html#ot.bregman.empirical_sinkhorn_divergence>`__,
-a new efficient (Cython implementation) solver for `EMD in
-1D <all.html#ot.lp.emd_1d>`__ and
-corresponding `Wasserstein
-1D <all.html#ot.lp.wasserstein_1d>`__.
-We now also have implementations for `Unbalanced
-OT <auto_examples/plot_UOT_1D.html>`__
-and a solver for `Unbalanced OT
-barycenters <auto_examples/plot_UOT_barycenter_1D.html>`__.
-A new variant of Gromov-Wasserstein divergence called `Fused
-Gromov-Wasserstein <all.html?highlight=fused_#ot.gromov.fused_gromov_wasserstein>`__
-has been also contributed with exemples of use on `structured
-data <auto_examples/plot_fgw.html>`__
-and computing `barycenters of labeld
-graphs <auto_examples/plot_barycenter_fgw.html>`__.
-
-A lot of work has been done on the documentation with several new
-examples corresponding to the new features and a lot of corrections for
-the docstrings. But the most visible change is a new `quick start
-guide <quickstart.html>`__ for POT
-that gives several pointers about which function or classes allow to
-solve which specific OT problem. When possible a link is provided to
-relevant examples.
-
-We will also provide with this release some pre-compiled Python wheels
-for Linux 64bit on github and pip. This will simplify the install
-process that before required a C compiler and numpy/cython already
-installed.
-
-Finally we would like to acknowledge and thank the numerous contributors
-of POT that has helped in the past build the foundation and are still
-contributing to bring new features and solvers to the library.
-
-Features
-^^^^^^^^
-
-- Add compiled manylinux 64bits wheels to pip releases (PR #91)
-- Add quick start guide (PR #88)
-- Make doctest work on travis (PR #90)
-- Update documentation (PR #79, PR #84)
-- Solver for EMD in 1D (PR #89)
-- Solvers for regularized unbalanced OT (PR #87, PR#99)
-- Solver for Fused Gromov-Wasserstein (PR #86)
-- Add empirical Sinkhorn and empirical Sinkhorn divergences (PR #80)
-
-Closed issues
-^^^^^^^^^^^^^
-
-- Issue #59 fail when using "pip install POT" (new details in doc+
- hopefully wheels)
-- Issue #85 Cannot run gpu modules
-- Issue #75 Greenkhorn do not return log (solved in PR #76)
-- Issue #82 Gromov-Wasserstein fails when the cost matrices are
- slightly different
-- Issue #72 Macosx build problem
-
-0.5.0
------
-
-*Sep 2018*
-
-POT is 2 years old! This release brings numerous new features to the
-toolbox as listed below but also several bug correction.
-
-| Among the new features, we can highlight a `non-regularized
- Gromov-Wasserstein
- solver <auto_examples/plot_gromov.html>`__,
- a new `greedy variant of
- sinkhorn <all.html#ot.bregman.greenkhorn>`__,
-| `non-regularized <all.html#ot.lp.barycenter>`__,
- `convolutional
- (2D) <auto_examples/plot_convolutional_barycenter.html>`__
- and `free
- support <auto_examples/plot_free_support_barycenter.html>`__
- Wasserstein barycenters and
- `smooth <https://github.com/rflamary/POT/blob/prV0.5/notebooks/plot_OT_1D_smooth.html>`__
- and
- `stochastic <all.html#ot.stochastic.sgd_entropic_regularization>`__
- implementation of entropic OT.
-
-POT 0.5 also comes with a rewriting of ot.gpu using the cupy framework
-instead of the unmaintained cudamat. Note that while we tried to keed
-changes to the minimum, the OTDA classes were deprecated. If you are
-happy with the cudamat implementation, we recommend you stay with stable
-release 0.4 for now.
-
-The code quality has also improved with 92% code coverage in tests that
-is now printed to the log in the Travis builds. The documentation has
-also been greatly improved with new modules and examples/notebooks.
-
-This new release is so full of new stuff and corrections thanks to the
-old and new POT contributors (you can see the list in the
-`readme <https://github.com/rflamary/POT/blob/master/README.md>`__).
-
-Features
-^^^^^^^^
-
-- Add non regularized Gromov-Wasserstein solver (PR #41)
-- Linear OT mapping between empirical distributions and 90% test
- coverage (PR #42)
-- Add log parameter in class EMDTransport and SinkhornLpL1Transport (PR
- #44)
-- Add Markdown format for Pipy (PR #45)
-- Test for Python 3.5 and 3.6 on Travis (PR #46)
-- Non regularized Wasserstein barycenter with scipy linear solver
- and/or cvxopt (PR #47)
-- Rename dataset functions to be more sklearn compliant (PR #49)
-- Smooth and sparse Optimal transport implementation with entropic and
- quadratic regularization (PR #50)
-- Stochastic OT in the dual and semi-dual (PR #52 and PR #62)
-- Free support barycenters (PR #56)
-- Speed-up Sinkhorn function (PR #57 and PR #58)
-- Add convolutional Wassersein barycenters for 2D images (PR #64)
-- Add Greedy Sinkhorn variant (Greenkhorn) (PR #66)
-- Big ot.gpu update with cupy implementation (instead of un-maintained
- cudamat) (PR #67)
-
-Deprecation
-^^^^^^^^^^^
-
-Deprecated OTDA Classes were removed from ot.da and ot.gpu for version
-0.5 (PR #48 and PR #67). The deprecation message has been for a year
-here since 0.4 and it is time to pull the plug.
-
-Closed issues
-^^^^^^^^^^^^^
-
-- Issue #35 : remove import plot from ot/\ **init**.py (See PR #41)
-- Issue #43 : Unusable parameter log for EMDTransport (See PR #44)
-- Issue #55 : UnicodeDecodeError: 'ascii' while installing with pip
-
-0.4
----
-
-*15 Sep 2017*
-
-This release contains a lot of contribution from new contributors.
-
-Features
-^^^^^^^^
-
-- Automatic notebooks and doc update (PR #27)
-- Add gromov Wasserstein solver and Gromov Barycenters (PR #23)
-- emd and emd2 can now return dual variables and have max\_iter (PR #29
- and PR #25)
-- New domain adaptation classes compatible with scikit-learn (PR #22)
-- Proper tests with pytest on travis (PR #19)
-- PEP 8 tests (PR #13)
-
-Closed issues
-^^^^^^^^^^^^^
-
-- emd convergence problem du to fixed max iterations (#24)
-- Semi supervised DA error (#26)
-
-0.3.1
------
-
-*11 Jul 2017*
-
-- Correct bug in emd on windows
-
-0.3
----
-
-*7 Jul 2017*
-
-- emd\* and sinkhorn\* are now performed in parallel for multiple
- target distributions
-- emd and sinkhorn are for OT matrix computation
-- emd2 and sinkhorn2 are for OT loss computation
-- new notebooks for emd computation and Wasserstein Discriminant
- Analysis
-- relocate notebooks
-- update documentation
-- clean\_zeros(a,b,M) for removimg zeros in sparse distributions
-- GPU implementations for sinkhorn and group lasso regularization
-
-V0.2
-----
-
-*7 Apr 2017*
-
-- New dimensionality reduction method (WDA)
-- Efficient method emd2 returns only tarnsport (in paralell if several
- histograms given)
-
-0.1.11
-------
-
-*5 Jan 2017*
-
-- Add sphinx gallery for better documentation
-- Small efficiency tweak in sinkhorn
-- Add simple tic() toc() functions for timing
-
-0.1.10
-------
-
-*7 Nov 2016* \* numerical stabilization for sinkhorn (log domain and
-epsilon scaling)
-
-0.1.9
------
-
-*4 Nov 2016*
-
-- Update classes and examples for domain adaptation
-- Joint OT matrix and mapping estimation
-
-0.1.7
------
-
-*31 Oct 2016*
-
-- Original Domain adaptation classes
-
-0.1.3
------
-
-- pipy works
-
-First pre-release
------------------
-
-*28 Oct 2016*
-
-It provides the following solvers: \* OT solver for the linear program/
-Earth Movers Distance. \* Entropic regularization OT solver with
-Sinkhorn Knopp Algorithm. \* Bregman projections for Wasserstein
-barycenter [3] and unmixing. \* Optimal transport for domain adaptation
-with group lasso regularization \* Conditional gradient and Generalized
-conditional gradient for regularized OT.
-
-Some demonstrations (both in Python and Jupyter Notebook format) are
-available in the examples folder.
+.. include:: ../../RELEASES.md
+ :parser: myst_parser.sphinx_
+ :start-line: 3
diff --git a/examples/plot_Intro_OT.py b/examples/plot_Intro_OT.py
index 2e2c6fd..f282950 100644
--- a/examples/plot_Intro_OT.py
+++ b/examples/plot_Intro_OT.py
@@ -327,7 +327,7 @@ for k in range(len(reg_parameter)):
time_sinkhorn_reg[k] = time.time() - start
if k % 4 == 0 and k > 0: # we only plot a few
- ax = pl.subplot(1, 5, k / 4)
+ ax = pl.subplot(1, 5, k // 4)
im = pl.imshow(ot_sinkhorn, vmin=0, vmax=max_ot)
pl.title('reg={0:.2g}'.format(reg_parameter[k]))
pl.xlabel('Cafés')
diff --git a/ot/__init__.py b/ot/__init__.py
index b6dc2b4..e436571 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -50,7 +50,7 @@ from .gromov import (gromov_wasserstein, gromov_wasserstein2,
# utils functions
from .utils import dist, unif, tic, toc, toq
-__version__ = "0.8.0"
+__version__ = "0.8.1"
__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
diff --git a/ot/backend.py b/ot/backend.py
index a044f84..58b652b 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -3,7 +3,7 @@
Multi-lib backend for POT
The goal is to write backend-agnostic code. Whether you're using Numpy, PyTorch,
-or Jax, POT code should work nonetheless.
+Jax, Cupy, or Tensorflow, POT code should work nonetheless.
To achieve that, POT provides backend classes which implements functions in their respective backend
imitating Numpy API. As a convention, we use nx instead of np to refer to the backend.
@@ -17,6 +17,68 @@ Examples
... nx = get_backend(a, b) # infer the backend from the arguments
... c = nx.dot(a, b) # now use the backend to do any calculation
... return c
+
+.. warning::
+ Tensorflow only works with the Numpy API. To activate it, please run the following:
+
+ .. code-block::
+
+ from tensorflow.python.ops.numpy_ops import np_config
+ np_config.enable_numpy_behavior()
+
+Performance
+--------
+
+- CPU: Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz
+- GPU: Tesla V100-SXM2-32GB
+- Date of the benchmark: December 8th, 2021
+- Commit of benchmark: PR #316, https://github.com/PythonOT/POT/pull/316
+
+.. raw:: html
+
+ <style>
+ #perftable {
+ width: 100%;
+ margin-bottom: 1em;
+ }
+
+ #perftable table{
+ border-collapse: collapse;
+ table-layout: fixed;
+ width: 100%;
+ }
+
+ #perftable th, #perftable td {
+ border: 1px solid #ddd;
+ padding: 8px;
+ font-size: smaller;
+ }
+ </style>
+
+ <div id="perftable">
+ <table>
+ <tr><th align="center" colspan="8">Sinkhorn Knopp - Averaged on 100 runs</th></tr>
+ <tr><th align="center">Bitsize</th><th align="center" colspan="7">32 bits</th></tr>
+ <tr><th align="center">Device</th><th align="center" colspan="3.0"">CPU</th><th align="center" colspan="4.0">GPU</tr>
+ <tr><th align="center">Sample size</th><th align="center">Numpy</th><th align="center">Pytorch</th><th align="center">Tensorflow</th><th align="center">Cupy</th><th align="center">Jax</th><th align="center">Pytorch</th><th align="center">Tensorflow</th></tr>
+ <tr><td align="center">50</td><td align="center">0.0008</td><td align="center">0.0022</td><td align="center">0.0151</td><td align="center">0.0095</td><td align="center">0.0193</td><td align="center">0.0051</td><td align="center">0.0293</td></tr>
+ <tr><td align="center">100</td><td align="center">0.0005</td><td align="center">0.0013</td><td align="center">0.0097</td><td align="center">0.0057</td><td align="center">0.0115</td><td align="center">0.0029</td><td align="center">0.0173</td></tr>
+ <tr><td align="center">500</td><td align="center">0.0009</td><td align="center">0.0016</td><td align="center">0.0110</td><td align="center">0.0058</td><td align="center">0.0115</td><td align="center">0.0029</td><td align="center">0.0166</td></tr>
+ <tr><td align="center">1000</td><td align="center">0.0021</td><td align="center">0.0021</td><td align="center">0.0145</td><td align="center">0.0056</td><td align="center">0.0118</td><td align="center">0.0029</td><td align="center">0.0168</td></tr>
+ <tr><td align="center">2000</td><td align="center">0.0069</td><td align="center">0.0043</td><td align="center">0.0278</td><td align="center">0.0059</td><td align="center">0.0118</td><td align="center">0.0030</td><td align="center">0.0165</td></tr>
+ <tr><td align="center">5000</td><td align="center">0.0707</td><td align="center">0.0314</td><td align="center">0.1395</td><td align="center">0.0074</td><td align="center">0.0125</td><td align="center">0.0035</td><td align="center">0.0198</td></tr>
+ <tr><td colspan="8">&nbsp;</td></tr>
+ <tr><th align="center">Bitsize</th><th align="center" colspan="7">64 bits</th></tr>
+ <tr><th align="center">Device</th><th align="center" colspan="3.0"">CPU</th><th align="center" colspan="4.0">GPU</tr>
+ <tr><th align="center">Sample size</th><th align="center">Numpy</th><th align="center">Pytorch</th><th align="center">Tensorflow</th><th align="center">Cupy</th><th align="center">Jax</th><th align="center">Pytorch</th><th align="center">Tensorflow</th></tr>
+ <tr><td align="center">50</td><td align="center">0.0008</td><td align="center">0.0020</td><td align="center">0.0154</td><td align="center">0.0093</td><td align="center">0.0191</td><td align="center">0.0051</td><td align="center">0.0328</td></tr>
+ <tr><td align="center">100</td><td align="center">0.0005</td><td align="center">0.0013</td><td align="center">0.0094</td><td align="center">0.0056</td><td align="center">0.0114</td><td align="center">0.0029</td><td align="center">0.0169</td></tr>
+ <tr><td align="center">500</td><td align="center">0.0013</td><td align="center">0.0017</td><td align="center">0.0120</td><td align="center">0.0059</td><td align="center">0.0116</td><td align="center">0.0029</td><td align="center">0.0168</td></tr>
+ <tr><td align="center">1000</td><td align="center">0.0034</td><td align="center">0.0027</td><td align="center">0.0177</td><td align="center">0.0058</td><td align="center">0.0118</td><td align="center">0.0029</td><td align="center">0.0167</td></tr>
+ <tr><td align="center">2000</td><td align="center">0.0146</td><td align="center">0.0075</td><td align="center">0.0436</td><td align="center">0.0059</td><td align="center">0.0120</td><td align="center">0.0029</td><td align="center">0.0165</td></tr>
+ <tr><td align="center">5000</td><td align="center">0.1467</td><td align="center">0.0568</td><td align="center">0.2468</td><td align="center">0.0077</td><td align="center">0.0146</td><td align="center">0.0045</td><td align="center">0.0204</td></tr>
+ </table>
+ </div>
"""
# Author: Remi Flamary <remi.flamary@polytechnique.edu>
@@ -27,6 +89,8 @@ Examples
import numpy as np
import scipy.special as scipy
from scipy.sparse import issparse, coo_matrix, csr_matrix
+import warnings
+import time
try:
import torch
@@ -39,11 +103,29 @@ try:
import jax
import jax.numpy as jnp
import jax.scipy.special as jscipy
+ from jax.lib import xla_bridge
jax_type = jax.numpy.ndarray
except ImportError:
jax = False
jax_type = float
+try:
+ import cupy as cp
+ import cupyx
+ cp_type = cp.ndarray
+except ImportError:
+ cp = False
+ cp_type = float
+
+try:
+ import tensorflow as tf
+ import tensorflow.experimental.numpy as tnp
+ tf_type = tf.Tensor
+except ImportError:
+ tf = False
+ tf_type = float
+
+
str_type_error = "All array should be from the same type/backend. Current types are : {}"
@@ -57,6 +139,12 @@ def get_backend_list():
if jax:
lst.append(JaxBackend())
+ if cp: # pragma: no cover
+ lst.append(CupyBackend())
+
+ if tf:
+ lst.append(TensorflowBackend())
+
return lst
@@ -78,6 +166,10 @@ def get_backend(*args):
return TorchBackend()
elif isinstance(args[0], jax_type):
return JaxBackend()
+ elif isinstance(args[0], cp_type): # pragma: no cover
+ return CupyBackend()
+ elif isinstance(args[0], tf_type):
+ return TensorflowBackend()
else:
raise ValueError("Unknown type of non implemented backend.")
@@ -94,7 +186,8 @@ def to_numpy(*args):
class Backend():
"""
Backend abstract class.
- Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend`
+ Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend`,
+ :py:class:`CupyBackend`, :py:class:`TensorflowBackend`
- The `__name__` class attribute refers to the name of the backend.
- The `__type__` class attribute refers to the data structure used by the backend.
@@ -665,6 +758,34 @@ class Backend():
"""
raise NotImplementedError()
+ def squeeze(self, a, axis=None):
+ r"""
+ Remove axes of length one from a.
+
+ This function follows the api from :any:`numpy.squeeze`.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.squeeze.html
+ """
+ raise NotImplementedError()
+
+ def bitsize(self, type_as):
+ r"""
+ Gives the number of bits used by the data type of the given tensor.
+ """
+ raise NotImplementedError()
+
+ def device_type(self, type_as):
+ r"""
+ Returns CPU or GPU depending on the device where the given tensor is located.
+ """
+ raise NotImplementedError()
+
+ def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
+ r"""
+ Executes a benchmark of the given callable with the given arguments.
+ """
+ raise NotImplementedError()
+
class NumpyBackend(Backend):
"""
@@ -902,6 +1023,29 @@ class NumpyBackend(Backend):
# numpy has implicit type conversion so we automatically validate the test
pass
+ def squeeze(self, a, axis=None):
+ return np.squeeze(a, axis=axis)
+
+ def bitsize(self, type_as):
+ return type_as.itemsize * 8
+
+ def device_type(self, type_as):
+ return "CPU"
+
+ def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
+ results = dict()
+ for type_as in self.__type_list__:
+ inputs = [self.from_numpy(arg, type_as=type_as) for arg in args]
+ for _ in range(warmup_runs):
+ callable(*inputs)
+ t0 = time.perf_counter()
+ for _ in range(n_runs):
+ callable(*inputs)
+ t1 = time.perf_counter()
+ key = ("Numpy", self.device_type(type_as), self.bitsize(type_as))
+ results[key] = (t1 - t0) / n_runs
+ return results
+
class JaxBackend(Backend):
"""
@@ -920,9 +1064,16 @@ class JaxBackend(Backend):
def __init__(self):
self.rng_ = jax.random.PRNGKey(42)
- for d in jax.devices():
- self.__type_list__ = [jax.device_put(jnp.array(1, dtype=jnp.float32), d),
- jax.device_put(jnp.array(1, dtype=jnp.float64), d)]
+ self.__type_list__ = []
+ # available_devices = jax.devices("cpu")
+ available_devices = []
+ if xla_bridge.get_backend().platform == "gpu":
+ available_devices += jax.devices("gpu")
+ for d in available_devices:
+ self.__type_list__ += [
+ jax.device_put(jnp.array(1, dtype=jnp.float32), d),
+ jax.device_put(jnp.array(1, dtype=jnp.float64), d)
+ ]
def to_numpy(self, a):
return np.array(a)
@@ -1162,6 +1313,32 @@ class JaxBackend(Backend):
assert a_dtype == b_dtype, "Dtype discrepancy"
assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
+ def squeeze(self, a, axis=None):
+ return jnp.squeeze(a, axis=axis)
+
+ def bitsize(self, type_as):
+ return type_as.dtype.itemsize * 8
+
+ def device_type(self, type_as):
+ return self.dtype_device(type_as)[1].platform.upper()
+
+ def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
+ results = dict()
+
+ for type_as in self.__type_list__:
+ inputs = [self.from_numpy(arg, type_as=type_as) for arg in args]
+ for _ in range(warmup_runs):
+ a = callable(*inputs)
+ a.block_until_ready()
+ t0 = time.perf_counter()
+ for _ in range(n_runs):
+ a = callable(*inputs)
+ a.block_until_ready()
+ t1 = time.perf_counter()
+ key = ("Jax", self.device_type(type_as), self.bitsize(type_as))
+ results[key] = (t1 - t0) / n_runs
+ return results
+
class TorchBackend(Backend):
"""
@@ -1203,7 +1380,7 @@ class TorchBackend(Backend):
@staticmethod
def backward(ctx, grad_output):
# the gradients are grad
- return (None, None) + ctx.grads
+ return (None, None) + tuple(g * grad_output for g in ctx.grads)
self.ValFunction = ValFunction
@@ -1500,3 +1677,690 @@ class TorchBackend(Backend):
assert a_dtype == b_dtype, "Dtype discrepancy"
assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
+
+ def squeeze(self, a, axis=None):
+ if axis is None:
+ return torch.squeeze(a)
+ else:
+ return torch.squeeze(a, dim=axis)
+
+ def bitsize(self, type_as):
+ return torch.finfo(type_as.dtype).bits
+
+ def device_type(self, type_as):
+ return type_as.device.type.replace("cuda", "gpu").upper()
+
+ def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
+ results = dict()
+ for type_as in self.__type_list__:
+ inputs = [self.from_numpy(arg, type_as=type_as) for arg in args]
+ for _ in range(warmup_runs):
+ callable(*inputs)
+ if self.device_type(type_as) == "GPU": # pragma: no cover
+ torch.cuda.synchronize()
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ else:
+ start = time.perf_counter()
+ for _ in range(n_runs):
+ callable(*inputs)
+ if self.device_type(type_as) == "GPU": # pragma: no cover
+ end.record()
+ torch.cuda.synchronize()
+ duration = start.elapsed_time(end) / 1000.
+ else:
+ end = time.perf_counter()
+ duration = end - start
+ key = ("Pytorch", self.device_type(type_as), self.bitsize(type_as))
+ results[key] = duration / n_runs
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ return results
+
+
+class CupyBackend(Backend): # pragma: no cover
+ """
+ CuPy implementation of the backend
+
+ - `__name__` is "cupy"
+ - `__type__` is cp.ndarray
+ """
+
+ __name__ = 'cupy'
+ __type__ = cp_type
+ __type_list__ = None
+
+ rng_ = None
+
+ def __init__(self):
+ self.rng_ = cp.random.RandomState()
+
+ self.__type_list__ = [
+ cp.array(1, dtype=cp.float32),
+ cp.array(1, dtype=cp.float64)
+ ]
+
+ def to_numpy(self, a):
+ return cp.asnumpy(a)
+
+ def from_numpy(self, a, type_as=None):
+ if type_as is None:
+ return cp.asarray(a)
+ else:
+ with cp.cuda.Device(type_as.device):
+ return cp.asarray(a, dtype=type_as.dtype)
+
+ def set_gradients(self, val, inputs, grads):
+ # No gradients for cupy
+ return val
+
+ def zeros(self, shape, type_as=None):
+ if isinstance(shape, (list, tuple)):
+ shape = tuple(int(i) for i in shape)
+ if type_as is None:
+ return cp.zeros(shape)
+ else:
+ with cp.cuda.Device(type_as.device):
+ return cp.zeros(shape, dtype=type_as.dtype)
+
+ def ones(self, shape, type_as=None):
+ if isinstance(shape, (list, tuple)):
+ shape = tuple(int(i) for i in shape)
+ if type_as is None:
+ return cp.ones(shape)
+ else:
+ with cp.cuda.Device(type_as.device):
+ return cp.ones(shape, dtype=type_as.dtype)
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ return cp.arange(start, stop, step)
+
+ def full(self, shape, fill_value, type_as=None):
+ if isinstance(shape, (list, tuple)):
+ shape = tuple(int(i) for i in shape)
+ if type_as is None:
+ return cp.full(shape, fill_value)
+ else:
+ with cp.cuda.Device(type_as.device):
+ return cp.full(shape, fill_value, dtype=type_as.dtype)
+
+ def eye(self, N, M=None, type_as=None):
+ if type_as is None:
+ return cp.eye(N, M)
+ else:
+ with cp.cuda.Device(type_as.device):
+ return cp.eye(N, M, dtype=type_as.dtype)
+
+ def sum(self, a, axis=None, keepdims=False):
+ return cp.sum(a, axis, keepdims=keepdims)
+
+ def cumsum(self, a, axis=None):
+ return cp.cumsum(a, axis)
+
+ def max(self, a, axis=None, keepdims=False):
+ return cp.max(a, axis, keepdims=keepdims)
+
+ def min(self, a, axis=None, keepdims=False):
+ return cp.min(a, axis, keepdims=keepdims)
+
+ def maximum(self, a, b):
+ return cp.maximum(a, b)
+
+ def minimum(self, a, b):
+ return cp.minimum(a, b)
+
+ def abs(self, a):
+ return cp.abs(a)
+
+ def exp(self, a):
+ return cp.exp(a)
+
+ def log(self, a):
+ return cp.log(a)
+
+ def sqrt(self, a):
+ return cp.sqrt(a)
+
+ def power(self, a, exponents):
+ return cp.power(a, exponents)
+
+ def dot(self, a, b):
+ return cp.dot(a, b)
+
+ def norm(self, a):
+ return cp.sqrt(cp.sum(cp.square(a)))
+
+ def any(self, a):
+ return cp.any(a)
+
+ def isnan(self, a):
+ return cp.isnan(a)
+
+ def isinf(self, a):
+ return cp.isinf(a)
+
+ def einsum(self, subscripts, *operands):
+ return cp.einsum(subscripts, *operands)
+
+ def sort(self, a, axis=-1):
+ return cp.sort(a, axis)
+
+ def argsort(self, a, axis=-1):
+ return cp.argsort(a, axis)
+
+ def searchsorted(self, a, v, side='left'):
+ if a.ndim == 1:
+ return cp.searchsorted(a, v, side)
+ else:
+ # this is a not very efficient way to make numpy
+ # searchsorted work on 2d arrays
+ ret = cp.empty(v.shape, dtype=int)
+ for i in range(a.shape[0]):
+ ret[i, :] = cp.searchsorted(a[i, :], v[i, :], side)
+ return ret
+
+ def flip(self, a, axis=None):
+ return cp.flip(a, axis)
+
+ def outer(self, a, b):
+ return cp.outer(a, b)
+
+ def clip(self, a, a_min, a_max):
+ return cp.clip(a, a_min, a_max)
+
+ def repeat(self, a, repeats, axis=None):
+ return cp.repeat(a, repeats, axis)
+
+ def take_along_axis(self, arr, indices, axis):
+ return cp.take_along_axis(arr, indices, axis)
+
+ def concatenate(self, arrays, axis=0):
+ return cp.concatenate(arrays, axis)
+
+ def zero_pad(self, a, pad_width):
+ return cp.pad(a, pad_width)
+
+ def argmax(self, a, axis=None):
+ return cp.argmax(a, axis=axis)
+
+ def mean(self, a, axis=None):
+ return cp.mean(a, axis=axis)
+
+ def std(self, a, axis=None):
+ return cp.std(a, axis=axis)
+
+ def linspace(self, start, stop, num):
+ return cp.linspace(start, stop, num)
+
+ def meshgrid(self, a, b):
+ return cp.meshgrid(a, b)
+
+ def diag(self, a, k=0):
+ return cp.diag(a, k)
+
+ def unique(self, a):
+ return cp.unique(a)
+
+ def logsumexp(self, a, axis=None):
+ # Taken from
+ # https://github.com/scipy/scipy/blob/v1.7.1/scipy/special/_logsumexp.py#L7-L127
+ a_max = cp.amax(a, axis=axis, keepdims=True)
+
+ if a_max.ndim > 0:
+ a_max[~cp.isfinite(a_max)] = 0
+ elif not cp.isfinite(a_max):
+ a_max = 0
+
+ tmp = cp.exp(a - a_max)
+ s = cp.sum(tmp, axis=axis)
+ out = cp.log(s)
+ a_max = cp.squeeze(a_max, axis=axis)
+ out += a_max
+ return out
+
+ def stack(self, arrays, axis=0):
+ return cp.stack(arrays, axis)
+
+ def reshape(self, a, shape):
+ return cp.reshape(a, shape)
+
+ def seed(self, seed=None):
+ if seed is not None:
+ self.rng_.seed(seed)
+
+ def rand(self, *size, type_as=None):
+ if type_as is None:
+ return self.rng_.rand(*size)
+ else:
+ with cp.cuda.Device(type_as.device):
+ return self.rng_.rand(*size, dtype=type_as.dtype)
+
+ def randn(self, *size, type_as=None):
+ if type_as is None:
+ return self.rng_.randn(*size)
+ else:
+ with cp.cuda.Device(type_as.device):
+ return self.rng_.randn(*size, dtype=type_as.dtype)
+
+ def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
+ data = self.from_numpy(data)
+ rows = self.from_numpy(rows)
+ cols = self.from_numpy(cols)
+ if type_as is None:
+ return cupyx.scipy.sparse.coo_matrix(
+ (data, (rows, cols)), shape=shape
+ )
+ else:
+ with cp.cuda.Device(type_as.device):
+ return cupyx.scipy.sparse.coo_matrix(
+ (data, (rows, cols)), shape=shape, dtype=type_as.dtype
+ )
+
+ def issparse(self, a):
+ return cupyx.scipy.sparse.issparse(a)
+
+ def tocsr(self, a):
+ if self.issparse(a):
+ return a.tocsr()
+ else:
+ return cupyx.scipy.sparse.csr_matrix(a)
+
+ def eliminate_zeros(self, a, threshold=0.):
+ if threshold > 0:
+ if self.issparse(a):
+ a.data[self.abs(a.data) <= threshold] = 0
+ else:
+ a[self.abs(a) <= threshold] = 0
+ if self.issparse(a):
+ a.eliminate_zeros()
+ return a
+
+ def todense(self, a):
+ if self.issparse(a):
+ return a.toarray()
+ else:
+ return a
+
+ def where(self, condition, x, y):
+ return cp.where(condition, x, y)
+
+ def copy(self, a):
+ return a.copy()
+
+ def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
+ return cp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
+
+ def dtype_device(self, a):
+ return a.dtype, a.device
+
+ def assert_same_dtype_device(self, a, b):
+ a_dtype, a_device = self.dtype_device(a)
+ b_dtype, b_device = self.dtype_device(b)
+
+ # cupy has implicit type conversion so
+ # we automatically validate the test for type
+ assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
+
+ def squeeze(self, a, axis=None):
+ return cp.squeeze(a, axis=axis)
+
+ def bitsize(self, type_as):
+ return type_as.itemsize * 8
+
+ def device_type(self, type_as):
+ return "GPU"
+
+ def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
+ mempool = cp.get_default_memory_pool()
+ pinned_mempool = cp.get_default_pinned_memory_pool()
+
+ results = dict()
+ for type_as in self.__type_list__:
+ inputs = [self.from_numpy(arg, type_as=type_as) for arg in args]
+ start_gpu = cp.cuda.Event()
+ end_gpu = cp.cuda.Event()
+ for _ in range(warmup_runs):
+ callable(*inputs)
+ start_gpu.synchronize()
+ start_gpu.record()
+ for _ in range(n_runs):
+ callable(*inputs)
+ end_gpu.record()
+ end_gpu.synchronize()
+ key = ("Cupy", self.device_type(type_as), self.bitsize(type_as))
+ t_gpu = cp.cuda.get_elapsed_time(start_gpu, end_gpu) / 1000.
+ results[key] = t_gpu / n_runs
+ mempool.free_all_blocks()
+ pinned_mempool.free_all_blocks()
+ return results
+
+
+class TensorflowBackend(Backend):
+
+ __name__ = "tf"
+ __type__ = tf_type
+ __type_list__ = None
+
+ rng_ = None
+
+ def __init__(self):
+ self.seed(None)
+
+ self.__type_list__ = [
+ tf.convert_to_tensor([1], dtype=tf.float32),
+ tf.convert_to_tensor([1], dtype=tf.float64)
+ ]
+
+ tmp = self.randn(15, 10)
+ try:
+ tmp.reshape((150, 1))
+ except AttributeError:
+ warnings.warn(
+ "To use TensorflowBackend, you need to activate the tensorflow "
+ "numpy API. You can activate it by running: \n"
+ "from tensorflow.python.ops.numpy_ops import np_config\n"
+ "np_config.enable_numpy_behavior()"
+ )
+
+ def to_numpy(self, a):
+ return a.numpy()
+
+ def from_numpy(self, a, type_as=None):
+ if not isinstance(a, self.__type__):
+ if type_as is None:
+ return tf.convert_to_tensor(a)
+ else:
+ return tf.convert_to_tensor(a, dtype=type_as.dtype)
+ else:
+ if type_as is None:
+ return a
+ else:
+ return tf.cast(a, dtype=type_as.dtype)
+
+ def set_gradients(self, val, inputs, grads):
+ @tf.custom_gradient
+ def tmp(input):
+ def grad(upstream):
+ return grads
+ return val, grad
+ return tmp(inputs)
+
+ def zeros(self, shape, type_as=None):
+ if type_as is None:
+ return tnp.zeros(shape)
+ else:
+ return tnp.zeros(shape, dtype=type_as.dtype)
+
+ def ones(self, shape, type_as=None):
+ if type_as is None:
+ return tnp.ones(shape)
+ else:
+ return tnp.ones(shape, dtype=type_as.dtype)
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ return tnp.arange(start, stop, step)
+
+ def full(self, shape, fill_value, type_as=None):
+ if type_as is None:
+ return tnp.full(shape, fill_value)
+ else:
+ return tnp.full(shape, fill_value, dtype=type_as.dtype)
+
+ def eye(self, N, M=None, type_as=None):
+ if type_as is None:
+ return tnp.eye(N, M)
+ else:
+ return tnp.eye(N, M, dtype=type_as.dtype)
+
+ def sum(self, a, axis=None, keepdims=False):
+ return tnp.sum(a, axis, keepdims=keepdims)
+
+ def cumsum(self, a, axis=None):
+ return tnp.cumsum(a, axis)
+
+ def max(self, a, axis=None, keepdims=False):
+ return tnp.max(a, axis, keepdims=keepdims)
+
+ def min(self, a, axis=None, keepdims=False):
+ return tnp.min(a, axis, keepdims=keepdims)
+
+ def maximum(self, a, b):
+ return tnp.maximum(a, b)
+
+ def minimum(self, a, b):
+ return tnp.minimum(a, b)
+
+ def dot(self, a, b):
+ if len(b.shape) == 1:
+ if len(a.shape) == 1:
+ # inner product
+ return tf.reduce_sum(tf.multiply(a, b))
+ else:
+ # matrix vector
+ return tf.linalg.matvec(a, b)
+ else:
+ if len(a.shape) == 1:
+ return tf.linalg.matvec(b.T, a.T).T
+ else:
+ return tf.matmul(a, b)
+
+ def abs(self, a):
+ return tnp.abs(a)
+
+ def exp(self, a):
+ return tnp.exp(a)
+
+ def log(self, a):
+ return tnp.log(a)
+
+ def sqrt(self, a):
+ return tnp.sqrt(a)
+
+ def power(self, a, exponents):
+ return tnp.power(a, exponents)
+
+ def norm(self, a):
+ return tf.math.reduce_euclidean_norm(a)
+
+ def any(self, a):
+ return tnp.any(a)
+
+ def isnan(self, a):
+ return tnp.isnan(a)
+
+ def isinf(self, a):
+ return tnp.isinf(a)
+
+ def einsum(self, subscripts, *operands):
+ return tnp.einsum(subscripts, *operands)
+
+ def sort(self, a, axis=-1):
+ return tnp.sort(a, axis)
+
+ def argsort(self, a, axis=-1):
+ return tnp.argsort(a, axis)
+
+ def searchsorted(self, a, v, side='left'):
+ return tf.searchsorted(a, v, side=side)
+
+ def flip(self, a, axis=None):
+ return tnp.flip(a, axis)
+
+ def outer(self, a, b):
+ return tnp.outer(a, b)
+
+ def clip(self, a, a_min, a_max):
+ return tnp.clip(a, a_min, a_max)
+
+ def repeat(self, a, repeats, axis=None):
+ return tnp.repeat(a, repeats, axis)
+
+ def take_along_axis(self, arr, indices, axis):
+ return tnp.take_along_axis(arr, indices, axis)
+
+ def concatenate(self, arrays, axis=0):
+ return tnp.concatenate(arrays, axis)
+
+ def zero_pad(self, a, pad_width):
+ return tnp.pad(a, pad_width, mode="constant")
+
+ def argmax(self, a, axis=None):
+ return tnp.argmax(a, axis=axis)
+
+ def mean(self, a, axis=None):
+ return tnp.mean(a, axis=axis)
+
+ def std(self, a, axis=None):
+ return tnp.std(a, axis=axis)
+
+ def linspace(self, start, stop, num):
+ return tnp.linspace(start, stop, num)
+
+ def meshgrid(self, a, b):
+ return tnp.meshgrid(a, b)
+
+ def diag(self, a, k=0):
+ return tnp.diag(a, k)
+
+ def unique(self, a):
+ return tf.sort(tf.unique(tf.reshape(a, [-1]))[0])
+
+ def logsumexp(self, a, axis=None):
+ return tf.math.reduce_logsumexp(a, axis=axis)
+
+ def stack(self, arrays, axis=0):
+ return tnp.stack(arrays, axis)
+
+ def reshape(self, a, shape):
+ return tnp.reshape(a, shape)
+
+ def seed(self, seed=None):
+ if isinstance(seed, int):
+ self.rng_ = tf.random.Generator.from_seed(seed)
+ elif isinstance(seed, tf.random.Generator):
+ self.rng_ = seed
+ elif seed is None:
+ self.rng_ = tf.random.Generator.from_non_deterministic_state()
+ else:
+ raise ValueError("Non compatible seed : {}".format(seed))
+
+ def rand(self, *size, type_as=None):
+ if type_as is None:
+ return self.rng_.uniform(size, minval=0., maxval=1.)
+ else:
+ return self.rng_.uniform(
+ size, minval=0., maxval=1., dtype=type_as.dtype
+ )
+
+ def randn(self, *size, type_as=None):
+ if type_as is None:
+ return self.rng_.normal(size)
+ else:
+ return self.rng_.normal(size, dtype=type_as.dtype)
+
+ def _convert_to_index_for_coo(self, tensor):
+ if isinstance(tensor, self.__type__):
+ return int(self.max(tensor)) + 1
+ else:
+ return int(np.max(tensor)) + 1
+
+ def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
+ if shape is None:
+ shape = (
+ self._convert_to_index_for_coo(rows),
+ self._convert_to_index_for_coo(cols)
+ )
+ if type_as is not None:
+ data = self.from_numpy(data, type_as=type_as)
+
+ sparse_tensor = tf.sparse.SparseTensor(
+ indices=tnp.stack([rows, cols]).T,
+ values=data,
+ dense_shape=shape
+ )
+ # if type_as is not None:
+ # sparse_tensor = self.from_numpy(sparse_tensor, type_as=type_as)
+ # SparseTensor are not subscriptable so we use dense tensors
+ return self.todense(sparse_tensor)
+
+ def issparse(self, a):
+ return isinstance(a, tf.sparse.SparseTensor)
+
+ def tocsr(self, a):
+ return a
+
+ def eliminate_zeros(self, a, threshold=0.):
+ if self.issparse(a):
+ values = a.values
+ if threshold > 0:
+ mask = self.abs(values) <= threshold
+ else:
+ mask = values == 0
+ return tf.sparse.retain(a, ~mask)
+ else:
+ if threshold > 0:
+ a = tnp.where(self.abs(a) > threshold, a, 0.)
+ return a
+
+ def todense(self, a):
+ if self.issparse(a):
+ return tf.sparse.to_dense(tf.sparse.reorder(a))
+ else:
+ return a
+
+ def where(self, condition, x, y):
+ return tnp.where(condition, x, y)
+
+ def copy(self, a):
+ return tf.identity(a)
+
+ def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
+ return tnp.allclose(
+ a, b, rtol=rtol, atol=atol, equal_nan=equal_nan
+ )
+
+ def dtype_device(self, a):
+ return a.dtype, a.device.split("device:")[1]
+
+ def assert_same_dtype_device(self, a, b):
+ a_dtype, a_device = self.dtype_device(a)
+ b_dtype, b_device = self.dtype_device(b)
+
+ assert a_dtype == b_dtype, "Dtype discrepancy"
+ assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
+
+ def squeeze(self, a, axis=None):
+ return tnp.squeeze(a, axis=axis)
+
+ def bitsize(self, type_as):
+ return type_as.dtype.size * 8
+
+ def device_type(self, type_as):
+ return self.dtype_device(type_as)[1].split(":")[0]
+
+ def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
+ results = dict()
+ device_contexts = [tf.device("/CPU:0")]
+ if len(tf.config.list_physical_devices('GPU')) > 0: # pragma: no cover
+ device_contexts.append(tf.device("/GPU:0"))
+
+ for device_context in device_contexts:
+ with device_context:
+ for type_as in self.__type_list__:
+ inputs = [self.from_numpy(arg, type_as=type_as) for arg in args]
+ for _ in range(warmup_runs):
+ callable(*inputs)
+ t0 = time.perf_counter()
+ for _ in range(n_runs):
+ res = callable(*inputs)
+ _ = res.numpy()
+ t1 = time.perf_counter()
+ key = (
+ "Tensorflow",
+ self.device_type(inputs[0]),
+ self.bitsize(type_as)
+ )
+ results[key] = (t1 - t0) / n_runs
+
+ return results
diff --git a/ot/bregman.py b/ot/bregman.py
index cce52e2..fc20175 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -830,9 +830,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
a, b, M = list_to_array(a, b, M)
nx = get_backend(M, a, b)
- if nx.__name__ == "jax":
- raise TypeError("JAX arrays have been received. Greenkhorn is not "
- "compatible with JAX")
+ if nx.__name__ in ("jax", "tf"):
+ raise TypeError("JAX or TF arrays have been received. Greenkhorn is not "
+ "compatible with neither JAX nor TF")
if len(a) == 0:
a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
@@ -865,20 +865,20 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
if m_viol_1 > m_viol_2:
old_u = u[i_1]
- new_u = a[i_1] / (K[i_1, :].dot(v))
+ new_u = a[i_1] / nx.dot(K[i_1, :], v)
G[i_1, :] = new_u * K[i_1, :] * v
- viol[i_1] = new_u * K[i_1, :].dot(v) - a[i_1]
+ viol[i_1] = nx.dot(new_u * K[i_1, :], v) - a[i_1]
viol_2 += (K[i_1, :].T * (new_u - old_u) * v)
u[i_1] = new_u
else:
old_v = v[i_2]
- new_v = b[i_2] / (K[:, i_2].T.dot(u))
+ new_v = b[i_2] / nx.dot(K[:, i_2].T, u)
G[:, i_2] = u * K[:, i_2] * new_v
# aviol = (G@one_m - a)
# aviol_2 = (G.T@one_n - b)
viol += (-old_v + new_v) * K[:, i_2] * u
- viol_2[i_2] = new_v * K[:, i_2].dot(u) - b[i_2]
+ viol_2[i_2] = new_v * nx.dot(K[:, i_2], u) - b[i_2]
v[i_2] = new_v
if stopThr_val <= stopThr:
@@ -1550,9 +1550,11 @@ def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000,
nx = get_backend(A, M)
- if nx.__name__ == "jax":
- raise NotImplementedError("Log-domain functions are not yet implemented"
- " for Jax. Use numpy or torch arrays instead.")
+ if nx.__name__ in ("jax", "tf"):
+ raise NotImplementedError(
+ "Log-domain functions are not yet implemented"
+ " for Jax and tf. Use numpy or torch arrays instead."
+ )
if weights is None:
weights = nx.ones(n_hists, type_as=A) / n_hists
@@ -1886,9 +1888,11 @@ def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000,
dim, n_hists = A.shape
nx = get_backend(A, M)
- if nx.__name__ == "jax":
- raise NotImplementedError("Log-domain functions are not yet implemented"
- " for Jax. Use numpy or torch arrays instead.")
+ if nx.__name__ in ("jax", "tf"):
+ raise NotImplementedError(
+ "Log-domain functions are not yet implemented"
+ " for Jax and TF. Use numpy or torch arrays instead."
+ )
if weights is None:
weights = nx.ones(n_hists, type_as=A) / n_hists
@@ -2043,7 +2047,7 @@ def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
log = {'err': []}
bar = nx.ones(A.shape[1:], type_as=A)
- bar /= bar.sum()
+ bar /= nx.sum(bar)
U = nx.ones(A.shape, type_as=A)
V = nx.ones(A.shape, type_as=A)
err = 1
@@ -2069,9 +2073,11 @@ def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
KV = convol_imgs(V)
U = A / KV
KU = convol_imgs(U)
- bar = nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0))
+ bar = nx.exp(
+ nx.sum(weights[:, None, None] * nx.log(KU + stabThr), axis=0)
+ )
if ii % 10 == 9:
- err = (V * KU).std(axis=0).sum()
+ err = nx.sum(nx.std(V * KU, axis=0))
# log and verbose print
if log:
log['err'].append(err)
@@ -2106,9 +2112,11 @@ def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000,
A = list_to_array(A)
nx = get_backend(A)
- if nx.__name__ == "jax":
- raise NotImplementedError("Log-domain functions are not yet implemented"
- " for Jax. Use numpy or torch arrays instead.")
+ if nx.__name__ in ("jax", "tf"):
+ raise NotImplementedError(
+ "Log-domain functions are not yet implemented"
+ " for Jax and TF. Use numpy or torch arrays instead."
+ )
n_hists, width, height = A.shape
@@ -2298,13 +2306,15 @@ def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000,
KV = convol_imgs(V)
U = A / KV
KU = convol_imgs(U)
- bar = c * nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0))
+ bar = c * nx.exp(
+ nx.sum(weights[:, None, None] * nx.log(KU + stabThr), axis=0)
+ )
for _ in range(10):
- c = (c * bar / convol_imgs(c[None]).squeeze()) ** 0.5
+ c = (c * bar / nx.squeeze(convol_imgs(c[None]))) ** 0.5
if ii % 10 == 9:
- err = (V * KU).std(axis=0).sum()
+ err = nx.sum(nx.std(V * KU, axis=0))
# log and verbose print
if log:
log['err'].append(err)
@@ -2340,9 +2350,11 @@ def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10
A = list_to_array(A)
n_hists, width, height = A.shape
nx = get_backend(A)
- if nx.__name__ == "jax":
- raise NotImplementedError("Log-domain functions are not yet implemented"
- " for Jax. Use numpy or torch arrays instead.")
+ if nx.__name__ in ("jax", "tf"):
+ raise NotImplementedError(
+ "Log-domain functions are not yet implemented"
+ " for Jax and TF. Use numpy or torch arrays instead."
+ )
if weights is None:
weights = nx.ones((n_hists,), type_as=A) / n_hists
else:
@@ -2382,7 +2394,7 @@ def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10
c = 0.5 * (c + log_bar - convol_img(c))
if ii % 10 == 9:
- err = nx.exp(G + log_KU).std(axis=0).sum()
+ err = nx.sum(nx.std(nx.exp(G + log_KU), axis=0))
# log and verbose print
if log:
log['err'].append(err)
@@ -3312,9 +3324,9 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
a, b, M = list_to_array(a, b, M)
nx = get_backend(M, a, b)
- if nx.__name__ == "jax":
- raise TypeError("JAX arrays have been received but screenkhorn is not "
- "compatible with JAX.")
+ if nx.__name__ in ("jax", "tf"):
+ raise TypeError("JAX or TF arrays have been received but screenkhorn is not "
+ "compatible with neither JAX nor TF.")
ns, nt = M.shape
@@ -3328,7 +3340,7 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
K = nx.exp(-M / reg)
def projection(u, epsilon):
- u[u <= epsilon] = epsilon
+ u = nx.maximum(u, epsilon)
return u
# ----------------------------------------------------------------------------------------------------------------#
diff --git a/ot/da.py b/ot/da.py
index 4fd97df..841f31a 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -906,7 +906,7 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
def distribution_estimation_uniform(X):
- """estimates a uniform distribution from an array of samples :math:`\mathbf{X}`
+ r"""estimates a uniform distribution from an array of samples :math:`\mathbf{X}`
Parameters
----------
@@ -950,7 +950,7 @@ class BaseTransport(BaseEstimator):
"""
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
- """Build a coupling matrix from source and target sets of samples
+ r"""Build a coupling matrix from source and target sets of samples
:math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
@@ -1010,7 +1010,7 @@ class BaseTransport(BaseEstimator):
return self
def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None):
- """Build a coupling matrix from source and target sets of samples
+ r"""Build a coupling matrix from source and target sets of samples
:math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
and transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
@@ -1038,7 +1038,7 @@ class BaseTransport(BaseEstimator):
return self.fit(Xs, ys, Xt, yt).transform(Xs, ys, Xt, yt)
def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
- """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
+ r"""Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
Parameters
----------
@@ -1105,7 +1105,7 @@ class BaseTransport(BaseEstimator):
return transp_Xs
def transform_labels(self, ys=None):
- """Propagate source labels :math:`\mathbf{y_s}` to obtain estimated target labels as in
+ r"""Propagate source labels :math:`\mathbf{y_s}` to obtain estimated target labels as in
:ref:`[27] <references-basetransport-transform-labels>`.
Parameters
@@ -1152,7 +1152,7 @@ class BaseTransport(BaseEstimator):
def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
batch_size=128):
- """Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}`
+ r"""Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}`
Parameters
----------
@@ -1218,7 +1218,7 @@ class BaseTransport(BaseEstimator):
return transp_Xt
def inverse_transform_labels(self, yt=None):
- """Propagate target labels :math:`\mathbf{y_t}` to obtain estimated source labels
+ r"""Propagate target labels :math:`\mathbf{y_t}` to obtain estimated source labels
:math:`\mathbf{y_s}`
Parameters
@@ -1307,7 +1307,7 @@ class LinearTransport(BaseTransport):
self.distribution_estimation = distribution_estimation
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
- """Build a coupling matrix from source and target sets of samples
+ r"""Build a coupling matrix from source and target sets of samples
:math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
@@ -1354,7 +1354,7 @@ class LinearTransport(BaseTransport):
return self
def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
- """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
+ r"""Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
Parameters
----------
@@ -1387,7 +1387,7 @@ class LinearTransport(BaseTransport):
def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
batch_size=128):
- """Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}`
+ r"""Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}`
Parameters
----------
@@ -1493,7 +1493,7 @@ class SinkhornTransport(BaseTransport):
self.out_of_sample_map = out_of_sample_map
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
- """Build a coupling matrix from source and target sets of samples
+ r"""Build a coupling matrix from source and target sets of samples
:math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
@@ -1592,7 +1592,7 @@ class EMDTransport(BaseTransport):
self.max_iter = max_iter
def fit(self, Xs, ys=None, Xt=None, yt=None):
- """Build a coupling matrix from source and target sets of samples
+ r"""Build a coupling matrix from source and target sets of samples
:math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
@@ -1711,7 +1711,7 @@ class SinkhornLpl1Transport(BaseTransport):
self.limit_max = limit_max
def fit(self, Xs, ys=None, Xt=None, yt=None):
- """Build a coupling matrix from source and target sets of samples
+ r"""Build a coupling matrix from source and target sets of samples
:math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
@@ -1839,7 +1839,7 @@ class EMDLaplaceTransport(BaseTransport):
self.out_of_sample_map = out_of_sample_map
def fit(self, Xs, ys=None, Xt=None, yt=None):
- """Build a coupling matrix from source and target sets of samples
+ r"""Build a coupling matrix from source and target sets of samples
:math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
@@ -1962,7 +1962,7 @@ class SinkhornL1l2Transport(BaseTransport):
self.limit_max = limit_max
def fit(self, Xs, ys=None, Xt=None, yt=None):
- """Build a coupling matrix from source and target sets of samples
+ r"""Build a coupling matrix from source and target sets of samples
:math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
@@ -2088,7 +2088,7 @@ class MappingTransport(BaseEstimator):
self.verbose2 = verbose2
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
- """Builds an optimal coupling and estimates the associated mapping
+ r"""Builds an optimal coupling and estimates the associated mapping
from source and target sets of samples
:math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
@@ -2146,7 +2146,7 @@ class MappingTransport(BaseEstimator):
return self
def transform(self, Xs):
- """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
+ r"""Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
Parameters
----------
@@ -2261,7 +2261,7 @@ class UnbalancedSinkhornTransport(BaseTransport):
self.limit_max = limit_max
def fit(self, Xs, ys=None, Xt=None, yt=None):
- """Build a coupling matrix from source and target sets of samples
+ r"""Build a coupling matrix from source and target sets of samples
:math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
@@ -2373,7 +2373,7 @@ class JCPOTTransport(BaseTransport):
self.out_of_sample_map = out_of_sample_map
def fit(self, Xs, ys=None, Xt=None, yt=None):
- """Building coupling matrices from a list of source and target sets of samples
+ r"""Building coupling matrices from a list of source and target sets of samples
:math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
@@ -2419,7 +2419,7 @@ class JCPOTTransport(BaseTransport):
return self
def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
- """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
+ r"""Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
Parameters
----------
@@ -2491,7 +2491,7 @@ class JCPOTTransport(BaseTransport):
return transp_Xs
def transform_labels(self, ys=None):
- """Propagate source labels :math:`\mathbf{y_s}` to obtain target labels as in
+ r"""Propagate source labels :math:`\mathbf{y_s}` to obtain target labels as in
:ref:`[27] <references-jcpottransport-transform-labels>`
Parameters
@@ -2542,7 +2542,7 @@ class JCPOTTransport(BaseTransport):
return yt.T
def inverse_transform_labels(self, yt=None):
- """Propagate target labels :math:`\mathbf{y_t}` to obtain estimated source labels
+ r"""Propagate target labels :math:`\mathbf{y_t}` to obtain estimated source labels
:math:`\mathbf{y_s}`
Parameters
diff --git a/ot/datasets.py b/ot/datasets.py
index ad6390c..a839074 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -41,7 +41,7 @@ def get_1D_gauss(n, m, sigma):
def make_2D_samples_gauss(n, m, sigma, random_state=None):
- """Return `n` samples drawn from 2D gaussian :math:`\mathcal{N}(m, \sigma)`
+ r"""Return `n` samples drawn from 2D gaussian :math:`\mathcal{N}(m, \sigma)`
Parameters
----------
diff --git a/ot/dr.py b/ot/dr.py
index c2f51f8..1671ca0 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -16,6 +16,7 @@ Dimension reduction with OT
from scipy import linalg
import autograd.numpy as np
+from pymanopt.function import Autograd
from pymanopt.manifolds import Stiefel
from pymanopt import Problem
from pymanopt.solvers import SteepestDescent, TrustRegions
@@ -181,6 +182,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
else:
regmean = np.ones((len(xc), len(xc)))
+ @Autograd
def cost(P):
# wda loss
loss_b = 0
diff --git a/ot/gromov.py b/ot/gromov.py
index ea667e4..6544260 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -822,8 +822,12 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T,
index_k = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
index_l = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
- index_i = generator.choice(len_p, size=nb_samples_p, p=p, replace=False)
- index_j = generator.choice(len_p, size=nb_samples_p, p=p, replace=False)
+ index_i = generator.choice(
+ len_p, size=nb_samples_p, p=nx.to_numpy(p), replace=False
+ )
+ index_j = generator.choice(
+ len_p, size=nb_samples_p, p=nx.to_numpy(p), replace=False
+ )
for i in range(nb_samples_p):
if nx.issparse(T):
@@ -836,13 +840,13 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T,
index_k[i] = generator.choice(
len_q,
size=nb_samples_q,
- p=T_indexi / nx.sum(T_indexi),
+ p=nx.to_numpy(T_indexi / nx.sum(T_indexi)),
replace=True
)
index_l[i] = generator.choice(
len_q,
size=nb_samples_q,
- p=T_indexj / nx.sum(T_indexj),
+ p=nx.to_numpy(T_indexj / nx.sum(T_indexj)),
replace=True
)
@@ -934,15 +938,17 @@ def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun,
index = np.zeros(2, dtype=int)
# Initialize with default marginal
- index[0] = generator.choice(len_p, size=1, p=p)
- index[1] = generator.choice(len_q, size=1, p=q)
+ index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p))
+ index[1] = generator.choice(len_q, size=1, p=nx.to_numpy(q))
T = nx.tocsr(emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False))
best_gw_dist_estimated = np.inf
for cpt in range(max_iter):
- index[0] = generator.choice(len_p, size=1, p=p)
+ index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p))
T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,))
- index[1] = generator.choice(len_q, size=1, p=T_index0 / T_index0.sum())
+ index[1] = generator.choice(
+ len_q, size=1, p=nx.to_numpy(T_index0 / nx.sum(T_index0))
+ )
if alpha == 1:
T = nx.tocsr(
@@ -1071,13 +1077,16 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun,
C_are_symmetric = nx.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and nx.allclose(C2, C2.T, rtol=1e-10, atol=1e-10)
for cpt in range(max_iter):
- index0 = generator.choice(len_p, size=nb_samples_grad_p, p=p, replace=False)
+ index0 = generator.choice(
+ len_p, size=nb_samples_grad_p, p=nx.to_numpy(p), replace=False
+ )
Lik = 0
for i, index0_i in enumerate(index0):
- index1 = generator.choice(len_q,
- size=nb_samples_grad_q,
- p=T[index0_i, :] / nx.sum(T[index0_i, :]),
- replace=False)
+ index1 = generator.choice(
+ len_q, size=nb_samples_grad_q,
+ p=nx.to_numpy(T[index0_i, :] / nx.sum(T[index0_i, :])),
+ replace=False
+ )
# If the matrices C are not symmetric, the gradient has 2 terms, thus the term is chosen randomly.
if (not C_are_symmetric) and generator.rand(1) > 0.5:
Lik += nx.mean(loss_fun(
@@ -1359,6 +1368,8 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
-------
C : array-like, shape (`N`, `N`)
Similarity matrix in the barycenter space (permutated arbitrarily)
+ log : dict
+ Log dictionary of error during iterations. Return only if `log=True` in parameters.
References
----------
@@ -1392,7 +1403,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
Cprev = C
T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
- max_iter, 1e-4, verbose, log) for s in range(S)]
+ max_iter, 1e-4, verbose, log=False) for s in range(S)]
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)
@@ -1405,9 +1416,6 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
err = nx.norm(C - Cprev)
error.append(err)
- if log:
- log['err'].append(err)
-
if verbose:
if cpt % 200 == 0:
print('{:5s}|{:12s}'.format(
@@ -1416,7 +1424,10 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
cpt += 1
- return C
+ if log:
+ return C, {"err": error}
+ else:
+ return C
def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
@@ -1470,6 +1481,8 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
-------
C : array-like, shape (`N`, `N`)
Similarity matrix in the barycenter space (permutated arbitrarily)
+ log : dict
+ Log dictionary of error during iterations. Return only if `log=True` in parameters.
References
----------
@@ -1504,7 +1517,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
Cprev = C
T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun,
- numItermax=max_iter, stopThr=1e-5, verbose=verbose, log=log) for s in range(S)]
+ numItermax=max_iter, stopThr=1e-5, verbose=verbose, log=False) for s in range(S)]
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)
@@ -1517,9 +1530,6 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
err = nx.norm(C - Cprev)
error.append(err)
- if log:
- log['err'].append(err)
-
if verbose:
if cpt % 200 == 0:
print('{:5s}|{:12s}'.format(
@@ -1528,7 +1538,10 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
cpt += 1
- return C
+ if log:
+ return C, {"err": error}
+ else:
+ return C
def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py
index 8b4d0c3..43763a9 100644
--- a/ot/lp/solver_1d.py
+++ b/ot/lp/solver_1d.py
@@ -100,11 +100,11 @@ def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, requ
m = v_values.shape[0]
if u_weights is None:
- u_weights = nx.full(u_values.shape, 1. / n)
+ u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values)
elif u_weights.ndim != u_values.ndim:
u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
if v_weights is None:
- v_weights = nx.full(v_values.shape, 1. / m)
+ v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values)
elif v_weights.ndim != v_values.ndim:
v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)
diff --git a/ot/optim.py b/ot/optim.py
index bd8ca26..f25e2c9 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -18,8 +18,10 @@ from .backend import get_backend
# The corresponding scipy function does not work for matrices
-def line_search_armijo(f, xk, pk, gfk, old_fval,
- args=(), c1=1e-4, alpha0=0.99):
+def line_search_armijo(
+ f, xk, pk, gfk, old_fval, args=(), c1=1e-4,
+ alpha0=0.99, alpha_min=None, alpha_max=None
+):
r"""
Armijo linesearch function that works with matrices
@@ -44,6 +46,10 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
:math:`c_1` const in armijo rule (>0)
alpha0 : float, optional
initial step (>0)
+ alpha_min : float, optional
+ minimum value for alpha
+ alpha_max : float, optional
+ maximum value for alpha
Returns
-------
@@ -77,14 +83,18 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
alpha, phi1 = scalar_search_armijo(
phi, phi0, derphi0, c1=c1, alpha0=alpha0)
- # scalar_search_armijo can return alpha > 1
- if alpha is not None:
- alpha = min(1, alpha)
- return alpha, fc[0], phi1
+ if alpha is None:
+ return 0., fc[0], phi0
+ else:
+ if alpha_min is not None or alpha_max is not None:
+ alpha = np.clip(alpha, alpha_min, alpha_max)
+ return float(alpha), fc[0], phi1
-def solve_linesearch(cost, G, deltaG, Mi, f_val,
- armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
+def solve_linesearch(
+ cost, G, deltaG, Mi, f_val, armijo=True, C1=None, C2=None,
+ reg=None, Gc=None, constC=None, M=None, alpha_min=None, alpha_max=None
+):
"""
Solve the linesearch in the FW iterations
@@ -115,6 +125,10 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
Constant for the gromov cost. See :ref:`[24] <references-solve-linesearch>`. Only used and necessary when armijo=False
M : array-like (ns,nt), optional
Cost matrix between the features. Only used and necessary when armijo=False
+ alpha_min : float, optional
+ Minimum value for alpha
+ alpha_max : float, optional
+ Maximum value for alpha
Returns
-------
@@ -134,7 +148,9 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
International Conference on Machine Learning (ICML). 2019.
"""
if armijo:
- alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val)
+ alpha, fc, f_val = line_search_armijo(
+ cost, G, deltaG, Mi, f_val, alpha_min=alpha_min, alpha_max=alpha_max
+ )
else: # requires symetric matrices
G, deltaG, C1, C2, constC, M = list_to_array(G, deltaG, C1, C2, constC, M)
if isinstance(M, int) or isinstance(M, float):
@@ -148,6 +164,8 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
c = cost(G)
alpha = solve_1d_linesearch_quad(a, b, c)
+ if alpha_min is not None or alpha_max is not None:
+ alpha = np.clip(alpha, alpha_min, alpha_max)
fc = None
f_val = cost(G + alpha * deltaG)
@@ -272,9 +290,10 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
deltaG = Gc - G
# line search
- alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs)
- if alpha is None:
- alpha = 0.0
+ alpha, fc, f_val = solve_linesearch(
+ cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc,
+ alpha_min=0., alpha_max=1., **kwargs
+ )
G = G + alpha * deltaG
@@ -420,7 +439,9 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
# line search
dcost = Mi + reg1 * (1 + nx.log(G)) # ??
- alpha, fc, f_val = line_search_armijo(cost, G, deltaG, dcost, f_val)
+ alpha, fc, f_val = line_search_armijo(
+ cost, G, deltaG, dcost, f_val, alpha_min=0., alpha_max=1.
+ )
G = G + alpha * deltaG
diff --git a/ot/plot.py b/ot/plot.py
index 3e3bed7..2208c90 100644
--- a/ot/plot.py
+++ b/ot/plot.py
@@ -18,7 +18,7 @@ from matplotlib import gridspec
def plot1D_mat(a, b, M, title=''):
- """ Plot matrix :math:`\mathbf{M}` with the source and target 1D distribution
+ r""" Plot matrix :math:`\mathbf{M}` with the source and target 1D distribution
Creates a subplot with the source distribution :math:`\mathbf{a}` on the left and
target distribution :math:`\mathbf{b}` on the top. The matrix :math:`\mathbf{M}` is shown in between.
@@ -61,7 +61,7 @@ def plot1D_mat(a, b, M, title=''):
def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs):
- """ Plot matrix :math:`\mathbf{G}` in 2D with lines using alpha values
+ r""" Plot matrix :math:`\mathbf{G}` in 2D with lines using alpha values
Plot lines between source and target 2D samples with a color
proportional to the value of the matrix :math:`\mathbf{G}` between samples.
diff --git a/ot/utils.py b/ot/utils.py
index c878563..e6c93c8 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -182,7 +182,7 @@ def euclidean_distances(X, Y, squared=False):
return c
-def dist(x1, x2=None, metric='sqeuclidean', p=2):
+def dist(x1, x2=None, metric='sqeuclidean', p=2, w=None):
r"""Compute distance between samples in :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`
.. note:: This function is backend-compatible and will work on arrays
@@ -202,6 +202,10 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2):
'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis',
'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean',
'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'.
+ p : float, optional
+ p-norm for the Minkowski and the Weighted Minkowski metrics. Default value is 2.
+ w : array-like, rank 1
+ Weights for the weighted metrics.
Returns
@@ -221,7 +225,9 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2):
if not get_backend(x1, x2).__name__ == 'numpy':
raise NotImplementedError()
else:
- return cdist(x1, x2, metric=metric, p=p)
+ if metric.endswith("minkowski"):
+ return cdist(x1, x2, metric=metric, p=p, w=w)
+ return cdist(x1, x2, metric=metric, w=w)
def dist0(n, method='lin_square'):
diff --git a/pyproject.toml b/pyproject.toml
index 3f8ae8b..93ebab3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,3 +1,3 @@
[build-system]
-requires = ["setuptools", "wheel", "numpy>=1.16", "cython>=0.23"]
+requires = ["setuptools", "wheel", "numpy>=1.20", "cython>=0.23"]
build-backend = "setuptools.build_meta" \ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 4353247..f9934ce 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,13 +1,14 @@
-numpy
+numpy>=1.20
scipy>=1.3
cython
matplotlib
autograd
pymanopt==0.2.4; python_version <'3'
-pymanopt; python_version >= '3'
+pymanopt==0.2.6rc1; python_version >= '3'
cvxopt
scikit-learn
torch
jax
jaxlib
+tensorflow
pytest \ No newline at end of file
diff --git a/setup.cfg b/setup.cfg
index 1177faf..9a4c434 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,5 +1,5 @@
[metadata]
-description-file = README.md
+description_file = README.md
[flake8]
exclude = __init__.py
diff --git a/setup.py b/setup.py
index 86c7c8d..17bf968 100644
--- a/setup.py
+++ b/setup.py
@@ -68,8 +68,8 @@ setup(
license='MIT',
scripts=[],
data_files=[],
- setup_requires=["numpy>=1.16", "cython>=0.23"],
- install_requires=["numpy>=1.16", "scipy>=1.0"],
+ setup_requires=["numpy>=1.20", "cython>=0.23"],
+ install_requires=["numpy>=1.20", "scipy>=1.0"],
classifiers=[
'Development Status :: 5 - Production/Stable',
'Intended Audience :: Developers',
@@ -95,5 +95,6 @@ setup(
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
+ 'Programming Language :: Python :: 3.10',
]
)
diff --git a/test/conftest.py b/test/conftest.py
index 987d98e..c0db8ab 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -5,7 +5,7 @@
# License: MIT License
import pytest
-from ot.backend import jax
+from ot.backend import jax, tf
from ot.backend import get_backend_list
import functools
@@ -13,6 +13,10 @@ if jax:
from jax.config import config
config.update("jax_enable_x64", True)
+if tf:
+ from tensorflow.python.ops.numpy_ops import np_config
+ np_config.enable_numpy_behavior()
+
backend_list = get_backend_list()
@@ -24,16 +28,16 @@ def nx(request):
def skip_arg(arg, value, reason=None, getter=lambda x: x):
- if isinstance(arg, tuple) or isinstance(arg, list):
+ if isinstance(arg, (tuple, list)):
n = len(arg)
else:
arg = (arg, )
n = 1
- if n != 1 and (isinstance(value, tuple) or isinstance(value, list)):
+ if n != 1 and isinstance(value, (tuple, list)):
pass
else:
value = (value, )
- if isinstance(getter, tuple) or isinstance(value, list):
+ if isinstance(getter, (tuple, list)):
pass
else:
getter = [getter] * n
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
index cb85cb9..6a42cfe 100644
--- a/test/test_1d_solver.py
+++ b/test/test_1d_solver.py
@@ -11,7 +11,7 @@ import pytest
import ot
from ot.lp import wasserstein_1d
-from ot.backend import get_backend_list
+from ot.backend import get_backend_list, tf
from scipy.stats import wasserstein_distance
backend_list = get_backend_list()
@@ -86,7 +86,6 @@ def test_wasserstein_1d(nx):
def test_wasserstein_1d_type_devices(nx):
-
rng = np.random.RandomState(0)
n = 10
@@ -108,6 +107,37 @@ def test_wasserstein_1d_type_devices(nx):
nx.assert_same_dtype_device(xb, res)
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_wasserstein_1d_device_tf():
+ if not tf:
+ return
+ nx = ot.backend.TensorflowBackend()
+ rng = np.random.RandomState(0)
+ n = 10
+ x = np.linspace(0, 5, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ xb = nx.from_numpy(x)
+ rho_ub = nx.from_numpy(rho_u)
+ rho_vb = nx.from_numpy(rho_v)
+ res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
+ nx.assert_same_dtype_device(xb, res)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ xb = nx.from_numpy(x)
+ rho_ub = nx.from_numpy(rho_u)
+ rho_vb = nx.from_numpy(rho_v)
+ res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
+ nx.assert_same_dtype_device(xb, res)
+ assert nx.dtype_device(res)[1].startswith("GPU")
+
+
def test_emd_1d_emd2_1d():
# test emd1d gives similar results as emd
n = 20
@@ -148,7 +178,6 @@ def test_emd_1d_emd2_1d():
def test_emd1d_type_devices(nx):
-
rng = np.random.RandomState(0)
n = 10
@@ -170,3 +199,36 @@ def test_emd1d_type_devices(nx):
nx.assert_same_dtype_device(xb, emd)
nx.assert_same_dtype_device(xb, emd2)
+
+
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_emd1d_device_tf():
+ nx = ot.backend.TensorflowBackend()
+ rng = np.random.RandomState(0)
+ n = 10
+ x = np.linspace(0, 5, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ xb = nx.from_numpy(x)
+ rho_ub = nx.from_numpy(rho_u)
+ rho_vb = nx.from_numpy(rho_v)
+ emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
+ emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
+ nx.assert_same_dtype_device(xb, emd)
+ nx.assert_same_dtype_device(xb, emd2)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ xb = nx.from_numpy(x)
+ rho_ub = nx.from_numpy(rho_u)
+ rho_vb = nx.from_numpy(rho_v)
+ emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
+ emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
+ nx.assert_same_dtype_device(xb, emd)
+ nx.assert_same_dtype_device(xb, emd2)
+ assert nx.dtype_device(emd)[1].startswith("GPU")
diff --git a/test/test_backend.py b/test/test_backend.py
index 1832b91..027c4cd 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -7,7 +7,7 @@
import ot
import ot.backend
-from ot.backend import torch, jax
+from ot.backend import torch, jax, cp, tf
import pytest
@@ -87,6 +87,34 @@ def test_get_backend():
with pytest.raises(ValueError):
get_backend(A, B2)
+ if cp:
+ A2 = cp.asarray(A)
+ B2 = cp.asarray(B)
+
+ nx = get_backend(A2)
+ assert nx.__name__ == 'cupy'
+
+ nx = get_backend(A2, B2)
+ assert nx.__name__ == 'cupy'
+
+ # test not unique types in input
+ with pytest.raises(ValueError):
+ get_backend(A, B2)
+
+ if tf:
+ A2 = tf.convert_to_tensor(A)
+ B2 = tf.convert_to_tensor(B)
+
+ nx = get_backend(A2)
+ assert nx.__name__ == 'tf'
+
+ nx = get_backend(A2, B2)
+ assert nx.__name__ == 'tf'
+
+ # test not unique types in input
+ with pytest.raises(ValueError):
+ get_backend(A, B2)
+
def test_convert_between_backends(nx):
@@ -228,6 +256,14 @@ def test_empty_backend():
nx.copy(M)
with pytest.raises(NotImplementedError):
nx.allclose(M, M)
+ with pytest.raises(NotImplementedError):
+ nx.squeeze(M)
+ with pytest.raises(NotImplementedError):
+ nx.bitsize(M)
+ with pytest.raises(NotImplementedError):
+ nx.device_type(M)
+ with pytest.raises(NotImplementedError):
+ nx._bench(lambda x: x, M, n_runs=1)
def test_func_backends(nx):
@@ -240,7 +276,7 @@ def test_func_backends(nx):
# Sparse tensors test
sp_row = np.array([0, 3, 1, 0, 3])
sp_col = np.array([0, 3, 1, 2, 2])
- sp_data = np.array([4, 5, 7, 9, 0])
+ sp_data = np.array([4, 5, 7, 9, 0], dtype=np.float64)
lst_tot = []
@@ -393,7 +429,8 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(A))
lst_name.append('argsort')
- A = nx.searchsorted(Mb, Mb, 'right')
+ tmp = nx.sort(Mb)
+ A = nx.searchsorted(tmp, tmp, 'right')
lst_b.append(nx.to_numpy(A))
lst_name.append('searchsorted')
@@ -476,7 +513,7 @@ def test_func_backends(nx):
lst_name.append('coo_matrix')
assert not nx.issparse(Mb), 'Assert fail on: issparse (expected False)'
- assert nx.issparse(sp_Mb) or nx.__name__ == "jax", 'Assert fail on: issparse (expected True)'
+ assert nx.issparse(sp_Mb) or nx.__name__ in ("jax", "tf"), 'Assert fail on: issparse (expected True)'
A = nx.tocsr(sp_Mb)
lst_b.append(nx.to_numpy(nx.todense(A)))
@@ -501,6 +538,18 @@ def test_func_backends(nx):
assert nx.allclose(Mb, Mb), 'Assert fail on: allclose (expected True)'
assert not nx.allclose(2 * Mb, Mb), 'Assert fail on: allclose (expected False)'
+ A = nx.squeeze(nx.zeros((3, 1, 4, 1)))
+ assert tuple(A.shape) == (3, 4), 'Assert fail on: squeeze'
+
+ A = nx.bitsize(Mb)
+ lst_b.append(float(A))
+ lst_name.append("bitsize")
+
+ A = nx.device_type(Mb)
+ assert A in ("CPU", "GPU")
+
+ nx._bench(lambda x: x, M, n_runs=1)
+
lst_tot.append(lst_b)
lst_np = lst_tot[0]
@@ -575,3 +624,17 @@ def test_gradients_backends():
np.testing.assert_almost_equal(fun(v, c, e), c * np.sum(v ** 4) + e, decimal=4)
np.testing.assert_allclose(grad_val[0], v, atol=1e-4)
np.testing.assert_allclose(grad_val[2], 2 * e, atol=1e-4)
+
+ if tf:
+ nx = ot.backend.TensorflowBackend()
+ w = tf.Variable(tf.random.normal((3, 2)), name='w')
+ b = tf.Variable(tf.random.normal((2,), dtype=tf.float32), name='b')
+ x = tf.random.normal((1, 3), dtype=tf.float32)
+
+ with tf.GradientTape() as tape:
+ y = x @ w + b
+ loss = tf.reduce_mean(y ** 2)
+ manipulated_loss = nx.set_gradients(loss, (w, b), (w, b))
+ [dl_dw, dl_db] = tape.gradient(manipulated_loss, [w, b])
+ assert nx.allclose(dl_dw, w)
+ assert nx.allclose(dl_db, b)
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 830052d..6e90aa4 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -12,7 +12,7 @@ import numpy as np
import pytest
import ot
-from ot.backend import torch
+from ot.backend import torch, tf
@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False]))
@@ -248,6 +248,7 @@ def test_sinkhorn_empty():
ot.sinkhorn([], [], M, 1, method='greenkhorn', stopThr=1e-10, log=True)
+@pytest.skip_backend('tf')
@pytest.skip_backend("jax")
def test_sinkhorn_variants(nx):
# test sinkhorn
@@ -282,6 +283,8 @@ def test_sinkhorn_variants(nx):
"sinkhorn_epsilon_scaling",
"greenkhorn",
"sinkhorn_log"])
+@pytest.skip_arg(("nx", "method"), ("tf", "sinkhorn_epsilon_scaling"), reason="tf does not support sinkhorn_epsilon_scaling", getter=str)
+@pytest.skip_arg(("nx", "method"), ("tf", "greenkhorn"), reason="tf does not support greenkhorn", getter=str)
@pytest.skip_arg(("nx", "method"), ("jax", "sinkhorn_epsilon_scaling"), reason="jax does not support sinkhorn_epsilon_scaling", getter=str)
@pytest.skip_arg(("nx", "method"), ("jax", "greenkhorn"), reason="jax does not support greenkhorn", getter=str)
def test_sinkhorn_variants_dtype_device(nx, method):
@@ -323,6 +326,36 @@ def test_sinkhorn2_variants_dtype_device(nx, method):
nx.assert_same_dtype_device(Mb, lossb)
+@pytest.mark.skipif(not tf, reason="tf not installed")
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"])
+def test_sinkhorn2_variants_device_tf(method):
+ nx = ot.backend.TensorflowBackend()
+ n = 100
+ x = np.random.randn(n, 2)
+ u = ot.utils.unif(n)
+ M = ot.dist(x, x)
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ ub = nx.from_numpy(u)
+ Mb = nx.from_numpy(M)
+ Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+ lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+ nx.assert_same_dtype_device(Mb, Gb)
+ nx.assert_same_dtype_device(Mb, lossb)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ ub = nx.from_numpy(u)
+ Mb = nx.from_numpy(M)
+ Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+ lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+ nx.assert_same_dtype_device(Mb, Gb)
+ nx.assert_same_dtype_device(Mb, lossb)
+ assert nx.dtype_device(Gb)[1].startswith("GPU")
+
+
+@pytest.skip_backend('tf')
@pytest.skip_backend("jax")
def test_sinkhorn_variants_multi_b(nx):
# test sinkhorn
@@ -352,6 +385,7 @@ def test_sinkhorn_variants_multi_b(nx):
np.testing.assert_allclose(G0, Gs, atol=1e-05)
+@pytest.skip_backend('tf')
@pytest.skip_backend("jax")
def test_sinkhorn2_variants_multi_b(nx):
# test sinkhorn
@@ -454,7 +488,7 @@ def test_barycenter(nx, method, verbose, warn):
weights_nx = nx.from_numpy(weights)
reg = 1e-2
- if nx.__name__ == "jax" and method == "sinkhorn_log":
+ if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.barycenter(A_nx, M_nx, reg, weights, method=method)
else:
@@ -495,7 +529,7 @@ def test_barycenter_debiased(nx, method, verbose, warn):
# wasserstein
reg = 1e-2
- if nx.__name__ == "jax" and method == "sinkhorn_log":
+ if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method)
else:
@@ -597,7 +631,7 @@ def test_wasserstein_bary_2d(nx, method):
# wasserstein
reg = 1e-2
- if nx.__name__ == "jax" and method == "sinkhorn_log":
+ if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
else:
@@ -629,7 +663,7 @@ def test_wasserstein_bary_2d_debiased(nx, method):
# wasserstein
reg = 1e-2
- if nx.__name__ == "jax" and method == "sinkhorn_log":
+ if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
else:
@@ -888,6 +922,8 @@ def test_implemented_methods():
ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
+@pytest.skip_backend('tf')
+@pytest.skip_backend("cupy")
@pytest.skip_backend("jax")
@pytest.mark.filterwarnings("ignore:Bottleneck")
def test_screenkhorn(nx):
diff --git a/test/test_gromov.py b/test/test_gromov.py
index c4bc04c..4b995d5 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -9,7 +9,7 @@
import numpy as np
import ot
from ot.backend import NumpyBackend
-from ot.backend import torch
+from ot.backend import torch, tf
import pytest
@@ -54,9 +54,12 @@ def test_gromov(nx):
gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=True)
+ gwb = nx.to_numpy(gwb)
gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False)
- gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ gw_valb = nx.to_numpy(
+ ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ )
G = log['T']
Gb = nx.to_numpy(logb['T'])
@@ -110,6 +113,45 @@ def test_gromov_dtype_device(nx):
nx.assert_same_dtype_device(C1b, gw_valb)
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_gromov_device_tf():
+ nx = ot.backend.TensorflowBackend()
+ n_samples = 50 # nb samples
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+ xt = xs[::-1].copy()
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+ Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ nx.assert_same_dtype_device(C1b, Gb)
+ nx.assert_same_dtype_device(C1b, gw_valb)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+ Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ nx.assert_same_dtype_device(C1b, Gb)
+ nx.assert_same_dtype_device(C1b, gw_valb)
+ assert nx.dtype_device(Gb)[1].startswith("GPU")
+
+
def test_gromov2_gradients():
n_samples = 50 # nb samples
@@ -147,6 +189,7 @@ def test_gromov2_gradients():
@pytest.skip_backend("jax", reason="test very slow with jax backend")
+@pytest.skip_backend("tf", reason="test very slow with tf backend")
def test_entropic_gromov(nx):
n_samples = 50 # nb samples
@@ -188,6 +231,7 @@ def test_entropic_gromov(nx):
C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True)
gwb, logb = ot.gromov.entropic_gromov_wasserstein2(
C1b, C2b, pb, qb, 'kl_loss', epsilon=1e-2, log=True)
+ gwb = nx.to_numpy(gwb)
G = log['T']
Gb = nx.to_numpy(logb['T'])
@@ -204,6 +248,7 @@ def test_entropic_gromov(nx):
@pytest.skip_backend("jax", reason="test very slow with jax backend")
+@pytest.skip_backend("tf", reason="test very slow with tf backend")
def test_entropic_gromov_dtype_device(nx):
# setup
n_samples = 50 # nb samples
@@ -287,8 +332,8 @@ def test_pointwise_gromov(nx):
np.testing.assert_allclose(
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
- np.testing.assert_allclose(logb['gw_dist_estimated'], 0.0, atol=1e-08)
- np.testing.assert_allclose(logb['gw_dist_std'], 0.0, atol=1e-08)
+ np.testing.assert_allclose(float(logb['gw_dist_estimated']), 0.0, atol=1e-08)
+ np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0, atol=1e-08)
G, log = ot.gromov.pointwise_gromov_wasserstein(
C1, C2, p, q, loss, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42)
@@ -298,10 +343,11 @@ def test_pointwise_gromov(nx):
Gb = nx.to_numpy(nx.todense(Gb))
np.testing.assert_allclose(G, Gb, atol=1e-06)
- np.testing.assert_allclose(logb['gw_dist_estimated'], 0.10342276348494964, atol=1e-8)
- np.testing.assert_allclose(logb['gw_dist_std'], 0.0015952535464736394, atol=1e-8)
+ np.testing.assert_allclose(float(logb['gw_dist_estimated']), 0.10342276348494964, atol=1e-8)
+ np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0015952535464736394, atol=1e-8)
+@pytest.skip_backend("tf", reason="test very slow with tf backend")
@pytest.skip_backend("jax", reason="test very slow with jax backend")
def test_sampled_gromov(nx):
n_samples = 50 # nb samples
@@ -346,8 +392,8 @@ def test_sampled_gromov(nx):
np.testing.assert_allclose(
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
- np.testing.assert_allclose(logb['gw_dist_estimated'], 0.05679474884977278, atol=1e-08)
- np.testing.assert_allclose(logb['gw_dist_std'], 0.0005986592106971995, atol=1e-08)
+ np.testing.assert_allclose(float(logb['gw_dist_estimated']), 0.05679474884977278, atol=1e-08)
+ np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0005986592106971995, atol=1e-08)
def test_gromov_barycenter(nx):
@@ -381,6 +427,20 @@ def test_gromov_barycenter(nx):
np.testing.assert_allclose(Cb, Cbb, atol=1e-06)
np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples))
+ # test of gromov_barycenters with `log` on
+ Cb_, err_ = ot.gromov.gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ )
+ Cbb_, errb_ = ot.gromov.gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ )
+ Cbb_ = nx.to_numpy(Cbb_)
+ np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06)
+ np.testing.assert_array_almost_equal(err_['err'], errb_['err'])
+ np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples))
+
Cb2 = ot.gromov.gromov_barycenters(
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
'kl_loss', max_iter=100, tol=1e-3, random_state=42
@@ -392,6 +452,20 @@ def test_gromov_barycenter(nx):
np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06)
np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples))
+ # test of gromov_barycenters with `log` on
+ Cb2_, err2_ = ot.gromov.gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'kl_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ )
+ Cb2b_, err2b_ = ot.gromov.gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'kl_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ )
+ Cb2b_ = nx.to_numpy(Cb2b_)
+ np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06)
+ np.testing.assert_array_almost_equal(err2_['err'], err2_['err'])
+ np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples))
+
@pytest.mark.filterwarnings("ignore:divide")
def test_gromov_entropic_barycenter(nx):
@@ -425,6 +499,20 @@ def test_gromov_entropic_barycenter(nx):
np.testing.assert_allclose(Cb, Cbb, atol=1e-06)
np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples))
+ # test of entropic_gromov_barycenters with `log` on
+ Cb_, err_ = ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'square_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ )
+ Cbb_, errb_ = ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'square_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ )
+ Cbb_ = nx.to_numpy(Cbb_)
+ np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06)
+ np.testing.assert_array_almost_equal(err_['err'], errb_['err'])
+ np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples))
+
Cb2 = ot.gromov.entropic_gromov_barycenters(
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42
@@ -436,6 +524,20 @@ def test_gromov_entropic_barycenter(nx):
np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06)
np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples))
+ # test of entropic_gromov_barycenters with `log` on
+ Cb2_, err2_ = ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'kl_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ )
+ Cb2b_, err2b_ = ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'kl_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ )
+ Cb2b_ = nx.to_numpy(Cb2b_)
+ np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06)
+ np.testing.assert_array_almost_equal(err2_['err'], err2_['err'])
+ np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples))
+
def test_fgw(nx):
n_samples = 50 # nb samples
@@ -486,6 +588,7 @@ def test_fgw(nx):
fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True)
+ fgwb = nx.to_numpy(fgwb)
G = log['T']
Gb = nx.to_numpy(logb['T'])
diff --git a/test/test_optim.py b/test/test_optim.py
index 4efd9b1..41f9cbe 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -142,7 +142,7 @@ def test_line_search_armijo(nx):
pk = np.array([[-0.25, 0.25], [0.25, -0.25]])
gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]])
old_fval = -123
- # Should not throw an exception and return None for alpha
+ # Should not throw an exception and return 0. for alpha
alpha, a, b = ot.optim.line_search_armijo(
lambda x: 1, nx.from_numpy(xk), nx.from_numpy(pk), nx.from_numpy(gfk), old_fval
)
@@ -151,7 +151,7 @@ def test_line_search_armijo(nx):
)
assert a == anp
assert b == bnp
- assert alpha is None
+ assert alpha == 0.
# check line search armijo
def f(x):
diff --git a/test/test_ot.py b/test/test_ot.py
index 92f26a7..53edf4f 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -11,7 +11,7 @@ import pytest
import ot
from ot.datasets import make_1D_gauss as gauss
-from ot.backend import torch
+from ot.backend import torch, tf
def test_emd_dimension_and_mass_mismatch():
@@ -101,6 +101,40 @@ def test_emd_emd2_types_devices(nx):
nx.assert_same_dtype_device(Mb, w)
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_emd_emd2_devices_tf():
+ if not tf:
+ return
+ nx = ot.backend.TensorflowBackend()
+
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+ M = ot.dist(x, y)
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
+ Gb = ot.emd(ab, ab, Mb)
+ w = ot.emd2(ab, ab, Mb)
+ nx.assert_same_dtype_device(Mb, Gb)
+ nx.assert_same_dtype_device(Mb, w)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
+ Gb = ot.emd(ab, ab, Mb)
+ w = ot.emd2(ab, ab, Mb)
+ nx.assert_same_dtype_device(Mb, Gb)
+ nx.assert_same_dtype_device(Mb, w)
+ assert nx.dtype_device(Gb)[1].startswith("GPU")
+
+
def test_emd2_gradients():
n_samples = 100
n_features = 2
@@ -126,6 +160,22 @@ def test_emd2_gradients():
assert b1.shape == b1.grad.shape
assert M1.shape == M1.grad.shape
+ # Testing for bug #309, checking for scaling of gradient
+ a2 = torch.tensor(a, requires_grad=True)
+ b2 = torch.tensor(a, requires_grad=True)
+ M2 = torch.tensor(M, requires_grad=True)
+
+ val = 10.0 * ot.emd2(a2, b2, M2)
+
+ val.backward()
+
+ assert np.allclose(10.0 * a1.grad.cpu().detach().numpy(),
+ a2.grad.cpu().detach().numpy())
+ assert np.allclose(10.0 * b1.grad.cpu().detach().numpy(),
+ b2.grad.cpu().detach().numpy())
+ assert np.allclose(10.0 * M1.grad.cpu().detach().numpy(),
+ M2.grad.cpu().detach().numpy())
+
def test_emd_emd2():
# test emd and emd2 for simple identity
diff --git a/test/test_sliced.py b/test/test_sliced.py
index 245202c..91e0961 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -10,6 +10,7 @@ import pytest
import ot
from ot.sliced import get_random_projections
+from ot.backend import tf
def test_get_random_projections():
@@ -161,6 +162,34 @@ def test_sliced_backend_type_devices(nx):
nx.assert_same_dtype_device(xb, valb)
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_sliced_backend_device_tf():
+ nx = ot.backend.TensorflowBackend()
+ n = 100
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ xb = nx.from_numpy(x)
+ yb = nx.from_numpy(y)
+ Pb = nx.from_numpy(P)
+ valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
+ nx.assert_same_dtype_device(xb, valb)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ xb = nx.from_numpy(x)
+ yb = nx.from_numpy(y)
+ Pb = nx.from_numpy(P)
+ valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
+ nx.assert_same_dtype_device(xb, valb)
+ assert nx.dtype_device(valb)[1].startswith("GPU")
+
+
def test_max_sliced_backend(nx):
n = 100
@@ -211,3 +240,31 @@ def test_max_sliced_backend_type_devices(nx):
valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
+
+
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_max_sliced_backend_device_tf():
+ nx = ot.backend.TensorflowBackend()
+ n = 100
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ xb = nx.from_numpy(x)
+ yb = nx.from_numpy(y)
+ Pb = nx.from_numpy(P)
+ valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
+ nx.assert_same_dtype_device(xb, valb)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ xb = nx.from_numpy(x)
+ yb = nx.from_numpy(y)
+ Pb = nx.from_numpy(P)
+ valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
+ nx.assert_same_dtype_device(xb, valb)
+ assert nx.dtype_device(valb)[1].startswith("GPU")
diff --git a/test/test_utils.py b/test/test_utils.py
index 40f4e49..6b476b2 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -117,6 +117,26 @@ def test_dist():
np.testing.assert_allclose(D, D2, atol=1e-14)
np.testing.assert_allclose(D, D3, atol=1e-14)
+ # tests that every metric runs correctly
+ metrics_w = [
+ 'braycurtis', 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice',
+ 'euclidean', 'hamming', 'jaccard', 'kulsinski',
+ 'matching', 'minkowski', 'rogerstanimoto', 'russellrao',
+ 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'
+ ] # those that support weights
+ metrics = ['mahalanobis', 'seuclidean'] # do not support weights depending on scipy's version
+
+ for metric in metrics_w:
+ print(metric)
+ ot.dist(x, x, metric=metric, p=3, w=np.random.random((2, )))
+ for metric in metrics:
+ print(metric)
+ ot.dist(x, x, metric=metric, p=3)
+
+ # weighted minkowski but with no weights
+ with pytest.raises(ValueError):
+ ot.dist(x, x, metric="wminkowski")
+
def test_dist_backends(nx):