summaryrefslogtreecommitdiff
path: root/ot/gromov.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2022-02-02 11:53:12 +0100
committerGitHub <noreply@github.com>2022-02-02 11:53:12 +0100
commita5e0f0d40d5046a6639924347ef97e2ac80ad0c9 (patch)
treedcd35e851ec2cc3f52eedbfa58fb6970664135c9 /ot/gromov.py
parent71a57c68ea9eb2bc948c4dd1cce9928f34bf20e8 (diff)
[MRG] Add weak OT solver (#341)
* add info in release file * update tests * pep8 * add weak OT example * update plot in doc * correction ewample with empirical sinkhorn * better thumbnail * comment from review * update documenation
Diffstat (limited to 'ot/gromov.py')
-rw-r--r--ot/gromov.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/ot/gromov.py b/ot/gromov.py
index 6544260..b7e7949 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -338,6 +338,10 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F
- :math:`\mathbf{q}`: distribution in the target space
- `L`: loss function to account for the misfit between the similarity matrices
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
Parameters
----------
C1 : array-like, shape (ns, ns)
@@ -436,6 +440,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
Note that when using backends, this loss function is differentiable wrt the
marices and weights for quadratic loss using the gradients from [38]_.
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
Parameters
----------
C1 : array-like, shape (ns, ns)
@@ -545,6 +553,10 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
- :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
- `L` is a loss function to account for the misfit between the similarity matrices
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] <references-fused-gromov-wasserstein>`
Parameters
@@ -645,6 +657,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
The algorithm used for solving the problem is conditional gradient as
discussed in :ref:`[24] <references-fused-gromov-wasserstein2>`
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
Note that when using backends, this loss function is differentiable wrt the
marices and weights for quadratic loss using the gradients from [38]_.