summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md3
-rw-r--r--RELEASES.md8
-rw-r--r--docs/source/all.rst1
-rw-r--r--examples/others/plot_WeakOT_VS_OT.py98
-rw-r--r--examples/plot_OT_2D_samples.py5
-rw-r--r--ot/__init__.py5
-rw-r--r--ot/gromov.py16
-rw-r--r--ot/lp/__init__.py9
-rw-r--r--ot/lp/cvx.py1
-rw-r--r--ot/utils.py12
-rw-r--r--ot/weak.py124
-rw-r--r--test/test_bregman.py13
-rw-r--r--test/test_ot.py2
-rw-r--r--test/test_utils.py18
-rw-r--r--test/test_weak.py54
15 files changed, 343 insertions, 26 deletions
diff --git a/README.md b/README.md
index 17fbe81..a7627df 100644
--- a/README.md
+++ b/README.md
@@ -25,6 +25,7 @@ POT provides the following generic OT solvers (links to examples):
* Sinkhorn divergence [23] and entropic regularization OT from empirical data.
* Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37]
* [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17].
+* Weak OT solver between empirical distributions [39]
* Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html)) with LP solver (only small scale).
* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from
* [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24]
@@ -301,3 +302,5 @@ Conference on Machine Learning, PMLR 119:4692-4701, 2020
[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph
Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021.
+
+[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405. \ No newline at end of file
diff --git a/RELEASES.md b/RELEASES.md
index 94c853b..4d05582 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -5,10 +5,12 @@
#### New features
-- Better list of related examples in quick start guide with `minigallery` (PR #334)
+- Better list of related examples in quick start guide with `minigallery` (PR #334).
- Add optional log-domain Sinkhorn implementation in WDA to support smaller values
- of the regularization parameter (PR #336)
-- Backend implementation for `ot.lp.free_support_barycenter` (PR #340)
+ of the regularization parameter (PR #336).
+- Backend implementation for `ot.lp.free_support_barycenter` (PR #340).
+- Add weak OT solver + example (PR #341).
+
#### Closed issues
diff --git a/docs/source/all.rst b/docs/source/all.rst
index 7f85a91..76d2ff5 100644
--- a/docs/source/all.rst
+++ b/docs/source/all.rst
@@ -28,6 +28,7 @@ API and modules
unbalanced
partial
sliced
+ weak
.. autosummary::
:toctree: ../modules/generated/
diff --git a/examples/others/plot_WeakOT_VS_OT.py b/examples/others/plot_WeakOT_VS_OT.py
new file mode 100644
index 0000000..a29c875
--- /dev/null
+++ b/examples/others/plot_WeakOT_VS_OT.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+"""
+====================================================
+Weak Optimal Transport VS exact Optimal Transport
+====================================================
+
+Illustration of 2D optimal transport between distributions that are weighted
+sum of diracs. The OT matrix is plotted with the samples.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 4
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import ot.plot
+
+##############################################################################
+# Generate data an plot it
+# ------------------------
+
+#%% parameters and data generation
+
+n = 50 # nb samples
+
+mu_s = np.array([0, 0])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([4, 4])
+cov_t = np.array([[1, -.8], [-.8, 1]])
+
+xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
+xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
+
+a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples
+
+# loss matrix
+M = ot.dist(xs, xt)
+M /= M.max()
+
+#%% plot samples
+
+pl.figure(1)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('Source and target distributions')
+
+pl.figure(2)
+pl.imshow(M, interpolation='nearest')
+pl.title('Cost matrix M')
+
+
+##############################################################################
+# Compute Weak OT and exact OT solutions
+# --------------------------------------
+
+#%% EMD
+
+G0 = ot.emd(a, b, M)
+
+#%% Weak OT
+
+Gweak = ot.weak_optimal_transport(xs, xt, a, b)
+
+
+##############################################################################
+# Plot weak OT and exact OT solutions
+# --------------------------------------
+
+pl.figure(3, (8, 5))
+
+pl.subplot(1, 2, 1)
+pl.imshow(G0, interpolation='nearest')
+pl.title('OT matrix')
+
+pl.subplot(1, 2, 2)
+pl.imshow(Gweak, interpolation='nearest')
+pl.title('Weak OT matrix')
+
+pl.figure(4, (8, 5))
+
+pl.subplot(1, 2, 1)
+ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.title('OT matrix with samples')
+
+pl.subplot(1, 2, 2)
+ot.plot.plot2D_samples_mat(xs, xt, Gweak, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.title('Weak OT matrix with samples')
diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py
index af1bc12..c3a7cd8 100644
--- a/examples/plot_OT_2D_samples.py
+++ b/examples/plot_OT_2D_samples.py
@@ -42,7 +42,6 @@ a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples
# loss matrix
M = ot.dist(xs, xt)
-M /= M.max()
##############################################################################
# Plot data
@@ -87,7 +86,7 @@ pl.title('OT matrix with samples')
#%% sinkhorn
# reg term
-lambd = 1e-3
+lambd = 1e-1
Gs = ot.sinkhorn(a, b, M, lambd)
@@ -112,7 +111,7 @@ pl.show()
#%% sinkhorn
# reg term
-lambd = 1e-3
+lambd = 1e-1
Ges = ot.bregman.empirical_sinkhorn(xs, xt, lambd)
diff --git a/ot/__init__.py b/ot/__init__.py
index 1ea7403..7253318 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -36,6 +36,7 @@ from . import unbalanced
from . import partial
from . import backend
from . import regpath
+from . import weak
# OT functions
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
@@ -46,7 +47,7 @@ from .da import sinkhorn_lpl1_mm
from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance
from .gromov import (gromov_wasserstein, gromov_wasserstein2,
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
-
+from .weak import weak_optimal_transport
# utils functions
from .utils import dist, unif, tic, toc, toq
@@ -59,5 +60,5 @@ __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
'sinkhorn_unbalanced', 'barycenter_unbalanced',
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance',
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2',
- 'max_sliced_wasserstein_distance',
+ 'max_sliced_wasserstein_distance', 'weak_optimal_transport',
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath']
diff --git a/ot/gromov.py b/ot/gromov.py
index 6544260..b7e7949 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -338,6 +338,10 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F
- :math:`\mathbf{q}`: distribution in the target space
- `L`: loss function to account for the misfit between the similarity matrices
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
Parameters
----------
C1 : array-like, shape (ns, ns)
@@ -436,6 +440,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
Note that when using backends, this loss function is differentiable wrt the
marices and weights for quadratic loss using the gradients from [38]_.
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
Parameters
----------
C1 : array-like, shape (ns, ns)
@@ -545,6 +553,10 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
- :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
- `L` is a loss function to account for the misfit between the similarity matrices
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] <references-fused-gromov-wasserstein>`
Parameters
@@ -645,6 +657,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
The algorithm used for solving the problem is conditional gradient as
discussed in :ref:`[24] <references-fused-gromov-wasserstein2>`
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
Note that when using backends, this loss function is differentiable wrt the
marices and weights for quadratic loss using the gradients from [38]_.
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 2ff7c1f..d9b6fa9 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -26,6 +26,8 @@ from ..utils import dist, list_to_array
from ..utils import parmap
from ..backend import get_backend
+
+
__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted',
'emd_1d', 'emd2_1d', 'wasserstein_1d']
@@ -220,7 +222,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
format
.. note:: This function is backend-compatible and will work on arrays
- from all compatible backends.
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
Uses the algorithm proposed in :ref:`[1] <references-emd>`.
@@ -358,7 +361,8 @@ def emd2(a, b, M, processes=1,
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
.. note:: This function is backend-compatible and will work on arrays
- from all compatible backends.
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
Uses the algorithm proposed in :ref:`[1] <references-emd2>`.
@@ -622,3 +626,4 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
return X, log_dict
else:
return X
+
diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py
index 869d450..fbf3c0e 100644
--- a/ot/lp/cvx.py
+++ b/ot/lp/cvx.py
@@ -11,7 +11,6 @@ import numpy as np
import scipy as sp
import scipy.sparse as sps
-
try:
import cvxopt
from cvxopt import solvers, matrix, spmatrix
diff --git a/ot/utils.py b/ot/utils.py
index e6c93c8..725ca00 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -116,7 +116,7 @@ def proj_simplex(v, z=1):
return w
-def unif(n):
+def unif(n, type_as=None):
r"""
Return a uniform histogram of length `n` (simplex).
@@ -124,13 +124,19 @@ def unif(n):
----------
n : int
number of bins in the histogram
+ type_as : array_like
+ array of the same type of the expected output (numpy/pytorch/jax)
Returns
-------
- h : np.array (`n`,)
+ h : array_like (`n`,)
histogram of length `n` such that :math:`\forall i, \mathbf{h}_i = \frac{1}{n}`
"""
- return np.ones((n,)) / n
+ if type_as is None:
+ return np.ones((n,)) / n
+ else:
+ nx = get_backend(type_as)
+ return nx.ones((n,)) / n
def clean_zeros(a, b, M):
diff --git a/ot/weak.py b/ot/weak.py
new file mode 100644
index 0000000..f7d5b23
--- /dev/null
+++ b/ot/weak.py
@@ -0,0 +1,124 @@
+"""
+Weak optimal ransport solvers
+"""
+
+# Author: Remi Flamary <remi.flamary@polytehnique.edu>
+#
+# License: MIT License
+
+from .backend import get_backend
+from .optim import cg
+import numpy as np
+
+__all__ = ['weak_optimal_transport']
+
+
+def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=None, **kwargs):
+ r"""Solves the weak optimal transport problem between two empirical distributions
+
+
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad \|X_a-diag(1/a)\gammaX_b\|_F^2
+
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
+
+ \gamma^T \mathbf{1} = \mathbf{b}
+
+ \gamma \geq 0
+
+ where :
+
+ - :math:`X_a` :math:`X_b` are the sample matrices.
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
+
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
+ Uses the conditional gradient algorithm to solve the problem proposed
+ in :ref:`[39] <references-weak>`.
+
+ Parameters
+ ----------
+ Xa : (ns,d) array-like, float
+ Source samples
+ Xb : (nt,d) array-like, float
+ Target samples
+ a : (ns,) array-like, float
+ Source histogram (uniform weight if empty list)
+ b : (nt,) array-like, float
+ Target histogram (uniform weight if empty list))
+ numItermax : int, optional
+ Max number of iterations
+ numItermaxEmd : int, optional
+ Max number of iterations for emd
+ stopThr : float, optional
+ Stop threshold on the relative variation (>0)
+ stopThr2 : float, optional
+ Stop threshold on the absolute variation (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma: array-like, shape (ns, nt)
+ Optimal transportation matrix for the given
+ parameters
+ log: dict, optional
+ If input log is true, a dictionary containing the
+ cost and dual variables and exit status
+
+
+ .. _references-weak:
+ References
+ ----------
+ .. [39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017).
+ Kantorovich duality for general transport costs and applications.
+ Journal of Functional Analysis, 273(11), 3327-3405.
+
+ See Also
+ --------
+ ot.bregman.sinkhorn : Entropic regularized OT
+ ot.optim.cg : General regularized OT
+ """
+
+ nx = get_backend(Xa, Xb)
+
+ Xa2 = nx.to_numpy(Xa)
+ Xb2 = nx.to_numpy(Xb)
+
+ if a is None:
+ a2 = np.ones((Xa.shape[0])) / Xa.shape[0]
+ else:
+ a2 = nx.to_numpy(a)
+ if b is None:
+ b2 = np.ones((Xb.shape[0])) / Xb.shape[0]
+ else:
+ b2 = nx.to_numpy(b)
+
+ # init uniform
+ if G0 is None:
+ T0 = a2[:, None] * b2[None, :]
+ else:
+ T0 = nx.to_numpy(G0)
+
+ # weak OT loss
+ def f(T):
+ return np.dot(a2, np.sum((Xa2 - np.dot(T, Xb2) / a2[:, None])**2, 1))
+
+ # weak OT gradient
+ def df(T):
+ return -2 * np.dot(Xa2 - np.dot(T, Xb2) / a2[:, None], Xb2.T)
+
+ # solve with conditional gradient and return solution
+ if log:
+ res, log = cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs)
+ log['u'] = nx.from_numpy(log['u'], type_as=Xa)
+ log['v'] = nx.from_numpy(log['v'], type_as=Xb)
+ return nx.from_numpy(res, type_as=Xa), log
+ else:
+ return nx.from_numpy(cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs), type_as=Xa)
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 6e90aa4..1419f9b 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -60,7 +60,7 @@ def test_convergence_warning(method):
ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1)
-def test_not_impemented_method():
+def test_not_implemented_method():
# test sinkhorn
w = 10
n = w ** 2
@@ -635,7 +635,7 @@ def test_wasserstein_bary_2d(nx, method):
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
else:
- bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method)
+ bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method, verbose=True, log=True)
bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method))
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
@@ -667,7 +667,7 @@ def test_wasserstein_bary_2d_debiased(nx, method):
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
else:
- bary_wass_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method)
+ bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method, verbose=True, log=True)
bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method))
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
@@ -940,14 +940,11 @@ def test_screenkhorn(nx):
bb = nx.from_numpy(b)
M_nx = nx.from_numpy(M, type_as=ab)
- # np sinkhorn
- G_sink_np = ot.sinkhorn(a, b, M, 1e-03)
# sinkhorn
- G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-03))
+ G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1))
# screenkhorn
- G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-03, uniform=True, verbose=True))
+ G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-1, uniform=True, verbose=True))
# check marginals
- np.testing.assert_allclose(G_sink_np, G_sink)
np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)
diff --git a/test/test_ot.py b/test/test_ot.py
index e8e2d97..3e2d845 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -232,7 +232,7 @@ def test_emd2_multi():
# Gaussian distributions
a = gauss(n, m=20, s=5) # m= mean, s= std
- ls = np.arange(20, 500, 20)
+ ls = np.arange(20, 500, 100)
nb = len(ls)
b = np.zeros((n, nb))
for i in range(nb):
diff --git a/test/test_utils.py b/test/test_utils.py
index 8b23c22..5ad167b 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -62,12 +62,12 @@ def test_tic_toc():
import time
ot.tic()
- time.sleep(0.5)
+ time.sleep(0.1)
t = ot.toc()
t2 = ot.toq()
# test timing
- np.testing.assert_allclose(0.5, t, rtol=1e-1, atol=1e-1)
+ np.testing.assert_allclose(0.1, t, rtol=1e-1, atol=1e-1)
# test toc vs toq
np.testing.assert_allclose(t, t2, rtol=1e-1, atol=1e-1)
@@ -94,10 +94,22 @@ def test_unif():
np.testing.assert_allclose(1, np.sum(u))
-def test_dist():
+def test_unif_backend(nx):
n = 100
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ u = ot.unif(n, type_as=tp)
+
+ np.testing.assert_allclose(1, np.sum(nx.to_numpy(u)), atol=1e-6)
+
+
+def test_dist():
+
+ n = 10
+
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
diff --git a/test/test_weak.py b/test/test_weak.py
new file mode 100644
index 0000000..c4c3278
--- /dev/null
+++ b/test/test_weak.py
@@ -0,0 +1,54 @@
+"""Tests for main module ot.weak """
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import ot
+import numpy as np
+
+
+def test_weak_ot():
+ # test weak ot solver and identity stationary point
+ n = 50
+ rng = np.random.RandomState(0)
+
+ xs = rng.randn(n, 2)
+ xt = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ G, log = ot.weak_optimal_transport(xs, xt, u, u, log=True)
+
+ # check constraints
+ np.testing.assert_allclose(u, G.sum(1))
+ np.testing.assert_allclose(u, G.sum(0))
+
+ # chaeck that identity is recovered
+ G = ot.weak_optimal_transport(xs, xs, G0=np.eye(n) / n)
+
+ # check G is identity
+ np.testing.assert_allclose(G, np.eye(n) / n)
+
+ # check constraints
+ np.testing.assert_allclose(u, G.sum(1))
+ np.testing.assert_allclose(u, G.sum(0))
+
+
+def test_weak_ot_bakends(nx):
+ # test weak ot solver for different backends
+ n = 50
+ rng = np.random.RandomState(0)
+
+ xs = rng.randn(n, 2)
+ xt = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ G = ot.weak_optimal_transport(xs, xt, u, u)
+
+ xs2 = nx.from_numpy(xs)
+ xt2 = nx.from_numpy(xt)
+ u2 = nx.from_numpy(u)
+
+ G2 = ot.weak_optimal_transport(xs2, xt2, u2, u2)
+
+ np.testing.assert_allclose(nx.to_numpy(G2), G)