diff options
Diffstat (limited to 'docs/source/quickstart.rst')
-rw-r--r-- | docs/source/quickstart.rst | 68 |
1 files changed, 67 insertions, 1 deletions
diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index cf5d6aa..fd046a1 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -15,6 +15,12 @@ are also available as notebooks on the POT Github. in ML applications we refer the reader to the following `OTML tutorial <https://remi.flamary.com/cours/tuto_otml.html>`_. +.. note:: + + Since version 0.8, POT provides a backend to automatically solve some OT + problems independently from the toolbox used by the user (numpy/torch/jax). + We provide a discussion about which functions are compatible in section + `Backend section <#solving-ot-with-multiple-backends>`_ . Why Optimal Transport ? @@ -158,7 +164,6 @@ Wasserstein but has better computational and `statistical properties <https://arxiv.org/pdf/1910.04091.pdf>`_. - Optimal transport and Wasserstein distance ------------------------------------------ @@ -922,6 +927,13 @@ The implementations of FGW and FGW barycenter is provided in functions GPU acceleration ^^^^^^^^^^^^^^^^ +.. warning:: + + The :any:`ot.gpu` has been deprecated since the release 0.8 of POT and + should not be used. The GPU implementation (in Pytorch for instance) can be + used with the novel backends using the compatible functions from POT. + + We provide several implementation of our OT solvers in :any:`ot.gpu`. Those implementations use the :code:`cupy` toolbox that obviously need to be installed. @@ -950,6 +962,60 @@ explicitly. use it you have to specifically import it with :code:`import ot.gpu` . +Solving OT with Multiple backends +--------------------------------- + +.. _backends_section: + +Since version 0.8, POT provides a backend that allows to code solvers +independently from the type of the input arrays. The idea is to provide the user +with a package that works seamlessly and returns a solution for instance as a +Pytorch tensors when the function has Pytorch tensors as input. + + +How it works +^^^^^^^^^^^^ + +The aim of the backend is to use the same function independently of the type of +the input arrays. + +For instance when executing the following code + +.. code:: python + + # a and b are 1D histograms (sum to 1 and positive) + # M is the ground cost matrix + T = ot.emd(a, b, M) # exact linear program + w = ot.emd2(a, b, M) # Wasserstein computation + +the functions :any:`ot.emd` and :any:`ot.emd2` can take inputs of the type +:any:`numpy.array`, :any:`torch.tensor` or :any:`jax.numpy.array`. The output of +the function will be the same type as the inputs and on the same device. When +possible all computations are done on the same device and also when possible the +output will be differentiable with respect to the input of the function. + + + +List of compatible Backends +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +- `Numpy <https://numpy.org/>`_ (all functions and solvers) +- `Pytorch <https://pytorch.org/>`_ (all outputs differentiable w.r.t. inputs) +- `Jax <https://github.com/google/jax>`_ (Some functions are differentiable some require a wrapper) + +List of compatible functions +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +This list will get longer for new releases and will hopefully disappear when POT +become fully implemented with the backend. + +- :any:`ot.emd` +- :any:`ot.emd2` +- :any:`ot.sinkhorn` +- :any:`ot.sinkhorn2` +- :any:`ot.dist` + + FAQ --- |