summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2021-06-01 10:10:54 +0200
committerGitHub <noreply@github.com>2021-06-01 10:10:54 +0200
commit184f8f4f7ac78f1dd7f653496d2753211a4e3426 (patch)
tree483a7274c91030fd644de49b03a5fad04af9deba
parent1f16614954e2522fbdb1598c5b1f5c3630c68472 (diff)
[MRG] POT numpy/torch/jax backends (#249)
* add numpy and torch backends * stat sets on functions * proper import * install recent torch on windows * install recent torch on windows * now testing all functions in backedn * add jax backedn * clenaup windowds * proper convert for jax backedn * pep8 * try again windows tests * test jax conversion * try proper widows tests * emd fuction ses backedn * better test partial OT * proper tests to_numpy and teplate Backend * pep8 * pep8 x2 * feaking sinkhorn works with torch * sinkhorn2 compatible * working ot.emd2 * important detach * it should work * jax autodiff emd * pep8 * no tast same for jax * new independat tests per backedn * freaking pep8 * add tests for gradients * deprecate ot.gpu * worging dist function * working dist * dist done in backedn * not in * remove indexing * change accuacy for jax * first pull backend * projection simplex * projection simplex * projection simplex * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 * add backedn discusion to quickstart guide * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 + better doc * proper links * corect doctest * big debug documentation * doctest again * doctest again bis * doctest again ter (last one or i kill myself) * backend test + doc proj simplex * correction test_utils * correction test_utils * correction cumsum * correction flip * correction flip v2 * more debug * more debug * more debug + pep8 * pep8 * argh * proj_simplex * backedn works for sort * proj simplex * jax sucks * update doc * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/lp/__init__.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * begin comment alex * comment alex part 2 * optimize test gromov * proj_simplex on vectors * add awesome gradient decsnt example on the weights * pep98 of course * proof read example by alex * pep8 again * encoding oos in translation * correct legend Co-authored-by: Nicolas Courty <ncourty@irisa.fr> Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
-rw-r--r--.github/requirements_test_windows.txt10
-rw-r--r--.github/workflows/build_tests.yml9
-rw-r--r--README.md8
-rw-r--r--docs/source/quickstart.rst68
-rw-r--r--docs/source/readme.rst70
-rw-r--r--examples/README.txt2
-rw-r--r--examples/backends/README.txt4
-rw-r--r--examples/backends/plot_unmix_optim_torch.py161
-rw-r--r--ot/__init__.py1
-rw-r--r--ot/backend.py536
-rw-r--r--ot/bregman.py141
-rw-r--r--ot/gpu/__init__.py4
-rw-r--r--ot/lp/__init__.py137
-rw-r--r--ot/utils.py128
-rw-r--r--requirements.txt3
-rw-r--r--test/test_backend.py364
-rw-r--r--test/test_bregman.py74
-rw-r--r--test/test_gromov.py10
-rw-r--r--test/test_ot.py91
-rwxr-xr-xtest/test_partial.py4
-rw-r--r--test/test_utils.py76
21 files changed, 1692 insertions, 209 deletions
diff --git a/.github/requirements_test_windows.txt b/.github/requirements_test_windows.txt
new file mode 100644
index 0000000..331dd57
--- /dev/null
+++ b/.github/requirements_test_windows.txt
@@ -0,0 +1,10 @@
+numpy
+scipy>=1.3
+cython
+matplotlib
+autograd
+pymanopt==0.2.4; python_version <'3'
+pymanopt; python_version >= '3'
+cvxopt
+scikit-learn
+pytest \ No newline at end of file
diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml
index 2fc6770..92a07b5 100644
--- a/.github/workflows/build_tests.yml
+++ b/.github/workflows/build_tests.yml
@@ -40,7 +40,7 @@ jobs:
pip install -e .
- name: Run tests
run: |
- python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot
+ python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes
- name: Upload codecov
run: |
codecov
@@ -142,11 +142,12 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- pip install -r requirements.txt
- pip install pytest "pytest-cov<2.6"
+ python -m pip install -r .github/requirements_test_windows.txt
+ python -m pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
+ python -m pip install pytest "pytest-cov<2.6"
- name: Install POT
run: |
- pip install -e .
+ python -m pip install -e .
- name: Run tests
run: |
python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot
diff --git a/README.md b/README.md
index f5d18c1..e5e16e0 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,7 @@ POT provides the following generic OT solvers (links to examples):
* [OT Network Simplex solver](https://pythonot.github.io/auto_examples/plot_OT_1D.html) for the linear program/ Earth Movers Distance [1] .
* [Conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) [6] and [Generalized conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) for regularized OT [7].
-* Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html) with optional GPU implementation (requires cupy).
+* Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html).
* Bregman projections for [Wasserstein barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) [3], [convolutional barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_convolutional_barycenter.html) [21] and unmixing [4].
* Sinkhorn divergence [23] and entropic regularization OT from empirical data.
* [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17].
@@ -33,6 +33,7 @@ POT provides the following generic OT solvers (links to examples):
* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3]
formulations).
* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32].
+* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/) arrays.
POT provides the following Machine Learning related solvers:
@@ -77,8 +78,7 @@ The library has been tested on Linux, MacOSX and Windows. It requires a C++ comp
- Numpy (>=1.16)
- Scipy (>=1.0)
-- Cython (>=0.23)
-- Matplotlib (>=1.5)
+- Cython (>=0.23) (build only, not necessary when installing wheels from pip or conda)
#### Pip installation
@@ -129,7 +129,7 @@ Some sub-modules require additional dependences which are discussed below
pip install pymanopt autograd
```
-* **ot.gpu** (GPU accelerated OT) depends on cupy that have to be installed following instructions on [this page](https://docs-cupy.chainer.org/en/stable/install.html). Obviously you will need CUDA installed and a compatible GPU.
+* **ot.gpu** (GPU accelerated OT) depends on cupy that have to be installed following instructions on [this page](https://docs-cupy.chainer.org/en/stable/install.html). Obviously you will need CUDA installed and a compatible GPU. Note that this module is deprecated since version 0.8 and will be deleted in the future. GPU is now handled automatically through the backends and several solver already can run on GPU using the Pytorch backend.
## Examples
diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst
index cf5d6aa..fd046a1 100644
--- a/docs/source/quickstart.rst
+++ b/docs/source/quickstart.rst
@@ -15,6 +15,12 @@ are also available as notebooks on the POT Github.
in ML applications we refer the reader to the following `OTML tutorial
<https://remi.flamary.com/cours/tuto_otml.html>`_.
+.. note::
+
+ Since version 0.8, POT provides a backend to automatically solve some OT
+ problems independently from the toolbox used by the user (numpy/torch/jax).
+ We provide a discussion about which functions are compatible in section
+ `Backend section <#solving-ot-with-multiple-backends>`_ .
Why Optimal Transport ?
@@ -158,7 +164,6 @@ Wasserstein but has better computational and
`statistical properties <https://arxiv.org/pdf/1910.04091.pdf>`_.
-
Optimal transport and Wasserstein distance
------------------------------------------
@@ -922,6 +927,13 @@ The implementations of FGW and FGW barycenter is provided in functions
GPU acceleration
^^^^^^^^^^^^^^^^
+.. warning::
+
+ The :any:`ot.gpu` has been deprecated since the release 0.8 of POT and
+ should not be used. The GPU implementation (in Pytorch for instance) can be
+ used with the novel backends using the compatible functions from POT.
+
+
We provide several implementation of our OT solvers in :any:`ot.gpu`. Those
implementations use the :code:`cupy` toolbox that obviously need to be installed.
@@ -950,6 +962,60 @@ explicitly.
use it you have to specifically import it with :code:`import ot.gpu` .
+Solving OT with Multiple backends
+---------------------------------
+
+.. _backends_section:
+
+Since version 0.8, POT provides a backend that allows to code solvers
+independently from the type of the input arrays. The idea is to provide the user
+with a package that works seamlessly and returns a solution for instance as a
+Pytorch tensors when the function has Pytorch tensors as input.
+
+
+How it works
+^^^^^^^^^^^^
+
+The aim of the backend is to use the same function independently of the type of
+the input arrays.
+
+For instance when executing the following code
+
+.. code:: python
+
+ # a and b are 1D histograms (sum to 1 and positive)
+ # M is the ground cost matrix
+ T = ot.emd(a, b, M) # exact linear program
+ w = ot.emd2(a, b, M) # Wasserstein computation
+
+the functions :any:`ot.emd` and :any:`ot.emd2` can take inputs of the type
+:any:`numpy.array`, :any:`torch.tensor` or :any:`jax.numpy.array`. The output of
+the function will be the same type as the inputs and on the same device. When
+possible all computations are done on the same device and also when possible the
+output will be differentiable with respect to the input of the function.
+
+
+
+List of compatible Backends
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+- `Numpy <https://numpy.org/>`_ (all functions and solvers)
+- `Pytorch <https://pytorch.org/>`_ (all outputs differentiable w.r.t. inputs)
+- `Jax <https://github.com/google/jax>`_ (Some functions are differentiable some require a wrapper)
+
+List of compatible functions
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+This list will get longer for new releases and will hopefully disappear when POT
+become fully implemented with the backend.
+
+- :any:`ot.emd`
+- :any:`ot.emd2`
+- :any:`ot.sinkhorn`
+- :any:`ot.sinkhorn2`
+- :any:`ot.dist`
+
+
FAQ
---
diff --git a/docs/source/readme.rst b/docs/source/readme.rst
index 3b594c2..82d3e6c 100644
--- a/docs/source/readme.rst
+++ b/docs/source/readme.rst
@@ -26,8 +26,7 @@ POT provides the following generic OT solvers (links to examples):
Algorithm <auto_examples/plot_OT_1D.html>`__
[2] , stabilized version [9] [10], greedy Sinkhorn [22] and
`Screening Sinkhorn
- [26] <auto_examples/plot_screenkhorn_1D.html>`__
- with optional GPU implementation (requires cupy).
+ [26] <auto_examples/plot_screenkhorn_1D.html>`__.
- Bregman projections for `Wasserstein
barycenter <auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html>`__
[3], `convolutional
@@ -69,6 +68,11 @@ POT provides the following generic OT solvers (links to examples):
- `Sliced
Wasserstein <auto_examples/sliced-wasserstein/plot_variance.html>`__
[31, 32].
+- `Several
+ backends <https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends>`__
+ for easy use of POT with
+ `Pytorch <https://pytorch.org/>`__/`jax <https://github.com/google/jax>`__/`Numpy <https://numpy.org/>`__
+ arrays.
POT provides the following Machine Learning related solvers:
@@ -104,12 +108,14 @@ paper <https://jmlr.org/papers/v22/20-451.html>`__:
::
- Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer;, POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021.
+ Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer,
+ POT Python Optimal Transport library,
+ Journal of Machine Learning Research, 22(78):1−8, 2021.
Website: https://pythonot.github.io/
In Bibtex format:
-::
+.. code:: bibtex
@article{flamary2021pot,
author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer},
@@ -131,8 +137,8 @@ following Python modules:
- Numpy (>=1.16)
- Scipy (>=1.0)
-- Cython (>=0.23)
-- Matplotlib (>=1.5)
+- Cython (>=0.23) (build only, not necessary when installing wheels
+ from pip or conda)
Pip installation
^^^^^^^^^^^^^^^^
@@ -140,19 +146,19 @@ Pip installation
Note that due to a limitation of pip, ``cython`` and ``numpy`` need to
be installed prior to installing POT. This can be done easily with
-::
+.. code:: console
pip install numpy cython
You can install the toolbox through PyPI with:
-::
+.. code:: console
pip install POT
or get the very latest version by running:
-::
+.. code:: console
pip install -U https://github.com/PythonOT/POT/archive/master.zip # with --user for user install (no root)
@@ -163,7 +169,7 @@ If you use the Anaconda python distribution, POT is available in
`conda-forge <https://conda-forge.org>`__. To install it and the
required dependencies:
-::
+.. code:: console
conda install -c conda-forge pot
@@ -188,15 +194,17 @@ below
- **ot.dr** (Wasserstein dimensionality reduction) depends on autograd
and pymanopt that can be installed with:
- ::
+.. code:: shell
- pip install pymanopt autograd
+ pip install pymanopt autograd
- **ot.gpu** (GPU accelerated OT) depends on cupy that have to be
installed following instructions on `this
page <https://docs-cupy.chainer.org/en/stable/install.html>`__.
-
-obviously you need CUDA installed and a compatible GPU.
+ Obviously you will need CUDA installed and a compatible GPU. Note
+ that this module is deprecated since version 0.8 and will be deleted
+ in the future. GPU is now handled automatically through the backends
+ and several solver already can run on GPU using the Pytorch backend.
Examples
--------
@@ -206,36 +214,36 @@ Short examples
- Import the toolbox
- .. code:: python
+.. code:: python
- import ot
+ import ot
- Compute Wasserstein distances
- .. code:: python
+.. code:: python
- # a,b are 1D histograms (sum to 1 and positive)
- # M is the ground cost matrix
- Wd=ot.emd2(a,b,M) # exact linear program
- Wd_reg=ot.sinkhorn2(a,b,M,reg) # entropic regularized OT
- # if b is a matrix compute all distances to a and return a vector
+ # a and b are 1D histograms (sum to 1 and positive)
+ # M is the ground cost matrix
+ Wd = ot.emd2(a, b, M) # exact linear program
+ Wd_reg = ot.sinkhorn2(a, b, M, reg) # entropic regularized OT
+ # if b is a matrix compute all distances to a and return a vector
- Compute OT matrix
- .. code:: python
+.. code:: python
- # a,b are 1D histograms (sum to 1 and positive)
- # M is the ground cost matrix
- T=ot.emd(a,b,M) # exact linear program
- T_reg=ot.sinkhorn(a,b,M,reg) # entropic regularized OT
+ # a and b are 1D histograms (sum to 1 and positive)
+ # M is the ground cost matrix
+ T = ot.emd(a, b, M) # exact linear program
+ T_reg = ot.sinkhorn(a, b, M, reg) # entropic regularized OT
- Compute Wasserstein barycenter
- .. code:: python
+.. code:: python
- # A is a n*d matrix containing d 1D histograms
- # M is the ground cost matrix
- ba=ot.barycenter(A,M,reg) # reg is regularization parameter
+ # A is a n*d matrix containing d 1D histograms
+ # M is the ground cost matrix
+ ba = ot.barycenter(A, M, reg) # reg is regularization parameter
Examples and Notebooks
~~~~~~~~~~~~~~~~~~~~~~
diff --git a/examples/README.txt b/examples/README.txt
index 69a9f84..b48487f 100644
--- a/examples/README.txt
+++ b/examples/README.txt
@@ -1,7 +1,7 @@
Examples gallery
================
-This is a gallery of all the POT example files.
+This is a gallery of all the POT example files.
OT and regularized OT
diff --git a/examples/backends/README.txt b/examples/backends/README.txt
new file mode 100644
index 0000000..3ee0e27
--- /dev/null
+++ b/examples/backends/README.txt
@@ -0,0 +1,4 @@
+
+
+POT backend examples
+-------------------- \ No newline at end of file
diff --git a/examples/backends/plot_unmix_optim_torch.py b/examples/backends/plot_unmix_optim_torch.py
new file mode 100644
index 0000000..9ae66e9
--- /dev/null
+++ b/examples/backends/plot_unmix_optim_torch.py
@@ -0,0 +1,161 @@
+# -*- coding: utf-8 -*-
+r"""
+=================================
+Wasserstein unmixing with PyTorch
+=================================
+
+In this example we estimate mixing parameters from distributions that minimize
+the Wasserstein distance. In other words we suppose that a target
+distribution :math:`\mu^t` can be expressed as a weighted sum of source
+distributions :math:`\mu^s_k` with the following model:
+
+.. math::
+ \mu^t = \sum_{k=1}^K w_k\mu^s_k
+
+where :math:`\mathbf{w}` is a vector of size :math:`K` and belongs in the
+distribution simplex :math:`\Delta_K`.
+
+In order to estimate this weight vector we propose to optimize the Wasserstein
+distance between the model and the observed :math:`\mu^t` with respect to
+the vector. This leads to the following optimization problem:
+
+.. math::
+ \min_{\mathbf{w}\in\Delta_K} \quad W \left(\mu^t,\sum_{k=1}^K w_k\mu^s_k\right)
+
+This minimization is done in this example with a simple projected gradient
+descent in PyTorch. We use the automatic backend of POT that allows us to
+compute the Wasserstein distance with :any:`ot.emd2` with
+differentiable losses.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 2
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import torch
+
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% Data
+
+nt = 100
+nt1 = 10 #
+
+ns1 = 50
+ns = 2 * ns1
+
+rng = np.random.RandomState(2)
+
+xt = rng.randn(nt, 2) * 0.2
+xt[:nt1, 0] += 1
+xt[nt1:, 1] += 1
+
+
+xs1 = rng.randn(ns1, 2) * 0.2
+xs1[:, 0] += 1
+xs2 = rng.randn(ns1, 2) * 0.2
+xs2[:, 1] += 1
+
+xs = np.concatenate((xs1, xs2))
+
+# Sample reweighting matrix H
+H = np.zeros((ns, 2))
+H[:ns1, 0] = 1 / ns1
+H[ns1:, 1] = 1 / ns1
+# each columns sums to 1 and has weights only for samples form the
+# corresponding source distribution
+
+M = ot.dist(xs, xt)
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% plot the distributions
+
+pl.figure(1)
+pl.scatter(xt[:, 0], xt[:, 1], label='Target $\mu^t$', alpha=0.5)
+pl.scatter(xs1[:, 0], xs1[:, 1], label='Source $\mu^s_1$', alpha=0.5)
+pl.scatter(xs2[:, 0], xs2[:, 1], label='Source $\mu^s_2$', alpha=0.5)
+pl.title('Sources and Target distributions')
+pl.legend()
+
+
+##############################################################################
+# Optimization of the model wrt the Wasserstein distance
+# ------------------------------------------------------
+
+
+#%% Weights optimization with gradient descent
+
+# convert numpy arrays to torch tensors
+H2 = torch.tensor(H)
+M2 = torch.tensor(M)
+
+# weights for the source distributions
+w = torch.tensor(ot.unif(2), requires_grad=True)
+
+# uniform weights for target
+b = torch.tensor(ot.unif(nt))
+
+lr = 2e-3 # learning rate
+niter = 500 # number of iterations
+losses = [] # loss along the iterations
+
+# loss for the minimal Wasserstein estimator
+
+
+def get_loss(w):
+ a = torch.mv(H2, w) # distribution reweighting
+ return ot.emd2(a, b, M2) # squared Wasserstein 2
+
+
+for i in range(niter):
+
+ loss = get_loss(w)
+ losses.append(float(loss))
+
+ loss.backward()
+
+ with torch.no_grad():
+ w -= lr * w.grad # gradient step
+ w[:] = ot.utils.proj_simplex(w) # projection on the simplex
+
+ w.grad.zero_()
+
+
+##############################################################################
+# Estimated weights and convergence of the objective
+# ---------------------------------------------------
+
+we = w.detach().numpy()
+print('Estimated mixture:', we)
+
+pl.figure(2)
+pl.semilogy(losses)
+pl.grid()
+pl.title('Wasserstein distance')
+pl.xlabel("Iterations")
+
+##############################################################################
+# Ploting the reweighted source distribution
+# ------------------------------------------
+
+pl.figure(3)
+
+# compute source weights
+ws = H.dot(we)
+
+pl.scatter(xt[:, 0], xt[:, 1], label='Target $\mu^t$', alpha=0.5)
+pl.scatter(xs[:, 0], xs[:, 1], color='C3', s=ws * 20 * ns, label='Weighted sources $\sum_{k} w_k\mu^s_k$', alpha=0.5)
+pl.title('Target and reweighted source distributions')
+pl.legend()
diff --git a/ot/__init__.py b/ot/__init__.py
index 5a8a415..3b072c6 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -33,6 +33,7 @@ from . import smooth
from . import stochastic
from . import unbalanced
from . import partial
+from . import backend
# OT functions
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
diff --git a/ot/backend.py b/ot/backend.py
new file mode 100644
index 0000000..d68f5cf
--- /dev/null
+++ b/ot/backend.py
@@ -0,0 +1,536 @@
+# -*- coding: utf-8 -*-
+"""
+Multi-lib backend for POT
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+
+try:
+ import torch
+ torch_type = torch.Tensor
+except ImportError:
+ torch = False
+ torch_type = float
+
+try:
+ import jax
+ import jax.numpy as jnp
+ jax_type = jax.numpy.ndarray
+except ImportError:
+ jax = False
+ jax_type = float
+
+str_type_error = "All array should be from the same type/backend. Current types are : {}"
+
+
+def get_backend_list():
+ """ returns the list of available backends)"""
+ lst = [NumpyBackend(), ]
+
+ if torch:
+ lst.append(TorchBackend())
+
+ if jax:
+ lst.append(JaxBackend())
+
+ return lst
+
+
+def get_backend(*args):
+ """returns the proper backend for a list of input arrays
+
+ Also raises TypeError if all arrays are not from the same backend
+ """
+ # check that some arrays given
+ if not len(args) > 0:
+ raise ValueError(" The function takes at least one parameter")
+ # check all same type
+
+ if isinstance(args[0], np.ndarray):
+ if not len(set(type(a) for a in args)) == 1:
+ raise ValueError(str_type_error.format([type(a) for a in args]))
+ return NumpyBackend()
+ elif torch and isinstance(args[0], torch_type):
+ if not len(set(type(a) for a in args)) == 1:
+ raise ValueError(str_type_error.format([type(a) for a in args]))
+ return TorchBackend()
+ elif isinstance(args[0], jax_type):
+ return JaxBackend()
+ else:
+ raise ValueError("Unknown type of non implemented backend.")
+
+
+def to_numpy(*args):
+ """returns numpy arrays from any compatible backend"""
+
+ if len(args) == 1:
+ return get_backend(args[0]).to_numpy(args[0])
+ else:
+ return [get_backend(a).to_numpy(a) for a in args]
+
+
+class Backend():
+
+ __name__ = None
+ __type__ = None
+
+ def __str__(self):
+ return self.__name__
+
+ # convert to numpy
+ def to_numpy(self, a):
+ raise NotImplementedError()
+
+ # convert from numpy
+ def from_numpy(self, a, type_as=None):
+ raise NotImplementedError()
+
+ def set_gradients(self, val, inputs, grads):
+ """ define the gradients for the value val wrt the inputs """
+ raise NotImplementedError()
+
+ def zeros(self, shape, type_as=None):
+ raise NotImplementedError()
+
+ def ones(self, shape, type_as=None):
+ raise NotImplementedError()
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ raise NotImplementedError()
+
+ def full(self, shape, fill_value, type_as=None):
+ raise NotImplementedError()
+
+ def eye(self, N, M=None, type_as=None):
+ raise NotImplementedError()
+
+ def sum(self, a, axis=None, keepdims=False):
+ raise NotImplementedError()
+
+ def cumsum(self, a, axis=None):
+ raise NotImplementedError()
+
+ def max(self, a, axis=None, keepdims=False):
+ raise NotImplementedError()
+
+ def min(self, a, axis=None, keepdims=False):
+ raise NotImplementedError()
+
+ def maximum(self, a, b):
+ raise NotImplementedError()
+
+ def minimum(self, a, b):
+ raise NotImplementedError()
+
+ def dot(self, a, b):
+ raise NotImplementedError()
+
+ def abs(self, a):
+ raise NotImplementedError()
+
+ def exp(self, a):
+ raise NotImplementedError()
+
+ def log(self, a):
+ raise NotImplementedError()
+
+ def sqrt(self, a):
+ raise NotImplementedError()
+
+ def norm(self, a):
+ raise NotImplementedError()
+
+ def any(self, a):
+ raise NotImplementedError()
+
+ def isnan(self, a):
+ raise NotImplementedError()
+
+ def isinf(self, a):
+ raise NotImplementedError()
+
+ def einsum(self, subscripts, *operands):
+ raise NotImplementedError()
+
+ def sort(self, a, axis=-1):
+ raise NotImplementedError()
+
+ def argsort(self, a, axis=None):
+ raise NotImplementedError()
+
+ def flip(self, a, axis=None):
+ raise NotImplementedError()
+
+
+class NumpyBackend(Backend):
+
+ __name__ = 'numpy'
+ __type__ = np.ndarray
+
+ def to_numpy(self, a):
+ return a
+
+ def from_numpy(self, a, type_as=None):
+ if type_as is None:
+ return a
+ elif isinstance(a, float):
+ return a
+ else:
+ return a.astype(type_as.dtype)
+
+ def set_gradients(self, val, inputs, grads):
+ # no gradients for numpy
+ return val
+
+ def zeros(self, shape, type_as=None):
+ if type_as is None:
+ return np.zeros(shape)
+ else:
+ return np.zeros(shape, dtype=type_as.dtype)
+
+ def ones(self, shape, type_as=None):
+ if type_as is None:
+ return np.ones(shape)
+ else:
+ return np.ones(shape, dtype=type_as.dtype)
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ return np.arange(start, stop, step)
+
+ def full(self, shape, fill_value, type_as=None):
+ if type_as is None:
+ return np.full(shape, fill_value)
+ else:
+ return np.full(shape, fill_value, dtype=type_as.dtype)
+
+ def eye(self, N, M=None, type_as=None):
+ if type_as is None:
+ return np.eye(N, M)
+ else:
+ return np.eye(N, M, dtype=type_as.dtype)
+
+ def sum(self, a, axis=None, keepdims=False):
+ return np.sum(a, axis, keepdims=keepdims)
+
+ def cumsum(self, a, axis=None):
+ return np.cumsum(a, axis)
+
+ def max(self, a, axis=None, keepdims=False):
+ return np.max(a, axis, keepdims=keepdims)
+
+ def min(self, a, axis=None, keepdims=False):
+ return np.min(a, axis, keepdims=keepdims)
+
+ def maximum(self, a, b):
+ return np.maximum(a, b)
+
+ def minimum(self, a, b):
+ return np.minimum(a, b)
+
+ def dot(self, a, b):
+ return np.dot(a, b)
+
+ def abs(self, a):
+ return np.abs(a)
+
+ def exp(self, a):
+ return np.exp(a)
+
+ def log(self, a):
+ return np.log(a)
+
+ def sqrt(self, a):
+ return np.sqrt(a)
+
+ def norm(self, a):
+ return np.sqrt(np.sum(np.square(a)))
+
+ def any(self, a):
+ return np.any(a)
+
+ def isnan(self, a):
+ return np.isnan(a)
+
+ def isinf(self, a):
+ return np.isinf(a)
+
+ def einsum(self, subscripts, *operands):
+ return np.einsum(subscripts, *operands)
+
+ def sort(self, a, axis=-1):
+ return np.sort(a, axis)
+
+ def argsort(self, a, axis=-1):
+ return np.argsort(a, axis)
+
+ def flip(self, a, axis=None):
+ return np.flip(a, axis)
+
+
+class JaxBackend(Backend):
+
+ __name__ = 'jax'
+ __type__ = jax_type
+
+ def to_numpy(self, a):
+ return np.array(a)
+
+ def from_numpy(self, a, type_as=None):
+ if type_as is None:
+ return jnp.array(a)
+ else:
+ return jnp.array(a).astype(type_as.dtype)
+
+ def set_gradients(self, val, inputs, grads):
+ # no gradients for jax because it is functional
+
+ # does not work
+ # from jax import custom_jvp
+ # @custom_jvp
+ # def f(*inputs):
+ # return val
+ # f.defjvps(*grads)
+ # return f(*inputs)
+
+ return val
+
+ def zeros(self, shape, type_as=None):
+ if type_as is None:
+ return jnp.zeros(shape)
+ else:
+ return jnp.zeros(shape, dtype=type_as.dtype)
+
+ def ones(self, shape, type_as=None):
+ if type_as is None:
+ return jnp.ones(shape)
+ else:
+ return jnp.ones(shape, dtype=type_as.dtype)
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ return jnp.arange(start, stop, step)
+
+ def full(self, shape, fill_value, type_as=None):
+ if type_as is None:
+ return jnp.full(shape, fill_value)
+ else:
+ return jnp.full(shape, fill_value, dtype=type_as.dtype)
+
+ def eye(self, N, M=None, type_as=None):
+ if type_as is None:
+ return jnp.eye(N, M)
+ else:
+ return jnp.eye(N, M, dtype=type_as.dtype)
+
+ def sum(self, a, axis=None, keepdims=False):
+ return jnp.sum(a, axis, keepdims=keepdims)
+
+ def cumsum(self, a, axis=None):
+ return jnp.cumsum(a, axis)
+
+ def max(self, a, axis=None, keepdims=False):
+ return jnp.max(a, axis, keepdims=keepdims)
+
+ def min(self, a, axis=None, keepdims=False):
+ return jnp.min(a, axis, keepdims=keepdims)
+
+ def maximum(self, a, b):
+ return jnp.maximum(a, b)
+
+ def minimum(self, a, b):
+ return jnp.minimum(a, b)
+
+ def dot(self, a, b):
+ return jnp.dot(a, b)
+
+ def abs(self, a):
+ return jnp.abs(a)
+
+ def exp(self, a):
+ return jnp.exp(a)
+
+ def log(self, a):
+ return jnp.log(a)
+
+ def sqrt(self, a):
+ return jnp.sqrt(a)
+
+ def norm(self, a):
+ return jnp.sqrt(jnp.sum(jnp.square(a)))
+
+ def any(self, a):
+ return jnp.any(a)
+
+ def isnan(self, a):
+ return jnp.isnan(a)
+
+ def isinf(self, a):
+ return jnp.isinf(a)
+
+ def einsum(self, subscripts, *operands):
+ return jnp.einsum(subscripts, *operands)
+
+ def sort(self, a, axis=-1):
+ return jnp.sort(a, axis)
+
+ def argsort(self, a, axis=-1):
+ return jnp.argsort(a, axis)
+
+ def flip(self, a, axis=None):
+ return jnp.flip(a, axis)
+
+
+class TorchBackend(Backend):
+
+ __name__ = 'torch'
+ __type__ = torch_type
+
+ def to_numpy(self, a):
+ return a.cpu().detach().numpy()
+
+ def from_numpy(self, a, type_as=None):
+ if type_as is None:
+ return torch.from_numpy(a)
+ else:
+ return torch.as_tensor(a, dtype=type_as.dtype, device=type_as.device)
+
+ def set_gradients(self, val, inputs, grads):
+ from torch.autograd import Function
+
+ # define a function that takes inputs and return val
+ class ValFunction(Function):
+ @staticmethod
+ def forward(ctx, *inputs):
+ return val
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ # the gradients are grad
+ return grads
+
+ return ValFunction.apply(*inputs)
+
+ def zeros(self, shape, type_as=None):
+ if type_as is None:
+ return torch.zeros(shape)
+ else:
+ return torch.zeros(shape, dtype=type_as.dtype, device=type_as.device)
+
+ def ones(self, shape, type_as=None):
+ if type_as is None:
+ return torch.ones(shape)
+ else:
+ return torch.ones(shape, dtype=type_as.dtype, device=type_as.device)
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ if type_as is None:
+ return torch.arange(start, stop, step)
+ else:
+ return torch.arange(start, stop, step, device=type_as.device)
+
+ def full(self, shape, fill_value, type_as=None):
+ if type_as is None:
+ return torch.full(shape, fill_value)
+ else:
+ return torch.full(shape, fill_value, dtype=type_as.dtype, device=type_as.device)
+
+ def eye(self, N, M=None, type_as=None):
+ if M is None:
+ M = N
+ if type_as is None:
+ return torch.eye(N, m=M)
+ else:
+ return torch.eye(N, m=M, dtype=type_as.dtype, device=type_as.device)
+
+ def sum(self, a, axis=None, keepdims=False):
+ if axis is None:
+ return torch.sum(a)
+ else:
+ return torch.sum(a, axis, keepdim=keepdims)
+
+ def cumsum(self, a, axis=None):
+ if axis is None:
+ return torch.cumsum(a.flatten(), 0)
+ else:
+ return torch.cumsum(a, axis)
+
+ def max(self, a, axis=None, keepdims=False):
+ if axis is None:
+ return torch.max(a)
+ else:
+ return torch.max(a, axis, keepdim=keepdims)[0]
+
+ def min(self, a, axis=None, keepdims=False):
+ if axis is None:
+ return torch.min(a)
+ else:
+ return torch.min(a, axis, keepdim=keepdims)[0]
+
+ def maximum(self, a, b):
+ if isinstance(a, int) or isinstance(a, float):
+ a = torch.tensor([float(a)], dtype=b.dtype, device=b.device)
+ if isinstance(b, int) or isinstance(b, float):
+ b = torch.tensor([float(b)], dtype=a.dtype, device=a.device)
+ return torch.maximum(a, b)
+
+ def minimum(self, a, b):
+ if isinstance(a, int) or isinstance(a, float):
+ a = torch.tensor([float(a)], dtype=b.dtype, device=b.device)
+ if isinstance(b, int) or isinstance(b, float):
+ b = torch.tensor([float(b)], dtype=a.dtype, device=a.device)
+ return torch.minimum(a, b)
+
+ def dot(self, a, b):
+ if len(a.shape) == len(b.shape) == 1:
+ return torch.dot(a, b)
+ elif len(a.shape) == 2 and len(b.shape) == 1:
+ return torch.mv(a, b)
+ else:
+ return torch.mm(a, b)
+
+ def abs(self, a):
+ return torch.abs(a)
+
+ def exp(self, a):
+ return torch.exp(a)
+
+ def log(self, a):
+ return torch.log(a)
+
+ def sqrt(self, a):
+ return torch.sqrt(a)
+
+ def norm(self, a):
+ return torch.sqrt(torch.sum(torch.square(a)))
+
+ def any(self, a):
+ return torch.any(a)
+
+ def isnan(self, a):
+ return torch.isnan(a)
+
+ def isinf(self, a):
+ return torch.isinf(a)
+
+ def einsum(self, subscripts, *operands):
+ return torch.einsum(subscripts, *operands)
+
+ def sort(self, a, axis=-1):
+ sorted0, indices = torch.sort(a, dim=axis)
+ return sorted0
+
+ def argsort(self, a, axis=-1):
+ sorted, indices = torch.sort(a, dim=axis)
+ return indices
+
+ def flip(self, a, axis=None):
+ if axis is None:
+ return torch.flip(a, tuple(i for i in range(len(a.shape))))
+ if isinstance(axis, int):
+ return torch.flip(a, (axis,))
+ else:
+ return torch.flip(a, dims=axis)
diff --git a/ot/bregman.py b/ot/bregman.py
index 559db14..b10effd 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -19,7 +19,8 @@ import warnings
import numpy as np
from scipy.optimize import fmin_l_bfgs_b
-from ot.utils import unif, dist
+from ot.utils import unif, dist, list_to_array
+from .backend import get_backend
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
@@ -43,17 +44,36 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- a and b are source and target weights (histograms, both sum to 1)
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
+ scaling algorithm as proposed in [2]_
+
+ **Choosing a Sinkhorn solver**
+
+ By default and when using a regularization parameter that is not too small
+ the default sinkhorn solver should be enough. If you need to use a small
+ regularization to get sharper OT matrices, you should use the
+ :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
+ errors. This last solver can be very slow in practice and might not even
+ converge to a reasonable OT matrix in a finite time. This is why
+ :any:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value
+ of the regularization (and using warm start) sometimes leads to better
+ solutions. Note that the greedy version of the sinkhorn
+ :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
+ version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a
+ fast approximation of the Sinkhorn problem.
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
+ b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (dim_a, dim_b)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
@@ -69,25 +89,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
log : bool, optional
record log if True
- **Choosing a Sinkhorn solver**
-
- By default and when using a regularization parameter that is not too small
- the default sinkhorn solver should be enough. If you need to use a small
- regularization to get sharper OT matrices, you should use the
- :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
- errors. This last solver can be very slow in practice and might not even
- converge to a reasonable OT matrix in a finite time. This is why
- :any:`ot.bregman.sinkhorn_epsilon_scaling` that relie on iterating the value
- of the regularization (and using warm start) sometimes leads to better
- solutions. Note that the greedy version of the sinkhorn
- :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
- version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a
- fast approximation of the Sinkhorn problem.
-
-
Returns
-------
- gamma : ndarray, shape (dim_a, dim_b)
+ gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -166,17 +170,35 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- a and b are source and target weights (histograms, both sum to 1)
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
+
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
+ **Choosing a Sinkhorn solver**
+
+ By default and when using a regularization parameter that is not too small
+ the default sinkhorn solver should be enough. If you need to use a small
+ regularization to get sharper OT matrices, you should use the
+ :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
+ errors. This last solver can be very slow in practice and might not even
+ converge to a reasonable OT matrix in a finite time. This is why
+ :any:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value
+ of the regularization (and using warm start) sometimes leads to better
+ solutions. Note that the greedy version of the sinkhorn
+ :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
+ version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a
+ fast approximation of the Sinkhorn problem.
+
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
+ b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (dim_a, dim_b)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
@@ -191,28 +213,14 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
log : bool, optional
record log if True
- **Choosing a Sinkhorn solver**
-
- By default and when using a regularization parameter that is not too small
- the default sinkhorn solver should be enough. If you need to use a small
- regularization to get sharper OT matrices, you should use the
- :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
- errors. This last solver can be very slow in practice and might not even
- converge to a reasonable OT matrix in a finite time. This is why
- :any:`ot.bregman.sinkhorn_epsilon_scaling` that relie on iterating the value
- of the regularization (and using warm start) sometimes leads to better
- solutions. Note that the greedy version of the sinkhorn
- :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
- version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a
- fast approximation of the Sinkhorn problem.
-
Returns
-------
- W : (n_hists) ndarray
+ W : (n_hists) float/array-like
Optimal transportation loss for the given parameters
log : dict
log dictionary return only if log==True in parameters
+
Examples
--------
@@ -247,7 +255,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
"""
- b = np.asarray(b, dtype=np.float64)
+
+ b = list_to_array(b)
if len(b.shape) < 2:
b = b[:, None]
@@ -339,14 +348,14 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
"""
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(M, a, b)
if len(a) == 0:
- a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M)
if len(b) == 0:
- b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+ b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M)
# init data
dim_a = len(a)
@@ -363,21 +372,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
# we assume that no distances are null except those of the diagonal of
# distances
if n_hists:
- u = np.ones((dim_a, n_hists)) / dim_a
- v = np.ones((dim_b, n_hists)) / dim_b
+ u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
+ v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
else:
- u = np.ones(dim_a) / dim_a
- v = np.ones(dim_b) / dim_b
+ u = nx.ones(dim_a, type_as=M) / dim_a
+ v = nx.ones(dim_b, type_as=M) / dim_b
- # print(reg)
-
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
-
- # print(np.min(K))
- tmp2 = np.empty(b.shape, dtype=M.dtype)
+ K = nx.exp(M / (-reg))
Kp = (1 / a).reshape(-1, 1) * K
cpt = 0
@@ -386,13 +387,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
uprev = u
vprev = v
- KtransposeU = np.dot(K.T, u)
- v = np.divide(b, KtransposeU)
- u = 1. / np.dot(Kp, v)
+ KtransposeU = nx.dot(K.T, u)
+ v = b / KtransposeU
+ u = 1. / nx.dot(Kp, v)
- if (np.any(KtransposeU == 0)
- or np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ if (nx.any(KtransposeU == 0)
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
+ or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
print('Warning: numerical errors at iteration', cpt)
@@ -403,11 +404,11 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
# we can speed up the process by checking for the error only all
# the 10th iterations
if n_hists:
- np.einsum('ik,ij,jk->jk', u, K, v, out=tmp2)
+ tmp2 = nx.einsum('ik,ij,jk->jk', u, K, v)
else:
# compute right marginal tmp2= (diag(u)Kdiag(v))^T1
- np.einsum('i,ij,j->j', u, K, v, out=tmp2)
- err = np.linalg.norm(tmp2 - b) # violation of marginal
+ tmp2 = nx.einsum('i,ij,j->j', u, K, v)
+ err = nx.norm(tmp2 - b) # violation of marginal
if log:
log['err'].append(err)
@@ -422,7 +423,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
log['v'] = v
if n_hists: # return only loss
- res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
+ res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M)
if log:
return res, log
else:
diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py
index 7478fb9..e939610 100644
--- a/ot/gpu/__init__.py
+++ b/ot/gpu/__init__.py
@@ -25,6 +25,8 @@ result of the function with parameter ``to_numpy=False``.
#
# License: MIT License
+import warnings
+
from . import bregman
from . import da
from .bregman import sinkhorn
@@ -34,7 +36,7 @@ from . import utils
from .utils import dist, to_gpu, to_np
-
+warnings.warn('This module will be deprecated in the next minor release of POT', category=DeprecationWarning)
__all__ = ["utils", "dist", "sinkhorn",
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index d5c3a5e..c8c9da6 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -18,8 +18,9 @@ from . import cvx
from .cvx import barycenter
# import compiled emd
from .emd_wrap import emd_c, check_result, emd_1d_sorted
-from ..utils import dist
+from ..utils import dist, list_to_array
from ..utils import parmap
+from ..backend import get_backend
__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
'emd_1d', 'emd2_1d', 'wasserstein_1d']
@@ -176,8 +177,7 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
r"""Solves the Earth Movers distance problem and returns the OT matrix
- .. math::
- \gamma = arg\min_\gamma <\gamma,M>_F
+ .. math:: \gamma = arg\min_\gamma <\gamma,M>_F
s.t. \gamma 1 = a
@@ -189,37 +189,41 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
- M is the metric cost matrix
- a and b are the sample weights
- .. warning::
- Note that the M matrix needs to be a C-order numpy.array in float64
- format.
+ .. warning:: Note that the M matrix in numpy needs to be a C-order
+ numpy.array in float64 format. It will be converted if not in this
+ format
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
Uses the algorithm proposed in [1]_
Parameters
----------
- a : (ns,) numpy.ndarray, float64
+ a : (ns,) array-like, float
Source histogram (uniform weight if empty list)
- b : (nt,) numpy.ndarray, float64
- Target histogram (uniform weight if empty list)
- M : (ns,nt) numpy.ndarray, float64
- Loss matrix (c-order array with type float64)
- numItermax : int, optional (default=100000)
+ b : (nt,) array-like, float
+ Target histogram (uniform weight if empty list)
+ M : (ns,nt) array-like, float
+ Loss matrix (c-order array in numpy with type float64)
+ numItermax : int, optional (default=100000)
The maximum number of iterations before stopping the optimization
- algorithm if it has not converged.
- log: bool, optional (default=False)
- If True, returns a dictionary containing the cost and dual
- variables. Otherwise returns only the optimal transportation matrix.
+ algorithm if it has not converged.
+ log: bool, optional (default=False)
+ If True, returns a dictionary containing the cost and dual variables.
+ Otherwise returns only the optimal transportation matrix.
center_dual: boolean, optional (default=True)
- If True, centers the dual potential using function
+ If True, centers the dual potential using function
:ref:`center_ot_dual`.
Returns
-------
- gamma: (ns x nt) numpy.ndarray
- Optimal transportation matrix for the given parameters
- log: dict
- If input log is true, a dictionary containing the cost and dual
- variables and exit status
+ gamma: array-like, shape (ns, nt)
+ Optimal transportation matrix for the given
+ parameters
+ log: dict, optional
+ If input log is true, a dictionary containing the
+ cost and dual variables and exit status
Examples
@@ -232,26 +236,37 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
>>> a=[.5,.5]
>>> b=[.5,.5]
>>> M=[[0.,1.],[1.,0.]]
- >>> ot.emd(a,b,M)
+ >>> ot.emd(a, b, M)
array([[0.5, 0. ],
[0. , 0.5]])
References
----------
- .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W.
- (2011, December). Displacement interpolation using Lagrangian mass
- transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p.
- 158). ACM.
+ .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011,
+ December). Displacement interpolation using Lagrangian mass transport.
+ In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
See Also
--------
- ot.bregman.sinkhorn : Entropic regularized OT
- ot.optim.cg : General regularized OT"""
-
+ ot.bregman.sinkhorn : Entropic regularized OT ot.optim.cg : General
+ regularized OT"""
+
+ # convert to numpy if list
+ a, b, M = list_to_array(a, b, M)
+
+ a0, b0, M0 = a, b, M
+ nx = get_backend(M0, a0, b0)
+
+ # convert to numpy
+ M = nx.to_numpy(M)
+ a = nx.to_numpy(a)
+ b = nx.to_numpy(b)
+
+ # ensure float64
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64, order='C')
# if empty array given then use uniform distributions
if len(a) == 0:
@@ -262,6 +277,11 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
"Dimension mismatch, check dimensions of M with a and b"
+ # ensure that same mass
+ np.testing.assert_almost_equal(a.sum(0),
+ b.sum(0), err_msg='a and b vector must have the same sum')
+ b=b*a.sum()/b.sum()
+
asel = a != 0
bsel = b != 0
@@ -277,12 +297,12 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
if log:
log = {}
log['cost'] = cost
- log['u'] = u
- log['v'] = v
+ log['u'] = nx.from_numpy(u, type_as=a0)
+ log['v'] = nx.from_numpy(v, type_as=b0)
log['warning'] = result_code_string
log['result_code'] = result_code
- return G, log
- return G
+ return nx.from_numpy(G, type_as=M0), log
+ return nx.from_numpy(G, type_as=M0)
def emd2(a, b, M, processes=multiprocessing.cpu_count(),
@@ -303,20 +323,19 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
- M is the metric cost matrix
- a and b are the sample weights
- .. warning::
- Note that the M matrix needs to be a C-order numpy.array in float64
- format.
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
Uses the algorithm proposed in [1]_
Parameters
----------
- a : (ns,) numpy.ndarray, float64
+ a : (ns,) array-like, float64
Source histogram (uniform weight if empty list)
- b : (nt,) numpy.ndarray, float64
+ b : (nt,) array-like, float64
Target histogram (uniform weight if empty list)
- M : (ns,nt) numpy.ndarray, float64
- Loss matrix (c-order array with type float64)
+ M : (ns,nt) array-like, float64
+ Loss matrix (for numpy c-order array with type float64)
processes : int, optional (default=nb cpu)
Nb of processes used for multiple emd computation (not used on windows)
numItermax : int, optional (default=100000)
@@ -333,9 +352,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
Returns
-------
- W: float
+ W: float, array-like
Optimal transportation loss for the given parameters
- log: dictnp
+ log: dict
If input log is true, a dictionary containing dual
variables and exit status
@@ -367,12 +386,22 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
ot.bregman.sinkhorn : Entropic regularized OT
ot.optim.cg : General regularized OT"""
+ a, b, M = list_to_array(a, b, M)
+
+ a0, b0, M0 = a, b, M
+ nx = get_backend(M0, a0, b0)
+
+ # convert to numpy
+ M = nx.to_numpy(M)
+ a = nx.to_numpy(a)
+ b = nx.to_numpy(b)
+
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64, order= 'C')
# problem with pikling Forks
- if sys.platform.endswith('win32'):
+ if sys.platform.endswith('win32') or not nx.__name__ == 'numpy':
processes = 1
# if empty array given then use uniform distributions
@@ -400,12 +429,15 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
result_code_string = check_result(result_code)
log = {}
+ G = nx.from_numpy(G, type_as=M0)
if return_matrix:
log['G'] = G
- log['u'] = u
- log['v'] = v
+ log['u'] = nx.from_numpy(u, type_as=a0)
+ log['v'] = nx.from_numpy(v, type_as=b0)
log['warning'] = result_code_string
log['result_code'] = result_code
+ cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0),
+ (a0,b0, M0), (log['u'], log['v'], G))
return [cost, log]
else:
def f(b):
@@ -418,6 +450,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
if np.any(~asel) or np.any(~bsel):
u, v = estimate_dual_null_weights(u, v, a, b, M)
+ G = nx.from_numpy(G, type_as=M0)
+ cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0),
+ (a0,b0, M0), (nx.from_numpy(u, type_as=a0),
+ nx.from_numpy(v, type_as=b0),G))
+
check_result(result_code)
return cost
@@ -637,6 +674,10 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
if b.ndim == 0 or len(b) == 0:
b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]
+ # ensure that same mass
+ np.testing.assert_almost_equal(a.sum(0),b.sum(0),err_msg='a and b vector must have the same sum')
+ b=b*a.sum()/b.sum()
+
x_a_1d = x_a.reshape((-1,))
x_b_1d = x_b.reshape((-1,))
perm_a = np.argsort(x_a_1d)
diff --git a/ot/utils.py b/ot/utils.py
index 544c569..4dac0c5 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -16,6 +16,7 @@ from scipy.spatial.distance import cdist
import sys
import warnings
from inspect import signature
+from .backend import get_backend
__time_tic_toc = time.time()
@@ -41,8 +42,11 @@ def toq():
def kernel(x1, x2, method='gaussian', sigma=1, **kwargs):
"""Compute kernel matrix"""
+
+ nx = get_backend(x1, x2)
+
if method.lower() in ['gaussian', 'gauss', 'rbf']:
- K = np.exp(-dist(x1, x2) / (2 * sigma**2))
+ K = nx.exp(-dist(x1, x2) / (2 * sigma**2))
return K
@@ -52,6 +56,66 @@ def laplacian(x):
return L
+def list_to_array(*lst):
+ """ Convert a list if in numpy format """
+ if len(lst) > 1:
+ return [np.array(a) if isinstance(a, list) else a for a in lst]
+ else:
+ return np.array(lst[0]) if isinstance(lst[0], list) else lst[0]
+
+
+def proj_simplex(v, z=1):
+ r""" compute the closest point (orthogonal projection) on the
+ generalized (n-1)-simplex of a vector v wrt. to the Euclidean
+ distance, thus solving:
+ .. math::
+ \mathcal{P}(w) \in arg\min_\gamma || \gamma - v ||_2
+
+ s.t. \gamma^T 1= z
+
+ \gamma\geq 0
+
+ If v is a 2d array, compute all the projections wrt. axis 0
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
+
+ Parameters
+ ----------
+ v : {array-like}, shape (n, d)
+ z : int, optional
+ 'size' of the simplex (each vectors sum to z, 1 by default)
+
+ Returns
+ -------
+ h : ndarray, shape (n,d)
+ Array of projections on the simplex
+ """
+ nx = get_backend(v)
+ n = v.shape[0]
+ if v.ndim == 1:
+ d1 = 1
+ v = v[:, None]
+ else:
+ d1 = 0
+ d = v.shape[1]
+
+ # sort u in ascending order
+ u = nx.sort(v, axis=0)
+ # take the descending order
+ u = nx.flip(u, 0)
+ cssv = nx.cumsum(u, axis=0) - z
+ ind = nx.arange(n, type_as=v)[:, None] + 1
+ cond = u - cssv / ind > 0
+ rho = nx.sum(cond, 0)
+ theta = cssv[rho - 1, nx.arange(d)] / rho
+ w = nx.maximum(v - theta[None, :], nx.zeros(v.shape, type_as=v))
+ if d1:
+ return w[:, 0]
+ else:
+ return w
+
+
def unif(n):
""" return a uniform histogram of length n (simplex)
@@ -84,52 +148,68 @@ def euclidean_distances(X, Y, squared=False):
"""
Considering the rows of X (and Y=X) as vectors, compute the
distance matrix between each pair of vectors.
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
+
Parameters
----------
X : {array-like}, shape (n_samples_1, n_features)
Y : {array-like}, shape (n_samples_2, n_features)
squared : boolean, optional
Return squared Euclidean distances.
+
Returns
-------
distances : {array}, shape (n_samples_1, n_samples_2)
"""
- XX = np.einsum('ij,ij->i', X, X)[:, np.newaxis]
- YY = np.einsum('ij,ij->i', Y, Y)[np.newaxis, :]
- distances = np.dot(X, Y.T)
- distances *= -2
- distances += XX
- distances += YY
- np.maximum(distances, 0, out=distances)
+
+ nx = get_backend(X, Y)
+
+ a2 = nx.einsum('ij,ij->i', X, X)
+ b2 = nx.einsum('ij,ij->i', Y, Y)
+
+ c = -2 * nx.dot(X, Y.T)
+ c += a2[:, None]
+ c += b2[None, :]
+
+ c = nx.maximum(c, 0)
+
+ if not squared:
+ c = nx.sqrt(c)
+
if X is Y:
- # Ensure that distances between vectors and themselves are set to 0.0.
- # This may not be the case due to floating point rounding errors.
- distances.flat[::distances.shape[0] + 1] = 0.0
- return distances if squared else np.sqrt(distances, out=distances)
+ c = c * (1 - nx.eye(X.shape[0], type_as=c))
+
+ return c
def dist(x1, x2=None, metric='sqeuclidean'):
- """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist
+ """Compute distance between samples in x1 and x2
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
Parameters
----------
- x1 : ndarray, shape (n1,d)
+ x1 : array-like, shape (n1,d)
matrix with n1 samples of size d
- x2 : array, shape (n2,d), optional
+ x2 : array-like, shape (n2,d), optional
matrix with n2 samples of size d (if None then x2=x1)
metric : str | callable, optional
- Name of the metric to be computed (full list in the doc of scipy), If a string,
- the distance function can be 'braycurtis', 'canberra', 'chebyshev', 'cityblock',
- 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski',
- 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean',
+ 'sqeuclidean' or 'euclidean' on all backends. On numpy the function also
+ accepts from the scipy.spatial.distance.cdist function : 'braycurtis',
+ 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice',
+ 'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis',
+ 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean',
'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'.
Returns
-------
- M : np.array (n1,n2)
+ M : array-like, shape (n1, n2)
distance matrix computed with given metric
"""
@@ -137,7 +217,13 @@ def dist(x1, x2=None, metric='sqeuclidean'):
x2 = x1
if metric == "sqeuclidean":
return euclidean_distances(x1, x2, squared=True)
- return cdist(x1, x2, metric=metric)
+ elif metric == "euclidean":
+ return euclidean_distances(x1, x2, squared=False)
+ else:
+ if not get_backend(x1, x2).__name__ == 'numpy':
+ raise NotImplementedError()
+ else:
+ return cdist(x1, x2, metric=metric)
def dist0(n, method='lin_square'):
diff --git a/requirements.txt b/requirements.txt
index 331dd57..4353247 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,4 +7,7 @@ pymanopt==0.2.4; python_version <'3'
pymanopt; python_version >= '3'
cvxopt
scikit-learn
+torch
+jax
+jaxlib
pytest \ No newline at end of file
diff --git a/test/test_backend.py b/test/test_backend.py
new file mode 100644
index 0000000..bc5b00c
--- /dev/null
+++ b/test/test_backend.py
@@ -0,0 +1,364 @@
+"""Tests for backend module """
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+import ot
+import ot.backend
+from ot.backend import torch, jax
+
+import pytest
+
+import numpy as np
+from numpy.testing import assert_array_almost_equal_nulp
+
+from ot.backend import get_backend, get_backend_list, to_numpy
+
+
+backend_list = get_backend_list()
+
+
+def test_get_backend_list():
+
+ lst = get_backend_list()
+
+ assert len(lst) > 0
+ assert isinstance(lst[0], ot.backend.NumpyBackend)
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_to_numpy(nx):
+
+ v = nx.zeros(10)
+ M = nx.ones((10, 10))
+
+ v2 = to_numpy(v)
+ assert isinstance(v2, np.ndarray)
+
+ v2, M2 = to_numpy(v, M)
+ assert isinstance(M2, np.ndarray)
+
+
+def test_get_backend():
+
+ A = np.zeros((3, 2))
+ B = np.zeros((3, 1))
+
+ nx = get_backend(A)
+ assert nx.__name__ == 'numpy'
+
+ nx = get_backend(A, B)
+ assert nx.__name__ == 'numpy'
+
+ # error if no parameters
+ with pytest.raises(ValueError):
+ get_backend()
+
+ # error if unknown types
+ with pytest.raises(ValueError):
+ get_backend(1, 2.0)
+
+ # test torch
+ if torch:
+
+ A2 = torch.from_numpy(A)
+ B2 = torch.from_numpy(B)
+
+ nx = get_backend(A2)
+ assert nx.__name__ == 'torch'
+
+ nx = get_backend(A2, B2)
+ assert nx.__name__ == 'torch'
+
+ # test not unique types in input
+ with pytest.raises(ValueError):
+ get_backend(A, B2)
+
+ if jax:
+
+ A2 = jax.numpy.array(A)
+ B2 = jax.numpy.array(B)
+
+ nx = get_backend(A2)
+ assert nx.__name__ == 'jax'
+
+ nx = get_backend(A2, B2)
+ assert nx.__name__ == 'jax'
+
+ # test not unique types in input
+ with pytest.raises(ValueError):
+ get_backend(A, B2)
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_convert_between_backends(nx):
+
+ A = np.zeros((3, 2))
+ B = np.zeros((3, 1))
+
+ A2 = nx.from_numpy(A)
+ B2 = nx.from_numpy(B)
+
+ assert isinstance(A2, nx.__type__)
+ assert isinstance(B2, nx.__type__)
+
+ nx2 = get_backend(A2, B2)
+
+ assert nx2.__name__ == nx.__name__
+
+ assert_array_almost_equal_nulp(nx.to_numpy(A2), A)
+ assert_array_almost_equal_nulp(nx.to_numpy(B2), B)
+
+
+def test_empty_backend():
+
+ rnd = np.random.RandomState(0)
+ M = rnd.randn(10, 3)
+ v = rnd.randn(3)
+
+ nx = ot.backend.Backend()
+
+ with pytest.raises(NotImplementedError):
+ nx.from_numpy(M)
+ with pytest.raises(NotImplementedError):
+ nx.to_numpy(M)
+ with pytest.raises(NotImplementedError):
+ nx.set_gradients(0, 0, 0)
+ with pytest.raises(NotImplementedError):
+ nx.zeros((10, 3))
+ with pytest.raises(NotImplementedError):
+ nx.ones((10, 3))
+ with pytest.raises(NotImplementedError):
+ nx.arange(10, 1, 2)
+ with pytest.raises(NotImplementedError):
+ nx.full((10, 3), 3.14)
+ with pytest.raises(NotImplementedError):
+ nx.eye((10, 3))
+ with pytest.raises(NotImplementedError):
+ nx.sum(M)
+ with pytest.raises(NotImplementedError):
+ nx.cumsum(M)
+ with pytest.raises(NotImplementedError):
+ nx.max(M)
+ with pytest.raises(NotImplementedError):
+ nx.min(M)
+ with pytest.raises(NotImplementedError):
+ nx.maximum(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.minimum(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.abs(M)
+ with pytest.raises(NotImplementedError):
+ nx.log(M)
+ with pytest.raises(NotImplementedError):
+ nx.exp(M)
+ with pytest.raises(NotImplementedError):
+ nx.sqrt(M)
+ with pytest.raises(NotImplementedError):
+ nx.dot(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.norm(M)
+ with pytest.raises(NotImplementedError):
+ nx.exp(M)
+ with pytest.raises(NotImplementedError):
+ nx.any(M)
+ with pytest.raises(NotImplementedError):
+ nx.isnan(M)
+ with pytest.raises(NotImplementedError):
+ nx.isinf(M)
+ with pytest.raises(NotImplementedError):
+ nx.einsum('ij->i', M)
+ with pytest.raises(NotImplementedError):
+ nx.sort(M)
+ with pytest.raises(NotImplementedError):
+ nx.argsort(M)
+ with pytest.raises(NotImplementedError):
+ nx.flip(M)
+
+
+@pytest.mark.parametrize('backend', backend_list)
+def test_func_backends(backend):
+
+ rnd = np.random.RandomState(0)
+ M = rnd.randn(10, 3)
+ v = rnd.randn(3)
+ val = np.array([1.0])
+
+ lst_tot = []
+
+ for nx in [ot.backend.NumpyBackend(), backend]:
+
+ print('Backend: ', nx.__name__)
+
+ lst_b = []
+ lst_name = []
+
+ Mb = nx.from_numpy(M)
+ vb = nx.from_numpy(v)
+ val = nx.from_numpy(val)
+
+ A = nx.set_gradients(val, v, v)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('set_gradients')
+
+ A = nx.zeros((10, 3))
+ A = nx.zeros((10, 3), type_as=Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('zeros')
+
+ A = nx.ones((10, 3))
+ A = nx.ones((10, 3), type_as=Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('ones')
+
+ A = nx.arange(10, 1, 2)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('arange')
+
+ A = nx.full((10, 3), 3.14)
+ A = nx.full((10, 3), 3.14, type_as=Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('full')
+
+ A = nx.eye(10, 3)
+ A = nx.eye(10, 3, type_as=Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('eye')
+
+ A = nx.sum(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('sum')
+
+ A = nx.sum(Mb, axis=1, keepdims=True)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('sum(axis)')
+
+ A = nx.cumsum(Mb, 0)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('cumsum(axis)')
+
+ A = nx.max(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('max')
+
+ A = nx.max(Mb, axis=1, keepdims=True)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('max(axis)')
+
+ A = nx.min(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('min')
+
+ A = nx.min(Mb, axis=1, keepdims=True)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('min(axis)')
+
+ A = nx.maximum(vb, 0)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('maximum')
+
+ A = nx.minimum(vb, 0)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('minimum')
+
+ A = nx.abs(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('abs')
+
+ A = nx.log(A)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('log')
+
+ A = nx.exp(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('exp')
+
+ A = nx.sqrt(nx.abs(Mb))
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('sqrt')
+
+ A = nx.dot(vb, vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('dot(v,v)')
+
+ A = nx.dot(Mb, vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('dot(M,v)')
+
+ A = nx.dot(Mb, Mb.T)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('dot(M,M)')
+
+ A = nx.norm(vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('norm')
+
+ A = nx.any(vb > 0)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('any')
+
+ A = nx.isnan(vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('isnan')
+
+ A = nx.isinf(vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('isinf')
+
+ A = nx.einsum('ij->i', Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('einsum(ij->i)')
+
+ A = nx.einsum('ij,j->i', Mb, vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('nx.einsum(ij,j->i)')
+
+ A = nx.einsum('ij->i', Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('nx.einsum(ij->i)')
+
+ A = nx.sort(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('sort')
+
+ A = nx.argsort(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('argsort')
+
+ A = nx.flip(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('flip')
+
+ lst_tot.append(lst_b)
+
+ lst_np = lst_tot[0]
+ lst_b = lst_tot[1]
+
+ for a1, a2, name in zip(lst_np, lst_b, lst_name):
+ if not np.allclose(a1, a2):
+ print('Assert fail on: ', name)
+ assert np.allclose(a1, a2, atol=1e-7)
+
+
+def test_gradients_backends():
+
+ rnd = np.random.RandomState(0)
+ v = rnd.randn(10)
+ c = rnd.randn(1)
+
+ if torch:
+
+ nx = ot.backend.TorchBackend()
+
+ v2 = torch.tensor(v, requires_grad=True)
+ c2 = torch.tensor(c, requires_grad=True)
+
+ val = c2 * torch.sum(v2 * v2)
+
+ val2 = nx.set_gradients(val, (v2, c2), (v2, c2))
+
+ val2.backward()
+
+ assert torch.equal(v2.grad, v2)
+ assert torch.equal(c2.grad, c2)
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 1ebd21f..7c5162a 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -9,6 +9,10 @@ import numpy as np
import pytest
import ot
+from ot.backend import get_backend_list
+from ot.backend import torch
+
+backend_list = get_backend_list()
def test_sinkhorn():
@@ -30,6 +34,76 @@ def test_sinkhorn():
u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+@pytest.mark.parametrize('nx', backend_list)
+def test_sinkhorn_backends(nx):
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ G = ot.sinkhorn(a, a, M, 1)
+
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
+
+ Gb = ot.sinkhorn(ab, ab, Mb, 1)
+
+ np.allclose(G, nx.to_numpy(Gb))
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_sinkhorn2_backends(nx):
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ G = ot.sinkhorn(a, a, M, 1)
+
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
+
+ Gb = ot.sinkhorn2(ab, ab, Mb, 1)
+
+ np.allclose(G, nx.to_numpy(Gb))
+
+
+def test_sinkhorn2_gradients():
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ if torch:
+
+ a1 = torch.tensor(a, requires_grad=True)
+ b1 = torch.tensor(a, requires_grad=True)
+ M1 = torch.tensor(M, requires_grad=True)
+
+ val = ot.sinkhorn2(a1, b1, M1, 1)
+
+ val.backward()
+
+ assert a1.shape == a1.grad.shape
+ assert b1.shape == b1.grad.shape
+ assert M1.shape == M1.grad.shape
+
+
def test_sinkhorn_empty():
# test sinkhorn
n = 100
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 43da9fc..81138ca 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -181,7 +181,7 @@ def test_fgw():
M = ot.dist(ys, yt)
M /= M.max()
- G = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5)
+ G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
# check constratints
np.testing.assert_allclose(
@@ -242,9 +242,9 @@ def test_fgw_barycenter():
init_X = np.random.randn(n_samples, ys.shape[1])
- X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
- fixed_structure=False, fixed_features=True, init_X=init_X,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3)
+ X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
+ fixed_structure=False, fixed_features=True, init_X=init_X,
+ p=ot.unif(n_samples), loss_fun='square_loss',
+ max_iter=100, tol=1e-3, log=True)
np.testing.assert_allclose(C.shape, (n_samples, n_samples))
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
diff --git a/test/test_ot.py b/test/test_ot.py
index f45e4c9..3e953dc 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -12,9 +12,12 @@ from scipy.stats import wasserstein_distance
import ot
from ot.datasets import make_1D_gauss as gauss
+from ot.backend import get_backend_list, torch
+backend_list = get_backend_list()
-def test_emd_dimension_mismatch():
+
+def test_emd_dimension_and_mass_mismatch():
# test emd and emd2 for dimension mismatch
n_samples = 100
n_features = 2
@@ -29,6 +32,80 @@ def test_emd_dimension_mismatch():
np.testing.assert_raises(AssertionError, ot.emd2, a, a, M)
+ b = a.copy()
+ a[0] = 100
+ np.testing.assert_raises(AssertionError, ot.emd, a, b, M)
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_emd_backends(nx):
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ G = ot.emd(a, a, M)
+
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
+
+ Gb = ot.emd(ab, ab, Mb)
+
+ np.allclose(G, nx.to_numpy(Gb))
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_emd2_backends(nx):
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ val = ot.emd2(a, a, M)
+
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
+
+ valb = ot.emd2(ab, ab, Mb)
+
+ np.allclose(val, nx.to_numpy(valb))
+
+
+def test_emd2_gradients():
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ if torch:
+
+ a1 = torch.tensor(a, requires_grad=True)
+ b1 = torch.tensor(a, requires_grad=True)
+ M1 = torch.tensor(M, requires_grad=True)
+
+ val = ot.emd2(a1, b1, M1)
+
+ val.backward()
+
+ assert a1.shape == a1.grad.shape
+ assert b1.shape == b1.grad.shape
+ assert M1.shape == M1.grad.shape
+
def test_emd_emd2():
# test emd and emd2 for simple identity
@@ -83,7 +160,7 @@ def test_emd_1d_emd2_1d():
np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))
# check G is similar
- np.testing.assert_allclose(G, G_1d)
+ np.testing.assert_allclose(G, G_1d, atol=1e-15)
# check AssertionError is raised if called on non 1d arrays
u = np.random.randn(n, 2)
@@ -292,16 +369,6 @@ def test_warnings():
ot.emd(a, b, M, numItermax=1)
assert "numItermax" in str(w[-1].message)
#assert len(w) == 1
- a[0] = 100
- print('Computing {} EMD '.format(2))
- ot.emd(a, b, M)
- assert "infeasible" in str(w[-1].message)
- #assert len(w) == 2
- a[0] = -1
- print('Computing {} EMD '.format(2))
- ot.emd(a, b, M)
- assert "infeasible" in str(w[-1].message)
- #assert len(w) == 3
def test_dual_variables():
diff --git a/test/test_partial.py b/test/test_partial.py
index 121f345..3571e2a 100755
--- a/test/test_partial.py
+++ b/test/test_partial.py
@@ -129,9 +129,9 @@ def test_partial_wasserstein():
# check constratints
np.testing.assert_equal(
- G.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein
+ G.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
np.testing.assert_equal(
- G.sum(0) <= q, [True] * len(q)) # cf convergence wasserstein
+ G.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
np.testing.assert_allclose(
np.sum(G), m, atol=1e-04)
diff --git a/test/test_utils.py b/test/test_utils.py
index db9cda6..76b1faa 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -4,11 +4,47 @@
#
# License: MIT License
-
+import pytest
import ot
import numpy as np
import sys
+from ot.backend import get_backend_list
+
+backend_list = get_backend_list()
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_proj_simplex(nx):
+ n = 10
+ rng = np.random.RandomState(0)
+
+ # test on matrix when projection is done on axis 0
+ x = rng.randn(n, 2)
+ x1 = nx.from_numpy(x)
+
+ # all projections should sum to 1
+ proj = ot.utils.proj_simplex(x1)
+ l1 = np.sum(nx.to_numpy(proj), axis=0)
+ l2 = np.ones(2)
+ np.testing.assert_allclose(l1, l2, atol=1e-5)
+
+ # all projections should sum to 3
+ proj = ot.utils.proj_simplex(x1, 3)
+ l1 = np.sum(nx.to_numpy(proj), axis=0)
+ l2 = 3 * np.ones(2)
+ np.testing.assert_allclose(l1, l2, atol=1e-5)
+
+ # tets on vector
+ x = rng.randn(n)
+ x1 = nx.from_numpy(x)
+
+ # all projections should sum to 1
+ proj = ot.utils.proj_simplex(x1)
+ l1 = np.sum(nx.to_numpy(proj), axis=0)
+ l2 = np.ones(2)
+ np.testing.assert_allclose(l1, l2, atol=1e-5)
+
def test_parmap():
@@ -45,8 +81,8 @@ def test_tic_toc():
def test_kernel():
n = 100
-
- x = np.random.randn(n, 2)
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
K = ot.utils.kernel(x, x)
@@ -67,7 +103,8 @@ def test_dist():
n = 100
- x = np.random.randn(n, 2)
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
D = np.zeros((n, n))
for i in range(n):
@@ -78,8 +115,27 @@ def test_dist():
D3 = ot.dist(x)
# dist shoul return squared euclidean
- np.testing.assert_allclose(D, D2)
- np.testing.assert_allclose(D, D3)
+ np.testing.assert_allclose(D, D2, atol=1e-14)
+ np.testing.assert_allclose(D, D3, atol=1e-14)
+
+
+@ pytest.mark.parametrize('nx', backend_list)
+def test_dist_backends(nx):
+
+ n = 100
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
+ x1 = nx.from_numpy(x)
+
+ lst_metric = ['euclidean', 'sqeuclidean']
+
+ for metric in lst_metric:
+
+ D = ot.dist(x, x, metric=metric)
+ D1 = ot.dist(x1, x1, metric=metric)
+
+ # low atol because jax forces float32
+ np.testing.assert_allclose(D, nx.to_numpy(D1), atol=1e-5)
def test_dist0():
@@ -95,9 +151,11 @@ def test_dots():
n1, n2, n3, n4 = 100, 50, 200, 100
- A = np.random.randn(n1, n2)
- B = np.random.randn(n2, n3)
- C = np.random.randn(n3, n4)
+ rng = np.random.RandomState(0)
+
+ A = rng.randn(n1, n2)
+ B = rng.randn(n2, n3)
+ C = rng.randn(n3, n4)
X1 = ot.utils.dots(A, B, C)