summaryrefslogtreecommitdiff
path: root/docs
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2021-06-01 10:10:54 +0200
committerGitHub <noreply@github.com>2021-06-01 10:10:54 +0200
commit184f8f4f7ac78f1dd7f653496d2753211a4e3426 (patch)
tree483a7274c91030fd644de49b03a5fad04af9deba /docs
parent1f16614954e2522fbdb1598c5b1f5c3630c68472 (diff)
[MRG] POT numpy/torch/jax backends (#249)
* add numpy and torch backends * stat sets on functions * proper import * install recent torch on windows * install recent torch on windows * now testing all functions in backedn * add jax backedn * clenaup windowds * proper convert for jax backedn * pep8 * try again windows tests * test jax conversion * try proper widows tests * emd fuction ses backedn * better test partial OT * proper tests to_numpy and teplate Backend * pep8 * pep8 x2 * feaking sinkhorn works with torch * sinkhorn2 compatible * working ot.emd2 * important detach * it should work * jax autodiff emd * pep8 * no tast same for jax * new independat tests per backedn * freaking pep8 * add tests for gradients * deprecate ot.gpu * worging dist function * working dist * dist done in backedn * not in * remove indexing * change accuacy for jax * first pull backend * projection simplex * projection simplex * projection simplex * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 * add backedn discusion to quickstart guide * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 + better doc * proper links * corect doctest * big debug documentation * doctest again * doctest again bis * doctest again ter (last one or i kill myself) * backend test + doc proj simplex * correction test_utils * correction test_utils * correction cumsum * correction flip * correction flip v2 * more debug * more debug * more debug + pep8 * pep8 * argh * proj_simplex * backedn works for sort * proj simplex * jax sucks * update doc * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/lp/__init__.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * begin comment alex * comment alex part 2 * optimize test gromov * proj_simplex on vectors * add awesome gradient decsnt example on the weights * pep98 of course * proof read example by alex * pep8 again * encoding oos in translation * correct legend Co-authored-by: Nicolas Courty <ncourty@irisa.fr> Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'docs')
-rw-r--r--docs/source/quickstart.rst68
-rw-r--r--docs/source/readme.rst70
2 files changed, 106 insertions, 32 deletions
diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst
index cf5d6aa..fd046a1 100644
--- a/docs/source/quickstart.rst
+++ b/docs/source/quickstart.rst
@@ -15,6 +15,12 @@ are also available as notebooks on the POT Github.
in ML applications we refer the reader to the following `OTML tutorial
<https://remi.flamary.com/cours/tuto_otml.html>`_.
+.. note::
+
+ Since version 0.8, POT provides a backend to automatically solve some OT
+ problems independently from the toolbox used by the user (numpy/torch/jax).
+ We provide a discussion about which functions are compatible in section
+ `Backend section <#solving-ot-with-multiple-backends>`_ .
Why Optimal Transport ?
@@ -158,7 +164,6 @@ Wasserstein but has better computational and
`statistical properties <https://arxiv.org/pdf/1910.04091.pdf>`_.
-
Optimal transport and Wasserstein distance
------------------------------------------
@@ -922,6 +927,13 @@ The implementations of FGW and FGW barycenter is provided in functions
GPU acceleration
^^^^^^^^^^^^^^^^
+.. warning::
+
+ The :any:`ot.gpu` has been deprecated since the release 0.8 of POT and
+ should not be used. The GPU implementation (in Pytorch for instance) can be
+ used with the novel backends using the compatible functions from POT.
+
+
We provide several implementation of our OT solvers in :any:`ot.gpu`. Those
implementations use the :code:`cupy` toolbox that obviously need to be installed.
@@ -950,6 +962,60 @@ explicitly.
use it you have to specifically import it with :code:`import ot.gpu` .
+Solving OT with Multiple backends
+---------------------------------
+
+.. _backends_section:
+
+Since version 0.8, POT provides a backend that allows to code solvers
+independently from the type of the input arrays. The idea is to provide the user
+with a package that works seamlessly and returns a solution for instance as a
+Pytorch tensors when the function has Pytorch tensors as input.
+
+
+How it works
+^^^^^^^^^^^^
+
+The aim of the backend is to use the same function independently of the type of
+the input arrays.
+
+For instance when executing the following code
+
+.. code:: python
+
+ # a and b are 1D histograms (sum to 1 and positive)
+ # M is the ground cost matrix
+ T = ot.emd(a, b, M) # exact linear program
+ w = ot.emd2(a, b, M) # Wasserstein computation
+
+the functions :any:`ot.emd` and :any:`ot.emd2` can take inputs of the type
+:any:`numpy.array`, :any:`torch.tensor` or :any:`jax.numpy.array`. The output of
+the function will be the same type as the inputs and on the same device. When
+possible all computations are done on the same device and also when possible the
+output will be differentiable with respect to the input of the function.
+
+
+
+List of compatible Backends
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+- `Numpy <https://numpy.org/>`_ (all functions and solvers)
+- `Pytorch <https://pytorch.org/>`_ (all outputs differentiable w.r.t. inputs)
+- `Jax <https://github.com/google/jax>`_ (Some functions are differentiable some require a wrapper)
+
+List of compatible functions
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+This list will get longer for new releases and will hopefully disappear when POT
+become fully implemented with the backend.
+
+- :any:`ot.emd`
+- :any:`ot.emd2`
+- :any:`ot.sinkhorn`
+- :any:`ot.sinkhorn2`
+- :any:`ot.dist`
+
+
FAQ
---
diff --git a/docs/source/readme.rst b/docs/source/readme.rst
index 3b594c2..82d3e6c 100644
--- a/docs/source/readme.rst
+++ b/docs/source/readme.rst
@@ -26,8 +26,7 @@ POT provides the following generic OT solvers (links to examples):
Algorithm <auto_examples/plot_OT_1D.html>`__
[2] , stabilized version [9] [10], greedy Sinkhorn [22] and
`Screening Sinkhorn
- [26] <auto_examples/plot_screenkhorn_1D.html>`__
- with optional GPU implementation (requires cupy).
+ [26] <auto_examples/plot_screenkhorn_1D.html>`__.
- Bregman projections for `Wasserstein
barycenter <auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html>`__
[3], `convolutional
@@ -69,6 +68,11 @@ POT provides the following generic OT solvers (links to examples):
- `Sliced
Wasserstein <auto_examples/sliced-wasserstein/plot_variance.html>`__
[31, 32].
+- `Several
+ backends <https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends>`__
+ for easy use of POT with
+ `Pytorch <https://pytorch.org/>`__/`jax <https://github.com/google/jax>`__/`Numpy <https://numpy.org/>`__
+ arrays.
POT provides the following Machine Learning related solvers:
@@ -104,12 +108,14 @@ paper <https://jmlr.org/papers/v22/20-451.html>`__:
::
- Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer;, POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021.
+ Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer,
+ POT Python Optimal Transport library,
+ Journal of Machine Learning Research, 22(78):1−8, 2021.
Website: https://pythonot.github.io/
In Bibtex format:
-::
+.. code:: bibtex
@article{flamary2021pot,
author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer},
@@ -131,8 +137,8 @@ following Python modules:
- Numpy (>=1.16)
- Scipy (>=1.0)
-- Cython (>=0.23)
-- Matplotlib (>=1.5)
+- Cython (>=0.23) (build only, not necessary when installing wheels
+ from pip or conda)
Pip installation
^^^^^^^^^^^^^^^^
@@ -140,19 +146,19 @@ Pip installation
Note that due to a limitation of pip, ``cython`` and ``numpy`` need to
be installed prior to installing POT. This can be done easily with
-::
+.. code:: console
pip install numpy cython
You can install the toolbox through PyPI with:
-::
+.. code:: console
pip install POT
or get the very latest version by running:
-::
+.. code:: console
pip install -U https://github.com/PythonOT/POT/archive/master.zip # with --user for user install (no root)
@@ -163,7 +169,7 @@ If you use the Anaconda python distribution, POT is available in
`conda-forge <https://conda-forge.org>`__. To install it and the
required dependencies:
-::
+.. code:: console
conda install -c conda-forge pot
@@ -188,15 +194,17 @@ below
- **ot.dr** (Wasserstein dimensionality reduction) depends on autograd
and pymanopt that can be installed with:
- ::
+.. code:: shell
- pip install pymanopt autograd
+ pip install pymanopt autograd
- **ot.gpu** (GPU accelerated OT) depends on cupy that have to be
installed following instructions on `this
page <https://docs-cupy.chainer.org/en/stable/install.html>`__.
-
-obviously you need CUDA installed and a compatible GPU.
+ Obviously you will need CUDA installed and a compatible GPU. Note
+ that this module is deprecated since version 0.8 and will be deleted
+ in the future. GPU is now handled automatically through the backends
+ and several solver already can run on GPU using the Pytorch backend.
Examples
--------
@@ -206,36 +214,36 @@ Short examples
- Import the toolbox
- .. code:: python
+.. code:: python
- import ot
+ import ot
- Compute Wasserstein distances
- .. code:: python
+.. code:: python
- # a,b are 1D histograms (sum to 1 and positive)
- # M is the ground cost matrix
- Wd=ot.emd2(a,b,M) # exact linear program
- Wd_reg=ot.sinkhorn2(a,b,M,reg) # entropic regularized OT
- # if b is a matrix compute all distances to a and return a vector
+ # a and b are 1D histograms (sum to 1 and positive)
+ # M is the ground cost matrix
+ Wd = ot.emd2(a, b, M) # exact linear program
+ Wd_reg = ot.sinkhorn2(a, b, M, reg) # entropic regularized OT
+ # if b is a matrix compute all distances to a and return a vector
- Compute OT matrix
- .. code:: python
+.. code:: python
- # a,b are 1D histograms (sum to 1 and positive)
- # M is the ground cost matrix
- T=ot.emd(a,b,M) # exact linear program
- T_reg=ot.sinkhorn(a,b,M,reg) # entropic regularized OT
+ # a and b are 1D histograms (sum to 1 and positive)
+ # M is the ground cost matrix
+ T = ot.emd(a, b, M) # exact linear program
+ T_reg = ot.sinkhorn(a, b, M, reg) # entropic regularized OT
- Compute Wasserstein barycenter
- .. code:: python
+.. code:: python
- # A is a n*d matrix containing d 1D histograms
- # M is the ground cost matrix
- ba=ot.barycenter(A,M,reg) # reg is regularization parameter
+ # A is a n*d matrix containing d 1D histograms
+ # M is the ground cost matrix
+ ba = ot.barycenter(A, M, reg) # reg is regularization parameter
Examples and Notebooks
~~~~~~~~~~~~~~~~~~~~~~