author     Rémi Flamary <remi.flamary@gmail.com>    2022-04-04 10:23:04 +0200
committer  GitHub <noreply@github.com>              2022-04-04 10:23:04 +0200
commit     0afd84d744a472903d427e3c7ae32e55fdd7b9a7 (patch)
tree       f9e2ef9b2155fae13591a01bb66c2ea67ce80d18
parent     82452e0f5f6dae05c7a1cc384e7a1fb62ae7e0d5 (diff)
[WIP] Add backend dual loss and plan computation for stochastic optimization or regularized OT (#360)
* add losses and plan computations and example for dual optimization
* pep8
* add nice example
* update awesome example stochastic dual
* add all tests
* pep8 + speedup example
* add release info
-rw-r--r--  RELEASES.md                                               2
-rw-r--r--  examples/backends/plot_dual_ot_pytorch.py               168
-rw-r--r--  examples/backends/plot_stoch_continuous_ot_pytorch.py   189
-rw-r--r--  ot/stochastic.py                                         242
-rw-r--r--  test/test_stochastic.py                                  115
5 files changed, 713 insertions, 3 deletions
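
For orientation, here is a minimal usage sketch of the dual API added by this commit, condensed from examples/backends/plot_dual_ot_pytorch.py below; the toy data, learning rate and iteration count are illustrative choices, not recommendations:

    import torch
    import ot

    # toy data (illustrative only)
    xs = torch.randn(100, 2)
    xt = torch.randn(100, 2) + 4

    # one dual potential value per sample
    u = torch.randn(xs.shape[0], requires_grad=True)
    v = torch.randn(xt.shape[0], requires_grad=True)
    optimizer = torch.optim.Adam([u, v], lr=1)

    for _ in range(200):
        # maximize the dual objective by minimizing its negative
        loss = -ot.stochastic.loss_dual_entropic(u, v, xs, xt, reg=0.5)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # recover the primal OT plan from the optimized potentials
    G = ot.stochastic.plan_dual_entropic(u, v, xs, xt, reg=0.5)
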
diff --git a/RELEASES.md b/RELEASES.md
index c2bd0d1..45336f7 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -5,6 +5,8 @@
#### New features
+- Add stochastic loss and OT plan computation for regularized OT and
+ backend examples (PR #360).
- Implementation of factored OT with emd and sinkhorn (PR #358).
- A brand new logo for POT (PR #357)
- Better list of related examples in quick start guide with `minigallery` (PR #334).
diff --git a/examples/backends/plot_dual_ot_pytorch.py b/examples/backends/plot_dual_ot_pytorch.py
new file mode 100644
index 0000000..d3f7a66
--- /dev/null
+++ b/examples/backends/plot_dual_ot_pytorch.py
@@ -0,0 +1,168 @@
+# -*- coding: utf-8 -*-
+r"""
+======================================================================
+Dual OT solvers for entropic and quadratic regularized OT with PyTorch
+======================================================================
+
+
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 3
+
+import numpy as np
+import matplotlib.pyplot as pl
+import torch
+import ot
+import ot.plot
+
+# %%
+# Data generation
+# ---------------
+
+torch.manual_seed(1)
+
+n_source_samples = 100
+n_target_samples = 100
+theta = 2 * np.pi / 20
+noise_level = 0.1
+
+Xs, ys = ot.datasets.make_data_classif(
+ 'gaussrot', n_source_samples, nz=noise_level)
+Xt, yt = ot.datasets.make_data_classif(
+ 'gaussrot', n_target_samples, theta=theta, nz=noise_level)
+
+# one of the target modes changes its variance (no linear mapping)
+Xt[yt == 2] *= 3
+Xt = Xt + 4
+
+
+# %%
+# Plot data
+# ---------
+
+pl.figure(1, (10, 5))
+pl.clf()
+pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples')
+pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples')
+pl.legend(loc=0)
+pl.title('Source and target distributions')
+
+# %%
+# Convert data to torch tensors
+# -----------------------------
+
+xs = torch.tensor(Xs)
+xt = torch.tensor(Xt)
+
+# %%
+# Estimating dual variables for entropic OT
+# -----------------------------------------
+
+u = torch.randn(n_source_samples, requires_grad=True)
+v = torch.randn(n_target_samples, requires_grad=True)
+
+reg = 0.5
+
+optimizer = torch.optim.Adam([u, v], lr=1)
+
+# number of iterations
+n_iter = 200
+
+
+losses = []
+
+for i in range(n_iter):
+
+    # no minibatch sampling here: the dual loss is evaluated on the full dataset
+
+    # minus sign because we maximize the dual loss
+ loss = -ot.stochastic.loss_dual_entropic(u, v, xs, xt, reg=reg)
+ losses.append(float(loss.detach()))
+
+ if i % 10 == 0:
+ print("Iter: {:3d}, loss={}".format(i, losses[-1]))
+
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+
+
+pl.figure(2)
+pl.plot(losses)
+pl.grid()
+pl.title('Dual objective (negative)')
+pl.xlabel("Iterations")
+
+Ge = ot.stochastic.plan_dual_entropic(u, v, xs, xt, reg=reg)
+
+# %%
+# Plot the estimated entropic OT plan
+# -----------------------------------
+
+pl.figure(3, (10, 5))
+pl.clf()
+ot.plot.plot2D_samples_mat(Xs, Xt, Ge.detach().numpy(), alpha=0.1)
+pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples', zorder=2)
+pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples', zorder=2)
+pl.legend(loc=0)
+pl.title('OT plan with entropic regularization')
+
+
+# %%
+# Estimating dual variables for quadratic OT
+# -------------------------------------------
+
+u = torch.randn(n_source_samples, requires_grad=True)
+v = torch.randn(n_target_samples, requires_grad=True)
+
+reg = 0.01
+
+optimizer = torch.optim.Adam([u, v], lr=1)
+
+# number of iterations
+n_iter = 200
+
+
+losses = []
+
+
+for i in range(n_iter):
+
+    # no minibatch sampling here: the dual loss is evaluated on the full dataset
+
+    # minus sign because we maximize the dual loss
+ loss = -ot.stochastic.loss_dual_quadratic(u, v, xs, xt, reg=reg)
+ losses.append(float(loss.detach()))
+
+ if i % 10 == 0:
+ print("Iter: {:3d}, loss={}".format(i, losses[-1]))
+
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+
+
+pl.figure(4)
+pl.plot(losses)
+pl.grid()
+pl.title('Dual objective (negative)')
+pl.xlabel("Iterations")
+
+Gq = ot.stochastic.plan_dual_quadratic(u, v, xs, xt, reg=reg)
+
+
+# %%
+# Plot the estimated quadratic OT plan
+# -------------------------------------
+
+pl.figure(5, (10, 5))
+pl.clf()
+ot.plot.plot2D_samples_mat(Xs, Xt, Gq.detach().numpy(), alpha=0.1)
+pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples', zorder=2)
+pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples', zorder=2)
+pl.legend(loc=0)
+pl.title('OT plan with quadratic regularization')
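
As a sanity check (not part of the committed example), the entropic plan Ge estimated above can be compared against the exact entropic solution returned by ot.sinkhorn on the same data; a sketch, assuming the sample sizes stay small enough for the full solver:

    M = ot.dist(Xs, Xt)                         # squared Euclidean cost matrix
    a = ot.unif(n_source_samples)               # uniform source weights
    b = ot.unif(n_target_samples)               # uniform target weights
    G_sinkhorn = ot.sinkhorn(a, b, M, reg=0.5)  # same reg as used for Ge
    print(np.abs(Ge.detach().numpy() - G_sinkhorn).max())
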
diff --git a/examples/backends/plot_stoch_continuous_ot_pytorch.py b/examples/backends/plot_stoch_continuous_ot_pytorch.py
new file mode 100644
index 0000000..6d9b916
--- /dev/null
+++ b/examples/backends/plot_stoch_continuous_ot_pytorch.py
@@ -0,0 +1,189 @@
+# -*- coding: utf-8 -*-
+r"""
+======================================================================
+Continuous OT plan estimation with PyTorch
+======================================================================
+
+
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 3
+
+import numpy as np
+import matplotlib.pyplot as pl
+import torch
+from torch import nn
+import ot
+import ot.plot
+
+# %%
+# Data generation
+# ---------------
+
+torch.manual_seed(42)
+np.random.seed(42)
+
+n_source_samples = 10000
+n_target_samples = 10000
+theta = 2 * np.pi / 20
+noise_level = 0.1
+
+Xs = np.random.randn(n_source_samples, 2) * 0.5
+Xt = np.random.randn(n_target_samples, 2) * 2
+
+# shift the target distribution (no linear mapping between source and target)
+Xt = Xt + 4
+
+
+# %%
+# Plot data
+# ---------
+nvisu = 300
+pl.figure(1, (5, 5))
+pl.clf()
+pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', label='Source samples', alpha=0.5)
+pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', label='Target samples', alpha=0.5)
+pl.legend(loc=0)
+ax_bounds = pl.axis()
+pl.title('Source and target distributions')
+
+# %%
+# Convert data to torch tensors
+# -----------------------------
+
+xs = torch.tensor(Xs)
+xt = torch.tensor(Xt)
+
+# %%
+# Estimating deep dual variables for entropic OT
+# ----------------------------------------------
+
+torch.manual_seed(42)
+
+# define the MLP model
+
+
+class Potential(torch.nn.Module):
+ def __init__(self):
+ super(Potential, self).__init__()
+ self.fc1 = nn.Linear(2, 200)
+ self.fc2 = nn.Linear(200, 1)
+ self.relu = torch.nn.ReLU() # instead of Heaviside step fn
+
+ def forward(self, x):
+ output = self.fc1(x)
+ output = self.relu(output) # instead of Heaviside step fn
+ output = self.fc2(output)
+ return output.ravel()
+
+
+u = Potential().double()
+v = Potential().double()
+
+reg = 1
+
+optimizer = torch.optim.Adam(list(u.parameters()) + list(v.parameters()), lr=.005)
+
+# number of iterations
+n_iter = 1000
+n_batch = 500
+
+
+losses = []
+
+for i in range(n_iter):
+
+    # sample a random minibatch of source and target points
+
+ iperms = torch.randint(0, n_source_samples, (n_batch,))
+ ipermt = torch.randint(0, n_target_samples, (n_batch,))
+
+ xsi = xs[iperms]
+ xti = xt[ipermt]
+
+    # minus sign because we maximize the dual loss
+ loss = -ot.stochastic.loss_dual_entropic(u(xsi), v(xti), xsi, xti, reg=reg)
+ losses.append(float(loss.detach()))
+
+ if i % 10 == 0:
+ print("Iter: {:3d}, loss={}".format(i, losses[-1]))
+
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+
+
+pl.figure(2)
+pl.plot(losses)
+pl.grid()
+pl.title('Dual objective (negative)')
+pl.xlabel("Iterations")
+
+
+# %%
+# Plot the density on target for a given source sample
+# -----------------------------------------------------
+
+
+nv = 100
+xl = np.linspace(ax_bounds[0], ax_bounds[1], nv)
+yl = np.linspace(ax_bounds[2], ax_bounds[3], nv)
+
+XX, YY = np.meshgrid(xl, yl)
+
+xg = np.concatenate((XX.ravel()[:, None], YY.ravel()[:, None]), axis=1)
+
+wxg = np.exp(-((xg[:, 0] - 4)**2 + (xg[:, 1] - 4)**2) / (2 * 2))  # Gaussian weights centered at (4, 4) on the target grid
+wxg = wxg / np.sum(wxg)  # normalize the weights to sum to 1
+
+xg = torch.tensor(xg)
+wxg = torch.tensor(wxg)
+
+
+pl.figure(4, (12, 4))
+pl.clf()
+pl.subplot(1, 3, 1)
+
+iv = 2
+Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg)
+Gg = Gg.reshape((nv, nv)).detach().numpy()
+
+pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05)
+pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05)
+pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0')
+pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported source sample')
+pl.legend(loc=0)
+ax_bounds = pl.axis()
+pl.title('Density of transported source sample')
+
+pl.subplot(1, 3, 2)
+
+iv = 3
+Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg)
+Gg = Gg.reshape((nv, nv)).detach().numpy()
+
+pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05)
+pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05)
+pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0')
+pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported source sample')
+pl.legend(loc=0)
+ax_bounds = pl.axis()
+pl.title('Density of transported source sample')
+
+pl.subplot(1, 3, 3)
+
+iv = 6
+Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg)
+Gg = Gg.reshape((nv, nv)).detach().numpy()
+
+pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05)
+pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05)
+pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0')
+pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported source sample')
+pl.legend(loc=0)
+ax_bounds = pl.axis()
+pl.title('Density of transported source sample')
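
Since the potentials u and v are neural networks here, they can be evaluated on points never seen during training. A possible extension (not part of the committed example; the point x0 below is an arbitrary illustration) is a barycentric mapping of a new source point through the estimated conditional density:

    x0 = torch.tensor([[0.0, 0.0]], dtype=torch.float64)
    G0 = ot.stochastic.plan_dual_entropic(u(x0), v(xg), x0, xg, reg=reg, wt=wxg)
    G0 = G0 / G0.sum()              # conditional density of the image of x0 over the grid
    x0_mapped = (G0.T * xg).sum(0)  # barycentric image of x0
    print(x0_mapped.detach().numpy())
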
diff --git a/ot/stochastic.py b/ot/stochastic.py
index 693675f..61be9bb 100644
--- a/ot/stochastic.py
+++ b/ot/stochastic.py
@@ -4,12 +4,14 @@ Stochastic solvers for regularized OT.
"""
-# Author: Kilian Fatras <kilian.fatras@gmail.com>
+# Authors: Kilian Fatras <kilian.fatras@gmail.com>
+# Rémi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License
import numpy as np
-
+from .utils import dist
+from .backend import get_backend
##############################################################################
# Optimization toolbox for SEMI - DUAL problems
@@ -747,3 +749,239 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
return pi, log
else:
return pi
+
+
+################################################################################
+# Losses for stochastic optimization
+################################################################################
+
+def loss_dual_entropic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+ r"""
+    Compute the dual loss of the entropic OT as in equations (6)-(7) of [19]
+
+ This loss is backend compatible and can be used for stochastic optimization
+ of the dual potentials. It can be used on the full dataset (beware of
+ memory) or on minibatches.
+
+
+ Parameters
+ ----------
+ u : array-like, shape (ns,)
+ Source dual potential
+ v : array-like, shape (nt,)
+ Target dual potential
+    xs : array-like, shape (ns, d)
+        Source samples
+    xt : array-like, shape (nt, d)
+        Target samples
+    reg : float
+        Regularization term > 0 (default=1)
+    ws : array-like, shape (ns,), optional
+        Source sample weights (default uniform)
+    wt : array-like, shape (nt,), optional
+        Target sample weights (default uniform)
+    metric : string, callable
+        Ground cost for OT (default squared Euclidean). Can be given as a
+        callable function taking (xs, xt) as parameters.
+
+ Returns
+ -------
+ dual_loss : array-like
+ Dual loss (to maximize)
+
+
+ References
+ ----------
+ .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018)
+ """
+
+ nx = get_backend(u, v, xs, xt)
+
+ if ws is None:
+ ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+ if wt is None:
+ wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+ if callable(metric):
+ M = metric(xs, xt)
+ else:
+ M = dist(xs, xt, metric=metric)
+
+ F = -reg * nx.exp((u[:, None] + v[None, :] - M) / reg)
+
+ return nx.sum(u * ws) + nx.sum(v * wt) + nx.sum(ws[:, None] * F * wt[None, :])
+
+
+def plan_dual_entropic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+ r"""
+    Compute the primal OT plan of the entropic OT as in equation (8) of [19]
+
+    This function is backend compatible and can be used for stochastic optimization
+ of the dual potentials. It can be used on the full dataset (beware of
+ memory) or on minibatches.
+
+
+ Parameters
+ ----------
+ u : array-like, shape (ns,)
+ Source dual potential
+ v : array-like, shape (nt,)
+ Target dual potential
+    xs : array-like, shape (ns, d)
+        Source samples
+    xt : array-like, shape (nt, d)
+        Target samples
+    reg : float
+        Regularization term > 0 (default=1)
+    ws : array-like, shape (ns,), optional
+        Source sample weights (default uniform)
+    wt : array-like, shape (nt,), optional
+        Target sample weights (default uniform)
+    metric : string, callable
+        Ground cost for OT (default squared Euclidean). Can be given as a
+        callable function taking (xs, xt) as parameters.
+
+ Returns
+ -------
+ G : array-like
+ Primal OT plan
+
+
+ References
+ ----------
+ .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018)
+ """
+
+ nx = get_backend(u, v, xs, xt)
+
+ if ws is None:
+ ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+ if wt is None:
+ wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+ if callable(metric):
+ M = metric(xs, xt)
+ else:
+ M = dist(xs, xt, metric=metric)
+
+ H = nx.exp((u[:, None] + v[None, :] - M) / reg)
+
+ return ws[:, None] * H * wt[None, :]
+
+
+def loss_dual_quadratic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+ r"""
+    Compute the dual loss of the quadratic regularized OT as in equations (6)-(7) of [19]
+
+ This loss is backend compatible and can be used for stochastic optimization
+ of the dual potentials. It can be used on the full dataset (beware of
+ memory) or on minibatches.
+
+
+ Parameters
+ ----------
+ u : array-like, shape (ns,)
+ Source dual potential
+ v : array-like, shape (nt,)
+ Target dual potential
+    xs : array-like, shape (ns, d)
+        Source samples
+    xt : array-like, shape (nt, d)
+        Target samples
+    reg : float
+        Regularization term > 0 (default=1)
+    ws : array-like, shape (ns,), optional
+        Source sample weights (default uniform)
+    wt : array-like, shape (nt,), optional
+        Target sample weights (default uniform)
+    metric : string, callable
+        Ground cost for OT (default squared Euclidean). Can be given as a
+        callable function taking (xs, xt) as parameters.
+
+ Returns
+ -------
+ dual_loss : array-like
+ Dual loss (to maximize)
+
+
+ References
+ ----------
+ .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018)
+ """
+
+ nx = get_backend(u, v, xs, xt)
+
+ if ws is None:
+ ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+ if wt is None:
+ wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+ if callable(metric):
+ M = metric(xs, xt)
+ else:
+ M = dist(xs, xt, metric=metric)
+
+ F = -1.0 / (4 * reg) * nx.maximum(u[:, None] + v[None, :] - M, 0.0)**2
+
+ return nx.sum(u * ws) + nx.sum(v * wt) + nx.sum(ws[:, None] * F * wt[None, :])
+
+
+def plan_dual_quadratic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+ r"""
+    Compute the primal OT plan of the quadratic regularized OT as in equation (8) of [19]
+
+    This function is backend compatible and can be used for stochastic optimization
+ of the dual potentials. It can be used on the full dataset (beware of
+ memory) or on minibatches.
+
+
+ Parameters
+ ----------
+ u : array-like, shape (ns,)
+ Source dual potential
+ v : array-like, shape (nt,)
+ Target dual potential
+    xs : array-like, shape (ns, d)
+        Source samples
+    xt : array-like, shape (nt, d)
+        Target samples
+    reg : float
+        Regularization term > 0 (default=1)
+    ws : array-like, shape (ns,), optional
+        Source sample weights (default uniform)
+    wt : array-like, shape (nt,), optional
+        Target sample weights (default uniform)
+    metric : string, callable
+        Ground cost for OT (default squared Euclidean). Can be given as a
+        callable function taking (xs, xt) as parameters.
+
+ Returns
+ -------
+ G : array-like
+ Primal OT plan
+
+
+ References
+ ----------
+ .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018)
+ """
+
+ nx = get_backend(u, v, xs, xt)
+
+ if ws is None:
+ ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+ if wt is None:
+ wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+ if callable(metric):
+ M = metric(xs, xt)
+ else:
+ M = dist(xs, xt, metric=metric)
+
+ H = 1.0 / (2 * reg) * nx.maximum(u[:, None] + v[None, :] - M, 0.0)
+
+ return ws[:, None] * H * wt[None, :]
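
For reference, the expressions implemented by the four functions above, read directly from the code (with a = ws, b = wt, M the ground cost, \varepsilon = reg and (\cdot)_+ = \max(\cdot, 0)):

Entropic dual objective (loss_dual_entropic), to be maximized:

    \mathcal{F}_\varepsilon(u, v) = \sum_i a_i u_i + \sum_j b_j v_j
        - \varepsilon \sum_{i,j} a_i b_j \exp\left(\frac{u_i + v_j - M_{ij}}{\varepsilon}\right)

Entropic primal plan (plan_dual_entropic):

    G_{ij} = a_i b_j \exp\left(\frac{u_i + v_j - M_{ij}}{\varepsilon}\right)

Quadratic dual objective (loss_dual_quadratic):

    \mathcal{F}_\varepsilon(u, v) = \sum_i a_i u_i + \sum_j b_j v_j
        - \frac{1}{4\varepsilon} \sum_{i,j} a_i b_j \left(u_i + v_j - M_{ij}\right)_+^2

Quadratic primal plan (plan_dual_quadratic):

    G_{ij} = \frac{a_i b_j}{2\varepsilon} \left(u_i + v_j - M_{ij}\right)_+
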
diff --git a/test/test_stochastic.py b/test/test_stochastic.py
index 736df32..2b5c0fb 100644
--- a/test/test_stochastic.py
+++ b/test/test_stochastic.py
@@ -8,7 +8,8 @@ for descrete and semicontinous measures from the POT library.
"""
-# Author: Kilian Fatras <kilian.fatras@gmail.com>
+# Authors: Kilian Fatras <kilian.fatras@gmail.com>
+# Rémi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License
@@ -213,3 +214,115 @@ def test_dual_sgd_sinkhorn():
G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
np.testing.assert_allclose(
G_sgd, G_sinkhorn, atol=1e-03) # cf convergence sgd
+
+
+def test_loss_dual_entropic(nx):
+
+ nx.seed(0)
+
+ xs = nx.randn(50, 2)
+ xt = nx.randn(40, 2) + 2
+
+ ws = nx.rand(50)
+ ws = ws / nx.sum(ws)
+
+ wt = nx.rand(40)
+ wt = wt / nx.sum(wt)
+
+ u = nx.randn(50)
+ v = nx.randn(40)
+
+ def metric(x, y):
+ return -nx.dot(x, y.T)
+
+ ot.stochastic.loss_dual_entropic(u, v, xs, xt)
+
+ ot.stochastic.loss_dual_entropic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+
+def test_plan_dual_entropic(nx):
+
+ nx.seed(0)
+
+ xs = nx.randn(50, 2)
+ xt = nx.randn(40, 2) + 2
+
+ ws = nx.rand(50)
+ ws = ws / nx.sum(ws)
+
+ wt = nx.rand(40)
+ wt = wt / nx.sum(wt)
+
+ u = nx.randn(50)
+ v = nx.randn(40)
+
+ def metric(x, y):
+ return -nx.dot(x, y.T)
+
+ G1 = ot.stochastic.plan_dual_entropic(u, v, xs, xt)
+
+ assert np.all(nx.to_numpy(G1) >= 0)
+ assert G1.shape[0] == 50
+ assert G1.shape[1] == 40
+
+ G2 = ot.stochastic.plan_dual_entropic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+ assert np.all(nx.to_numpy(G2) >= 0)
+ assert G2.shape[0] == 50
+ assert G2.shape[1] == 40
+
+
+def test_loss_dual_quadratic(nx):
+
+ nx.seed(0)
+
+ xs = nx.randn(50, 2)
+ xt = nx.randn(40, 2) + 2
+
+ ws = nx.rand(50)
+ ws = ws / nx.sum(ws)
+
+ wt = nx.rand(40)
+ wt = wt / nx.sum(wt)
+
+ u = nx.randn(50)
+ v = nx.randn(40)
+
+ def metric(x, y):
+ return -nx.dot(x, y.T)
+
+ ot.stochastic.loss_dual_quadratic(u, v, xs, xt)
+
+ ot.stochastic.loss_dual_quadratic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+
+def test_plan_dual_quadratic(nx):
+
+ nx.seed(0)
+
+ xs = nx.randn(50, 2)
+ xt = nx.randn(40, 2) + 2
+
+ ws = nx.rand(50)
+ ws = ws / nx.sum(ws)
+
+ wt = nx.rand(40)
+ wt = wt / nx.sum(wt)
+
+ u = nx.randn(50)
+ v = nx.randn(40)
+
+ def metric(x, y):
+ return -nx.dot(x, y.T)
+
+ G1 = ot.stochastic.plan_dual_quadratic(u, v, xs, xt)
+
+ assert np.all(nx.to_numpy(G1) >= 0)
+ assert G1.shape[0] == 50
+ assert G1.shape[1] == 40
+
+ G2 = ot.stochastic.plan_dual_quadratic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+ assert np.all(nx.to_numpy(G2) >= 0)
+ assert G2.shape[0] == 50
+ assert G2.shape[1] == 40