path: root/examples
author     Gard Spreemann <gspr@nonempty.org>  2022-04-27 11:49:23 +0200
committer  Gard Spreemann <gspr@nonempty.org>  2022-04-27 11:49:23 +0200
commit     35bd2c98b642df78638d7d733bc1a89d873db1de (patch)
tree       6bc637624004713808d3097b95acdccbb9608e52 /examples
parent     c4753bd3f74139af8380127b66b484bc09b50661 (diff)
parent     eccb1386eea52b94b82456d126bd20cbe3198e05 (diff)
Merge tag '0.8.2' into dfsg/latest
Diffstat (limited to 'examples')
-rw-r--r--  examples/backends/plot_dual_ot_pytorch.py | 168
-rw-r--r--  examples/backends/plot_sliced_wass_grad_flow_pytorch.py | 2
-rw-r--r--  examples/backends/plot_stoch_continuous_ot_pytorch.py | 189
-rw-r--r--  examples/backends/plot_wass1d_torch.py | 8
-rw-r--r--  examples/barycenters/plot_free_support_barycenter.py | 55
-rwxr-xr-x  examples/gromov/plot_gromov_wasserstein_dictionary_learning.py | 357
-rw-r--r--  examples/others/plot_WeakOT_VS_OT.py | 98
-rw-r--r--  examples/others/plot_factored_coupling.py | 86
-rw-r--r--  examples/others/plot_logo.py | 112
-rw-r--r--  examples/others/plot_screenkhorn_1D.py (renamed from examples/plot_screenkhorn_1D.py) | 6
-rw-r--r--  examples/others/plot_stochastic.py (renamed from examples/plot_stochastic.py) | 0
-rw-r--r--  examples/plot_Intro_OT.py | 4
-rw-r--r--  examples/plot_OT_1D.py | 12
-rw-r--r--  examples/plot_OT_1D_smooth.py | 6
-rw-r--r--  examples/plot_OT_2D_samples.py | 7
-rw-r--r--  examples/plot_OT_L1_vs_L2.py | 34
-rw-r--r--  examples/plot_compute_emd.py | 72
-rw-r--r--  examples/plot_optim_OTreg.py | 38
-rw-r--r--  examples/sliced-wasserstein/README.txt | 2
-rw-r--r--  examples/sliced-wasserstein/plot_variance.py | 8
-rw-r--r--  examples/unbalanced-partial/plot_UOT_1D.py | 17
-rw-r--r--  examples/unbalanced-partial/plot_regpath.py | 88
-rw-r--r--  examples/unbalanced-partial/plot_unbalanced_OT.py | 116
23 files changed, 1387 insertions, 98 deletions
diff --git a/examples/backends/plot_dual_ot_pytorch.py b/examples/backends/plot_dual_ot_pytorch.py
new file mode 100644
index 0000000..d3f7a66
--- /dev/null
+++ b/examples/backends/plot_dual_ot_pytorch.py
@@ -0,0 +1,168 @@
+# -*- coding: utf-8 -*-
+r"""
+======================================================================
+Dual OT solvers for entropic and quadratic regularized OT with PyTorch
+======================================================================
+
+This example shows how to estimate dual variables for entropic and quadratic regularized OT with PyTorch optimizers.
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 3
+
+import numpy as np
+import matplotlib.pyplot as pl
+import torch
+import ot
+import ot.plot
+
+# %%
+# Data generation
+# ---------------
+
+torch.manual_seed(1)
+
+n_source_samples = 100
+n_target_samples = 100
+theta = 2 * np.pi / 20
+noise_level = 0.1
+
+Xs, ys = ot.datasets.make_data_classif(
+ 'gaussrot', n_source_samples, nz=noise_level)
+Xt, yt = ot.datasets.make_data_classif(
+ 'gaussrot', n_target_samples, theta=theta, nz=noise_level)
+
+# one of the target modes changes its variance (no linear mapping)
+Xt[yt == 2] *= 3
+Xt = Xt + 4
+
+
+# %%
+# Plot data
+# ---------
+
+pl.figure(1, (10, 5))
+pl.clf()
+pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples')
+pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples')
+pl.legend(loc=0)
+pl.title('Source and target distributions')
+
+# %%
+# Convert data to torch tensors
+# -----------------------------
+
+xs = torch.tensor(Xs)
+xt = torch.tensor(Xt)
+
+# %%
+# Estimating dual variables for entropic OT
+# -----------------------------------------
+
+u = torch.randn(n_source_samples, requires_grad=True)
+v = torch.randn(n_target_samples, requires_grad=True)
+
+reg = 0.5
+
+optimizer = torch.optim.Adam([u, v], lr=1)
+
+# number of iterations
+n_iter = 200
+
+
+losses = []
+
+for i in range(n_iter):
+
+    # evaluate the dual loss on all samples
+
+    # minus sign because we maximize the dual loss
+ loss = -ot.stochastic.loss_dual_entropic(u, v, xs, xt, reg=reg)
+ losses.append(float(loss.detach()))
+
+ if i % 10 == 0:
+ print("Iter: {:3d}, loss={}".format(i, losses[-1]))
+
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+
+
+pl.figure(2)
+pl.plot(losses)
+pl.grid()
+pl.title('Dual objective (negative)')
+pl.xlabel("Iterations")
+
+Ge = ot.stochastic.plan_dual_entropic(u, v, xs, xt, reg=reg)
+
+# %%
+# Plot the estimated entropic OT plan
+# -----------------------------------
+
+pl.figure(3, (10, 5))
+pl.clf()
+ot.plot.plot2D_samples_mat(Xs, Xt, Ge.detach().numpy(), alpha=0.1)
+pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples', zorder=2)
+pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples', zorder=2)
+pl.legend(loc=0)
+pl.title('OT plan with entropic regularization')
+
+
+# %%
+# Estimating dual variables for quadratic OT
+# ------------------------------------------
+
+u = torch.randn(n_source_samples, requires_grad=True)
+v = torch.randn(n_target_samples, requires_grad=True)
+
+reg = 0.01
+
+optimizer = torch.optim.Adam([u, v], lr=1)
+
+# number of iterations
+n_iter = 200
+
+
+losses = []
+
+
+for i in range(n_iter):
+
+    # evaluate the dual loss on all samples
+
+    # minus sign because we maximize the dual loss
+ loss = -ot.stochastic.loss_dual_quadratic(u, v, xs, xt, reg=reg)
+ losses.append(float(loss.detach()))
+
+ if i % 10 == 0:
+ print("Iter: {:3d}, loss={}".format(i, losses[-1]))
+
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+
+
+pl.figure(4)
+pl.plot(losses)
+pl.grid()
+pl.title('Dual objective (negative)')
+pl.xlabel("Iterations")
+
+Gq = ot.stochastic.plan_dual_quadratic(u, v, xs, xt, reg=reg)
+
+
+# %%
+# Plot the estimated quadratic OT plan
+# ------------------------------------
+
+pl.figure(5, (10, 5))
+pl.clf()
+ot.plot.plot2D_samples_mat(Xs, Xt, Gq.detach().numpy(), alpha=0.1)
+pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples', zorder=2)
+pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples', zorder=2)
+pl.legend(loc=0)
+pl.title('OT plan with quadratic regularization')
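A quick way to validate the plans recovered above is to check their marginals: for a converged dual pair, the row and column sums of Ge should be close to the (uniform) input weights. A minimal sketch, meant to run after the entropic training loop above; this is a sanity check, not part of the committed example:

    # sanity check: marginals of the entropic plan vs. the uniform weights
    Ge_np = Ge.detach().numpy()
    print('max row-marginal error:', np.abs(Ge_np.sum(1) - 1 / n_source_samples).max())
    print('max col-marginal error:', np.abs(Ge_np.sum(0) - 1 / n_target_samples).max())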
diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
index 05b9952..cf5d64d 100644
--- a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
+++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
@@ -27,6 +27,8 @@ Machine Learning (pp. 4104-4113). PMLR.
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 4
+
# %%
# Loading the data
diff --git a/examples/backends/plot_stoch_continuous_ot_pytorch.py b/examples/backends/plot_stoch_continuous_ot_pytorch.py
new file mode 100644
index 0000000..6d9b916
--- /dev/null
+++ b/examples/backends/plot_stoch_continuous_ot_pytorch.py
@@ -0,0 +1,189 @@
+# -*- coding: utf-8 -*-
+r"""
+======================================================================
+Continuous OT plan estimation with PyTorch
+======================================================================
+
+This example shows how to estimate a continuous entropic OT plan between 2D distributions using neural-network dual potentials.
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 3
+
+import numpy as np
+import matplotlib.pyplot as pl
+import torch
+from torch import nn
+import ot
+import ot.plot
+
+# %%
+# Data generation
+# ---------------
+
+torch.manual_seed(42)
+np.random.seed(42)
+
+n_source_samples = 10000
+n_target_samples = 10000
+theta = 2 * np.pi / 20
+noise_level = 0.1
+
+Xs = np.random.randn(n_source_samples, 2) * 0.5
+Xt = np.random.randn(n_target_samples, 2) * 2
+
+# shift the target distribution away from the source
+Xt = Xt + 4
+
+
+# %%
+# Plot data
+# ---------
+nvisu = 300
+pl.figure(1, (5, 5))
+pl.clf()
+pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', label='Source samples', alpha=0.5)
+pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', label='Target samples', alpha=0.5)
+pl.legend(loc=0)
+ax_bounds = pl.axis()
+pl.title('Source and target distributions')
+
+# %%
+# Convert data to torch tensors
+# -----------------------------
+
+xs = torch.tensor(Xs)
+xt = torch.tensor(Xt)
+
+# %%
+# Estimating deep dual variables for entropic OT
+# ----------------------------------------------
+
+torch.manual_seed(42)
+
+# define the MLP model
+
+
+class Potential(torch.nn.Module):
+ def __init__(self):
+ super(Potential, self).__init__()
+ self.fc1 = nn.Linear(2, 200)
+ self.fc2 = nn.Linear(200, 1)
+ self.relu = torch.nn.ReLU() # instead of Heaviside step fn
+
+ def forward(self, x):
+ output = self.fc1(x)
+ output = self.relu(output) # instead of Heaviside step fn
+ output = self.fc2(output)
+ return output.ravel()
+
+
+u = Potential().double()
+v = Potential().double()
+
+reg = 1
+
+optimizer = torch.optim.Adam(list(u.parameters()) + list(v.parameters()), lr=.005)
+
+# number of iterations
+n_iter = 1000
+n_batch = 500
+
+
+losses = []
+
+for i in range(n_iter):
+
+    # draw random minibatches from the source and target samples
+
+ iperms = torch.randint(0, n_source_samples, (n_batch,))
+ ipermt = torch.randint(0, n_target_samples, (n_batch,))
+
+ xsi = xs[iperms]
+ xti = xt[ipermt]
+
+    # minus sign because we maximize the dual loss
+ loss = -ot.stochastic.loss_dual_entropic(u(xsi), v(xti), xsi, xti, reg=reg)
+ losses.append(float(loss.detach()))
+
+ if i % 10 == 0:
+ print("Iter: {:3d}, loss={}".format(i, losses[-1]))
+
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+
+
+pl.figure(2)
+pl.plot(losses)
+pl.grid()
+pl.title('Dual objective (negative)')
+pl.xlabel("Iterations")
+
+
+# %%
+# Plot the density on target for a given source sample
+# ----------------------------------------------------
+
+
+nv = 100
+xl = np.linspace(ax_bounds[0], ax_bounds[1], nv)
+yl = np.linspace(ax_bounds[2], ax_bounds[3], nv)
+
+XX, YY = np.meshgrid(xl, yl)
+
+xg = np.concatenate((XX.ravel()[:, None], YY.ravel()[:, None]), axis=1)
+
+wxg = np.exp(-((xg[:, 0] - 4)**2 + (xg[:, 1] - 4)**2) / (2 * 2))
+wxg = wxg / np.sum(wxg)
+
+xg = torch.tensor(xg)
+wxg = torch.tensor(wxg)
+
+
+pl.figure(4, (12, 4))
+pl.clf()
+pl.subplot(1, 3, 1)
+
+iv = 2
+Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg)
+Gg = Gg.reshape((nv, nv)).detach().numpy()
+
+pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05)
+pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05)
+pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0')
+pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported source sample')
+pl.legend(loc=0)
+ax_bounds = pl.axis()
+pl.title('Density of transported source sample')
+
+pl.subplot(1, 3, 2)
+
+iv = 3
+Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg)
+Gg = Gg.reshape((nv, nv)).detach().numpy()
+
+pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05)
+pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05)
+pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0')
+pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported source sample')
+pl.legend(loc=0)
+ax_bounds = pl.axis()
+pl.title('Density of transported source sample')
+
+pl.subplot(1, 3, 3)
+
+iv = 6
+Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg)
+Gg = Gg.reshape((nv, nv)).detach().numpy()
+
+pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05)
+pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05)
+pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0')
+pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported source sample')
+pl.legend(loc=0)
+ax_bounds = pl.axis()
+pl.title('Density of transported source sample')
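Because the potentials u and v are neural networks here, the learned plan generalizes beyond the training samples. A small sketch, to run after the training loop above, evaluating the plan on freshly drawn points (the sample size 500 is arbitrary):

    # evaluate the learned plan on samples never seen during training
    xs_new = torch.randn(500, 2, dtype=torch.float64) * 0.5
    xt_new = torch.randn(500, 2, dtype=torch.float64) * 2 + 4
    G_new = ot.stochastic.plan_dual_entropic(u(xs_new), v(xt_new), xs_new, xt_new, reg=reg)
    print(G_new.shape)  # (500, 500), no retraining needed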
diff --git a/examples/backends/plot_wass1d_torch.py b/examples/backends/plot_wass1d_torch.py
index 0abdd6d..cd8e2fd 100644
--- a/examples/backends/plot_wass1d_torch.py
+++ b/examples/backends/plot_wass1d_torch.py
@@ -1,9 +1,9 @@
r"""
-=================================
-Wasserstein 1D with PyTorch
-=================================
+=================================================
+Wasserstein 1D (flow and barycenter) with PyTorch
+=================================================
-In this small example, we consider the following minization problem:
+In this small example, we consider the following minimization problem:
.. math::
\mu^* = \min_\mu W(\mu,\nu)
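The hunk above only shows the retitled header and the objective; for reference, one projected-gradient step on that minimization problem looks as follows. A self-contained sketch, assuming (as this example relies on) that ot.wasserstein_1d accepts torch tensors and is differentiable with respect to the weights; the grid and step size are arbitrary:

    import torch
    import ot

    x = torch.linspace(0, 1, 100)              # common support of mu and nu
    a = torch.rand(100)
    a = a / a.sum()
    a.requires_grad_(True)                     # weights of mu (optimized)
    b = torch.ones(100) / 100                  # weights of nu (fixed)

    loss = ot.wasserstein_1d(x, x, a, b, p=2)  # W(mu, nu)
    loss.backward()
    with torch.no_grad():
        a -= 0.1 * a.grad                      # gradient step
        a.clamp_(min=0)
        a /= a.sum()                           # project back onto the simplex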
diff --git a/examples/barycenters/plot_free_support_barycenter.py b/examples/barycenters/plot_free_support_barycenter.py
index 2d68a39..226dfeb 100644
--- a/examples/barycenters/plot_free_support_barycenter.py
+++ b/examples/barycenters/plot_free_support_barycenter.py
@@ -9,61 +9,62 @@ sum of diracs.
"""
-# Author: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>
+# Authors: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>
+# Rémi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 2
+
import numpy as np
import matplotlib.pylab as pl
import ot
-##############################################################################
+# %%
# Generate data
# -------------
-N = 3
+N = 2
d = 2
-measures_locations = []
-measures_weights = []
-
-for i in range(N):
- n_i = np.random.randint(low=1, high=20) # nb samples
+I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::4, ::4, 2]
+I2 = pl.imread('../../data/duck.png').astype(np.float64)[::4, ::4, 2]
- mu_i = np.random.normal(0., 4., (d,)) # Gaussian mean
+sz = I2.shape[0]
+XX, YY = np.meshgrid(np.arange(sz), np.arange(sz))
- A_i = np.random.rand(d, d)
- cov_i = np.dot(A_i, A_i.transpose()) # Gaussian covariance matrix
+x1 = np.stack((XX[I1 == 0], YY[I1 == 0]), 1) * 1.0
+x2 = np.stack((XX[I2 == 0] + 80, -YY[I2 == 0] + 32), 1) * 1.0
+x3 = np.stack((XX[I2 == 0], -YY[I2 == 0] + 32), 1) * 1.0
- x_i = ot.datasets.make_2D_samples_gauss(n_i, mu_i, cov_i) # Dirac locations
- b_i = np.random.uniform(0., 1., (n_i,))
- b_i = b_i / np.sum(b_i) # Dirac weights
+measures_locations = [x1, x2]
+measures_weights = [ot.unif(x1.shape[0]), ot.unif(x2.shape[0])]
- measures_locations.append(x_i)
- measures_weights.append(b_i)
+pl.figure(1, (12, 4))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)
+pl.title('Distributions')
-##############################################################################
+# %%
# Compute free support barycenter
# -------------------------------
-k = 10 # number of Diracs of the barycenter
+k = 200 # number of Diracs of the barycenter
X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations
b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized)
X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b)
-
-##############################################################################
-# Plot data
+# %%
+# Plot the barycenter
# ---------
-pl.figure(1)
-for (x_i, b_i) in zip(measures_locations, measures_weights):
- color = np.random.randint(low=1, high=10 * N)
- pl.scatter(x_i[:, 0], x_i[:, 1], s=b_i * 1000, label='input measure')
-pl.scatter(X[:, 0], X[:, 1], s=b * 1000, c='black', marker='^', label='2-Wasserstein barycenter')
+pl.figure(2, (8, 3))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)
+pl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker='s', label='2-Wasserstein barycenter')
pl.title('Data measures and their barycenter')
-pl.legend(loc=0)
+pl.legend(loc="lower right")
pl.show()
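The free_support_barycenter call above weights the two input measures uniformly; passing the weights argument instead interpolates between the shapes. A sketch to run after the cell above (the 0.25/0.75 split is arbitrary):

    # barycenter closer to the second shape
    X_interp = ot.lp.free_support_barycenter(
        measures_locations, measures_weights, X_init, b,
        weights=np.array([0.25, 0.75]))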
diff --git a/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py
new file mode 100755
index 0000000..1fdc3b9
--- /dev/null
+++ b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py
@@ -0,0 +1,357 @@
+# -*- coding: utf-8 -*-
+
+r"""
+=====================================================
+(Fused) Gromov-Wasserstein Linear Dictionary Learning
+=====================================================
+
+In this example, we illustrate how to learn a Gromov-Wasserstein dictionary on
+a dataset of structured data such as graphs, denoted
+:math:`\{ \mathbf{C_s} \}_{s \in [S]}` where every node has uniform weights.
+Given a dictionary :math:`\mathbf{C_{dict}}` composed of D structures of a fixed
+size nt, each graph :math:`(\mathbf{C_s}, \mathbf{p_s})`
+is modeled as a convex combination :math:`\mathbf{w_s} \in \Sigma_D` of these
+dictionary atoms as :math:`\sum_d w_{s,d} \mathbf{C_{dict}[d]}`.
+
+
+First, we consider a dataset composed of graphs generated by Stochastic Block models
+with variable sizes taken in :math:`\{30, ... , 50\}` and numbers of clusters
+varying in :math:`\{ 1, 2, 3\}`. We learn a dictionary of 3 atoms by minimizing
+the Gromov-Wasserstein distance from each sample to its model in the dictionary
+with respect to the dictionary atoms.
+
+Second, we illustrate the extension of this dictionary learning framework to
+structured data endowed with node features by using the Fused Gromov-Wasserstein
+distance. Starting from the aforementioned dataset of unattributed graphs, we
+add discrete labels uniformly depending on the number of clusters. Then we learn
+and visualize attributed graph atoms where each sample is modeled as a joint convex
+combination between atom structures and features.
+
+
+[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online Graph
+Dictionary Learning, International Conference on Machine Learning (ICML), 2021.
+
+"""
+# Author: Cédric Vincent-Cuaz <cedric.vincent-cuaz@inria.fr>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 4
+
+import numpy as np
+import matplotlib.pylab as pl
+from sklearn.manifold import MDS
+from ot.gromov import gromov_wasserstein_linear_unmixing, gromov_wasserstein_dictionary_learning, fused_gromov_wasserstein_linear_unmixing, fused_gromov_wasserstein_dictionary_learning
+import ot
+import networkx
+from networkx.generators.community import stochastic_block_model as sbm
+# %%
+# =============================================================================
+# Generate a dataset composed of graphs following Stochastic Block models of 1, 2 and 3 clusters.
+# =============================================================================
+
+np.random.seed(42)
+
+N = 60 # number of graphs in the dataset
+# For every number of clusters, we generate SBM with fixed inter/intra-clusters probability.
+clusters = [1, 2, 3]
+Nc = N // len(clusters) # number of graphs by cluster
+nlabels = len(clusters)
+dataset = []
+labels = []
+
+p_inter = 0.1
+p_intra = 0.9
+for n_cluster in clusters:
+ for i in range(Nc):
+ n_nodes = int(np.random.uniform(low=30, high=50))
+
+ if n_cluster > 1:
+ P = p_inter * np.ones((n_cluster, n_cluster))
+ np.fill_diagonal(P, p_intra)
+ else:
+ P = p_intra * np.eye(1)
+ sizes = np.round(n_nodes * np.ones(n_cluster) / n_cluster).astype(np.int32)
+ G = sbm(sizes, P, seed=i, directed=False)
+ C = networkx.to_numpy_array(G)
+ dataset.append(C)
+ labels.append(n_cluster)
+
+
+# Visualize samples
+
+def plot_graph(x, C, binary=True, color='C0', s=None):
+ for j in range(C.shape[0]):
+ for i in range(j):
+ if binary:
+ if C[i, j] > 0:
+ pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color='k')
+ else: # connection intensity proportional to C[i,j]
+ pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=C[i, j], color='k')
+
+ pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors='k', cmap='tab10', vmax=9)
+
+
+pl.figure(1, (12, 8))
+pl.clf()
+for idx_c, c in enumerate(clusters):
+ C = dataset[(c - 1) * Nc] # sample with c clusters
+ # get 2d position for nodes
+ x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C)
+ pl.subplot(2, nlabels, c)
+ pl.title('(graph) sample from label ' + str(c), fontsize=14)
+ plot_graph(x, C, binary=True, color='C0', s=50.)
+ pl.axis("off")
+ pl.subplot(2, nlabels, nlabels + c)
+ pl.title('(matrix) sample from label %s \n' % c, fontsize=14)
+ pl.imshow(C, interpolation='nearest')
+ pl.axis("off")
+pl.tight_layout()
+pl.show()
+
+# %%
+# =============================================================================
+# Estimate the gromov-wasserstein dictionary from the dataset
+# =============================================================================
+
+
+np.random.seed(0)
+ps = [ot.unif(C.shape[0]) for C in dataset]
+
+D = 3 # 3 atoms in the dictionary
+nt = 6 # of 6 nodes each
+
+q = ot.unif(nt)
+reg = 0. # regularization coefficient to promote sparsity of unmixings {w_s}
+
+Cdict_GW, log = gromov_wasserstein_dictionary_learning(
+ Cs=dataset, D=D, nt=nt, ps=ps, q=q, epochs=10, batch_size=16,
+ learning_rate=0.1, reg=reg, projection='nonnegative_symmetric',
+ tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=30, max_iter_inner=300,
+ use_log=True, use_adam_optimizer=True, verbose=True
+)
+# visualize loss evolution over epochs
+pl.figure(2, (4, 3))
+pl.clf()
+pl.title('loss evolution by epoch', fontsize=14)
+pl.plot(log['loss_epochs'])
+pl.xlabel('epochs', fontsize=12)
+pl.ylabel('loss', fontsize=12)
+pl.tight_layout()
+pl.show()
+
+# %%
+# =============================================================================
+# Visualization of the estimated dictionary atoms
+# =============================================================================
+
+
+# Continuous connections between nodes of the atoms are colored in shades of grey (stronger connections appear darker)
+
+pl.figure(3, (12, 8))
+pl.clf()
+for idx_atom, atom in enumerate(Cdict_GW):
+ scaled_atom = (atom - atom.min()) / (atom.max() - atom.min())
+ x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - scaled_atom)
+ pl.subplot(2, D, idx_atom + 1)
+ pl.title('(graph) atom ' + str(idx_atom + 1), fontsize=14)
+ plot_graph(x, atom / atom.max(), binary=False, color='C0', s=100.)
+ pl.axis("off")
+ pl.subplot(2, D, D + idx_atom + 1)
+ pl.title('(matrix) atom %s \n' % (idx_atom + 1), fontsize=14)
+ pl.imshow(scaled_atom, interpolation='nearest')
+ pl.colorbar()
+ pl.axis("off")
+pl.tight_layout()
+pl.show()
+# %%
+# =============================================================================
+# Visualization of the embedding space
+# =============================================================================
+
+unmixings = []
+reconstruction_errors = []
+for C in dataset:
+ p = ot.unif(C.shape[0])
+ unmixing, Cembedded, OT, reconstruction_error = gromov_wasserstein_linear_unmixing(
+ C, Cdict_GW, p=p, q=q, reg=reg,
+ tol_outer=10**(-5), tol_inner=10**(-5),
+ max_iter_outer=30, max_iter_inner=300
+ )
+ unmixings.append(unmixing)
+ reconstruction_errors.append(reconstruction_error)
+unmixings = np.array(unmixings)
+print('cumulated reconstruction error:', np.array(reconstruction_errors).sum())
+
+
+# Compute the 2D representation of the unmixings living in the 2-simplex of probability
+unmixings2D = np.zeros(shape=(N, 2))
+for i, w in enumerate(unmixings):
+ unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2.
+ unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2.
+x = [0., 0.]
+y = [1., 0.]
+z = [0.5, np.sqrt(3) / 2.]
+extremities = np.stack([x, y, z])
+
+pl.figure(4, (4, 4))
+pl.clf()
+pl.title('Embedding space', fontsize=14)
+for cluster in range(nlabels):
+ start, end = Nc * cluster, Nc * (cluster + 1)
+ if cluster == 0:
+ pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='1 cluster')
+ else:
+ pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='%s clusters' % (cluster + 1))
+pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=80., label='atoms')
+pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.)
+pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.)
+pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.)
+pl.axis('off')
+pl.legend(fontsize=11)
+pl.tight_layout()
+pl.show()
+# %%
+# =============================================================================
+# Endow the dataset with node features
+# =============================================================================
+
+# We assign features to all nodes of a graph depending on its label/number of clusters
+# 1 cluster --> 0 as nodes feature
+# 2 clusters --> 1 as nodes feature
+# 3 clusters --> 2 as nodes feature
+# features are one-hot encoded following these assignments
+dataset_features = []
+for i in range(len(dataset)):
+ n = dataset[i].shape[0]
+ F = np.zeros((n, 3))
+ if i < Nc: # graph with 1 cluster
+ F[:, 0] = 1.
+ elif i < 2 * Nc: # graph with 2 clusters
+ F[:, 1] = 1.
+ else: # graph with 3 clusters
+ F[:, 2] = 1.
+ dataset_features.append(F)
+
+pl.figure(5, (12, 8))
+pl.clf()
+for idx_c, c in enumerate(clusters):
+ C = dataset[(c - 1) * Nc] # sample with c clusters
+ F = dataset_features[(c - 1) * Nc]
+ colors = ['C' + str(np.argmax(F[i])) for i in range(F.shape[0])]
+ # get 2d position for nodes
+ x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C)
+ pl.subplot(2, nlabels, c)
+ pl.title('(graph) sample from label ' + str(c), fontsize=14)
+ plot_graph(x, C, binary=True, color=colors, s=50)
+ pl.axis("off")
+ pl.subplot(2, nlabels, nlabels + c)
+ pl.title('(matrix) sample from label %s \n' % c, fontsize=14)
+ pl.imshow(C, interpolation='nearest')
+ pl.axis("off")
+pl.tight_layout()
+pl.show()
+# %%
+# =============================================================================
+# Estimate a Fused Gromov-Wasserstein dictionary from the dataset of attributed graphs
+# =============================================================================
+np.random.seed(0)
+ps = [ot.unif(C.shape[0]) for C in dataset]
+D = 3 # 3 atoms in the dictionary
+nt = 6
+q = ot.unif(nt)
+reg = 0.001
+alpha = 0.5 # trade-off parameter between structure and feature information of Fused Gromov-Wasserstein
+
+
+Cdict_FGW, Ydict_FGW, log = fused_gromov_wasserstein_dictionary_learning(
+ Cs=dataset, Ys=dataset_features, D=D, nt=nt, ps=ps, q=q, alpha=alpha,
+ epochs=10, batch_size=16, learning_rate_C=0.1, learning_rate_Y=0.1, reg=reg,
+ tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=30, max_iter_inner=300,
+ projection='nonnegative_symmetric', use_log=True, use_adam_optimizer=True, verbose=True
+)
+# visualize loss evolution
+pl.figure(6, (4, 3))
+pl.clf()
+pl.title('loss evolution by epoch', fontsize=14)
+pl.plot(log['loss_epochs'])
+pl.xlabel('epochs', fontsize=12)
+pl.ylabel('loss', fontsize=12)
+pl.tight_layout()
+pl.show()
+
+# %%
+# =============================================================================
+# Visualization of the estimated dictionary atoms
+# =============================================================================
+
+pl.figure(7, (12, 8))
+pl.clf()
+max_features = Ydict_FGW.max()
+min_features = Ydict_FGW.min()
+
+for idx_atom, (Catom, Fatom) in enumerate(zip(Cdict_FGW, Ydict_FGW)):
+ scaled_atom = (Catom - Catom.min()) / (Catom.max() - Catom.min())
+ #scaled_F = 2 * (Fatom - min_features) / (max_features - min_features)
+ colors = ['C%s' % np.argmax(Fatom[i]) for i in range(Fatom.shape[0])]
+ x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - scaled_atom)
+ pl.subplot(2, D, idx_atom + 1)
+ pl.title('(attributed graph) atom ' + str(idx_atom + 1), fontsize=14)
+ plot_graph(x, Catom / Catom.max(), binary=False, color=colors, s=100)
+ pl.axis("off")
+ pl.subplot(2, D, D + idx_atom + 1)
+ pl.title('(matrix) atom %s \n' % (idx_atom + 1), fontsize=14)
+ pl.imshow(scaled_atom, interpolation='nearest')
+ pl.colorbar()
+ pl.axis("off")
+pl.tight_layout()
+pl.show()
+
+# %%
+# =============================================================================
+# Visualization of the embedding space
+# =============================================================================
+
+unmixings = []
+reconstruction_errors = []
+for i in range(len(dataset)):
+ C = dataset[i]
+ Y = dataset_features[i]
+ p = ot.unif(C.shape[0])
+ unmixing, Cembedded, Yembedded, OT, reconstruction_error = fused_gromov_wasserstein_linear_unmixing(
+ C, Y, Cdict_FGW, Ydict_FGW, p=p, q=q, alpha=alpha,
+ reg=reg, tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=30, max_iter_inner=300
+ )
+ unmixings.append(unmixing)
+ reconstruction_errors.append(reconstruction_error)
+unmixings = np.array(unmixings)
+print('cumulated reconstruction error:', np.array(reconstruction_errors).sum())
+
+# Visualize unmixings in the 2-simplex of probability
+unmixings2D = np.zeros(shape=(N, 2))
+for i, w in enumerate(unmixings):
+ unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2.
+ unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2.
+x = [0., 0.]
+y = [1., 0.]
+z = [0.5, np.sqrt(3) / 2.]
+extremities = np.stack([x, y, z])
+
+pl.figure(8, (4, 4))
+pl.clf()
+pl.title('Embedding space', fontsize=14)
+for cluster in range(nlabels):
+ start, end = Nc * cluster, Nc * (cluster + 1)
+ if cluster == 0:
+ pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='1 cluster')
+ else:
+ pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='%s clusters' % (cluster + 1))
+
+pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=80., label='atoms')
+pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.)
+pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.)
+pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.)
+pl.axis('off')
+pl.legend(fontsize=11)
+pl.tight_layout()
+pl.show()
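Once the dictionary Cdict_GW is learned, unseen graphs can be embedded with the same unmixing call used above. A sketch on a freshly generated two-cluster SBM sample, to run after the GW section above (sizes and seed are arbitrary):

    # embed a held-out graph into the learned GW dictionary
    P_new = 0.1 * np.ones((2, 2)) + 0.8 * np.eye(2)  # same inter/intra probabilities
    C_new = networkx.to_numpy_array(sbm([15, 15], P_new, seed=123, directed=False))
    w_new, Cemb, T, err = gromov_wasserstein_linear_unmixing(
        C_new, Cdict_GW, p=ot.unif(C_new.shape[0]), q=q, reg=0.,
        tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=30, max_iter_inner=300)
    print('unmixing:', w_new, 'reconstruction error:', err)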
diff --git a/examples/others/plot_WeakOT_VS_OT.py b/examples/others/plot_WeakOT_VS_OT.py
new file mode 100644
index 0000000..a29c875
--- /dev/null
+++ b/examples/others/plot_WeakOT_VS_OT.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+"""
+====================================================
+Weak Optimal Transport VS exact Optimal Transport
+====================================================
+
+Illustration of 2D optimal transport between distributions that are weighted
+sums of Diracs. The OT matrix is plotted with the samples.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 4
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import ot.plot
+
+##############################################################################
+# Generate data and plot it
+# --------------------------
+
+#%% parameters and data generation
+
+n = 50 # nb samples
+
+mu_s = np.array([0, 0])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([4, 4])
+cov_t = np.array([[1, -.8], [-.8, 1]])
+
+xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
+xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
+
+a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples
+
+# loss matrix
+M = ot.dist(xs, xt)
+M /= M.max()
+
+#%% plot samples
+
+pl.figure(1)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('Source and target distributions')
+
+pl.figure(2)
+pl.imshow(M, interpolation='nearest')
+pl.title('Cost matrix M')
+
+
+##############################################################################
+# Compute Weak OT and exact OT solutions
+# --------------------------------------
+
+#%% EMD
+
+G0 = ot.emd(a, b, M)
+
+#%% Weak OT
+
+Gweak = ot.weak_optimal_transport(xs, xt, a, b)
+
+
+##############################################################################
+# Plot weak OT and exact OT solutions
+# --------------------------------------
+
+pl.figure(3, (8, 5))
+
+pl.subplot(1, 2, 1)
+pl.imshow(G0, interpolation='nearest')
+pl.title('OT matrix')
+
+pl.subplot(1, 2, 2)
+pl.imshow(Gweak, interpolation='nearest')
+pl.title('Weak OT matrix')
+
+pl.figure(4, (8, 5))
+
+pl.subplot(1, 2, 1)
+ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.title('OT matrix with samples')
+
+pl.subplot(1, 2, 2)
+ot.plot.plot2D_samples_mat(xs, xt, Gweak, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.title('Weak OT matrix with samples')
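The two plans above can also be compared on the linear cost they incur: the weak OT plan optimizes a different (barycentric) objective, so its cost against M is at least that of exact OT. A one-line check to run after both solvers:

    print('exact OT cost:', np.sum(G0 * M))
    print('weak  OT cost:', np.sum(Gweak * M))  # >= exact OT cost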
diff --git a/examples/others/plot_factored_coupling.py b/examples/others/plot_factored_coupling.py
new file mode 100644
index 0000000..b5b1c9f
--- /dev/null
+++ b/examples/others/plot_factored_coupling.py
@@ -0,0 +1,86 @@
+# -*- coding: utf-8 -*-
+"""
+==========================================
+Optimal transport with factored couplings
+==========================================
+
+Illustration of the factored coupling OT between 2D empirical distributions
+
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 2
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import ot.plot
+
+# %%
+# Generate data and plot it
+# --------------------------
+
+# parameters and data generation
+
+np.random.seed(42)
+
+n = 100 # nb samples
+
+xs = np.random.rand(n, 2) - .5
+
+xs = xs + np.sign(xs)
+
+xt = np.random.rand(n, 2) - .5
+
+a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples
+
+#%% plot samples
+
+pl.figure(1)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('Source and target distributions')
+
+
+# %%
+# Compute factored OT and exact OT solutions
+# -------------------------------------------
+
+#%% EMD
+M = ot.dist(xs, xt)
+G0 = ot.emd(a, b, M)
+
+#%% factored OT
+
+Ga, Gb, xb = ot.factored_optimal_transport(xs, xt, a, b, r=4)
+
+
+# %%
+# Plot factored OT and exact OT solutions
+# ----------------------------------------
+
+pl.figure(2, (14, 4))
+
+pl.subplot(1, 3, 1)
+ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.2, .2, .2], alpha=0.1)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.title('Exact OT with samples')
+
+pl.subplot(1, 3, 2)
+ot.plot.plot2D_samples_mat(xs, xb, Ga, c=[.6, .6, .9], alpha=0.5)
+ot.plot.plot2D_samples_mat(xb, xt, Gb, c=[.9, .6, .6], alpha=0.5)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.plot(xb[:, 0], xb[:, 1], 'og', label='Template samples')
+pl.title('Factored OT with template samples')
+
+pl.subplot(1, 3, 3)
+ot.plot.plot2D_samples_mat(xs, xt, Ga.dot(Gb), c=[.2, .2, .2], alpha=0.1)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.title('Factored OT low rank OT plan')
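Since Ga is n x r and Gb is r x n, the composed plan Ga.dot(Gb) plotted above factors through the r=4 template points and therefore has rank at most r. A quick numerical check (the tolerance is an assumption):

    print(np.linalg.matrix_rank(Ga.dot(Gb), tol=1e-10))  # expected <= 4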
diff --git a/examples/others/plot_logo.py b/examples/others/plot_logo.py
new file mode 100644
index 0000000..bb4f640
--- /dev/null
+++ b/examples/others/plot_logo.py
@@ -0,0 +1,112 @@
+
+# -*- coding: utf-8 -*-
+r"""
+=======================
+Logo of the POT toolbox
+=======================
+
+In this example we plot the logo of the POT toolbox.
+
+This logo is done 100% in Python: it is generated with matplotlib
+by plotting the solution of the EMD solver from POT.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 1
+
+# %% Load modules
+import numpy as np
+import matplotlib.pyplot as pl
+import ot
+
+# %%
+# Data for logo
+# -------------
+
+
+# Letter P
+p1 = np.array([[0, 6.], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], ])
+p2 = np.array([[1.5, 6], [2, 4], [2, 5], [1.5, 3], [0.5, 2], [.5, 1], ])
+
+# Letter O
+o1 = np.array([[0, 6.], [-1, 5], [-1.5, 4], [-1.5, 3], [-1, 2], [0, 1], ])
+o2 = np.array([[1, 6.], [2, 5], [2.5, 4], [2.5, 3], [2, 2], [1, 1], ])
+
+# Scaling and translation for letter O
+o1[:, 0] += 6.4
+o2[:, 0] += 6.4
+o1[:, 0] *= 0.6
+o2[:, 0] *= 0.6
+
+# Letter T
+t1 = np.array([[-1, 6.], [-1, 5], [0, 4], [0, 3], [0, 2], [0, 1], ])
+t2 = np.array([[1.5, 6.], [1.5, 5], [0.5, 4], [0.5, 3], [0.5, 2], [0.5, 1], ])
+
+# Translating the T
+t1[:, 0] += 7.1
+t2[:, 0] += 7.1
+
+# Concatenate all letters
+x1 = np.concatenate((p1, o1, t1), axis=0)
+x2 = np.concatenate((p2, o2, t2), axis=0)
+
+# Horizontal and vertical scaling
+sx = 1.0
+sy = .5
+x1[:, 0] *= sx
+x1[:, 1] *= sy
+x2[:, 0] *= sx
+x2[:, 1] *= sy
+
+# %%
+# Plot the logo (clear background)
+# --------------------------------
+
+# Solve OT problem between the points
+M = ot.dist(x1, x2, metric='euclidean')
+T = ot.emd([], [], M)
+
+pl.figure(1, (3.5, 1.1))
+pl.clf()
+# plot the OT plan
+for i in range(M.shape[0]):
+ for j in range(M.shape[1]):
+ if T[i, j] > 1e-8:
+ pl.plot([x1[i, 0], x2[j, 0]], [x1[i, 1], x2[j, 1]], color='k', alpha=0.6, linewidth=3, zorder=1)
+# plot the samples
+pl.plot(x1[:, 0], x1[:, 1], 'o', markerfacecolor='C3', markeredgecolor='k')
+pl.plot(x2[:, 0], x2[:, 1], 'o', markerfacecolor='b', markeredgecolor='k')
+
+
+pl.axis('equal')
+pl.axis('off')
+
+# Save logo file
+# pl.savefig('logo.svg', dpi=150, transparent=True, bbox_inches='tight')
+# pl.savefig('logo.png', dpi=150, transparent=True, bbox_inches='tight')
+
+# %%
+# Plot the logo (dark background)
+# --------------------------------
+
+pl.figure(2, (3.5, 1.1), facecolor='darkgray')
+pl.clf()
+# plot the OT plan
+for i in range(M.shape[0]):
+ for j in range(M.shape[1]):
+ if T[i, j] > 1e-8:
+ pl.plot([x1[i, 0], x2[j, 0]], [x1[i, 1], x2[j, 1]], color='w', alpha=0.8, linewidth=3, zorder=1)
+# plot the samples
+pl.plot(x1[:, 0], x1[:, 1], 'o', markerfacecolor='w', markeredgecolor='w')
+pl.plot(x2[:, 0], x2[:, 1], 'o', markerfacecolor='w', markeredgecolor='w')
+
+pl.axis('equal')
+pl.axis('off')
+
+# Save logo file
+# pl.savefig('logo_dark.svg', dpi=150, transparent=True, bbox_inches='tight')
+# pl.savefig('logo_dark.png', dpi=150, transparent=True, bbox_inches='tight')
diff --git a/examples/plot_screenkhorn_1D.py b/examples/others/plot_screenkhorn_1D.py
index 785642a..2023649 100644
--- a/examples/plot_screenkhorn_1D.py
+++ b/examples/others/plot_screenkhorn_1D.py
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
"""
-===============================
-1D Screened optimal transport
-===============================
+========================================
+Screened optimal transport (Screenkhorn)
+========================================
This example illustrates the computation of Screenkhorn [26].
diff --git a/examples/plot_stochastic.py b/examples/others/plot_stochastic.py
index 3a1ef31..3a1ef31 100644
--- a/examples/plot_stochastic.py
+++ b/examples/others/plot_stochastic.py
diff --git a/examples/plot_Intro_OT.py b/examples/plot_Intro_OT.py
index f282950..219aa51 100644
--- a/examples/plot_Intro_OT.py
+++ b/examples/plot_Intro_OT.py
@@ -58,7 +58,7 @@ help(ot.dist)
# number of Bakeries to Cafés in a City (in this case Manhattan). We did a
# quick google map search in Manhattan for bakeries and Cafés:
#
-# .. image:: images/bak.png
+# .. image:: ../_static/images/bak.png
# :align: center
# :alt: bakery-cafe-manhattan
# :width: 600px
@@ -233,7 +233,7 @@ print('Wasserstein loss (EMD) = {0:.2f}'.format(W))
# The Sinkhorn algorithm is very simple to code. You can implement it directly
# using the following pseudo-code
#
-# .. image:: images/sinkhorn.png
+# .. image:: ../_static/images/sinkhorn.png
# :align: center
# :alt: Sinkhorn algorithm
# :width: 440px
diff --git a/examples/plot_OT_1D.py b/examples/plot_OT_1D.py
index 15ead96..62f0b7d 100644
--- a/examples/plot_OT_1D.py
+++ b/examples/plot_OT_1D.py
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
"""
-====================
-1D optimal transport
-====================
+======================================
+Optimal Transport for 1D distributions
+======================================
This example illustrates the computation of EMD and Sinkhorn transport plans
and their visualization.
@@ -64,7 +64,11 @@ ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
#%% EMD
-G0 = ot.emd(a, b, M)
+# use fast 1D solver
+G0 = ot.emd_1d(x, x, a, b)
+
+# Equivalent to
+# G0 = ot.emd(a, b, M)
pl.figure(3, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
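The equivalence claimed in the comment above can be verified directly: for the (default) squared Euclidean loss, the dedicated 1D solver and the generic LP solver should return the same plan. A sketch to run after the cell above:

    G_fast = ot.emd_1d(x, x, a, b, metric='sqeuclidean')
    G_ref = ot.emd(a, b, ot.dist(x.reshape((n, 1)), x.reshape((n, 1))))
    print(np.allclose(G_fast, G_ref))  # True, up to ties between optimal plans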
diff --git a/examples/plot_OT_1D_smooth.py b/examples/plot_OT_1D_smooth.py
index b07f99f..5415e4f 100644
--- a/examples/plot_OT_1D_smooth.py
+++ b/examples/plot_OT_1D_smooth.py
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
"""
-===========================
-1D smooth optimal transport
-===========================
+================================
+Smooth optimal transport example
+================================
This example illustrates the computation of EMD, Sinkhorn and smooth OT plans
and their visualization.
diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py
index af1bc12..1d82fb8 100644
--- a/examples/plot_OT_2D_samples.py
+++ b/examples/plot_OT_2D_samples.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""
====================================================
-2D Optimal transport between empirical distributions
+Optimal Transport between 2D empirical distributions
====================================================
Illustration of 2D optimal transport between distributions that are weighted
@@ -42,7 +42,6 @@ a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples
# loss matrix
M = ot.dist(xs, xt)
-M /= M.max()
##############################################################################
# Plot data
@@ -87,7 +86,7 @@ pl.title('OT matrix with samples')
#%% sinkhorn
# reg term
-lambd = 1e-3
+lambd = 1e-1
Gs = ot.sinkhorn(a, b, M, lambd)
@@ -112,7 +111,7 @@ pl.show()
#%% sinkhorn
# reg term
-lambd = 1e-3
+lambd = 1e-1
Ges = ot.bregman.empirical_sinkhorn(xs, xt, lambd)
diff --git a/examples/plot_OT_L1_vs_L2.py b/examples/plot_OT_L1_vs_L2.py
index 60353ab..cce51f8 100644
--- a/examples/plot_OT_L1_vs_L2.py
+++ b/examples/plot_OT_L1_vs_L2.py
@@ -1,10 +1,10 @@
# -*- coding: utf-8 -*-
"""
-==========================================
-2D Optimal transport for different metrics
-==========================================
+================================================
+Optimal Transport with different ground metrics
+================================================
-2D OT on empirical distributio with different gound metric.
+2D OT on empirical distributions with different ground metrics.
Stole the figure idea from Fig. 1 and 2 in
https://arxiv.org/pdf/1706.07650.pdf
@@ -23,7 +23,7 @@ import matplotlib.pylab as pl
import ot
import ot.plot
-##############################################################################
+# %%
# Dataset 1 : uniform sampling
# ----------------------------
@@ -46,7 +46,7 @@ M2 = ot.dist(xs, xt, metric='sqeuclidean')
M2 /= M2.max()
# loss matrix
-Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))
+Mp = ot.dist(xs, xt, metric='cityblock')
Mp /= Mp.max()
# Data
@@ -71,7 +71,7 @@ pl.title('Squared Euclidean cost')
pl.subplot(1, 3, 3)
pl.imshow(Mp, interpolation='nearest')
-pl.title('Sqrt Euclidean cost')
+pl.title('L1 (cityblock) cost')
pl.tight_layout()
##############################################################################
@@ -109,22 +109,22 @@ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.axis('equal')
# pl.legend(loc=0)
-pl.title('OT sqrt Euclidean')
+pl.title('OT L1 (cityblock)')
pl.tight_layout()
pl.show()
-##############################################################################
+# %%
# Dataset 2 : Partial circle
# --------------------------
-n = 50 # nb samples
+n = 20 # nb samples
xtot = np.zeros((n + 1, 2))
xtot[:, 0] = np.cos(
- (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)
+ (np.arange(n + 1) + 1.0) * 0.8 / (n + 2) * 2 * np.pi)
xtot[:, 1] = np.sin(
- (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)
+ (np.arange(n + 1) + 1.0) * 0.8 / (n + 2) * 2 * np.pi)
xs = xtot[:n, :]
xt = xtot[1:, :]
@@ -140,7 +140,7 @@ M2 = ot.dist(xs, xt, metric='sqeuclidean')
M2 /= M2.max()
# loss matrix
-Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))
+Mp = ot.dist(xs, xt, metric='cityblock')
Mp /= Mp.max()
@@ -150,7 +150,7 @@ pl.clf()
pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.axis('equal')
-pl.title('Source and traget distributions')
+pl.title('Source and target distributions')
# Cost matrices
@@ -166,13 +166,13 @@ pl.title('Squared Euclidean cost')
pl.subplot(1, 3, 3)
pl.imshow(Mp, interpolation='nearest')
-pl.title('Sqrt Euclidean cost')
+pl.title('L1 (cityblock) cost')
pl.tight_layout()
##############################################################################
# Dataset 2 : Plot OT Matrices
# -----------------------------
-
+#
#%% EMD
G1 = ot.emd(a, b, M1)
@@ -204,7 +204,7 @@ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.axis('equal')
# pl.legend(loc=0)
-pl.title('OT sqrt Euclidean')
+pl.title('OT L1 (cityblock)')
pl.tight_layout()
pl.show()
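For reference, the cityblock metric used above is the per-coordinate absolute difference, and ot.dist delegates it to SciPy's cdist. A small consistency sketch:

    Mp_manual = np.abs(xs[:, None, :] - xt[None, :, :]).sum(axis=2)
    print(np.allclose(Mp_manual, ot.dist(xs, xt, metric='cityblock')))  # True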
diff --git a/examples/plot_compute_emd.py b/examples/plot_compute_emd.py
index 527a847..36cc7da 100644
--- a/examples/plot_compute_emd.py
+++ b/examples/plot_compute_emd.py
@@ -1,10 +1,10 @@
# -*- coding: utf-8 -*-
"""
-=================
-Plot multiple EMD
-=================
+==================
+OT distances in 1D
+==================
-Shows how to compute multiple EMD and Sinkhorn with two different
+Shows how to compute multiple Wasserstein and Sinkhorn distances with two different
ground metrics and plot their values for different distributions.
@@ -14,7 +14,7 @@ ground metrics and plot their values for different distributions.
#
# License: MIT License
-# sphinx_gallery_thumbnail_number = 3
+# sphinx_gallery_thumbnail_number = 2
import numpy as np
import matplotlib.pylab as pl
@@ -29,7 +29,7 @@ from ot.datasets import make_1D_gauss as gauss
#%% parameters
n = 100 # nb bins
-n_target = 50 # nb target distributions
+n_target = 20 # nb target distributions
# bin positions
@@ -47,9 +47,9 @@ for i, m in enumerate(lst_m):
# loss matrix and normalization
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean')
-M /= M.max()
+M /= M.max() * 0.1
M2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'sqeuclidean')
-M2 /= M2.max()
+M2 /= M2.max() * 0.1
##############################################################################
# Plot data
@@ -59,10 +59,12 @@ M2 /= M2.max()
pl.figure(1)
pl.subplot(2, 1, 1)
-pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, a, 'r', label='Source distribution')
pl.title('Source distribution')
pl.subplot(2, 1, 2)
-pl.plot(x, B, label='Target distributions')
+for i in range(n_target):
+ pl.plot(x, B[:, i], 'b', alpha=i / n_target)
+pl.plot(x, B[:, -1], 'b', label='Target distributions')
pl.title('Target distributions')
pl.tight_layout()
@@ -73,14 +75,27 @@ pl.tight_layout()
#%% Compute and plot distributions and loss matrix
-d_emd = ot.emd2(a, B, M) # direct computation of EMD
-d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M2
-
+d_emd = ot.emd2(a, B, M) # direct computation of OT loss
+d_emd2 = ot.emd2(a, B, M2) # direct computation of OT loss with metric M2
+d_tv = [np.sum(abs(a - B[:, i])) for i in range(n_target)]
pl.figure(2)
-pl.plot(d_emd, label='Euclidean EMD')
-pl.plot(d_emd2, label='Squared Euclidean EMD')
-pl.title('EMD distances')
+pl.subplot(2, 1, 1)
+pl.plot(x, a, 'r', label='Source distribution')
+pl.title('Distributions')
+for i in range(n_target):
+ pl.plot(x, B[:, i], 'b', alpha=i / n_target)
+pl.plot(x, B[:, -1], 'b', label='Target distributions')
+pl.ylim((-.01, 0.13))
+pl.xticks(())
+pl.legend()
+pl.subplot(2, 1, 2)
+pl.plot(d_emd, label='Euclidean OT')
+pl.plot(d_emd2, label='Squared Euclidean OT')
+pl.plot(d_tv, label='Total Variation (TV)')
+#pl.xlim((-7,23))
+pl.xlabel('Displacement')
+pl.title('Divergences')
pl.legend()
##############################################################################
@@ -88,17 +103,30 @@ pl.legend()
# -----------------------------------------
#%%
-reg = 1e-2
+reg = 1e-1
d_sinkhorn = ot.sinkhorn2(a, B, M, reg)
d_sinkhorn2 = ot.sinkhorn2(a, B, M2, reg)
-pl.figure(2)
+pl.figure(3)
pl.clf()
-pl.plot(d_emd, label='Euclidean EMD')
-pl.plot(d_emd2, label='Squared Euclidean EMD')
+
+pl.subplot(2, 1, 1)
+pl.plot(x, a, 'r', label='Source distribution')
+pl.title('Distributions')
+for i in range(n_target):
+ pl.plot(x, B[:, i], 'b', alpha=i / n_target)
+pl.plot(x, B[:, -1], 'b', label='Target distributions')
+pl.ylim((-.01, 0.13))
+pl.xticks(())
+pl.legend()
+pl.subplot(2, 1, 2)
+pl.plot(d_emd, label='Euclidean OT')
+pl.plot(d_emd2, label='Squared Euclidean OT')
pl.plot(d_sinkhorn, '+', label='Euclidean Sinkhorn')
pl.plot(d_sinkhorn2, '+', label='Squared Euclidean Sinkhorn')
-pl.title('EMD distances')
+pl.plot(d_tv, label='Total Variation (TV)')
+#pl.xlim((-7,23))
+pl.xlabel('Displacement')
+pl.title('Divergences')
pl.legend()
-
pl.show()
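A convention note on the TV curves added above: d_tv uses sum(|a - b|), while the usual total-variation distance between probability vectors carries a factor 1/2. If the standard normalization is preferred:

    d_tv_standard = [0.5 * np.sum(np.abs(a - B[:, i])) for i in range(n_target)]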
diff --git a/examples/plot_optim_OTreg.py b/examples/plot_optim_OTreg.py
index 5eb15bd..7b021d2 100644
--- a/examples/plot_optim_OTreg.py
+++ b/examples/plot_optim_OTreg.py
@@ -24,7 +24,7 @@ arXiv preprint arXiv:1510.06567.
"""
-# sphinx_gallery_thumbnail_number = 4
+# sphinx_gallery_thumbnail_number = 5
import numpy as np
import matplotlib.pylab as pl
@@ -58,7 +58,7 @@ M /= M.max()
G0 = ot.emd(a, b, M)
-pl.figure(3, figsize=(5, 5))
+pl.figure(1, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
##############################################################################
@@ -80,7 +80,7 @@ reg = 1e-1
Gl2 = ot.optim.cg(a, b, M, reg, f, df, verbose=True)
-pl.figure(3)
+pl.figure(2)
ot.plot.plot1D_mat(a, b, Gl2, 'OT matrix Frob. reg')
##############################################################################
@@ -102,7 +102,7 @@ reg = 1e-3
Ge = ot.optim.cg(a, b, M, reg, f, df, verbose=True)
-pl.figure(4, figsize=(5, 5))
+pl.figure(3, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Ge, 'OT matrix Entrop. reg')
##############################################################################
@@ -125,6 +125,34 @@ reg2 = 1e-1
Gel2 = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True)
-pl.figure(5, figsize=(5, 5))
+pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Gel2, 'OT entropic + matrix Frob. reg')
pl.show()
+
+
+# %%
+# Comparison of the OT matrices
+
+nvisu = 40
+
+pl.figure(5, figsize=(10, 4))
+
+pl.subplot(2, 2, 1)
+pl.imshow(G0[:nvisu, :])
+pl.axis('off')
+pl.title('Exact OT')
+
+pl.subplot(2, 2, 2)
+pl.imshow(Gl2[:nvisu, :])
+pl.axis('off')
+pl.title('Frobenius reg.')
+
+pl.subplot(2, 2, 3)
+pl.imshow(Ge[:nvisu, :])
+pl.axis('off')
+pl.title('Entropic reg.')
+
+pl.subplot(2, 2, 4)
+pl.imshow(Gel2[:nvisu, :])
+pl.axis('off')
+pl.title('Entropic + Frobenius reg.')
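For reference, the f/df pairs passed to ot.optim.cg and ot.optim.gcg in the hunks above are defined earlier in the example, outside these hunks; they look like this:

    # Frobenius regularization, used for Gl2 and Gel2
    def f(G):
        return 0.5 * np.sum(G ** 2)

    def df(G):
        return G

    # entropic regularization, used for Ge:
    # def f(G): return np.sum(G * np.log(G))
    # def df(G): return np.log(G) + 1.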
diff --git a/examples/sliced-wasserstein/README.txt b/examples/sliced-wasserstein/README.txt
index a575345..73e6122 100644
--- a/examples/sliced-wasserstein/README.txt
+++ b/examples/sliced-wasserstein/README.txt
@@ -1,4 +1,4 @@
Sliced Wasserstein Distance
---------------------------- \ No newline at end of file
+---------------------------
diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py
index 7d73907..f12b522 100644
--- a/examples/sliced-wasserstein/plot_variance.py
+++ b/examples/sliced-wasserstein/plot_variance.py
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
"""
-==============================
-2D Sliced Wasserstein Distance
-==============================
+===============================================
+Sliced Wasserstein Distance on 2D distributions
+===============================================
This example illustrates the computation of the sliced Wasserstein Distance as
proposed in [31].
@@ -16,6 +16,8 @@ measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 2
+
import matplotlib.pylab as pl
import numpy as np
diff --git a/examples/unbalanced-partial/plot_UOT_1D.py b/examples/unbalanced-partial/plot_UOT_1D.py
index 183849c..06dd02d 100644
--- a/examples/unbalanced-partial/plot_UOT_1D.py
+++ b/examples/unbalanced-partial/plot_UOT_1D.py
@@ -12,6 +12,8 @@ using a Kullback-Leibler relaxation.
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 4
+
import numpy as np
import matplotlib.pylab as pl
import ot
@@ -69,7 +71,20 @@ epsilon = 0.1 # entropy parameter
alpha = 1. # Unbalanced KL relaxation parameter
Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True)
-pl.figure(4, figsize=(5, 5))
+pl.figure(3, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn')
pl.show()
+
+
+# %%
+# plot the transported mass
+# -------------------------
+
+pl.figure(4, figsize=(6.4, 3))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+pl.fill(x, Gs.sum(1), 'b', alpha=0.5, label='Transported source')
+pl.fill(x, Gs.sum(0), 'r', alpha=0.5, label='Transported target')
+pl.legend(loc='upper right')
+pl.title('Distributions and transported mass for UOT')
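The new figure above shows that unbalanced OT does not transport all the mass. This can also be checked numerically after the Sinkhorn call:

    print('input mass      :', a.sum())
    print('transported mass:', Gs.sum())  # < 1 due to the KL marginal relaxation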
diff --git a/examples/unbalanced-partial/plot_regpath.py b/examples/unbalanced-partial/plot_regpath.py
index 4a51c2d..782e8c2 100644
--- a/examples/unbalanced-partial/plot_regpath.py
+++ b/examples/unbalanced-partial/plot_regpath.py
@@ -15,11 +15,12 @@ penalized linear regression.
# Author: Haoran Wu <haoran.wu@univ-ubs.fr>
# License: MIT License
+# sphinx_gallery_thumbnail_number = 2
import numpy as np
import matplotlib.pylab as pl
import ot
-
+import matplotlib.animation as animation
##############################################################################
# Generate data
# -------------
@@ -72,6 +73,9 @@ t2, t_list2, g_list2 = ot.regpath.regularization_path(a, b, M, reg=final_gamma,
##############################################################################
# Plot the regularization path
# ----------------
+#
+# The OT plan is plotted as a function of $\gamma$, which is the inverse of
+# the weight on the marginal relaxations.
#%% fully relaxed l2-penalized UOT
@@ -103,13 +107,53 @@ for p in range(4):
pl.show()
+# %%
+# Animation of the regpath for UOT l2
+# -----------------------------------
+
+nv = 100
+g_list_v = np.logspace(-.5, -2.5, nv)
+
+pl.figure(3)
+
+
+def _update_plot(iv):
+ pl.clf()
+ tp = ot.regpath.compute_transport_plan(g_list_v[iv], g_list,
+ t_list)
+ P = tp.reshape((n, n))
+ if P.sum() > 0:
+ P = P / P.max()
+ for i in range(n):
+ for j in range(n):
+ if P[i, j] > 0:
+ pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2',
+ alpha=P[i, j] * 0.5)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 4,
+ label='Re-weighted source', alpha=1)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 4,
+ label='Re-weighted target', alpha=1)
+ pl.plot([], [], color='C2', alpha=0.8, label='OT plan')
+ pl.title(r'$\ell_2$ UOT $\gamma$={:1.3f}'.format(g_list_v[iv]),
+ fontsize=11)
+ return 1
+
+
+i = 0
+_update_plot(i)
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000)
+
+
##############################################################################
# Plot the semi-relaxed regularization path
# -------------------
#%% semi-relaxed l2-penalized UOT
-pl.figure(3)
+pl.figure(4)
selected_gamma = [10, 1, 1e-1, 1e-2]
for p in range(4):
tp = ot.regpath.compute_transport_plan(selected_gamma[p], g_list2,
@@ -133,3 +177,43 @@ for p in range(4):
if p < 2:
pl.xticks(())
pl.show()
+
+
+# %%
+# Animation of the regpath for semi-relaxed UOT l2
+# ------------------------------------------------
+
+nv = 100
+g_list_v = np.logspace(2.5, -2, nv)
+
+pl.figure(5)
+
+
+def _update_plot(iv):
+ pl.clf()
+ tp = ot.regpath.compute_transport_plan(g_list_v[iv], g_list2,
+ t_list2)
+ P = tp.reshape((n, n))
+ if P.sum() > 0:
+ P = P / P.max()
+ for i in range(n):
+ for j in range(n):
+ if P[i, j] > 0:
+ pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2',
+ alpha=P[i, j] * 0.5)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 4,
+ label='Re-weighted source', alpha=1)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 4,
+ label='Re-weighted target', alpha=1)
+ pl.plot([], [], color='C2', alpha=0.8, label='OT plan')
+ pl.title(r'Semi-relaxed $\ell_2$ UOT $\gamma$={:1.3f}'.format(g_list_v[iv]),
+ fontsize=11)
+ return 1
+
+
+i = 0
+_update_plot(i)
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000)
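Both animations are built with matplotlib.animation.FuncAnimation and are rendered by sphinx-gallery; to export one outside the gallery, something like the following works (filename and writer are assumptions):

    ani.save('regpath_semirelaxed.gif', writer='pillow', fps=20)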
diff --git a/examples/unbalanced-partial/plot_unbalanced_OT.py b/examples/unbalanced-partial/plot_unbalanced_OT.py
new file mode 100644
index 0000000..03487e7
--- /dev/null
+++ b/examples/unbalanced-partial/plot_unbalanced_OT.py
@@ -0,0 +1,116 @@
+# -*- coding: utf-8 -*-
+"""
+==============================================================
+2D examples of exact and entropic unbalanced optimal transport
+==============================================================
+This example is designed to show how to compute unbalanced and
+partial OT in POT.
+
+UOT aims at solving the following optimization problem:
+
+ .. math::
+ W = \min_{\gamma} <\gamma, \mathbf{M}>_F +
+ \mathrm{reg}\cdot\Omega(\gamma) +
+ \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) +
+ \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b})
+
+ s.t.
+ \gamma \geq 0
+
+where :math:`\mathrm{div}` is a divergence.
+When using the entropic UOT, :math:`\mathrm{reg}>0` and :math:`\mathrm{div}`
+should be the Kullback-Leibler divergence.
+When solving exact UOT, :math:`\mathrm{reg}=0` and :math:`\mathrm{div}`
+can be either the Kullback-Leibler or the quadratic divergence.
+Using :math:`\ell_1` norm gives the so-called partial OT.
+"""
+
+# Author: Laetitia Chapel <laetitia.chapel@univ-ubs.fr>
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+
+##############################################################################
+# Generate data
+# -------------
+
+# %% parameters and data generation
+
+n = 40 # nb samples
+
+mu_s = np.array([-1, -1])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([4, 4])
+cov_t = np.array([[1, -.8], [-.8, 1]])
+
+np.random.seed(0)
+xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
+xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
+
+n_noise = 10
+
+xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) - 4))), axis=0)
+xt = np.concatenate((xt, ((np.random.rand(n_noise, 2) + 6))), axis=0)
+
+n = n + n_noise
+
+a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples
+
+# loss matrix
+M = ot.dist(xs, xt)
+M /= M.max()
+
+
+##############################################################################
+# Compute entropic kl-regularized UOT, kl- and l2-regularized UOT
+# ---------------------------------------------------------------
+
+reg = 0.005
+reg_m_kl = 0.05
+reg_m_l2 = 5
+mass = 0.7
+
+entropic_kl_uot = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl)
+kl_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_kl, div='kl')
+l2_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_l2, div='l2')
+partial_ot = ot.partial.partial_wasserstein(a, b, M, m=mass)
+
+##############################################################################
+# Plot the results
+# ----------------
+
+pl.figure(2)
+transp = [partial_ot, l2_uot, kl_uot, entropic_kl_uot]
+title = ["partial OT \n m=" + str(mass), "$\ell_2$-UOT \n $\mathrm{reg_m}$=" +
+ str(reg_m_l2), "kl-UOT \n $\mathrm{reg_m}$=" + str(reg_m_kl),
+ "entropic kl-UOT \n $\mathrm{reg_m}$=" + str(reg_m_kl)]
+
+for p in range(4):
+ pl.subplot(2, 4, p + 1)
+ P = transp[p]
+ if P.sum() > 0:
+ P = P / P.max()
+ for i in range(n):
+ for j in range(n):
+ if P[i, j] > 0:
+ pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2',
+ alpha=P[i, j] * 0.3)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 2)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 2)
+ pl.title(title[p])
+ pl.yticks(())
+ pl.xticks(())
+ if p < 1:
+ pl.ylabel("mappings")
+ pl.subplot(2, 4, p + 5)
+ pl.imshow(P, cmap='jet')
+ pl.yticks(())
+ pl.xticks(())
+ if p < 1:
+ pl.ylabel("transport plans")
+pl.show()
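The four plans above retain different amounts of mass depending on the relaxation; a short sketch comparing them, to run after the solvers:

    for name, P in zip(title, transp):
        print(name.split('\n')[0].strip(), '-> transported mass:', P.sum())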