author    | Rémi Flamary <remi.flamary@gmail.com> | 2021-06-01 10:10:54 +0200
committer | GitHub <noreply@github.com>           | 2021-06-01 10:10:54 +0200
commit    | 184f8f4f7ac78f1dd7f653496d2753211a4e3426 (patch)
tree      | 483a7274c91030fd644de49b03a5fad04af9deba /docs/source/quickstart.rst
parent    | 1f16614954e2522fbdb1598c5b1f5c3630c68472 (diff)
[MRG] POT numpy/torch/jax backends (#249)
* add numpy and torch backends
* stat sets on functions
* proper import
* install recent torch on windows
* install recent torch on windows
* now testing all functions in backend
* add jax backend
* cleanup windows
* proper convert for jax backend
* pep8
* try again windows tests
* test jax conversion
* try proper windows tests
* emd function uses backend
* better test partial OT
* proper tests to_numpy and template Backend
* pep8
* pep8 x2
* freaking sinkhorn works with torch
* sinkhorn2 compatible
* working ot.emd2
* important detach
* it should work
* jax autodiff emd
* pep8
* no test same for jax
* new independent tests per backend
* freaking pep8
* add tests for gradients
* deprecate ot.gpu
* working dist function
* working dist
* dist done in backend
* not in
* remove indexing
* change accuracy for jax
* first pull backend
* projection simplex
* projection simplex
* projection simplex
* projection simplex no ci
* projection simplex no ci
* projection simplex no ci
* pep8
* add backend discussion to quickstart guide
* projection simplex no ci
* projection simplex no ci
* projection simplex no ci
* pep8 + better doc
* proper links
* correct doctest
* big debug documentation
* doctest again
* doctest again bis
* doctest again ter (last one or i kill myself)
* backend test + doc proj simplex
* correction test_utils
* correction test_utils
* correction cumsum
* correction flip
* correction flip v2
* more debug
* more debug
* more debug + pep8
* pep8
* argh
* proj_simplex
* backend works for sort
* proj simplex
* jax sucks
* update doc
* Update test/test_utils.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/quickstart.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/quickstart.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/quickstart.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/readme.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update test/test_utils.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update ot/utils.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/readme.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update ot/lp/__init__.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* begin comment alex
* comment alex part 2
* optimize test gromov
* proj_simplex on vectors
* add awesome gradient descent example on the weights
* pep8 of course
* proof read example by alex
* pep8 again
* encoding oos in translation
* correct legend
Co-authored-by: Nicolas Courty <ncourty@irisa.fr>
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'docs/source/quickstart.rst')
-rw-r--r--  docs/source/quickstart.rst | 68 +-
1 file changed, 67 insertions(+), 1 deletion(-)
diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst
index cf5d6aa..fd046a1 100644
--- a/docs/source/quickstart.rst
+++ b/docs/source/quickstart.rst
@@ -15,6 +15,12 @@ are also available as notebooks on the POT Github.
    in ML applications we refer the reader to the following `OTML tutorial
    <https://remi.flamary.com/cours/tuto_otml.html>`_.
 
+.. note::
+
+    Since version 0.8, POT provides a backend to automatically solve some OT
+    problems independently of the toolbox used by the user (numpy/torch/jax).
+    We discuss which functions are compatible in the
+    `Backend section <#solving-ot-with-multiple-backends>`_.
 
 Why Optimal Transport ?
@@ -158,7 +164,6 @@ Wasserstein but has better computational and
 `statistical properties <https://arxiv.org/pdf/1910.04091.pdf>`_.
-
 Optimal transport and Wasserstein distance
 ------------------------------------------
@@ -922,6 +927,13 @@ The implementations of FGW and FGW barycenter is provided in functions
 GPU acceleration
 ^^^^^^^^^^^^^^^^
 
+.. warning::
+
+    :any:`ot.gpu` has been deprecated since release 0.8 of POT and should not
+    be used. The GPU implementation (in Pytorch for instance) can be used with
+    the new backends through the compatible functions of POT.
+
+
 We provide several implementations of our OT solvers in :any:`ot.gpu`. Those
 implementations use the :code:`cupy` toolbox, which obviously needs to be
 installed.
@@ -950,6 +962,60 @@ explicitly.
 use it you have to specifically import it with :code:`import ot.gpu` .
 
+Solving OT with Multiple backends
+---------------------------------
+
+.. _backends_section:
+
+Since version 0.8, POT provides a backend that allows writing solvers
+independently of the type of the input arrays. The idea is to provide the user
+with a package that works seamlessly and returns a solution as, for instance,
+Pytorch tensors when the function receives Pytorch tensors as input.
+
+
+How it works
+^^^^^^^^^^^^
+
+The aim of the backend is to use the same function independently of the type
+of the input arrays.
+
+For instance, when executing the following code
+
+.. code:: python
+
+    # a and b are 1D histograms (positive and summing to 1)
+    # M is the ground cost matrix
+    T = ot.emd(a, b, M)   # exact linear program
+    w = ot.emd2(a, b, M)  # Wasserstein computation
+
+the functions :any:`ot.emd` and :any:`ot.emd2` can take inputs of type
+:any:`numpy.array`, :any:`torch.tensor` or :any:`jax.numpy.array`. The output
+of the function will have the same type as the inputs and live on the same
+device. When possible, all computations are done on that device and the output
+is differentiable with respect to the inputs of the function.
+
+
+List of compatible backends
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+- `Numpy <https://numpy.org/>`_ (all functions and solvers)
+- `Pytorch <https://pytorch.org/>`_ (all outputs differentiable w.r.t. inputs)
+- `Jax <https://github.com/google/jax>`_ (some functions are differentiable,
+  some require a wrapper)
+
+List of compatible functions
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+This list will grow with new releases and will hopefully disappear once POT is
+fully implemented with the backend.
+
+- :any:`ot.emd`
+- :any:`ot.emd2`
+- :any:`ot.sinkhorn`
+- :any:`ot.sinkhorn2`
+- :any:`ot.dist`
+
+
 FAQ
 ---
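The "How it works" section added by this diff describes one solver working over several array toolboxes, which boils down to dispatching on the type of the input arrays. The following is a minimal, hypothetical sketch of that idea in plain Python with a numpy-only registry; the names `NumpyBackend`, `get_backend`, and `total_mass` are illustrative and are not POT's actual implementation.

```python
import numpy as np


class NumpyBackend:
    """Wraps the array API of one toolbox behind a common interface."""

    name = "numpy"

    def handles(self, x):
        # True when this backend recognizes the input array type.
        return isinstance(x, np.ndarray)

    def sum(self, x):
        return np.sum(x)


# A Torch or Jax backend exposing the same interface would be appended here.
_BACKENDS = [NumpyBackend()]


def get_backend(*arrays):
    """Return the first registered backend that handles all input arrays."""
    for nx in _BACKENDS:
        if all(nx.handles(a) for a in arrays):
            return nx
    raise ValueError("no compatible backend for these inputs")


def total_mass(a):
    # A solver written against the backend interface runs unchanged for any
    # registered toolbox, and its output keeps the type of its inputs.
    nx = get_backend(a)
    return nx.sum(a)
```

Under this scheme a call such as `total_mass(np.array([0.5, 0.5]))` is routed through the numpy backend, which is the behavior the diff describes for `ot.emd` and friends with numpy, torch, or jax inputs.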