From a5e0f0d40d5046a6639924347ef97e2ac80ad0c9 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Wed, 2 Feb 2022 11:53:12 +0100 Subject: [MRG] Add weak OT solver (#341) * add info in release file * update tests * pep8 * add weak OT example * update plot in doc * correction ewample with empirical sinkhorn * better thumbnail * comment from review * update documenation --- ot/gromov.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) (limited to 'ot/gromov.py') diff --git a/ot/gromov.py b/ot/gromov.py index 6544260..b7e7949 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -338,6 +338,10 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F - :math:`\mathbf{q}`: distribution in the target space - `L`: loss function to account for the misfit between the similarity matrices + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + Parameters ---------- C1 : array-like, shape (ns, ns) @@ -436,6 +440,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo= Note that when using backends, this loss function is differentiable wrt the marices and weights for quadratic loss using the gradients from [38]_. + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + Parameters ---------- C1 : array-like, shape (ns, ns) @@ -545,6 +553,10 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1) - `L` is a loss function to account for the misfit between the similarity matrices + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] ` Parameters @@ -645,6 +657,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] ` + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + Note that when using backends, this loss function is differentiable wrt the marices and weights for quadratic loss using the gradients from [38]_. -- cgit v1.2.3