summaryrefslogtreecommitdiff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2022-02-02 11:53:12 +0100
committerGitHub <noreply@github.com>2022-02-02 11:53:12 +0100
commita5e0f0d40d5046a6639924347ef97e2ac80ad0c9 (patch)
treedcd35e851ec2cc3f52eedbfa58fb6970664135c9 /ot/lp/__init__.py
parent71a57c68ea9eb2bc948c4dd1cce9928f34bf20e8 (diff)
[MRG] Add weak OT solver (#341)
* add info in release file * update tests * pep8 * add weak OT example * update plot in doc * correction ewample with empirical sinkhorn * better thumbnail * comment from review * update documenation
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r--ot/lp/__init__.py9
1 files changed, 7 insertions, 2 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 2ff7c1f..d9b6fa9 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -26,6 +26,8 @@ from ..utils import dist, list_to_array
from ..utils import parmap
from ..backend import get_backend
+
+
__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted',
'emd_1d', 'emd2_1d', 'wasserstein_1d']
@@ -220,7 +222,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
format
.. note:: This function is backend-compatible and will work on arrays
- from all compatible backends.
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
Uses the algorithm proposed in :ref:`[1] <references-emd>`.
@@ -358,7 +361,8 @@ def emd2(a, b, M, processes=1,
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
.. note:: This function is backend-compatible and will work on arrays
- from all compatible backends.
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
Uses the algorithm proposed in :ref:`[1] <references-emd2>`.
@@ -622,3 +626,4 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
return X, log_dict
else:
return X
+