author     Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com>  2021-12-09 17:55:12 +0100
committer  GitHub <noreply@github.com>  2021-12-09 17:55:12 +0100
commit     f8d871e8c6f15009f559ece6a12eb8d8891c60fb (patch)
tree       9aa46b2fcc8046c6cddd8e9159a6f607dcf0e1e9
parent     b3dc68feac355fa94c4237f4ecad65edc9f7a7e8 (diff)
[MRG] Tensorflow backend & Benchmarker & Myst_parser (#316)
* First batch of tf methods (to be continued)
* Second batch of method (yet to debug)
* tensorflow for cpu
* add tf requirement
* pep8 + bug
* small changes
* attempt to solve pymanopt bug with tf2
* attempt #2
* attempt #3
* attempt 4
* docstring
* correct pep8 violation introduced in merge conflicts resolution
* attempt 5
* attempt 6
* just a random try
* Revert "just a random try" (this reverts commit 8223e768bfe33635549fb66cca2267514a60ebbf)
* GPU tests for tensorflow
* pep8
* attempt to solve issue with m2r2
* Remove transpose backend method
* first draft of benchmarker (need to correct time measurement)
* prettier bench table
* Bitsize and prettier device methods
* prettified table bench
* Bug corrected (results were mixed up in the final table)
* Better perf counter (for GPU support)
* pep8
* EMD bench
* solve bug if no GPU available
* pep8
* warning about tensorflow numpy api being required in the backend.py docstring
* Bug solve in backend docstring
* not covering code which requires a GPU
* Tensorflow gradients manipulation tested
* Number of warmup runs is now customizable
* typo
* Remove some warnings while building docs
* Change prettier_device to device_type in backend
* Correct JAX mistakes preventing to see the CPU if a GPU is present
* Attempt to solve JAX bug in case no GPU is found
* Reworked benchmarks order and results storage & clear GPU after usage by benchmark
* Add bench to backend docstring
* better benchs
* remove useless stuff
* Better device_type
* Now using MYST_PARSER and solving links issue in the README.md / online docs
-rw-r--r--.github/requirements_test_windows.txt2
-rw-r--r--README.md8
-rw-r--r--benchmarks/__init__.py5
-rw-r--r--benchmarks/benchmark.py105
-rw-r--r--benchmarks/emd.py40
-rw-r--r--benchmarks/sinkhorn_knopp.py42
-rw-r--r--docs/requirements.txt2
-rw-r--r--docs/requirements_rtd.txt2
-rw-r--r--docs/source/.github/CODE_OF_CONDUCT.rst6
-rw-r--r--docs/source/.github/CONTRIBUTING.rst6
-rw-r--r--docs/source/code_of_conduct.rst1
-rw-r--r--docs/source/conf.py2
-rw-r--r--docs/source/contributing.rst1
-rw-r--r--docs/source/index.rst9
-rw-r--r--ot/backend.py580
-rw-r--r--ot/bregman.py72
-rw-r--r--ot/da.py44
-rw-r--r--ot/datasets.py2
-rw-r--r--ot/dr.py2
-rw-r--r--ot/gromov.py2
-rw-r--r--ot/lp/solver_1d.py4
-rw-r--r--ot/plot.py4
-rw-r--r--requirements.txt3
-rw-r--r--test/conftest.py12
-rw-r--r--test/test_1d_solver.py68
-rw-r--r--test/test_backend.py52
-rw-r--r--test/test_bregman.py45
-rw-r--r--test/test_gromov.py44
-rw-r--r--test/test_ot.py36
-rw-r--r--test/test_sliced.py57
30 files changed, 1161 insertions, 97 deletions
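The headline change is a Tensorflow backend for POT's generic solvers. Tensorflow is only supported through its Numpy API, which must be switched on before any POT call (see the warning added to the ot/backend.py docstring below). A minimal sketch of the intended workflow, illustrative rather than part of the commit:

    import numpy as np
    from tensorflow.python.ops.numpy_ops import np_config
    np_config.enable_numpy_behavior()  # required once, before using the TF backend

    import tensorflow as tf
    import ot

    n = 50
    a = tf.convert_to_tensor(ot.unif(n), dtype=tf.float64)  # uniform weights
    M = tf.convert_to_tensor(np.random.rand(n, n))          # random cost matrix
    G = ot.emd(a, a, M)  # dispatched to the new TensorflowBackend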
diff --git a/.github/requirements_test_windows.txt b/.github/requirements_test_windows.txt
index 331dd57..b94392f 100644
--- a/.github/requirements_test_windows.txt
+++ b/.github/requirements_test_windows.txt
@@ -4,7 +4,7 @@ cython
matplotlib
autograd
pymanopt==0.2.4; python_version <'3'
-pymanopt; python_version >= '3'
+pymanopt==0.2.6rc1; python_version >= '3'
cvxopt
scikit-learn
pytest \ No newline at end of file
diff --git a/README.md b/README.md
index 18064a3..17fbe81 100644
--- a/README.md
+++ b/README.md
@@ -35,7 +35,7 @@ POT provides the following generic OT solvers (links to examples):
* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3]
formulations).
* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36].
-* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/) arrays.
+* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.
POT provides the following Machine Learning related solvers:
@@ -202,12 +202,12 @@ This toolbox benefit a lot from open source research and we would like to thank
* [Gabriel Peyré](http://gpeyre.github.io/) (Wasserstein Barycenters in Matlab)
* [Mathieu Blondel](https://mblondel.org/) (original implementation smooth OT)
-* [Nicolas Bonneel](http://liris.cnrs.fr/~nbonneel/) ( C++ code for EMD)
+* [Nicolas Bonneel](http://liris.cnrs.fr/~nbonneel/) (C++ code for EMD)
* [Marco Cuturi](http://marcocuturi.net/) (Sinkhorn Knopp in Matlab/Cuda)
## Contributions and code of conduct
-Every contribution is welcome and should respect the [contribution guidelines](https://pythonot.github.io/contributing.html). Each member of the project is expected to follow the [code of conduct](https://pythonot.github.io/code_of_conduct.html).
+Every contribution is welcome and should respect the [contribution guidelines](.github/CONTRIBUTING.md). Each member of the project is expected to follow the [code of conduct](.github/CODE_OF_CONDUCT.md).
## Support
@@ -217,7 +217,7 @@ You can ask questions and join the development discussion:
* On the POT [gitter channel](https://gitter.im/PythonOT/community)
* On the POT [mailing list](https://mail.python.org/mm3/mailman3/lists/pot.python.org/)
-You can also post bug reports and feature requests in Github issues. Make sure to read our [guidelines](https://pythonot.github.io/contributing.html) first.
+You can also post bug reports and feature requests in Github issues. Make sure to read our [guidelines](.github/CONTRIBUTING.md) first.
## References
diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py
new file mode 100644
index 0000000..37f5e56
--- /dev/null
+++ b/benchmarks/__init__.py
@@ -0,0 +1,5 @@
+from . import benchmark
+from . import sinkhorn_knopp
+from . import emd
+
+__all__ = ["benchmark", "sinkhorn_knopp", "emd"]
diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py
new file mode 100644
index 0000000..7973c6b
--- /dev/null
+++ b/benchmarks/benchmark.py
@@ -0,0 +1,105 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+from ot.backend import get_backend_list, jax, tf
+import gc
+
+
+def setup_backends():
+ if jax:
+ from jax.config import config
+ config.update("jax_enable_x64", True)
+
+ if tf:
+ from tensorflow.python.ops.numpy_ops import np_config
+ np_config.enable_numpy_behavior()
+
+
+def exec_bench(setup, tested_function, param_list, n_runs, warmup_runs):
+ backend_list = get_backend_list()
+ for i, nx in enumerate(backend_list):
+ if nx.__name__ == "tf" and i < len(backend_list) - 1:
+ # Tensorflow should be the last one to be benchmarked because
+ # as far as I'm aware, there is no way to force it to release
+ # GPU memory. Hence, if any other backend is benchmarked after
+ # Tensorflow and requires the usage of a GPU, it will not have the
+ # full memory available and you may have a GPU Out Of Memory error
+ # even though your GPU can technically hold your tensors in memory.
+ backend_list.pop(i)
+ backend_list.append(nx)
+ break
+
+ inputs = [setup(param) for param in param_list]
+ results = dict()
+ for nx in backend_list:
+ for i in range(len(param_list)):
+ print(nx, param_list[i])
+ args = inputs[i]
+ results_nx = nx._bench(
+ tested_function,
+ *args,
+ n_runs=n_runs,
+ warmup_runs=warmup_runs
+ )
+ gc.collect()
+ results_nx_with_param_in_key = dict()
+ for key in results_nx:
+ new_key = (param_list[i], *key)
+ results_nx_with_param_in_key[new_key] = results_nx[key]
+ results.update(results_nx_with_param_in_key)
+ return results
+
+
+def convert_to_html_table(results, param_name, main_title=None, comments=None):
+ string = "<table>\n"
+ keys = list(results.keys())
+ params, names, devices, bitsizes = zip(*keys)
+
+ devices_names = sorted(list(set(zip(devices, names))))
+ params = sorted(list(set(params)))
+ bitsizes = sorted(list(set(bitsizes)))
+ length = len(devices_names) + 1
+ cpus_cols = list(devices).count("CPU") // (len(bitsizes) * len(params))
+ gpus_cols = list(devices).count("GPU") // (len(bitsizes) * len(params))
+ assert cpus_cols + gpus_cols == len(devices_names)
+
+ if main_title is not None:
+ string += f'<tr><th align="center" colspan="{length}">{str(main_title)}</th></tr>\n'
+
+ for i, bitsize in enumerate(bitsizes):
+
+ if i != 0:
+ string += f'<tr><td colspan="{length}">&nbsp;</td></tr>\n'
+
+ # make bitsize header
+ text = f"{bitsize} bits"
+ if comments is not None:
+ text += " - "
+ if isinstance(comments, (tuple, list)) and len(comments) == len(bitsizes):
+ text += str(comments[i])
+ else:
+ text += str(comments)
+ string += f'<tr><th align="center">Bitsize</th>'
+ string += f'<th align="center" colspan="{length - 1}">{text}</th></tr>\n'
+
+ # make device header
+ string += f'<tr><th align="center">Device</th>'
+ string += f'<th align="center" colspan="{cpus_cols}">CPU</th>'
+ string += f'<th align="center" colspan="{gpus_cols}">GPU</th></tr>\n'
+
+ # make param_name / backend header
+ string += f'<tr><th align="center">{param_name}</th>'
+ for device, name in devices_names:
+ string += f'<th align="center">{name}</th>'
+ string += "</tr>\n"
+
+ # make results rows
+ for param in params:
+ string += f'<tr><td align="center">{param}</td>'
+ for device, name in devices_names:
+ key = (param, name, device, bitsize)
+ string += f'<td align="center">{results[key]:.4f}</td>'
+ string += "</tr>\n"
+
+ string += "</table>"
+ return string
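For reference, exec_bench returns a flat dict keyed by (parameter, backend name, device, bitsize) tuples, one averaged runtime per key, which convert_to_html_table then pivots into the HTML matrix shown in the backend docstring. A sketch of the expected structure, with made-up timings:

    results = {
        (100, "Numpy", "CPU", 64): 0.0005,    # all numbers hypothetical
        (100, "Pytorch", "CPU", 64): 0.0013,
        (100, "Pytorch", "GPU", 64): 0.0029,
    }
    print(convert_to_html_table(results, param_name="Sample size",
                                main_title="Example table"))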
diff --git a/benchmarks/emd.py b/benchmarks/emd.py
new file mode 100644
index 0000000..9f64863
--- /dev/null
+++ b/benchmarks/emd.py
@@ -0,0 +1,40 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import numpy as np
+import ot
+from .benchmark import (
+ setup_backends,
+ exec_bench,
+ convert_to_html_table
+)
+
+
+def setup(n_samples):
+ rng = np.random.RandomState(789465132)
+ x = rng.randn(n_samples, 2)
+ y = rng.randn(n_samples, 2)
+
+ a = ot.utils.unif(n_samples)
+ M = ot.dist(x, y)
+ return a, M
+
+
+if __name__ == "__main__":
+ n_runs = 100
+ warmup_runs = 10
+ param_list = [50, 100, 500, 1000, 2000, 5000]
+
+ setup_backends()
+ results = exec_bench(
+ setup=setup,
+ tested_function=lambda a, M: ot.emd(a, a, M),
+ param_list=param_list,
+ n_runs=n_runs,
+ warmup_runs=warmup_runs
+ )
+ print(convert_to_html_table(
+ results,
+ param_name="Sample size",
+ main_title=f"EMD - Averaged on {n_runs} runs"
+ ))
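Since the script relies on package-relative imports, it presumably has to be launched as a module from the repository root (python -m benchmarks.emd), with the printed HTML table then copied into the docs by hand. The same holds for benchmarks/sinkhorn_knopp.py below.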
diff --git a/benchmarks/sinkhorn_knopp.py b/benchmarks/sinkhorn_knopp.py
new file mode 100644
index 0000000..3a1ef3f
--- /dev/null
+++ b/benchmarks/sinkhorn_knopp.py
@@ -0,0 +1,42 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import numpy as np
+import ot
+from .benchmark import (
+ setup_backends,
+ exec_bench,
+ convert_to_html_table
+)
+
+
+def setup(n_samples):
+ rng = np.random.RandomState(123456789)
+ a = rng.rand(n_samples // 4, 100)
+ b = rng.rand(n_samples, 100)
+
+ wa = ot.unif(n_samples // 4)
+ wb = ot.unif(n_samples)
+
+ M = ot.dist(a.copy(), b.copy())
+ return wa, wb, M
+
+
+if __name__ == "__main__":
+ n_runs = 100
+ warmup_runs = 10
+ param_list = [50, 100, 500, 1000, 2000, 5000]
+
+ setup_backends()
+ results = exec_bench(
+ setup=setup,
+ tested_function=lambda *args: ot.bregman.sinkhorn(*args, reg=1, stopThr=1e-7),
+ param_list=param_list,
+ n_runs=n_runs,
+ warmup_runs=warmup_runs
+ )
+ print(convert_to_html_table(
+ results,
+ param_name="Sample size",
+ main_title=f"Sinkhorn Knopp - Averaged on {n_runs} runs"
+ ))
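Note that this setup deliberately builds a rectangular problem (n_samples // 4 source points against n_samples target points), so the Sinkhorn benchmark also exercises non-square cost matrices.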
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 95147d2..2e060b9 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -4,4 +4,4 @@ numpydoc
memory_profiler
pillow
networkx
-m2r2 \ No newline at end of file
+myst-parser \ No newline at end of file
diff --git a/docs/requirements_rtd.txt b/docs/requirements_rtd.txt
index 5963ea2..11957fb 100644
--- a/docs/requirements_rtd.txt
+++ b/docs/requirements_rtd.txt
@@ -3,7 +3,7 @@ numpydoc
memory_profiler
pillow
networkx
-m2r2
+myst-parser
numpy
scipy>=1.0
cython
diff --git a/docs/source/.github/CODE_OF_CONDUCT.rst b/docs/source/.github/CODE_OF_CONDUCT.rst
new file mode 100644
index 0000000..d4c5cec
--- /dev/null
+++ b/docs/source/.github/CODE_OF_CONDUCT.rst
@@ -0,0 +1,6 @@
+Code of Conduct
+===============
+
+.. include:: ../../../.github/CODE_OF_CONDUCT.md
+ :parser: myst_parser.sphinx_
+ :start-line: 2
diff --git a/docs/source/.github/CONTRIBUTING.rst b/docs/source/.github/CONTRIBUTING.rst
new file mode 100644
index 0000000..aef24e9
--- /dev/null
+++ b/docs/source/.github/CONTRIBUTING.rst
@@ -0,0 +1,6 @@
+Contributing to POT
+===================
+
+.. include:: ../../../.github/CONTRIBUTING.md
+ :parser: myst_parser.sphinx_
+ :start-line: 3
diff --git a/docs/source/code_of_conduct.rst b/docs/source/code_of_conduct.rst
deleted file mode 100644
index b37ba7b..0000000
--- a/docs/source/code_of_conduct.rst
+++ /dev/null
@@ -1 +0,0 @@
-.. mdinclude:: ../../.github/CODE_OF_CONDUCT.md \ No newline at end of file
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 1320afa..849e97c 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -69,7 +69,7 @@ extensions = [
'sphinx.ext.viewcode',
'sphinx.ext.napoleon',
'sphinx_gallery.gen_gallery',
- 'm2r2'
+ 'myst_parser'
]
autosummary_generate = True
diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst
deleted file mode 100644
index dc81e75..0000000
--- a/docs/source/contributing.rst
+++ /dev/null
@@ -1 +0,0 @@
-.. mdinclude:: ../../.github/CONTRIBUTING.md \ No newline at end of file
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 7aaa524..8de31ae 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -17,12 +17,11 @@ Contents
all
auto_examples/index
releases
- contributing
- Code of Conduct <code_of_conduct>
-
-.. mdinclude:: ../../README.md
- :start-line: 2
+ .github/CONTRIBUTING
+ .github/CODE_OF_CONDUCT
+.. include:: ../../README.md
+ :parser: myst_parser.sphinx_
Indices and tables
diff --git a/ot/backend.py b/ot/backend.py
index 1630ac4..58b652b 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -3,7 +3,7 @@
Multi-lib backend for POT
The goal is to write backend-agnostic code. Whether you're using Numpy, PyTorch,
-Jax, or Cupy, POT code should work nonetheless.
+Jax, Cupy, or Tensorflow, POT code should work nonetheless.
To achieve that, POT provides backend classes which implement functions in their respective backend,
imitating the Numpy API. As a convention, we use nx instead of np to refer to the backend.
@@ -17,6 +17,68 @@ Examples
... nx = get_backend(a, b) # infer the backend from the arguments
... c = nx.dot(a, b) # now use the backend to do any calculation
... return c
+
+.. warning::
+ Tensorflow only works with the Numpy API. To activate it, please run the following:
+
+ .. code-block::
+
+ from tensorflow.python.ops.numpy_ops import np_config
+ np_config.enable_numpy_behavior()
+
+Performance
+-----------
+
+- CPU: Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz
+- GPU: Tesla V100-SXM2-32GB
+- Date of the benchmark: December 8th, 2021
+- Commit of benchmark: PR #316, https://github.com/PythonOT/POT/pull/316
+
+.. raw:: html
+
+ <style>
+ #perftable {
+ width: 100%;
+ margin-bottom: 1em;
+ }
+
+ #perftable table{
+ border-collapse: collapse;
+ table-layout: fixed;
+ width: 100%;
+ }
+
+ #perftable th, #perftable td {
+ border: 1px solid #ddd;
+ padding: 8px;
+ font-size: smaller;
+ }
+ </style>
+
+ <div id="perftable">
+ <table>
+ <tr><th align="center" colspan="8">Sinkhorn Knopp - Averaged on 100 runs</th></tr>
+ <tr><th align="center">Bitsize</th><th align="center" colspan="7">32 bits</th></tr>
+ <tr><th align="center">Device</th><th align="center" colspan="3.0"">CPU</th><th align="center" colspan="4.0">GPU</tr>
+ <tr><th align="center">Sample size</th><th align="center">Numpy</th><th align="center">Pytorch</th><th align="center">Tensorflow</th><th align="center">Cupy</th><th align="center">Jax</th><th align="center">Pytorch</th><th align="center">Tensorflow</th></tr>
+ <tr><td align="center">50</td><td align="center">0.0008</td><td align="center">0.0022</td><td align="center">0.0151</td><td align="center">0.0095</td><td align="center">0.0193</td><td align="center">0.0051</td><td align="center">0.0293</td></tr>
+ <tr><td align="center">100</td><td align="center">0.0005</td><td align="center">0.0013</td><td align="center">0.0097</td><td align="center">0.0057</td><td align="center">0.0115</td><td align="center">0.0029</td><td align="center">0.0173</td></tr>
+ <tr><td align="center">500</td><td align="center">0.0009</td><td align="center">0.0016</td><td align="center">0.0110</td><td align="center">0.0058</td><td align="center">0.0115</td><td align="center">0.0029</td><td align="center">0.0166</td></tr>
+ <tr><td align="center">1000</td><td align="center">0.0021</td><td align="center">0.0021</td><td align="center">0.0145</td><td align="center">0.0056</td><td align="center">0.0118</td><td align="center">0.0029</td><td align="center">0.0168</td></tr>
+ <tr><td align="center">2000</td><td align="center">0.0069</td><td align="center">0.0043</td><td align="center">0.0278</td><td align="center">0.0059</td><td align="center">0.0118</td><td align="center">0.0030</td><td align="center">0.0165</td></tr>
+ <tr><td align="center">5000</td><td align="center">0.0707</td><td align="center">0.0314</td><td align="center">0.1395</td><td align="center">0.0074</td><td align="center">0.0125</td><td align="center">0.0035</td><td align="center">0.0198</td></tr>
+ <tr><td colspan="8">&nbsp;</td></tr>
+ <tr><th align="center">Bitsize</th><th align="center" colspan="7">64 bits</th></tr>
+ <tr><th align="center">Device</th><th align="center" colspan="3.0"">CPU</th><th align="center" colspan="4.0">GPU</tr>
+ <tr><th align="center">Sample size</th><th align="center">Numpy</th><th align="center">Pytorch</th><th align="center">Tensorflow</th><th align="center">Cupy</th><th align="center">Jax</th><th align="center">Pytorch</th><th align="center">Tensorflow</th></tr>
+ <tr><td align="center">50</td><td align="center">0.0008</td><td align="center">0.0020</td><td align="center">0.0154</td><td align="center">0.0093</td><td align="center">0.0191</td><td align="center">0.0051</td><td align="center">0.0328</td></tr>
+ <tr><td align="center">100</td><td align="center">0.0005</td><td align="center">0.0013</td><td align="center">0.0094</td><td align="center">0.0056</td><td align="center">0.0114</td><td align="center">0.0029</td><td align="center">0.0169</td></tr>
+ <tr><td align="center">500</td><td align="center">0.0013</td><td align="center">0.0017</td><td align="center">0.0120</td><td align="center">0.0059</td><td align="center">0.0116</td><td align="center">0.0029</td><td align="center">0.0168</td></tr>
+ <tr><td align="center">1000</td><td align="center">0.0034</td><td align="center">0.0027</td><td align="center">0.0177</td><td align="center">0.0058</td><td align="center">0.0118</td><td align="center">0.0029</td><td align="center">0.0167</td></tr>
+ <tr><td align="center">2000</td><td align="center">0.0146</td><td align="center">0.0075</td><td align="center">0.0436</td><td align="center">0.0059</td><td align="center">0.0120</td><td align="center">0.0029</td><td align="center">0.0165</td></tr>
+ <tr><td align="center">5000</td><td align="center">0.1467</td><td align="center">0.0568</td><td align="center">0.2468</td><td align="center">0.0077</td><td align="center">0.0146</td><td align="center">0.0045</td><td align="center">0.0204</td></tr>
+ </table>
+ </div>
"""
# Author: Remi Flamary <remi.flamary@polytechnique.edu>
@@ -27,6 +89,8 @@ Examples
import numpy as np
import scipy.special as scipy
from scipy.sparse import issparse, coo_matrix, csr_matrix
+import warnings
+import time
try:
import torch
@@ -39,6 +103,7 @@ try:
import jax
import jax.numpy as jnp
import jax.scipy.special as jscipy
+ from jax.lib import xla_bridge
jax_type = jax.numpy.ndarray
except ImportError:
jax = False
@@ -52,6 +117,15 @@ except ImportError:
cp = False
cp_type = float
+try:
+ import tensorflow as tf
+ import tensorflow.experimental.numpy as tnp
+ tf_type = tf.Tensor
+except ImportError:
+ tf = False
+ tf_type = float
+
+
str_type_error = "All array should be from the same type/backend. Current types are : {}"
@@ -65,9 +139,12 @@ def get_backend_list():
if jax:
lst.append(JaxBackend())
- if cp:
+ if cp: # pragma: no cover
lst.append(CupyBackend())
+ if tf:
+ lst.append(TensorflowBackend())
+
return lst
@@ -89,8 +166,10 @@ def get_backend(*args):
return TorchBackend()
elif isinstance(args[0], jax_type):
return JaxBackend()
- elif isinstance(args[0], cp_type):
+ elif isinstance(args[0], cp_type): # pragma: no cover
return CupyBackend()
+ elif isinstance(args[0], tf_type):
+ return TensorflowBackend()
else:
raise ValueError("Unknown type of non implemented backend.")
@@ -108,7 +187,7 @@ class Backend():
"""
Backend abstract class.
Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend`,
- :py:class:`CupyBackend`
+ :py:class:`CupyBackend`, :py:class:`TensorflowBackend`
- The `__name__` class attribute refers to the name of the backend.
- The `__type__` class attribute refers to the data structure used by the backend.
@@ -679,6 +758,34 @@ class Backend():
"""
raise NotImplementedError()
+ def squeeze(self, a, axis=None):
+ r"""
+ Remove axes of length one from a.
+
+ This function follows the api from :any:`numpy.squeeze`.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.squeeze.html
+ """
+ raise NotImplementedError()
+
+ def bitsize(self, type_as):
+ r"""
+ Gives the number of bits used by the data type of the given tensor.
+ """
+ raise NotImplementedError()
+
+ def device_type(self, type_as):
+ r"""
+ Returns CPU or GPU depending on the device where the given tensor is located.
+ """
+ raise NotImplementedError()
+
+ def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
+ r"""
+ Executes a benchmark of the given callable with the given arguments.
+ """
+ raise NotImplementedError()
+
class NumpyBackend(Backend):
"""
@@ -916,6 +1023,29 @@ class NumpyBackend(Backend):
# numpy has implicit type conversion so we automatically validate the test
pass
+ def squeeze(self, a, axis=None):
+ return np.squeeze(a, axis=axis)
+
+ def bitsize(self, type_as):
+ return type_as.itemsize * 8
+
+ def device_type(self, type_as):
+ return "CPU"
+
+ def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
+ results = dict()
+ for type_as in self.__type_list__:
+ inputs = [self.from_numpy(arg, type_as=type_as) for arg in args]
+ for _ in range(warmup_runs):
+ callable(*inputs)
+ t0 = time.perf_counter()
+ for _ in range(n_runs):
+ callable(*inputs)
+ t1 = time.perf_counter()
+ key = ("Numpy", self.device_type(type_as), self.bitsize(type_as))
+ results[key] = (t1 - t0) / n_runs
+ return results
+
class JaxBackend(Backend):
"""
@@ -934,9 +1064,16 @@ class JaxBackend(Backend):
def __init__(self):
self.rng_ = jax.random.PRNGKey(42)
- for d in jax.devices():
- self.__type_list__ = [jax.device_put(jnp.array(1, dtype=jnp.float32), d),
- jax.device_put(jnp.array(1, dtype=jnp.float64), d)]
+ self.__type_list__ = []
+ # available_devices = jax.devices("cpu")
+ available_devices = []
+ if xla_bridge.get_backend().platform == "gpu":
+ available_devices += jax.devices("gpu")
+ for d in available_devices:
+ self.__type_list__ += [
+ jax.device_put(jnp.array(1, dtype=jnp.float32), d),
+ jax.device_put(jnp.array(1, dtype=jnp.float64), d)
+ ]
def to_numpy(self, a):
return np.array(a)
@@ -1176,6 +1313,32 @@ class JaxBackend(Backend):
assert a_dtype == b_dtype, "Dtype discrepancy"
assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
+ def squeeze(self, a, axis=None):
+ return jnp.squeeze(a, axis=axis)
+
+ def bitsize(self, type_as):
+ return type_as.dtype.itemsize * 8
+
+ def device_type(self, type_as):
+ return self.dtype_device(type_as)[1].platform.upper()
+
+ def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
+ results = dict()
+
+ for type_as in self.__type_list__:
+ inputs = [self.from_numpy(arg, type_as=type_as) for arg in args]
+ for _ in range(warmup_runs):
+ a = callable(*inputs)
+ a.block_until_ready()
+ t0 = time.perf_counter()
+ for _ in range(n_runs):
+ a = callable(*inputs)
+ a.block_until_ready()
+ t1 = time.perf_counter()
+ key = ("Jax", self.device_type(type_as), self.bitsize(type_as))
+ results[key] = (t1 - t0) / n_runs
+ return results
+
class TorchBackend(Backend):
"""
@@ -1515,6 +1678,46 @@ class TorchBackend(Backend):
assert a_dtype == b_dtype, "Dtype discrepancy"
assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
+ def squeeze(self, a, axis=None):
+ if axis is None:
+ return torch.squeeze(a)
+ else:
+ return torch.squeeze(a, dim=axis)
+
+ def bitsize(self, type_as):
+ return torch.finfo(type_as.dtype).bits
+
+ def device_type(self, type_as):
+ return type_as.device.type.replace("cuda", "gpu").upper()
+
+ def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
+ results = dict()
+ for type_as in self.__type_list__:
+ inputs = [self.from_numpy(arg, type_as=type_as) for arg in args]
+ for _ in range(warmup_runs):
+ callable(*inputs)
+ if self.device_type(type_as) == "GPU": # pragma: no cover
+ torch.cuda.synchronize()
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ else:
+ start = time.perf_counter()
+ for _ in range(n_runs):
+ callable(*inputs)
+ if self.device_type(type_as) == "GPU": # pragma: no cover
+ end.record()
+ torch.cuda.synchronize()
+ duration = start.elapsed_time(end) / 1000.
+ else:
+ end = time.perf_counter()
+ duration = end - start
+ key = ("Pytorch", self.device_type(type_as), self.bitsize(type_as))
+ results[key] = duration / n_runs
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ return results
+
class CupyBackend(Backend): # pragma: no cover
"""
@@ -1798,3 +2001,366 @@ class CupyBackend(Backend): # pragma: no cover
# cupy has implicit type conversion so
# we automatically validate the test for type
assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
+
+ def squeeze(self, a, axis=None):
+ return cp.squeeze(a, axis=axis)
+
+ def bitsize(self, type_as):
+ return type_as.itemsize * 8
+
+ def device_type(self, type_as):
+ return "GPU"
+
+ def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
+ mempool = cp.get_default_memory_pool()
+ pinned_mempool = cp.get_default_pinned_memory_pool()
+
+ results = dict()
+ for type_as in self.__type_list__:
+ inputs = [self.from_numpy(arg, type_as=type_as) for arg in args]
+ start_gpu = cp.cuda.Event()
+ end_gpu = cp.cuda.Event()
+ for _ in range(warmup_runs):
+ callable(*inputs)
+ start_gpu.synchronize()
+ start_gpu.record()
+ for _ in range(n_runs):
+ callable(*inputs)
+ end_gpu.record()
+ end_gpu.synchronize()
+ key = ("Cupy", self.device_type(type_as), self.bitsize(type_as))
+ t_gpu = cp.cuda.get_elapsed_time(start_gpu, end_gpu) / 1000.
+ results[key] = t_gpu / n_runs
+ mempool.free_all_blocks()
+ pinned_mempool.free_all_blocks()
+ return results
+
+
+class TensorflowBackend(Backend):
+
+ __name__ = "tf"
+ __type__ = tf_type
+ __type_list__ = None
+
+ rng_ = None
+
+ def __init__(self):
+ self.seed(None)
+
+ self.__type_list__ = [
+ tf.convert_to_tensor([1], dtype=tf.float32),
+ tf.convert_to_tensor([1], dtype=tf.float64)
+ ]
+
+ tmp = self.randn(15, 10)
+ try:
+ tmp.reshape((150, 1))
+ except AttributeError:
+ warnings.warn(
+ "To use TensorflowBackend, you need to activate the tensorflow "
+ "numpy API. You can activate it by running: \n"
+ "from tensorflow.python.ops.numpy_ops import np_config\n"
+ "np_config.enable_numpy_behavior()"
+ )
+
+ def to_numpy(self, a):
+ return a.numpy()
+
+ def from_numpy(self, a, type_as=None):
+ if not isinstance(a, self.__type__):
+ if type_as is None:
+ return tf.convert_to_tensor(a)
+ else:
+ return tf.convert_to_tensor(a, dtype=type_as.dtype)
+ else:
+ if type_as is None:
+ return a
+ else:
+ return tf.cast(a, dtype=type_as.dtype)
+
+ def set_gradients(self, val, inputs, grads):
+ @tf.custom_gradient
+ def tmp(input):
+ def grad(upstream):
+ return grads
+ return val, grad
+ return tmp(inputs)
+
+ def zeros(self, shape, type_as=None):
+ if type_as is None:
+ return tnp.zeros(shape)
+ else:
+ return tnp.zeros(shape, dtype=type_as.dtype)
+
+ def ones(self, shape, type_as=None):
+ if type_as is None:
+ return tnp.ones(shape)
+ else:
+ return tnp.ones(shape, dtype=type_as.dtype)
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ return tnp.arange(start, stop, step)
+
+ def full(self, shape, fill_value, type_as=None):
+ if type_as is None:
+ return tnp.full(shape, fill_value)
+ else:
+ return tnp.full(shape, fill_value, dtype=type_as.dtype)
+
+ def eye(self, N, M=None, type_as=None):
+ if type_as is None:
+ return tnp.eye(N, M)
+ else:
+ return tnp.eye(N, M, dtype=type_as.dtype)
+
+ def sum(self, a, axis=None, keepdims=False):
+ return tnp.sum(a, axis, keepdims=keepdims)
+
+ def cumsum(self, a, axis=None):
+ return tnp.cumsum(a, axis)
+
+ def max(self, a, axis=None, keepdims=False):
+ return tnp.max(a, axis, keepdims=keepdims)
+
+ def min(self, a, axis=None, keepdims=False):
+ return tnp.min(a, axis, keepdims=keepdims)
+
+ def maximum(self, a, b):
+ return tnp.maximum(a, b)
+
+ def minimum(self, a, b):
+ return tnp.minimum(a, b)
+
+ def dot(self, a, b):
+ if len(b.shape) == 1:
+ if len(a.shape) == 1:
+ # inner product
+ return tf.reduce_sum(tf.multiply(a, b))
+ else:
+ # matrix vector
+ return tf.linalg.matvec(a, b)
+ else:
+ if len(a.shape) == 1:
+ return tf.linalg.matvec(b.T, a.T).T
+ else:
+ return tf.matmul(a, b)
+
+ def abs(self, a):
+ return tnp.abs(a)
+
+ def exp(self, a):
+ return tnp.exp(a)
+
+ def log(self, a):
+ return tnp.log(a)
+
+ def sqrt(self, a):
+ return tnp.sqrt(a)
+
+ def power(self, a, exponents):
+ return tnp.power(a, exponents)
+
+ def norm(self, a):
+ return tf.math.reduce_euclidean_norm(a)
+
+ def any(self, a):
+ return tnp.any(a)
+
+ def isnan(self, a):
+ return tnp.isnan(a)
+
+ def isinf(self, a):
+ return tnp.isinf(a)
+
+ def einsum(self, subscripts, *operands):
+ return tnp.einsum(subscripts, *operands)
+
+ def sort(self, a, axis=-1):
+ return tnp.sort(a, axis)
+
+ def argsort(self, a, axis=-1):
+ return tnp.argsort(a, axis)
+
+ def searchsorted(self, a, v, side='left'):
+ return tf.searchsorted(a, v, side=side)
+
+ def flip(self, a, axis=None):
+ return tnp.flip(a, axis)
+
+ def outer(self, a, b):
+ return tnp.outer(a, b)
+
+ def clip(self, a, a_min, a_max):
+ return tnp.clip(a, a_min, a_max)
+
+ def repeat(self, a, repeats, axis=None):
+ return tnp.repeat(a, repeats, axis)
+
+ def take_along_axis(self, arr, indices, axis):
+ return tnp.take_along_axis(arr, indices, axis)
+
+ def concatenate(self, arrays, axis=0):
+ return tnp.concatenate(arrays, axis)
+
+ def zero_pad(self, a, pad_width):
+ return tnp.pad(a, pad_width, mode="constant")
+
+ def argmax(self, a, axis=None):
+ return tnp.argmax(a, axis=axis)
+
+ def mean(self, a, axis=None):
+ return tnp.mean(a, axis=axis)
+
+ def std(self, a, axis=None):
+ return tnp.std(a, axis=axis)
+
+ def linspace(self, start, stop, num):
+ return tnp.linspace(start, stop, num)
+
+ def meshgrid(self, a, b):
+ return tnp.meshgrid(a, b)
+
+ def diag(self, a, k=0):
+ return tnp.diag(a, k)
+
+ def unique(self, a):
+ return tf.sort(tf.unique(tf.reshape(a, [-1]))[0])
+
+ def logsumexp(self, a, axis=None):
+ return tf.math.reduce_logsumexp(a, axis=axis)
+
+ def stack(self, arrays, axis=0):
+ return tnp.stack(arrays, axis)
+
+ def reshape(self, a, shape):
+ return tnp.reshape(a, shape)
+
+ def seed(self, seed=None):
+ if isinstance(seed, int):
+ self.rng_ = tf.random.Generator.from_seed(seed)
+ elif isinstance(seed, tf.random.Generator):
+ self.rng_ = seed
+ elif seed is None:
+ self.rng_ = tf.random.Generator.from_non_deterministic_state()
+ else:
+ raise ValueError("Non compatible seed : {}".format(seed))
+
+ def rand(self, *size, type_as=None):
+ if type_as is None:
+ return self.rng_.uniform(size, minval=0., maxval=1.)
+ else:
+ return self.rng_.uniform(
+ size, minval=0., maxval=1., dtype=type_as.dtype
+ )
+
+ def randn(self, *size, type_as=None):
+ if type_as is None:
+ return self.rng_.normal(size)
+ else:
+ return self.rng_.normal(size, dtype=type_as.dtype)
+
+ def _convert_to_index_for_coo(self, tensor):
+ if isinstance(tensor, self.__type__):
+ return int(self.max(tensor)) + 1
+ else:
+ return int(np.max(tensor)) + 1
+
+ def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
+ if shape is None:
+ shape = (
+ self._convert_to_index_for_coo(rows),
+ self._convert_to_index_for_coo(cols)
+ )
+ if type_as is not None:
+ data = self.from_numpy(data, type_as=type_as)
+
+ sparse_tensor = tf.sparse.SparseTensor(
+ indices=tnp.stack([rows, cols]).T,
+ values=data,
+ dense_shape=shape
+ )
+ # if type_as is not None:
+ # sparse_tensor = self.from_numpy(sparse_tensor, type_as=type_as)
+ # SparseTensor are not subscriptable so we use dense tensors
+ return self.todense(sparse_tensor)
+
+ def issparse(self, a):
+ return isinstance(a, tf.sparse.SparseTensor)
+
+ def tocsr(self, a):
+ return a
+
+ def eliminate_zeros(self, a, threshold=0.):
+ if self.issparse(a):
+ values = a.values
+ if threshold > 0:
+ mask = self.abs(values) <= threshold
+ else:
+ mask = values == 0
+ return tf.sparse.retain(a, ~mask)
+ else:
+ if threshold > 0:
+ a = tnp.where(self.abs(a) > threshold, a, 0.)
+ return a
+
+ def todense(self, a):
+ if self.issparse(a):
+ return tf.sparse.to_dense(tf.sparse.reorder(a))
+ else:
+ return a
+
+ def where(self, condition, x, y):
+ return tnp.where(condition, x, y)
+
+ def copy(self, a):
+ return tf.identity(a)
+
+ def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
+ return tnp.allclose(
+ a, b, rtol=rtol, atol=atol, equal_nan=equal_nan
+ )
+
+ def dtype_device(self, a):
+ return a.dtype, a.device.split("device:")[1]
+
+ def assert_same_dtype_device(self, a, b):
+ a_dtype, a_device = self.dtype_device(a)
+ b_dtype, b_device = self.dtype_device(b)
+
+ assert a_dtype == b_dtype, "Dtype discrepancy"
+ assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
+
+ def squeeze(self, a, axis=None):
+ return tnp.squeeze(a, axis=axis)
+
+ def bitsize(self, type_as):
+ return type_as.dtype.size * 8
+
+ def device_type(self, type_as):
+ return self.dtype_device(type_as)[1].split(":")[0]
+
+ def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
+ results = dict()
+ device_contexts = [tf.device("/CPU:0")]
+ if len(tf.config.list_physical_devices('GPU')) > 0: # pragma: no cover
+ device_contexts.append(tf.device("/GPU:0"))
+
+ for device_context in device_contexts:
+ with device_context:
+ for type_as in self.__type_list__:
+ inputs = [self.from_numpy(arg, type_as=type_as) for arg in args]
+ for _ in range(warmup_runs):
+ callable(*inputs)
+ t0 = time.perf_counter()
+ for _ in range(n_runs):
+ res = callable(*inputs)
+ _ = res.numpy()
+ t1 = time.perf_counter()
+ key = (
+ "Tensorflow",
+ self.device_type(inputs[0]),
+ self.bitsize(type_as)
+ )
+ results[key] = (t1 - t0) / n_runs
+
+ return results
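With TensorflowBackend registered in get_backend_list and get_backend, Tensorflow tensors follow the same nx convention as every other backend. A small usage sketch (assuming the numpy behavior flag from the docstring warning has been enabled first):

    import tensorflow as tf
    from ot.backend import get_backend

    a = tf.random.uniform((5,), dtype=tf.float64)
    b = tf.random.uniform((5,), dtype=tf.float64)
    nx = get_backend(a, b)  # returns a TensorflowBackend instance
    c = nx.dot(a, b)        # inner product, computed as a tf.Tensor
    c_np = nx.to_numpy(c)   # back to numpy when needed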
diff --git a/ot/bregman.py b/ot/bregman.py
index cce52e2..fc20175 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -830,9 +830,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
a, b, M = list_to_array(a, b, M)
nx = get_backend(M, a, b)
- if nx.__name__ == "jax":
- raise TypeError("JAX arrays have been received. Greenkhorn is not "
- "compatible with JAX")
+ if nx.__name__ in ("jax", "tf"):
+ raise TypeError("JAX or TF arrays have been received. Greenkhorn is not "
+ "compatible with neither JAX nor TF")
if len(a) == 0:
a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
@@ -865,20 +865,20 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
if m_viol_1 > m_viol_2:
old_u = u[i_1]
- new_u = a[i_1] / (K[i_1, :].dot(v))
+ new_u = a[i_1] / nx.dot(K[i_1, :], v)
G[i_1, :] = new_u * K[i_1, :] * v
- viol[i_1] = new_u * K[i_1, :].dot(v) - a[i_1]
+ viol[i_1] = nx.dot(new_u * K[i_1, :], v) - a[i_1]
viol_2 += (K[i_1, :].T * (new_u - old_u) * v)
u[i_1] = new_u
else:
old_v = v[i_2]
- new_v = b[i_2] / (K[:, i_2].T.dot(u))
+ new_v = b[i_2] / nx.dot(K[:, i_2].T, u)
G[:, i_2] = u * K[:, i_2] * new_v
# aviol = (G@one_m - a)
# aviol_2 = (G.T@one_n - b)
viol += (-old_v + new_v) * K[:, i_2] * u
- viol_2[i_2] = new_v * K[:, i_2].dot(u) - b[i_2]
+ viol_2[i_2] = new_v * nx.dot(K[:, i_2], u) - b[i_2]
v[i_2] = new_v
if stopThr_val <= stopThr:
@@ -1550,9 +1550,11 @@ def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000,
nx = get_backend(A, M)
- if nx.__name__ == "jax":
- raise NotImplementedError("Log-domain functions are not yet implemented"
- " for Jax. Use numpy or torch arrays instead.")
+ if nx.__name__ in ("jax", "tf"):
+ raise NotImplementedError(
+ "Log-domain functions are not yet implemented"
+ " for Jax and tf. Use numpy or torch arrays instead."
+ )
if weights is None:
weights = nx.ones(n_hists, type_as=A) / n_hists
@@ -1886,9 +1888,11 @@ def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000,
dim, n_hists = A.shape
nx = get_backend(A, M)
- if nx.__name__ == "jax":
- raise NotImplementedError("Log-domain functions are not yet implemented"
- " for Jax. Use numpy or torch arrays instead.")
+ if nx.__name__ in ("jax", "tf"):
+ raise NotImplementedError(
+ "Log-domain functions are not yet implemented"
+ " for Jax and TF. Use numpy or torch arrays instead."
+ )
if weights is None:
weights = nx.ones(n_hists, type_as=A) / n_hists
@@ -2043,7 +2047,7 @@ def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
log = {'err': []}
bar = nx.ones(A.shape[1:], type_as=A)
- bar /= bar.sum()
+ bar /= nx.sum(bar)
U = nx.ones(A.shape, type_as=A)
V = nx.ones(A.shape, type_as=A)
err = 1
@@ -2069,9 +2073,11 @@ def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
KV = convol_imgs(V)
U = A / KV
KU = convol_imgs(U)
- bar = nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0))
+ bar = nx.exp(
+ nx.sum(weights[:, None, None] * nx.log(KU + stabThr), axis=0)
+ )
if ii % 10 == 9:
- err = (V * KU).std(axis=0).sum()
+ err = nx.sum(nx.std(V * KU, axis=0))
# log and verbose print
if log:
log['err'].append(err)
@@ -2106,9 +2112,11 @@ def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000,
A = list_to_array(A)
nx = get_backend(A)
- if nx.__name__ == "jax":
- raise NotImplementedError("Log-domain functions are not yet implemented"
- " for Jax. Use numpy or torch arrays instead.")
+ if nx.__name__ in ("jax", "tf"):
+ raise NotImplementedError(
+ "Log-domain functions are not yet implemented"
+ " for Jax and TF. Use numpy or torch arrays instead."
+ )
n_hists, width, height = A.shape
@@ -2298,13 +2306,15 @@ def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000,
KV = convol_imgs(V)
U = A / KV
KU = convol_imgs(U)
- bar = c * nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0))
+ bar = c * nx.exp(
+ nx.sum(weights[:, None, None] * nx.log(KU + stabThr), axis=0)
+ )
for _ in range(10):
- c = (c * bar / convol_imgs(c[None]).squeeze()) ** 0.5
+ c = (c * bar / nx.squeeze(convol_imgs(c[None]))) ** 0.5
if ii % 10 == 9:
- err = (V * KU).std(axis=0).sum()
+ err = nx.sum(nx.std(V * KU, axis=0))
# log and verbose print
if log:
log['err'].append(err)
@@ -2340,9 +2350,11 @@ def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10
A = list_to_array(A)
n_hists, width, height = A.shape
nx = get_backend(A)
- if nx.__name__ == "jax":
- raise NotImplementedError("Log-domain functions are not yet implemented"
- " for Jax. Use numpy or torch arrays instead.")
+ if nx.__name__ in ("jax", "tf"):
+ raise NotImplementedError(
+ "Log-domain functions are not yet implemented"
+ " for Jax and TF. Use numpy or torch arrays instead."
+ )
if weights is None:
weights = nx.ones((n_hists,), type_as=A) / n_hists
else:
@@ -2382,7 +2394,7 @@ def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10
c = 0.5 * (c + log_bar - convol_img(c))
if ii % 10 == 9:
- err = nx.exp(G + log_KU).std(axis=0).sum()
+ err = nx.sum(nx.std(nx.exp(G + log_KU), axis=0))
# log and verbose print
if log:
log['err'].append(err)
@@ -3312,9 +3324,9 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
a, b, M = list_to_array(a, b, M)
nx = get_backend(M, a, b)
- if nx.__name__ == "jax":
- raise TypeError("JAX arrays have been received but screenkhorn is not "
- "compatible with JAX.")
+ if nx.__name__ in ("jax", "tf"):
+ raise TypeError("JAX or TF arrays have been received but screenkhorn is not "
+ "compatible with neither JAX nor TF.")
ns, nt = M.shape
@@ -3328,7 +3340,7 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
K = nx.exp(-M / reg)
def projection(u, epsilon):
- u[u <= epsilon] = epsilon
+ u = nx.maximum(u, epsilon)
return u
# ----------------------------------------------------------------------------------------------------------------#
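Most of the bregman.py edits replace array-method calls (.dot(), .sum(), .std(), .squeeze()) with their nx equivalents, since Tensorflow tensors do not expose the full set of Numpy methods. The screenkhorn projection is a different case: Jax and Tensorflow tensors are immutable, so the in-place assignment has to become a functional update. A small illustration with Jax (values are arbitrary):

    import jax.numpy as jnp

    u = jnp.array([0.5, 1e-12, 2.0])
    # u[u <= 1e-9] = 1e-9      # TypeError: JAX arrays do not support item assignment
    u = jnp.maximum(u, 1e-9)   # functional replacement, as in screenkhorn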
diff --git a/ot/da.py b/ot/da.py
index 4fd97df..841f31a 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -906,7 +906,7 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
def distribution_estimation_uniform(X):
- """estimates a uniform distribution from an array of samples :math:`\mathbf{X}`
+ r"""estimates a uniform distribution from an array of samples :math:`\mathbf{X}`
Parameters
----------
@@ -950,7 +950,7 @@ class BaseTransport(BaseEstimator):
"""
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
- """Build a coupling matrix from source and target sets of samples
+ r"""Build a coupling matrix from source and target sets of samples
:math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
@@ -1010,7 +1010,7 @@ class BaseTransport(BaseEstimator):
return self
def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None):
- """Build a coupling matrix from source and target sets of samples
+ r"""Build a coupling matrix from source and target sets of samples
:math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
and transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
@@ -1038,7 +1038,7 @@ class BaseTransport(BaseEstimator):
return self.fit(Xs, ys, Xt, yt).transform(Xs, ys, Xt, yt)
def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
- """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
+ r"""Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
Parameters
----------
@@ -1105,7 +1105,7 @@ class BaseTransport(BaseEstimator):
return transp_Xs
def transform_labels(self, ys=None):
- """Propagate source labels :math:`\mathbf{y_s}` to obtain estimated target labels as in
+ r"""Propagate source labels :math:`\mathbf{y_s}` to obtain estimated target labels as in
:ref:`[27] <references-basetransport-transform-labels>`.
Parameters
@@ -1152,7 +1152,7 @@ class BaseTransport(BaseEstimator):
def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
batch_size=128):
- """Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}`
+ r"""Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}`
Parameters
----------
@@ -1218,7 +1218,7 @@ class BaseTransport(BaseEstimator):
return transp_Xt
def inverse_transform_labels(self, yt=None):
- """Propagate target labels :math:`\mathbf{y_t}` to obtain estimated source labels
+ r"""Propagate target labels :math:`\mathbf{y_t}` to obtain estimated source labels
:math:`\mathbf{y_s}`
Parameters
@@ -1307,7 +1307,7 @@ class LinearTransport(BaseTransport):
self.distribution_estimation = distribution_estimation
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
- """Build a coupling matrix from source and target sets of samples
+ r"""Build a coupling matrix from source and target sets of samples
:math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
@@ -1354,7 +1354,7 @@ class LinearTransport(BaseTransport):
return self
def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
- """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
+ r"""Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
Parameters
----------
@@ -1387,7 +1387,7 @@ class LinearTransport(BaseTransport):
def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
batch_size=128):
- """Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}`
+ r"""Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}`
Parameters
----------
@@ -1493,7 +1493,7 @@ class SinkhornTransport(BaseTransport):
self.out_of_sample_map = out_of_sample_map
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
- """Build a coupling matrix from source and target sets of samples
+ r"""Build a coupling matrix from source and target sets of samples
:math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
@@ -1592,7 +1592,7 @@ class EMDTransport(BaseTransport):
self.max_iter = max_iter
def fit(self, Xs, ys=None, Xt=None, yt=None):
- """Build a coupling matrix from source and target sets of samples
+ r"""Build a coupling matrix from source and target sets of samples
:math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
@@ -1711,7 +1711,7 @@ class SinkhornLpl1Transport(BaseTransport):
self.limit_max = limit_max
def fit(self, Xs, ys=None, Xt=None, yt=None):
- """Build a coupling matrix from source and target sets of samples
+ r"""Build a coupling matrix from source and target sets of samples
:math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
@@ -1839,7 +1839,7 @@ class EMDLaplaceTransport(BaseTransport):
self.out_of_sample_map = out_of_sample_map
def fit(self, Xs, ys=None, Xt=None, yt=None):
- """Build a coupling matrix from source and target sets of samples
+ r"""Build a coupling matrix from source and target sets of samples
:math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
@@ -1962,7 +1962,7 @@ class SinkhornL1l2Transport(BaseTransport):
self.limit_max = limit_max
def fit(self, Xs, ys=None, Xt=None, yt=None):
- """Build a coupling matrix from source and target sets of samples
+ r"""Build a coupling matrix from source and target sets of samples
:math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
@@ -2088,7 +2088,7 @@ class MappingTransport(BaseEstimator):
self.verbose2 = verbose2
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
- """Builds an optimal coupling and estimates the associated mapping
+ r"""Builds an optimal coupling and estimates the associated mapping
from source and target sets of samples
:math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
@@ -2146,7 +2146,7 @@ class MappingTransport(BaseEstimator):
return self
def transform(self, Xs):
- """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
+ r"""Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
Parameters
----------
@@ -2261,7 +2261,7 @@ class UnbalancedSinkhornTransport(BaseTransport):
self.limit_max = limit_max
def fit(self, Xs, ys=None, Xt=None, yt=None):
- """Build a coupling matrix from source and target sets of samples
+ r"""Build a coupling matrix from source and target sets of samples
:math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
@@ -2373,7 +2373,7 @@ class JCPOTTransport(BaseTransport):
self.out_of_sample_map = out_of_sample_map
def fit(self, Xs, ys=None, Xt=None, yt=None):
- """Building coupling matrices from a list of source and target sets of samples
+ r"""Building coupling matrices from a list of source and target sets of samples
:math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
@@ -2419,7 +2419,7 @@ class JCPOTTransport(BaseTransport):
return self
def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
- """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
+ r"""Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
Parameters
----------
@@ -2491,7 +2491,7 @@ class JCPOTTransport(BaseTransport):
return transp_Xs
def transform_labels(self, ys=None):
- """Propagate source labels :math:`\mathbf{y_s}` to obtain target labels as in
+ r"""Propagate source labels :math:`\mathbf{y_s}` to obtain target labels as in
:ref:`[27] <references-jcpottransport-transform-labels>`
Parameters
@@ -2542,7 +2542,7 @@ class JCPOTTransport(BaseTransport):
return yt.T
def inverse_transform_labels(self, yt=None):
- """Propagate target labels :math:`\mathbf{y_t}` to obtain estimated source labels
+ r"""Propagate target labels :math:`\mathbf{y_t}` to obtain estimated source labels
:math:`\mathbf{y_s}`
Parameters
diff --git a/ot/datasets.py b/ot/datasets.py
index ad6390c..a839074 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -41,7 +41,7 @@ def get_1D_gauss(n, m, sigma):
def make_2D_samples_gauss(n, m, sigma, random_state=None):
- """Return `n` samples drawn from 2D gaussian :math:`\mathcal{N}(m, \sigma)`
+ r"""Return `n` samples drawn from 2D gaussian :math:`\mathcal{N}(m, \sigma)`
Parameters
----------
diff --git a/ot/dr.py b/ot/dr.py
index c2f51f8..1671ca0 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -16,6 +16,7 @@ Dimension reduction with OT
from scipy import linalg
import autograd.numpy as np
+from pymanopt.function import Autograd
from pymanopt.manifolds import Stiefel
from pymanopt import Problem
from pymanopt.solvers import SteepestDescent, TrustRegions
@@ -181,6 +182,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
else:
regmean = np.ones((len(xc), len(xc)))
+ @Autograd
def cost(P):
# wda loss
loss_b = 0
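The @Autograd decorator belongs to the pymanopt 0.2.6 API (hence the pymanopt==0.2.6rc1 pin in the requirements files): cost functions now have to be wrapped by a backend decorator before being handed to Problem. A toy sketch of the pattern, with a hypothetical cost rather than POT's actual wda loss:

    import autograd.numpy as np
    from pymanopt import Problem
    from pymanopt.function import Autograd
    from pymanopt.manifolds import Stiefel

    manifold = Stiefel(5, 2)

    @Autograd
    def cost(P):
        # any autograd-differentiable scalar works here
        return np.sum(P ** 2)

    problem = Problem(manifold=manifold, cost=cost)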
diff --git a/ot/gromov.py b/ot/gromov.py
index dc95c74..6544260 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -947,7 +947,7 @@ def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun,
index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p))
T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,))
index[1] = generator.choice(
- len_q, size=1, p=nx.to_numpy(T_index0 / T_index0.sum())
+ len_q, size=1, p=nx.to_numpy(T_index0 / nx.sum(T_index0))
)
if alpha == 1:
diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py
index 8b4d0c3..43763a9 100644
--- a/ot/lp/solver_1d.py
+++ b/ot/lp/solver_1d.py
@@ -100,11 +100,11 @@ def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, requ
m = v_values.shape[0]
if u_weights is None:
- u_weights = nx.full(u_values.shape, 1. / n)
+ u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values)
elif u_weights.ndim != u_values.ndim:
u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
if v_weights is None:
- v_weights = nx.full(v_values.shape, 1. / m)
+ v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values)
elif v_weights.ndim != v_values.ndim:
v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)
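The type_as argument matters here: without it, nx.full creates the default weights with the backend's default dtype and device, which can then clash with u_values or v_values in later operations. A short torch illustration (shapes are arbitrary):

    import torch
    from ot.backend import get_backend

    u_values = torch.rand(10, dtype=torch.float64)
    nx = get_backend(u_values)
    w_default = nx.full(u_values.shape, 0.1)                    # float32, torch default
    w_matched = nx.full(u_values.shape, 0.1, type_as=u_values)  # float64, same device
    assert w_matched.dtype == u_values.dtype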
diff --git a/ot/plot.py b/ot/plot.py
index 3e3bed7..2208c90 100644
--- a/ot/plot.py
+++ b/ot/plot.py
@@ -18,7 +18,7 @@ from matplotlib import gridspec
def plot1D_mat(a, b, M, title=''):
- """ Plot matrix :math:`\mathbf{M}` with the source and target 1D distribution
+ r""" Plot matrix :math:`\mathbf{M}` with the source and target 1D distribution
Creates a subplot with the source distribution :math:`\mathbf{a}` on the left and
target distribution :math:`\mathbf{b}` on the top. The matrix :math:`\mathbf{M}` is shown in between.
@@ -61,7 +61,7 @@ def plot1D_mat(a, b, M, title=''):
def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs):
- """ Plot matrix :math:`\mathbf{G}` in 2D with lines using alpha values
+ r""" Plot matrix :math:`\mathbf{G}` in 2D with lines using alpha values
Plot lines between source and target 2D samples with a color
proportional to the value of the matrix :math:`\mathbf{G}` between samples.
diff --git a/requirements.txt b/requirements.txt
index 4353247..d43be7a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,10 +4,11 @@ cython
matplotlib
autograd
pymanopt==0.2.4; python_version <'3'
-pymanopt; python_version >= '3'
+pymanopt==0.2.6rc1; python_version >= '3'
cvxopt
scikit-learn
torch
jax
jaxlib
+tensorflow
pytest \ No newline at end of file
diff --git a/test/conftest.py b/test/conftest.py
index 987d98e..c0db8ab 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -5,7 +5,7 @@
# License: MIT License
import pytest
-from ot.backend import jax
+from ot.backend import jax, tf
from ot.backend import get_backend_list
import functools
@@ -13,6 +13,10 @@ if jax:
from jax.config import config
config.update("jax_enable_x64", True)
+if tf:
+ from tensorflow.python.ops.numpy_ops import np_config
+ np_config.enable_numpy_behavior()
+
backend_list = get_backend_list()
@@ -24,16 +28,16 @@ def nx(request):
def skip_arg(arg, value, reason=None, getter=lambda x: x):
- if isinstance(arg, tuple) or isinstance(arg, list):
+ if isinstance(arg, (tuple, list)):
n = len(arg)
else:
arg = (arg, )
n = 1
- if n != 1 and (isinstance(value, tuple) or isinstance(value, list)):
+ if n != 1 and isinstance(value, (tuple, list)):
pass
else:
value = (value, )
- if isinstance(getter, tuple) or isinstance(value, list):
+ if isinstance(getter, (tuple, list)):
pass
else:
getter = [getter] * n
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
index cb85cb9..6a42cfe 100644
--- a/test/test_1d_solver.py
+++ b/test/test_1d_solver.py
@@ -11,7 +11,7 @@ import pytest
import ot
from ot.lp import wasserstein_1d
-from ot.backend import get_backend_list
+from ot.backend import get_backend_list, tf
from scipy.stats import wasserstein_distance
backend_list = get_backend_list()
@@ -86,7 +86,6 @@ def test_wasserstein_1d(nx):
def test_wasserstein_1d_type_devices(nx):
-
rng = np.random.RandomState(0)
n = 10
@@ -108,6 +107,37 @@ def test_wasserstein_1d_type_devices(nx):
nx.assert_same_dtype_device(xb, res)
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_wasserstein_1d_device_tf():
+ if not tf:
+ return
+ nx = ot.backend.TensorflowBackend()
+ rng = np.random.RandomState(0)
+ n = 10
+ x = np.linspace(0, 5, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ xb = nx.from_numpy(x)
+ rho_ub = nx.from_numpy(rho_u)
+ rho_vb = nx.from_numpy(rho_v)
+ res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
+ nx.assert_same_dtype_device(xb, res)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ xb = nx.from_numpy(x)
+ rho_ub = nx.from_numpy(rho_u)
+ rho_vb = nx.from_numpy(rho_v)
+ res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
+ nx.assert_same_dtype_device(xb, res)
+ assert nx.dtype_device(res)[1].startswith("GPU")
+
+
def test_emd_1d_emd2_1d():
# test emd1d gives similar results as emd
n = 20
@@ -148,7 +178,6 @@ def test_emd_1d_emd2_1d():
def test_emd1d_type_devices(nx):
-
rng = np.random.RandomState(0)
n = 10
@@ -170,3 +199,36 @@ def test_emd1d_type_devices(nx):
nx.assert_same_dtype_device(xb, emd)
nx.assert_same_dtype_device(xb, emd2)
+
+
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_emd1d_device_tf():
+ nx = ot.backend.TensorflowBackend()
+ rng = np.random.RandomState(0)
+ n = 10
+ x = np.linspace(0, 5, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ xb = nx.from_numpy(x)
+ rho_ub = nx.from_numpy(rho_u)
+ rho_vb = nx.from_numpy(rho_v)
+ emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
+ emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
+ nx.assert_same_dtype_device(xb, emd)
+ nx.assert_same_dtype_device(xb, emd2)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ xb = nx.from_numpy(x)
+ rho_ub = nx.from_numpy(rho_u)
+ rho_vb = nx.from_numpy(rho_v)
+ emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
+ emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
+ nx.assert_same_dtype_device(xb, emd)
+ nx.assert_same_dtype_device(xb, emd2)
+ assert nx.dtype_device(emd)[1].startswith("GPU")
diff --git a/test/test_backend.py b/test/test_backend.py
index 2e7eecc..027c4cd 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -7,7 +7,7 @@
import ot
import ot.backend
-from ot.backend import torch, jax, cp
+from ot.backend import torch, jax, cp, tf
import pytest
@@ -101,6 +101,20 @@ def test_get_backend():
with pytest.raises(ValueError):
get_backend(A, B2)
+ if tf:
+ A2 = tf.convert_to_tensor(A)
+ B2 = tf.convert_to_tensor(B)
+
+ nx = get_backend(A2)
+ assert nx.__name__ == 'tf'
+
+ nx = get_backend(A2, B2)
+ assert nx.__name__ == 'tf'
+
+ # test not unique types in input
+ with pytest.raises(ValueError):
+ get_backend(A, B2)
+
def test_convert_between_backends(nx):
@@ -242,6 +256,14 @@ def test_empty_backend():
nx.copy(M)
with pytest.raises(NotImplementedError):
nx.allclose(M, M)
+ with pytest.raises(NotImplementedError):
+ nx.squeeze(M)
+ with pytest.raises(NotImplementedError):
+ nx.bitsize(M)
+ with pytest.raises(NotImplementedError):
+ nx.device_type(M)
+ with pytest.raises(NotImplementedError):
+ nx._bench(lambda x: x, M, n_runs=1)
def test_func_backends(nx):
@@ -491,7 +513,7 @@ def test_func_backends(nx):
lst_name.append('coo_matrix')
assert not nx.issparse(Mb), 'Assert fail on: issparse (expected False)'
- assert nx.issparse(sp_Mb) or nx.__name__ == "jax", 'Assert fail on: issparse (expected True)'
+ assert nx.issparse(sp_Mb) or nx.__name__ in ("jax", "tf"), 'Assert fail on: issparse (expected True)'
A = nx.tocsr(sp_Mb)
lst_b.append(nx.to_numpy(nx.todense(A)))
@@ -516,6 +538,19 @@
assert nx.allclose(Mb, Mb), 'Assert fail on: allclose (expected True)'
assert not nx.allclose(2 * Mb, Mb), 'Assert fail on: allclose (expected False)'
+ A = nx.squeeze(nx.zeros((3, 1, 4, 1)))
+ assert tuple(A.shape) == (3, 4), 'Assert fail on: squeeze'
+
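+    # bitsize: number of bits per tensor element, appended to the per-backend
+    # results so it is compared against the NumPy reference below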
+ A = nx.bitsize(Mb)
+ lst_b.append(float(A))
+ lst_name.append("bitsize")
+
+ A = nx.device_type(Mb)
+ assert A in ("CPU", "GPU")
+
+ nx._bench(lambda x: x, M, n_runs=1)
+
lst_tot.append(lst_b)
lst_np = lst_tot[0]
@@ -590,3 +624,18 @@
np.testing.assert_almost_equal(fun(v, c, e), c * np.sum(v ** 4) + e, decimal=4)
np.testing.assert_allclose(grad_val[0], v, atol=1e-4)
np.testing.assert_allclose(grad_val[2], 2 * e, atol=1e-4)
+
+ if tf:
+ nx = ot.backend.TensorflowBackend()
+ w = tf.Variable(tf.random.normal((3, 2)), name='w')
+ b = tf.Variable(tf.random.normal((2,), dtype=tf.float32), name='b')
+ x = tf.random.normal((1, 3), dtype=tf.float32)
+
+ with tf.GradientTape() as tape:
+ y = x @ w + b
+ loss = tf.reduce_mean(y ** 2)
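+            # set_gradients should make the gradients reported for the loss equal
+            # to the tensors passed in (here w and b), which the asserts verify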
+ manipulated_loss = nx.set_gradients(loss, (w, b), (w, b))
+ [dl_dw, dl_db] = tape.gradient(manipulated_loss, [w, b])
+ assert nx.allclose(dl_dw, w)
+ assert nx.allclose(dl_db, b)
diff --git a/test/test_bregman.py b/test/test_bregman.py
index f42ac6f..6e90aa4 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -12,7 +12,7 @@ import numpy as np
import pytest
import ot
-from ot.backend import torch
+from ot.backend import torch, tf
@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False]))
@@ -248,6 +248,7 @@ def test_sinkhorn_empty():
ot.sinkhorn([], [], M, 1, method='greenkhorn', stopThr=1e-10, log=True)
+@pytest.skip_backend("tf")
@pytest.skip_backend("jax")
def test_sinkhorn_variants(nx):
# test sinkhorn
@@ -282,6 +283,8 @@ def test_sinkhorn_variants(nx):
"sinkhorn_epsilon_scaling",
"greenkhorn",
"sinkhorn_log"])
+@pytest.skip_arg(("nx", "method"), ("tf", "sinkhorn_epsilon_scaling"), reason="tf does not support sinkhorn_epsilon_scaling", getter=str)
+@pytest.skip_arg(("nx", "method"), ("tf", "greenkhorn"), reason="tf does not support greenkhorn", getter=str)
@pytest.skip_arg(("nx", "method"), ("jax", "sinkhorn_epsilon_scaling"), reason="jax does not support sinkhorn_epsilon_scaling", getter=str)
@pytest.skip_arg(("nx", "method"), ("jax", "greenkhorn"), reason="jax does not support greenkhorn", getter=str)
def test_sinkhorn_variants_dtype_device(nx, method):
@@ -323,6 +326,37 @@ def test_sinkhorn2_variants_dtype_device(nx, method):
nx.assert_same_dtype_device(Mb, lossb)
+@pytest.mark.skipif(not tf, reason="tf not installed")
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"])
+def test_sinkhorn2_variants_device_tf(method):
+ nx = ot.backend.TensorflowBackend()
+ n = 100
+    rng = np.random.RandomState(0)
+    x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+ M = ot.dist(x, x)
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ ub = nx.from_numpy(u)
+ Mb = nx.from_numpy(M)
+ Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+ lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+ nx.assert_same_dtype_device(Mb, Gb)
+ nx.assert_same_dtype_device(Mb, lossb)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ ub = nx.from_numpy(u)
+ Mb = nx.from_numpy(M)
+ Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+ lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+ nx.assert_same_dtype_device(Mb, Gb)
+ nx.assert_same_dtype_device(Mb, lossb)
+ assert nx.dtype_device(Gb)[1].startswith("GPU")
+
+
+@pytest.skip_backend("tf")
@pytest.skip_backend("jax")
def test_sinkhorn_variants_multi_b(nx):
# test sinkhorn
@@ -352,6 +385,7 @@ def test_sinkhorn_variants_multi_b(nx):
np.testing.assert_allclose(G0, Gs, atol=1e-05)
+@pytest.skip_backend("tf")
@pytest.skip_backend("jax")
def test_sinkhorn2_variants_multi_b(nx):
# test sinkhorn
@@ -454,7 +488,7 @@ def test_barycenter(nx, method, verbose, warn):
weights_nx = nx.from_numpy(weights)
reg = 1e-2
- if nx.__name__ == "jax" and method == "sinkhorn_log":
+ if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.barycenter(A_nx, M_nx, reg, weights, method=method)
else:
@@ -495,7 +529,7 @@ def test_barycenter_debiased(nx, method, verbose, warn):
# wasserstein
reg = 1e-2
- if nx.__name__ == "jax" and method == "sinkhorn_log":
+ if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method)
else:
@@ -597,7 +631,7 @@ def test_wasserstein_bary_2d(nx, method):
# wasserstein
reg = 1e-2
- if nx.__name__ == "jax" and method == "sinkhorn_log":
+ if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
else:
@@ -629,7 +663,7 @@ def test_wasserstein_bary_2d_debiased(nx, method):
# wasserstein
reg = 1e-2
- if nx.__name__ == "jax" and method == "sinkhorn_log":
+ if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
else:
@@ -888,6 +922,7 @@ def test_implemented_methods():
ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
+@pytest.skip_backend("tf")
@pytest.skip_backend("cupy")
@pytest.skip_backend("jax")
@pytest.mark.filterwarnings("ignore:Bottleneck")
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 38a7fd7..4b995d5 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -9,7 +9,7 @@
import numpy as np
import ot
from ot.backend import NumpyBackend
-from ot.backend import torch
+from ot.backend import torch, tf
import pytest
@@ -113,6 +113,45 @@ def test_gromov_dtype_device(nx):
nx.assert_same_dtype_device(C1b, gw_valb)
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_gromov_device_tf():
+ nx = ot.backend.TensorflowBackend()
+ n_samples = 50 # nb samples
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+ xt = xs[::-1].copy()
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+ Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ nx.assert_same_dtype_device(C1b, Gb)
+ nx.assert_same_dtype_device(C1b, gw_valb)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+ Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ nx.assert_same_dtype_device(C1b, Gb)
+ nx.assert_same_dtype_device(C1b, gw_valb)
+ assert nx.dtype_device(Gb)[1].startswith("GPU")
+
+
def test_gromov2_gradients():
n_samples = 50 # nb samples
@@ -150,6 +189,7 @@ def test_gromov2_gradients():
@pytest.skip_backend("jax", reason="test very slow with jax backend")
+@pytest.skip_backend("tf", reason="test very slow with tf backend")
def test_entropic_gromov(nx):
n_samples = 50 # nb samples
@@ -208,6 +248,7 @@ def test_entropic_gromov(nx):
@pytest.skip_backend("jax", reason="test very slow with jax backend")
+@pytest.skip_backend("tf", reason="test very slow with tf backend")
def test_entropic_gromov_dtype_device(nx):
# setup
n_samples = 50 # nb samples
@@ -306,6 +347,7 @@ def test_pointwise_gromov(nx):
np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0015952535464736394, atol=1e-8)
+@pytest.skip_backend("tf", reason="test very slow with tf backend")
@pytest.skip_backend("jax", reason="test very slow with jax backend")
def test_sampled_gromov(nx):
n_samples = 50 # nb samples
diff --git a/test/test_ot.py b/test/test_ot.py
index c4d7713..53edf4f 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -11,7 +11,7 @@ import pytest
import ot
from ot.datasets import make_1D_gauss as gauss
-from ot.backend import torch
+from ot.backend import torch, tf
def test_emd_dimension_and_mass_mismatch():
@@ -101,6 +101,38 @@ def test_emd_emd2_types_devices(nx):
nx.assert_same_dtype_device(Mb, w)
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_emd_emd2_devices_tf():
+ nx = ot.backend.TensorflowBackend()
+
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+ M = ot.dist(x, y)
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
+ Gb = ot.emd(ab, ab, Mb)
+ w = ot.emd2(ab, ab, Mb)
+ nx.assert_same_dtype_device(Mb, Gb)
+ nx.assert_same_dtype_device(Mb, w)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
+ Gb = ot.emd(ab, ab, Mb)
+ w = ot.emd2(ab, ab, Mb)
+ nx.assert_same_dtype_device(Mb, Gb)
+ nx.assert_same_dtype_device(Mb, w)
+ assert nx.dtype_device(Gb)[1].startswith("GPU")
+
+
def test_emd2_gradients():
n_samples = 100
n_features = 2
diff --git a/test/test_sliced.py b/test/test_sliced.py
index 245202c..91e0961 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -10,6 +10,7 @@ import pytest
import ot
from ot.sliced import get_random_projections
+from ot.backend import tf
def test_get_random_projections():
@@ -161,6 +162,34 @@ def test_sliced_backend_type_devices(nx):
nx.assert_same_dtype_device(xb, valb)
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_sliced_backend_device_tf():
+ nx = ot.backend.TensorflowBackend()
+ n = 100
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ xb = nx.from_numpy(x)
+ yb = nx.from_numpy(y)
+ Pb = nx.from_numpy(P)
+ valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
+ nx.assert_same_dtype_device(xb, valb)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ xb = nx.from_numpy(x)
+ yb = nx.from_numpy(y)
+ Pb = nx.from_numpy(P)
+ valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
+ nx.assert_same_dtype_device(xb, valb)
+ assert nx.dtype_device(valb)[1].startswith("GPU")
+
+
def test_max_sliced_backend(nx):
n = 100
@@ -211,3 +240,31 @@ def test_max_sliced_backend_type_devices(nx):
valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
+
+
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_max_sliced_backend_device_tf():
+ nx = ot.backend.TensorflowBackend()
+ n = 100
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ xb = nx.from_numpy(x)
+ yb = nx.from_numpy(y)
+ Pb = nx.from_numpy(P)
+ valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
+ nx.assert_same_dtype_device(xb, valb)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ xb = nx.from_numpy(x)
+ yb = nx.from_numpy(y)
+ Pb = nx.from_numpy(P)
+ valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
+ nx.assert_same_dtype_device(xb, valb)
+ assert nx.dtype_device(valb)[1].startswith("GPU")