summaryrefslogtreecommitdiff
path: root/docs/source/quickstart.rst
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2021-06-01 10:10:54 +0200
committerGitHub <noreply@github.com>2021-06-01 10:10:54 +0200
commit184f8f4f7ac78f1dd7f653496d2753211a4e3426 (patch)
tree483a7274c91030fd644de49b03a5fad04af9deba /docs/source/quickstart.rst
parent1f16614954e2522fbdb1598c5b1f5c3630c68472 (diff)
[MRG] POT numpy/torch/jax backends (#249)
* add numpy and torch backends * stat sets on functions * proper import * install recent torch on windows * install recent torch on windows * now testing all functions in backedn * add jax backedn * clenaup windowds * proper convert for jax backedn * pep8 * try again windows tests * test jax conversion * try proper widows tests * emd fuction ses backedn * better test partial OT * proper tests to_numpy and teplate Backend * pep8 * pep8 x2 * feaking sinkhorn works with torch * sinkhorn2 compatible * working ot.emd2 * important detach * it should work * jax autodiff emd * pep8 * no tast same for jax * new independat tests per backedn * freaking pep8 * add tests for gradients * deprecate ot.gpu * worging dist function * working dist * dist done in backedn * not in * remove indexing * change accuacy for jax * first pull backend * projection simplex * projection simplex * projection simplex * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 * add backedn discusion to quickstart guide * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 + better doc * proper links * corect doctest * big debug documentation * doctest again * doctest again bis * doctest again ter (last one or i kill myself) * backend test + doc proj simplex * correction test_utils * correction test_utils * correction cumsum * correction flip * correction flip v2 * more debug * more debug * more debug + pep8 * pep8 * argh * proj_simplex * backedn works for sort * proj simplex * jax sucks * update doc * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/lp/__init__.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * begin comment alex * comment alex part 2 * optimize test gromov * proj_simplex on vectors * add awesome gradient decsnt example on the weights * pep98 of course * proof read example by alex * pep8 again * encoding oos in translation * correct legend Co-authored-by: Nicolas Courty <ncourty@irisa.fr> Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'docs/source/quickstart.rst')
-rw-r--r--docs/source/quickstart.rst68
1 files changed, 67 insertions, 1 deletions
diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst
index cf5d6aa..fd046a1 100644
--- a/docs/source/quickstart.rst
+++ b/docs/source/quickstart.rst
@@ -15,6 +15,12 @@ are also available as notebooks on the POT Github.
in ML applications we refer the reader to the following `OTML tutorial
<https://remi.flamary.com/cours/tuto_otml.html>`_.
+.. note::
+
+ Since version 0.8, POT provides a backend to automatically solve some OT
+ problems independently from the toolbox used by the user (numpy/torch/jax).
+ We provide a discussion about which functions are compatible in section
+ `Backend section <#solving-ot-with-multiple-backends>`_ .
Why Optimal Transport ?
@@ -158,7 +164,6 @@ Wasserstein but has better computational and
`statistical properties <https://arxiv.org/pdf/1910.04091.pdf>`_.
-
Optimal transport and Wasserstein distance
------------------------------------------
@@ -922,6 +927,13 @@ The implementations of FGW and FGW barycenter is provided in functions
GPU acceleration
^^^^^^^^^^^^^^^^
+.. warning::
+
+ The :any:`ot.gpu` has been deprecated since the release 0.8 of POT and
+ should not be used. The GPU implementation (in Pytorch for instance) can be
+ used with the novel backends using the compatible functions from POT.
+
+
We provide several implementation of our OT solvers in :any:`ot.gpu`. Those
implementations use the :code:`cupy` toolbox that obviously need to be installed.
@@ -950,6 +962,60 @@ explicitly.
use it you have to specifically import it with :code:`import ot.gpu` .
+Solving OT with Multiple backends
+---------------------------------
+
+.. _backends_section:
+
+Since version 0.8, POT provides a backend that allows to code solvers
+independently from the type of the input arrays. The idea is to provide the user
+with a package that works seamlessly and returns a solution for instance as a
+Pytorch tensors when the function has Pytorch tensors as input.
+
+
+How it works
+^^^^^^^^^^^^
+
+The aim of the backend is to use the same function independently of the type of
+the input arrays.
+
+For instance when executing the following code
+
+.. code:: python
+
+ # a and b are 1D histograms (sum to 1 and positive)
+ # M is the ground cost matrix
+ T = ot.emd(a, b, M) # exact linear program
+ w = ot.emd2(a, b, M) # Wasserstein computation
+
+the functions :any:`ot.emd` and :any:`ot.emd2` can take inputs of the type
+:any:`numpy.array`, :any:`torch.tensor` or :any:`jax.numpy.array`. The output of
+the function will be the same type as the inputs and on the same device. When
+possible all computations are done on the same device and also when possible the
+output will be differentiable with respect to the input of the function.
+
+
+
+List of compatible Backends
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+- `Numpy <https://numpy.org/>`_ (all functions and solvers)
+- `Pytorch <https://pytorch.org/>`_ (all outputs differentiable w.r.t. inputs)
+- `Jax <https://github.com/google/jax>`_ (Some functions are differentiable some require a wrapper)
+
+List of compatible functions
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+This list will get longer for new releases and will hopefully disappear when POT
+become fully implemented with the backend.
+
+- :any:`ot.emd`
+- :any:`ot.emd2`
+- :any:`ot.sinkhorn`
+- :any:`ot.sinkhorn2`
+- :any:`ot.dist`
+
+
FAQ
---