path: root/examples
authorGard Spreemann <gspr@nonempty.org>2021-11-09 17:05:13 +0100
committerGard Spreemann <gspr@nonempty.org>2021-11-09 17:05:13 +0100
commita9fdc844907decddf54bed3ebeea8d8b2cf0fc5c (patch)
tree449a03fce8fafb78b6badd12b6e633f1e5d73a64 /examples
parenta16b9471d7114ec08977479b7249efe747702b97 (diff)
parentf1628794d521a8dfa00af383b5e06cd6d34af619 (diff)
Merge tag '0.8.0' into dfsg/latest
Diffstat (limited to 'examples')
-rw-r--r--  examples/README.txt                                            |   2
-rw-r--r--  examples/backends/README.txt                                   |   4
-rw-r--r--  examples/backends/plot_optim_gromov_pytorch.py                 | 260
-rw-r--r--  examples/backends/plot_sliced_wass_grad_flow_pytorch.py        | 185
-rw-r--r--  examples/backends/plot_unmix_optim_torch.py                    | 161
-rw-r--r--  examples/backends/plot_wass1d_torch.py                         | 152
-rw-r--r--  examples/backends/plot_wass2_gan_torch.py                      | 227
-rw-r--r--  examples/barycenters/plot_barycenter_1D.py                     |  63
-rw-r--r--  examples/barycenters/plot_barycenter_lp_vs_entropic.py         |   2
-rw-r--r--  examples/barycenters/plot_convolutional_barycenter.py          |  53
-rw-r--r--  examples/barycenters/plot_debiased_barycenter.py               | 131
-rw-r--r--  examples/barycenters/plot_free_support_barycenter.py           |   4
-rw-r--r--  examples/domain-adaptation/plot_otda_color_images.py           | 128
-rw-r--r--  examples/domain-adaptation/plot_otda_jcpot.py                  |   4
-rw-r--r--  examples/domain-adaptation/plot_otda_linear_mapping.py         |  81
-rw-r--r--  examples/domain-adaptation/plot_otda_mapping_colors_images.py  | 128
-rw-r--r--  examples/gromov/plot_barycenter_fgw.py                         |   2
-rw-r--r--  examples/gromov/plot_fgw.py                                    |  10
-rw-r--r--  examples/gromov/plot_gromov.py                                 |  34
-rwxr-xr-x  examples/gromov/plot_gromov_barycenter.py                      |  94
-rw-r--r--  examples/plot_Intro_OT.py                                      | 373
-rw-r--r--  examples/plot_OT_1D_smooth.py                                  |   2
-rw-r--r--  examples/plot_OT_2D_samples.py                                 |   2
-rw-r--r--  examples/sliced-wasserstein/README.txt                         |   4
-rw-r--r--  examples/sliced-wasserstein/plot_variance.py                   |  86
-rw-r--r--  examples/unbalanced-partial/plot_UOT_1D.py                     |   3
-rwxr-xr-x  examples/unbalanced-partial/plot_partial_wass_and_gromov.py   |  23
-rw-r--r--  examples/unbalanced-partial/plot_regpath.py                    | 135
28 files changed, 2054 insertions, 299 deletions
diff --git a/examples/README.txt b/examples/README.txt
index 69a9f84..b48487f 100644
--- a/examples/README.txt
+++ b/examples/README.txt
@@ -1,7 +1,7 @@
Examples gallery
================
-This is a gallery of all the POT example files.
+This is a gallery of all the POT example files.
OT and regularized OT
diff --git a/examples/backends/README.txt b/examples/backends/README.txt
new file mode 100644
index 0000000..3ee0e27
--- /dev/null
+++ b/examples/backends/README.txt
@@ -0,0 +1,4 @@
+
+
+POT backend examples
+-------------------- \ No newline at end of file
diff --git a/examples/backends/plot_optim_gromov_pytorch.py b/examples/backends/plot_optim_gromov_pytorch.py
new file mode 100644
index 0000000..969707f
--- /dev/null
+++ b/examples/backends/plot_optim_gromov_pytorch.py
@@ -0,0 +1,260 @@
+r"""
+=======================================================
+Optimizing the Gromov-Wasserstein distance with PyTorch
+=======================================================
+
+In this example we use the PyTorch backend to optimize the Gromov-Wasserstein
+(GW) loss between two graphs expressed as empirical distributions.
+
+In the first example we optimize the weights on the nodes of a simple template
+graph so that it minimizes the GW distance to a given Stochastic Block Model graph.
+We can see that this actually recovers the proportion of classes in the SBM
+and allows for an accurate clustering of the nodes using the GW optimal plan.
+
+In a second example we simultaneously optimize the weights and the structure of
+the template graph which allows us to perform graph compression and to recover
+other properties of the SBM.
+
+The backend actually uses the gradients expressed in [38] to optimize the
+weights.
+
+[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online Graph
+Dictionary Learning, International Conference on Machine Learning (ICML), 2021.
+
+"""
+# Author: Rémi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 3
+
+from sklearn.manifold import MDS
+import numpy as np
+import matplotlib.pylab as pl
+import torch
+
+import ot
+from ot.gromov import gromov_wasserstein2
+
+# %%
+# Graph generation
+# ---------------
+
+rng = np.random.RandomState(42)
+
+
+def get_sbm(n, nc, ratio, P):
+ nbpc = np.round(n * ratio).astype(int)
+ n = np.sum(nbpc)
+ C = np.zeros((n, n))
+ for c1 in range(nc):
+ for c2 in range(c1 + 1):
+ if c1 == c2:
+ for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[:c1 + 1])):
+ for j in range(np.sum(nbpc[:c2]), i):
+ if rng.rand() <= P[c1, c2]:
+ C[i, j] = 1
+ else:
+ for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[:c1 + 1])):
+ for j in range(np.sum(nbpc[:c2]), np.sum(nbpc[:c2 + 1])):
+ if rng.rand() <= P[c1, c2]:
+ C[i, j] = 1
+
+ return C + C.T
+
+
+n = 100
+nc = 3
+ratio = np.array([.5, .3, .2])
+P = np.array(0.6 * np.eye(3) + 0.05 * np.ones((3, 3)))
+C1 = get_sbm(n, nc, ratio, P)
+
+# get 2d position for nodes
+x1 = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C1)
+
+
+def plot_graph(x, C, color='C0', s=None):
+ for j in range(C.shape[0]):
+ for i in range(j):
+ if C[i, j] > 0:
+ pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color='k')
+ pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors='k', cmap='tab10', vmax=9)
+
+
+pl.figure(1, (10, 5))
+pl.clf()
+pl.subplot(1, 2, 1)
+plot_graph(x1, C1, color='C0')
+pl.title("SBM Graph")
+pl.axis("off")
+pl.subplot(1, 2, 2)
+pl.imshow(C1, interpolation='nearest')
+pl.title("Adjacency matrix")
+pl.axis("off")
+
+
+# %%
+# Optimizing GW w.r.t. the weights on a template structure
+# --------------------------------------------------------
+# The adjacency matrix C1 is block diagonal with 3 blocks. We want to
+# optimize the weights of a simple template C0=eye(3) and see if we can
+# recover the proportion of classes from the SBM (up to a permutation).
+
+C0 = np.eye(3)
+
+
+def min_weight_gw(C1, C2, a2, nb_iter_max=100, lr=1e-2):
+ """ solve min_a GW(C1,C2,a, a2) by gradient descent"""
+
+ # use pyTorch for our data
+ C1_torch = torch.tensor(C1)
+ C2_torch = torch.tensor(C2)
+
+ a0 = rng.rand(C1.shape[0]) # random_init
+ a0 /= a0.sum() # on simplex
+ a1_torch = torch.tensor(a0).requires_grad_(True)
+ a2_torch = torch.tensor(a2)
+
+ loss_iter = []
+
+ for i in range(nb_iter_max):
+
+ loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch)
+
+ loss_iter.append(loss.clone().detach().cpu().numpy())
+ loss.backward()
+
+ #print("{:03d} | {}".format(i, loss_iter[-1]))
+
+ # performs a step of projected gradient descent
+ with torch.no_grad():
+ grad = a1_torch.grad
+ a1_torch -= grad * lr # step
+ a1_torch.grad.zero_()
+ a1_torch.data = ot.utils.proj_simplex(a1_torch)
+
+ a1 = a1_torch.clone().detach().cpu().numpy()
+
+ return a1, loss_iter
+
+
+a0_est, loss_iter0 = min_weight_gw(C0, C1, ot.unif(n), nb_iter_max=100, lr=1e-2)
+
+pl.figure(2)
+pl.plot(loss_iter0)
+pl.title("Loss along iterations")
+
+print("Estimated weights : ", a0_est)
+print("True proportions : ", ratio)
+
+
+# %%
+# It is clear that the optimization has converged and that we recover the
+# ratio of the different classes in the SBM graph up to a permutation.
+
+
+# %%
+# Community clustering with uniform and estimated weights
+# -------------------------------------------------------
+# The GW OT plan can be used to perform a clustering of the nodes of a graph
+# when computing the GW with a simple template like C0, by labeling each node
+# of the original graph with the index of the template node that receives the
+# most mass.
+#
+# We show here the result of such a clustering when using uniform weights on
+# the template C0 and when using the optimal weights previously estimated.
+
+
+T_unif = ot.gromov_wasserstein(C1, C0, ot.unif(n), ot.unif(3))
+label_unif = T_unif.argmax(1)
+
+T_est = ot.gromov_wasserstein(C1, C0, ot.unif(n), a0_est)
+label_est = T_est.argmax(1)
+
+pl.figure(3, (10, 5))
+pl.clf()
+pl.subplot(1, 2, 1)
+plot_graph(x1, C1, color=label_unif)
+pl.title("Graph clustering unif. weights")
+pl.axis("off")
+pl.subplot(1, 2, 2)
+plot_graph(x1, C1, color=label_est)
+pl.title("Graph clustering est. weights")
+pl.axis("off")
+
+
+# %%
+# Graph compression with GW
+# -------------------------
+
+# Now we optimize both the weights and the structure of a small graph that
+# minimizes the GW distance w.r.t. our data graph. This can be seen as graph
+# compression, but it can also recover important properties of an SBM, such
+# as its class proportions and its matrix of link probabilities between
+# classes.
+
+
+def graph_compression_gw(nb_nodes, C2, a2, nb_iter_max=100, lr=1e-2):
+ """ solve min_{a, C1} GW(C1, C2, a, a2) by gradient descent"""
+
+ # use pyTorch for our data
+
+ C2_torch = torch.tensor(C2)
+ a2_torch = torch.tensor(a2)
+
+ a0 = rng.rand(nb_nodes) # random_init
+ a0 /= a0.sum() # on simplex
+ a1_torch = torch.tensor(a0).requires_grad_(True)
+ C0 = np.eye(nb_nodes)
+ C1_torch = torch.tensor(C0).requires_grad_(True)
+
+ loss_iter = []
+
+ for i in range(nb_iter_max):
+
+ loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch)
+
+ loss_iter.append(loss.clone().detach().cpu().numpy())
+ loss.backward()
+
+ #print("{:03d} | {}".format(i, loss_iter[-1]))
+
+ # performs a step of projected gradient descent
+ with torch.no_grad():
+ grad = a1_torch.grad
+ a1_torch -= grad * lr # step
+ a1_torch.grad.zero_()
+ a1_torch.data = ot.utils.proj_simplex(a1_torch)
+
+ grad = C1_torch.grad
+ C1_torch -= grad * lr # step
+ C1_torch.grad.zero_()
+ C1_torch.data = torch.clamp(C1_torch, 0, 1)
+
+ a1 = a1_torch.clone().detach().cpu().numpy()
+ C1 = C1_torch.clone().detach().cpu().numpy()
+
+ return a1, C1, loss_iter
+
+
+nb_nodes = 3
+a0_est2, C0_est2, loss_iter2 = graph_compression_gw(nb_nodes, C1, ot.unif(n),
+ nb_iter_max=100, lr=5e-2)
+
+pl.figure(4)
+pl.plot(loss_iter2)
+pl.title("Loss along iterations")
+
+
+print("Estimated weights : ", a0_est2)
+print("True proportions : ", ratio)
+
+pl.figure(6, (10, 3.5))
+pl.clf()
+pl.subplot(1, 2, 1)
+pl.imshow(P, vmin=0, vmax=1)
+pl.title('True SBM P matrix')
+pl.subplot(1, 2, 2)
+pl.imshow(C0_est2, vmin=0, vmax=1)
+pl.title('Estimated C0 matrix')
+pl.colorbar()
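
The two helpers above share the same projected-gradient pattern: take a manual
gradient step on a torch tensor, then project back onto the probability simplex
with ot.utils.proj_simplex. Below is a minimal standalone sketch of that
pattern, assuming POT >= 0.8 with its torch backend; the helper name
step_simplex and the toy matrices are illustrative, not part of the POT API.

import numpy as np
import torch
import ot

def step_simplex(a, lr):
    # one plain gradient step, then a projection back onto the simplex
    with torch.no_grad():
        a -= lr * a.grad
        a.grad.zero_()
        a.data = ot.utils.proj_simplex(a)

C_template = torch.tensor(np.eye(3))
C_data = torch.tensor(np.eye(3))
a = torch.tensor(ot.unif(3), requires_grad=True)  # weights to optimize
b = torch.tensor(ot.unif(3))                      # fixed target weights

loss = ot.gromov.gromov_wasserstein2(C_template, C_data, a, b)
loss.backward()
step_simplex(a, lr=1e-2)
print(a.detach().numpy())  # still a probability vector after the step
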
diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
new file mode 100644
index 0000000..05b9952
--- /dev/null
+++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
@@ -0,0 +1,185 @@
+r"""
+============================================================
+Sliced Wasserstein barycenter and gradient flow with PyTorch
+============================================================
+
+In this example we use the PyTorch backend to optimize the sliced Wasserstein
+loss between two empirical distributions [31].
+
+In the first example we perform a gradient flow on the support of a
+distribution that minimizes the sliced Wasserstein distance, as proposed
+in [36].
+
+In the second example we use gradient descent to optimize the sliced
+Wasserstein barycenter between two distributions as in [31].
+
+[31] Bonneel, Nicolas, et al. "Sliced and Radon Wasserstein barycenters of
+measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45.
+
+[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R.
+(2019, May). Sliced-Wasserstein flows: Nonparametric generative modeling
+via optimal transport and diffusions. In International Conference on
+Machine Learning (pp. 4104-4113). PMLR.
+
+
+"""
+# Author: Rémi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+
+# %%
+# Loading the data
+
+
+import numpy as np
+import matplotlib.pylab as pl
+import torch
+import ot
+import matplotlib.animation as animation
+
+I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::4, ::4, 2]
+I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::4, ::4, 2]
+
+sz = I2.shape[0]
+XX, YY = np.meshgrid(np.arange(sz), np.arange(sz))
+
+x1 = np.stack((XX[I1 == 0], YY[I1 == 0]), 1) * 1.0
+x2 = np.stack((XX[I2 == 0] + 60, -YY[I2 == 0] + 32), 1) * 1.0
+x3 = np.stack((XX[I2 == 0], -YY[I2 == 0] + 32), 1) * 1.0
+
+pl.figure(1, (8, 4))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)
+
+# %%
+# Sliced Wasserstein gradient flow with PyTorch
+# ---------------------------------------------
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# use pyTorch for our data
+x1_torch = torch.tensor(x1).to(device=device).requires_grad_(True)
+x2_torch = torch.tensor(x2).to(device=device)
+
+
+lr = 1e3
+nb_iter_max = 100
+
+x_all = np.zeros((nb_iter_max, x1.shape[0], 2))
+
+loss_iter = []
+
+# generator for random projections
+gen = torch.Generator()
+gen.manual_seed(42)
+
+for i in range(nb_iter_max):
+
+ loss = ot.sliced_wasserstein_distance(x1_torch, x2_torch, n_projections=20, seed=gen)
+
+ loss_iter.append(loss.clone().detach().cpu().numpy())
+ loss.backward()
+
+ # performs a step of projected gradient descent
+ with torch.no_grad():
+ grad = x1_torch.grad
+ x1_torch -= grad * lr / (1 + i / 5e1) # step
+ x1_torch.grad.zero_()
+ x_all[i, :, :] = x1_torch.clone().detach().cpu().numpy()
+
+xb = x1_torch.clone().detach().cpu().numpy()
+
+pl.figure(2, (8, 4))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu^{(0)}$')
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$')
+pl.scatter(xb[:, 0], xb[:, 1], alpha=0.5, label='$\mu^{(100)}$')
+pl.title('Sliced Wasserstein gradient flow')
+pl.legend()
+ax = pl.axis()
+
+# %%
+# Animate trajectories of the gradient flow along iterations
+# ----------------------------------------------------------
+
+pl.figure(3, (8, 4))
+
+
+def _update_plot(i):
+ pl.clf()
+ pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu^{(0)}$')
+ pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$')
+ pl.scatter(x_all[i, :, 0], x_all[i, :, 1], alpha=0.5, label='$\mu^{(100)}$')
+ pl.title('Sliced Wasserstein gradient flow Iter. {}'.format(i))
+ pl.axis(ax)
+ return 1
+
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100, repeat_delay=2000)
+
+# %%
+# Compute the Sliced Wasserstein Barycenter
+# -----------------------------------------
+x1_torch = torch.tensor(x1).to(device=device)
+x3_torch = torch.tensor(x3).to(device=device)
+xbinit = np.random.randn(500, 2) * 10 + 16
+xbary_torch = torch.tensor(xbinit).to(device=device).requires_grad_(True)
+
+lr = 1e3
+nb_iter_max = 100
+
+x_all = np.zeros((nb_iter_max, xbary_torch.shape[0], 2))
+
+loss_iter = []
+
+# generator for random projections
+gen = torch.Generator()
+gen.manual_seed(42)
+
+alpha = 0.5
+
+for i in range(nb_iter_max):
+
+ loss = alpha * ot.sliced_wasserstein_distance(xbary_torch, x3_torch, n_projections=50, seed=gen) \
+ + (1 - alpha) * ot.sliced_wasserstein_distance(xbary_torch, x1_torch, n_projections=50, seed=gen)
+
+ loss_iter.append(loss.clone().detach().cpu().numpy())
+ loss.backward()
+
+ # performs a step of projected gradient descent
+ with torch.no_grad():
+ grad = xbary_torch.grad
+ xbary_torch -= grad * lr # / (1 + i / 5e1) # step
+ xbary_torch.grad.zero_()
+ x_all[i, :, :] = xbary_torch.clone().detach().cpu().numpy()
+
+xb = xbary_torch.clone().detach().cpu().numpy()
+
+pl.figure(4, (8, 4))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu$')
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$')
+pl.scatter(xb[:, 0] + 30, xb[:, 1], alpha=0.5, label='Barycenter')
+pl.title('Sliced Wasserstein barycenter')
+pl.legend()
+ax = pl.axis()
+
+
+# %%
+# Animate trajectories of the barycenter along gradient descent
+# -------------------------------------------------------------
+
+pl.figure(5, (8, 4))
+
+
+def _update_plot(i):
+ pl.clf()
+ pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu^{(0)}$')
+ pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$')
+ pl.scatter(x_all[i, :, 0] + 30, x_all[i, :, 1], alpha=0.5, label='$\mu^{(100)}$')
+ pl.title('Sliced Wasserstein barycenter Iter. {}'.format(i))
+ pl.axis(ax)
+ return 1
+
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100, repeat_delay=2000)
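
The key ingredient in both flows above is that ot.sliced_wasserstein_distance
is differentiable with respect to the support positions when called on torch
tensors. A minimal sketch of a single differentiable evaluation, under the
same POT >= 0.8 assumption; the toy point clouds are illustrative only.

import torch
import ot

x = torch.randn(100, 2, dtype=torch.float64, requires_grad=True)
y = torch.randn(100, 2, dtype=torch.float64)

gen = torch.Generator()
gen.manual_seed(0)  # seeds the random projection directions

sw = ot.sliced_wasserstein_distance(x, y, n_projections=20, seed=gen)
sw.backward()
print(x.grad.shape)  # torch.Size([100, 2]): one gradient per support point
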
diff --git a/examples/backends/plot_unmix_optim_torch.py b/examples/backends/plot_unmix_optim_torch.py
new file mode 100644
index 0000000..9ae66e9
--- /dev/null
+++ b/examples/backends/plot_unmix_optim_torch.py
@@ -0,0 +1,161 @@
+# -*- coding: utf-8 -*-
+r"""
+=================================
+Wasserstein unmixing with PyTorch
+=================================
+
+In this example we estimate mixing parameters from distributions that minimize
+the Wasserstein distance. In other words we suppose that a target
+distribution :math:`\mu^t` can be expressed as a weighted sum of source
+distributions :math:`\mu^s_k` with the following model:
+
+.. math::
+ \mu^t = \sum_{k=1}^K w_k\mu^s_k
+
+where :math:`\mathbf{w}` is a vector of size :math:`K` and belongs to the
+distribution simplex :math:`\Delta_K`.
+
+In order to estimate this weight vector we propose to optimize the Wasserstein
+distance between the model and the observed :math:`\mu^t` with respect to
+the vector. This leads to the following optimization problem:
+
+.. math::
+ \min_{\mathbf{w}\in\Delta_K} \quad W \left(\mu^t,\sum_{k=1}^K w_k\mu^s_k\right)
+
+This minimization is done in this example with a simple projected gradient
+descent in PyTorch. We use the automatic backend of POT, which makes the
+Wasserstein distance computed with :any:`ot.emd2` differentiable with
+respect to the weights.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 2
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import torch
+
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% Data
+
+nt = 100
+nt1 = 10
+
+ns1 = 50
+ns = 2 * ns1
+
+rng = np.random.RandomState(2)
+
+xt = rng.randn(nt, 2) * 0.2
+xt[:nt1, 0] += 1
+xt[nt1:, 1] += 1
+
+
+xs1 = rng.randn(ns1, 2) * 0.2
+xs1[:, 0] += 1
+xs2 = rng.randn(ns1, 2) * 0.2
+xs2[:, 1] += 1
+
+xs = np.concatenate((xs1, xs2))
+
+# Sample reweighting matrix H
+H = np.zeros((ns, 2))
+H[:ns1, 0] = 1 / ns1
+H[ns1:, 1] = 1 / ns1
+# each column sums to 1 and has weights only for samples from the
+# corresponding source distribution
+
+M = ot.dist(xs, xt)
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% plot the distributions
+
+pl.figure(1)
+pl.scatter(xt[:, 0], xt[:, 1], label='Target $\mu^t$', alpha=0.5)
+pl.scatter(xs1[:, 0], xs1[:, 1], label='Source $\mu^s_1$', alpha=0.5)
+pl.scatter(xs2[:, 0], xs2[:, 1], label='Source $\mu^s_2$', alpha=0.5)
+pl.title('Sources and Target distributions')
+pl.legend()
+
+
+##############################################################################
+# Optimization of the model wrt the Wasserstein distance
+# ------------------------------------------------------
+
+
+#%% Weights optimization with gradient descent
+
+# convert numpy arrays to torch tensors
+H2 = torch.tensor(H)
+M2 = torch.tensor(M)
+
+# weights for the source distributions
+w = torch.tensor(ot.unif(2), requires_grad=True)
+
+# uniform weights for target
+b = torch.tensor(ot.unif(nt))
+
+lr = 2e-3 # learning rate
+niter = 500 # number of iterations
+losses = [] # loss along the iterations
+
+# loss for the minimal Wasserstein estimator
+
+
+def get_loss(w):
+ a = torch.mv(H2, w) # distribution reweighting
+ return ot.emd2(a, b, M2) # squared Wasserstein 2
+
+
+for i in range(niter):
+
+ loss = get_loss(w)
+ losses.append(float(loss))
+
+ loss.backward()
+
+ with torch.no_grad():
+ w -= lr * w.grad # gradient step
+ w[:] = ot.utils.proj_simplex(w) # projection on the simplex
+
+ w.grad.zero_()
+
+
+##############################################################################
+# Estimated weights and convergence of the objective
+# ---------------------------------------------------
+
+we = w.detach().numpy()
+print('Estimated mixture:', we)
+
+pl.figure(2)
+pl.semilogy(losses)
+pl.grid()
+pl.title('Wasserstein distance')
+pl.xlabel("Iterations")
+
+##############################################################################
+# Plotting the reweighted source distribution
+# --------------------------------------------
+
+pl.figure(3)
+
+# compute source weights
+ws = H.dot(we)
+
+pl.scatter(xt[:, 0], xt[:, 1], label='Target $\mu^t$', alpha=0.5)
+pl.scatter(xs[:, 0], xs[:, 1], color='C3', s=ws * 20 * ns, label='Weighted sources $\sum_{k} w_k\mu^s_k$', alpha=0.5)
+pl.title('Target and reweighted source distributions')
+pl.legend()
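
The estimator above works because ot.emd2 is differentiable with respect to
the input marginals when called on torch tensors; the gradient it returns for
the first marginal is an optimal dual potential of the OT problem. A minimal
sketch of that mechanism on toy inputs, assuming POT >= 0.8 (illustrative only):

import torch
import ot

a = torch.tensor(ot.unif(5), requires_grad=True)  # source weights
b = torch.tensor(ot.unif(5))                      # target weights
M = torch.rand(5, 5, dtype=torch.float64)         # toy cost matrix

loss = ot.emd2(a, b, M)  # exact OT loss, differentiable w.r.t. a
loss.backward()
print(a.grad)  # an optimal dual potential
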
diff --git a/examples/backends/plot_wass1d_torch.py b/examples/backends/plot_wass1d_torch.py
new file mode 100644
index 0000000..0abdd6d
--- /dev/null
+++ b/examples/backends/plot_wass1d_torch.py
@@ -0,0 +1,152 @@
+r"""
+=================================
+Wasserstein 1D with PyTorch
+=================================
+
+In this small example, we consider the following minimization problem:
+
+.. math::
+ \mu^* = \min_\mu W(\mu,\nu)
+
+where :math:`\nu` is a reference 1D measure. The problem is handled
+by a projected gradient descent method, where the gradient is computed
+by PyTorch automatic differentiation. The projection on the simplex
+ensures that the iterate will remain on the probability simplex.
+
+This example illustrates the use of the `wasserstein_1d` function and of the
+backend within the POT framework.
+"""
+# Author: Nicolas Courty <ncourty@irisa.fr>
+# Rémi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import matplotlib as mpl
+import torch
+
+from ot.lp import wasserstein_1d
+from ot.datasets import make_1D_gauss as gauss
+from ot.utils import proj_simplex
+
+red = np.array(mpl.colors.to_rgb('red'))
+blue = np.array(mpl.colors.to_rgb('blue'))
+
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a = gauss(n, m=20, s=5) # m= mean, s= std
+b = gauss(n, m=60, s=10)
+
+# enforce sum to one on the support
+a = a / a.sum()
+b = b / b.sum()
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# use pyTorch for our data
+x_torch = torch.tensor(x).to(device=device)
+a_torch = torch.tensor(a).to(device=device).requires_grad_(True)
+b_torch = torch.tensor(b).to(device=device)
+
+lr = 1e-6
+nb_iter_max = 800
+
+loss_iter = []
+
+pl.figure(1, figsize=(8, 4))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+
+for i in range(nb_iter_max):
+ # Compute the Wasserstein 1D with torch backend
+ loss = wasserstein_1d(x_torch, x_torch, a_torch, b_torch, p=2)
+ # record the corresponding loss value
+ loss_iter.append(loss.clone().detach().cpu().numpy())
+ loss.backward()
+
+ # performs a step of projected gradient descent
+ with torch.no_grad():
+ grad = a_torch.grad
+ a_torch -= a_torch.grad * lr # step
+ a_torch.grad.zero_()
+ a_torch.data = proj_simplex(a_torch) # projection onto the simplex
+
+ # plot one curve every 10 iterations
+ if i % 10 == 0:
+ mix = float(i) / nb_iter_max
+ pl.plot(x, a_torch.clone().detach().cpu().numpy(), c=(1 - mix) * blue + mix * red)
+
+pl.legend()
+pl.title('Distribution along the iterations of the projected gradient descent')
+pl.show()
+
+pl.figure(2)
+pl.plot(range(nb_iter_max), loss_iter, lw=3)
+pl.title('Evolution of the loss along iterations', fontsize=16)
+pl.show()
+
+# %%
+# Wasserstein barycenter
+# ----------------------
+# In this example, we consider the following Wasserstein barycenter problem:
+#
+# .. math::
+#    \eta^* = \min_\eta \quad (1 - t) W(\mu, \eta) + t W(\eta, \nu)
+#
+# where :math:`\mu` and :math:`\nu` are reference 1D measures, and :math:`t`
+# is a parameter in :math:`[0, 1]`. The problem is handled by a projected
+# gradient descent method, where the gradient is computed by PyTorch automatic
+# differentiation. The projection onto the simplex ensures that the iterate
+# remains on the probability simplex.
+#
+# This example illustrates the use of the `wasserstein_1d` function and of the
+# backend within the POT framework.
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# use pyTorch for our data
+x_torch = torch.tensor(x).to(device=device)
+a_torch = torch.tensor(a).to(device=device)
+b_torch = torch.tensor(b).to(device=device)
+bary_torch = torch.tensor((a + b).copy() / 2).to(device=device).requires_grad_(True)
+
+
+lr = 1e-6
+nb_iter_max = 2000
+
+loss_iter = []
+
+# interpolation parameter
+t = 0.5
+
+for i in range(nb_iter_max):
+ # Compute the Wasserstein 1D with torch backend
+ loss = (1 - t) * wasserstein_1d(x_torch, x_torch, a_torch.detach(), bary_torch, p=2) + t * wasserstein_1d(x_torch, x_torch, b_torch, bary_torch, p=2)
+ # record the corresponding loss value
+ loss_iter.append(loss.clone().detach().cpu().numpy())
+ loss.backward()
+
+ # performs a step of projected gradient descent
+ with torch.no_grad():
+ grad = bary_torch.grad
+ bary_torch -= bary_torch.grad * lr # step
+ bary_torch.grad.zero_()
+ bary_torch.data = proj_simplex(bary_torch) # projection onto the simplex
+
+pl.figure(3, figsize=(8, 4))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+pl.plot(x, bary_torch.clone().detach().cpu().numpy(), c='green', label='W barycenter')
+pl.legend()
+pl.title('Wasserstein barycenter computed by gradient descent')
+pl.show()
+
+pl.figure(4)
+pl.plot(range(nb_iter_max), loss_iter, lw=3)
+pl.title('Evolution of the loss along iterations', fontsize=16)
+pl.show()
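
For reference, wasserstein_1d can also be evaluated on its own, outside any
optimization loop; with p=2 the returned value is the quadratic-cost 1D OT
loss between the two histograms. A short standalone sketch reusing the same
Gaussian histograms as above (a sketch, assuming POT >= 0.8):

import numpy as np
import torch
from ot.lp import wasserstein_1d
from ot.datasets import make_1D_gauss as gauss

n = 100
x = torch.tensor(np.arange(n, dtype=np.float64))  # common bin positions
a = torch.tensor(gauss(n, m=20, s=5))
b = torch.tensor(gauss(n, m=60, s=10))

# quadratic-cost 1D OT loss between the two histograms
print(float(wasserstein_1d(x, x, a, b, p=2)))
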
diff --git a/examples/backends/plot_wass2_gan_torch.py b/examples/backends/plot_wass2_gan_torch.py
new file mode 100644
index 0000000..ca5b3c9
--- /dev/null
+++ b/examples/backends/plot_wass2_gan_torch.py
@@ -0,0 +1,227 @@
+# -*- coding: utf-8 -*-
+r"""
+========================================
+Wasserstein 2 Minibatch GAN with PyTorch
+========================================
+
+In this example we train a Wasserstein GAN using Wasserstein 2 on minibatches
+as a distribution fitting term.
+
+We want to train a generator :math:`G_\theta` that generates realistic
+data from random noise drawn from a Gaussian distribution :math:`\mu_n` so
+that the data is indistinguishable from true data in the data distribution
+:math:`\mu_d`. To this end the Wasserstein GAN [Arjovsky2017] aims at optimizing
+the parameters :math:`\theta` of the generator with the following
+optimization problem:
+
+.. math::
+ \min_{\theta} W(\mu_d,G_\theta\#\mu_n)
+
+
+In practice we do not have access to the full distribution :math:`\mu_d`, only
+to samples, and we cannot compute the Wasserstein distance for large datasets.
+[Arjovsky2017] proposed to approximate the dual potential of Wasserstein 1
+with a neural network, recovering an optimization problem similar to GAN.
+In this example
+we will optimize the expectation of the Wasserstein distance over minibatches
+at each iteration as proposed in [Genevay2018]. Optimizing minibatches
+of the Wasserstein distance has been studied in [Fatras2019].
+
+[Arjovsky2017] Arjovsky, M., Chintala, S., & Bottou, L. (2017, July).
+Wasserstein generative adversarial networks. In International conference
+on machine learning (pp. 214-223). PMLR.
+
+[Genevay2018] Genevay, Aude, Gabriel Peyré, and Marco Cuturi. "Learning generative models
+with sinkhorn divergences." International Conference on Artificial Intelligence
+and Statistics. PMLR, 2018.
+
+[Fatras2019] Fatras, K., Zine, Y., Flamary, R., Gribonval, R., & Courty, N.
+(2020, June). Learning with minibatch Wasserstein: asymptotic and gradient
+properties. In the 23rd International Conference on Artificial Intelligence
+and Statistics (Vol. 108).
+
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 3
+
+import numpy as np
+import matplotlib.pyplot as pl
+import matplotlib.animation as animation
+import torch
+from torch import nn
+import ot
+
+
+# %%
+# Data generation
+# ---------------
+
+torch.manual_seed(1)
+sigma = 0.1
+n_dims = 2
+n_features = 2
+
+
+def get_data(n_samples):
+ c = torch.rand(size=(n_samples, 1))
+ angle = c * 2 * np.pi
+ x = torch.cat((torch.cos(angle), torch.sin(angle)), 1)
+ x += torch.randn(n_samples, 2) * sigma
+ return x
+
+
+# %%
+# Plot data
+# ---------
+
+# plot the distributions
+x = get_data(500)
+pl.figure(1)
+pl.scatter(x[:, 0], x[:, 1], label='Data samples from $\mu_d$', alpha=0.5)
+pl.title('Data distribution')
+pl.legend()
+
+
+# %%
+# Generator Model
+# ---------------
+
+# define the MLP model
+class Generator(torch.nn.Module):
+ def __init__(self):
+ super(Generator, self).__init__()
+ self.fc1 = nn.Linear(n_features, 200)
+ self.fc2 = nn.Linear(200, 500)
+ self.fc3 = nn.Linear(500, n_dims)
+ self.relu = torch.nn.ReLU() # instead of Heaviside step fn
+
+ def forward(self, x):
+ output = self.fc1(x)
+ output = self.relu(output) # instead of Heaviside step fn
+ output = self.fc2(output)
+ output = self.relu(output)
+ output = self.fc3(output)
+ return output
+
+# %%
+# Training the model
+# ------------------
+
+
+G = Generator()
+optimizer = torch.optim.RMSprop(G.parameters(), lr=0.00019, eps=1e-5)
+
+# number of iteration and size of the batches
+n_iter = 200 # set to 200 for doc build but 1000 is better ;)
+size_batch = 500
+
+# generate static samples to see their trajectory along training
+n_visu = 100
+xnvisu = torch.randn(n_visu, n_features)
+xvisu = torch.zeros(n_iter, n_visu, n_dims)
+
+ab = torch.ones(size_batch) / size_batch
+losses = []
+
+
+for i in range(n_iter):
+
+ # generate noise samples
+ xn = torch.randn(size_batch, n_features)
+
+ # generate data samples
+ xd = get_data(size_batch)
+
+ # generate sample along iterations
+ xvisu[i, :, :] = G(xnvisu).detach()
+
+ # generate samples and compute the distance matrix
+ xg = G(xn)
+ M = ot.dist(xg, xd)
+
+ loss = ot.emd2(ab, ab, M)
+ losses.append(float(loss.detach()))
+
+ if i % 10 == 0:
+ print("Iter: {:3d}, loss={}".format(i, losses[-1]))
+
+ loss.backward()
+ optimizer.step()
+
+ del M
+
+pl.figure(2)
+pl.semilogy(losses)
+pl.grid()
+pl.title('Wasserstein distance')
+pl.xlabel("Iterations")
+
+
+# %%
+# Plot trajectories of generated samples along iterations
+# -------------------------------------------------------
+
+
+pl.figure(3, (10, 10))
+
+ivisu = [0, 10, 25, 50, 75, 125, 150, 175, 199]
+
+for i in range(9):
+ pl.subplot(3, 3, i + 1)
+ pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1)
+ pl.scatter(xvisu[ivisu[i], :, 0], xvisu[ivisu[i], :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5)
+ pl.xticks(())
+ pl.yticks(())
+ pl.title('Iter. {}'.format(ivisu[i]))
+ if i == 0:
+ pl.legend()
+
+# %%
+# Animate trajectories of generated samples along iterations
+# ----------------------------------------------------------
+
+pl.figure(4, (8, 8))
+
+
+def _update_plot(i):
+ pl.clf()
+ pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1)
+ pl.scatter(xvisu[i, :, 0], xvisu[i, :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5)
+ pl.xticks(())
+ pl.yticks(())
+ pl.xlim((-1.5, 1.5))
+ pl.ylim((-1.5, 1.5))
+ pl.title('Iter. {}'.format(i))
+ return 1
+
+
+i = 0
+pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1)
+pl.scatter(xvisu[i, :, 0], xvisu[i, :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5)
+pl.xticks(())
+pl.yticks(())
+pl.xlim((-1.5, 1.5))
+pl.ylim((-1.5, 1.5))
+pl.title('Iter. {}'.format(i))
+
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter, interval=100, repeat_delay=2000)
+
+# %%
+# Generate and visualize data
+# ---------------------------
+
+size_batch = 500
+xd = get_data(size_batch)
+xn = torch.randn(size_batch, 2)
+x = G(xn).detach().numpy()
+
+pl.figure(5)
+pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.5)
+pl.scatter(x[:, 0], x[:, 1], label='Data samples from $G\#\mu_n$', alpha=0.5)
+pl.title('Sources and Target distributions')
+pl.legend()
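
The training loop above boils down to one minibatch loss: the exact OT loss
between a generated batch and a data batch with uniform weights. A minimal
sketch of that loss as a reusable function; the name minibatch_w2 is ours,
not part of POT, and the batches are illustrative stand-ins.

import torch
import ot

def minibatch_w2(xg, xd):
    # exact OT loss between two equally sized minibatches, uniform weights
    ab = torch.ones(xg.shape[0], dtype=xg.dtype) / xg.shape[0]
    M = ot.dist(xg, xd)  # squared Euclidean cost by default
    return ot.emd2(ab, ab, M)

xg = torch.randn(64, 2, requires_grad=True)  # stand-in for G(xn)
xd = torch.randn(64, 2)                      # stand-in for a data batch

loss = minibatch_w2(xg, xd)
loss.backward()  # gradients flow back to the generated samples
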
diff --git a/examples/barycenters/plot_barycenter_1D.py b/examples/barycenters/plot_barycenter_1D.py
index 63dc460..2373e99 100644
--- a/examples/barycenters/plot_barycenter_1D.py
+++ b/examples/barycenters/plot_barycenter_1D.py
@@ -18,10 +18,10 @@ SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
#
# License: MIT License
-# sphinx_gallery_thumbnail_number = 4
+# sphinx_gallery_thumbnail_number = 1
import numpy as np
-import matplotlib.pylab as pl
+import matplotlib.pyplot as plt
import ot
# necessary for 3d plot even if not used
from mpl_toolkits.mplot3d import Axes3D # noqa
@@ -51,18 +51,6 @@ M = ot.utils.dist0(n)
M /= M.max()
##############################################################################
-# Plot data
-# ---------
-
-#%% plot the distributions
-
-pl.figure(1, figsize=(6.4, 3))
-for i in range(n_distributions):
- pl.plot(x, A[:, i])
-pl.title('Distributions')
-pl.tight_layout()
-
-##############################################################################
# Barycenter computation
# ----------------------
@@ -78,24 +66,20 @@ bary_l2 = A.dot(weights)
reg = 1e-3
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
-pl.figure(2)
-pl.clf()
-pl.subplot(2, 1, 1)
-for i in range(n_distributions):
- pl.plot(x, A[:, i])
-pl.title('Distributions')
+f, (ax1, ax2) = plt.subplots(2, 1, tight_layout=True, num=1)
+ax1.plot(x, A, color="black")
+ax1.set_title('Distributions')
-pl.subplot(2, 1, 2)
-pl.plot(x, bary_l2, 'r', label='l2')
-pl.plot(x, bary_wass, 'g', label='Wasserstein')
-pl.legend()
-pl.title('Barycenters')
-pl.tight_layout()
+ax2.plot(x, bary_l2, 'r', label='l2')
+ax2.plot(x, bary_wass, 'g', label='Wasserstein')
+ax2.set_title('Barycenters')
+
+plt.legend()
+plt.show()
##############################################################################
# Barycentric interpolation
# -------------------------
-
#%% barycenter interpolation
n_alpha = 11
@@ -106,24 +90,23 @@ B_l2 = np.zeros((n, n_alpha))
B_wass = np.copy(B_l2)
-for i in range(0, n_alpha):
+for i in range(n_alpha):
alpha = alpha_list[i]
weights = np.array([1 - alpha, alpha])
B_l2[:, i] = A.dot(weights)
B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights)
#%% plot interpolation
+plt.figure(2)
-pl.figure(3)
-
-cmap = pl.cm.get_cmap('viridis')
+cmap = plt.cm.get_cmap('viridis')
verts = []
zs = alpha_list
for i, z in enumerate(zs):
ys = B_l2[:, i]
verts.append(list(zip(x, ys)))
-ax = pl.gcf().gca(projection='3d')
+ax = plt.gcf().gca(projection='3d')
poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
@@ -134,18 +117,18 @@ ax.set_ylabel('$\\alpha$')
ax.set_ylim3d(0, 1)
ax.set_zlabel('')
ax.set_zlim3d(0, B_l2.max() * 1.01)
-pl.title('Barycenter interpolation with l2')
-pl.tight_layout()
+plt.title('Barycenter interpolation with l2')
+plt.tight_layout()
-pl.figure(4)
-cmap = pl.cm.get_cmap('viridis')
+plt.figure(3)
+cmap = plt.cm.get_cmap('viridis')
verts = []
zs = alpha_list
for i, z in enumerate(zs):
ys = B_wass[:, i]
verts.append(list(zip(x, ys)))
-ax = pl.gcf().gca(projection='3d')
+ax = plt.gcf().gca(projection='3d')
poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
@@ -156,7 +139,7 @@ ax.set_ylabel('$\\alpha$')
ax.set_ylim3d(0, 1)
ax.set_zlabel('')
ax.set_zlim3d(0, B_l2.max() * 1.01)
-pl.title('Barycenter interpolation with Wasserstein')
-pl.tight_layout()
+plt.title('Barycenter interpolation with Wasserstein')
+plt.tight_layout()
-pl.show()
+plt.show()
diff --git a/examples/barycenters/plot_barycenter_lp_vs_entropic.py b/examples/barycenters/plot_barycenter_lp_vs_entropic.py
index 57a6bac..6502f16 100644
--- a/examples/barycenters/plot_barycenter_lp_vs_entropic.py
+++ b/examples/barycenters/plot_barycenter_lp_vs_entropic.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""
=================================================================================
-1D Wasserstein barycenter comparison between exact LP and entropic regularization
+1D Wasserstein barycenter: exact LP vs entropic regularization
=================================================================================
This example illustrates the computation of regularized Wasserstein Barycenter
diff --git a/examples/barycenters/plot_convolutional_barycenter.py b/examples/barycenters/plot_convolutional_barycenter.py
index cbcd4a1..3721f31 100644
--- a/examples/barycenters/plot_convolutional_barycenter.py
+++ b/examples/barycenters/plot_convolutional_barycenter.py
@@ -6,17 +6,18 @@
Convolutional Wasserstein Barycenter example
============================================
-This example is designed to illustrate how the Convolutional Wasserstein Barycenter
-function of POT works.
+This example is designed to illustrate how the Convolutional Wasserstein
+Barycenter function of POT works.
"""
# Author: Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License
-
+import os
+from pathlib import Path
import numpy as np
-import pylab as pl
+import matplotlib.pyplot as plt
import ot
##############################################################################
@@ -25,22 +26,19 @@ import ot
#
# The four distributions are constructed from 4 simple images
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
-f1 = 1 - pl.imread('../../data/redcross.png')[:, :, 2]
-f2 = 1 - pl.imread('../../data/duck.png')[:, :, 2]
-f3 = 1 - pl.imread('../../data/heart.png')[:, :, 2]
-f4 = 1 - pl.imread('../../data/tooth.png')[:, :, 2]
+f1 = 1 - plt.imread(os.path.join(data_path, 'redcross.png'))[:, :, 2]
+f2 = 1 - plt.imread(os.path.join(data_path, 'tooth.png'))[:, :, 2]
+f3 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2]
+f4 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2]
-A = []
f1 = f1 / np.sum(f1)
f2 = f2 / np.sum(f2)
f3 = f3 / np.sum(f3)
f4 = f4 / np.sum(f4)
-A.append(f1)
-A.append(f2)
-A.append(f3)
-A.append(f4)
-A = np.array(A)
+A = np.array([f1, f2, f3, f4])
nb_images = 5
@@ -57,14 +55,13 @@ v4 = np.array((0, 0, 0, 1))
# ----------------------------------------
#
-pl.figure(figsize=(10, 10))
-pl.title('Convolutional Wasserstein Barycenters in POT')
+fig, axes = plt.subplots(nb_images, nb_images, figsize=(7, 7))
+plt.suptitle('Convolutional Wasserstein Barycenters in POT')
cm = 'Blues'
# regularization parameter
reg = 0.004
for i in range(nb_images):
for j in range(nb_images):
- pl.subplot(nb_images, nb_images, i * nb_images + j + 1)
tx = float(i) / (nb_images - 1)
ty = float(j) / (nb_images - 1)
@@ -74,19 +71,19 @@ for i in range(nb_images):
weights = (1 - ty) * tmp1 + ty * tmp2
if i == 0 and j == 0:
- pl.imshow(f1, cmap=cm)
- pl.axis('off')
+ axes[i, j].imshow(f1, cmap=cm)
elif i == 0 and j == (nb_images - 1):
- pl.imshow(f3, cmap=cm)
- pl.axis('off')
+ axes[i, j].imshow(f3, cmap=cm)
elif i == (nb_images - 1) and j == 0:
- pl.imshow(f2, cmap=cm)
- pl.axis('off')
+ axes[i, j].imshow(f2, cmap=cm)
elif i == (nb_images - 1) and j == (nb_images - 1):
- pl.imshow(f4, cmap=cm)
- pl.axis('off')
+ axes[i, j].imshow(f4, cmap=cm)
else:
# call to barycenter computation
- pl.imshow(ot.bregman.convolutional_barycenter2d(A, reg, weights), cmap=cm)
- pl.axis('off')
-pl.show()
+ axes[i, j].imshow(
+ ot.bregman.convolutional_barycenter2d(A, reg, weights),
+ cmap=cm
+ )
+ axes[i, j].axis('off')
+plt.tight_layout()
+plt.show()
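
Each interior cell of the grid above is produced by a single call to
ot.bregman.convolutional_barycenter2d, which takes a stack of normalized 2D
histograms and interpolation weights. A minimal standalone sketch on random
images; the input sizes and regularization value are illustrative.

import numpy as np
import ot

rng = np.random.RandomState(0)
A = rng.rand(2, 20, 20) + 1e-2          # two small random "images"
A /= A.sum(axis=(1, 2))[:, None, None]  # normalize each to unit mass

bary = ot.bregman.convolutional_barycenter2d(A, reg=0.004,
                                             weights=np.array([0.5, 0.5]))
print(bary.shape)  # (20, 20): the barycenter lives on the same grid
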
diff --git a/examples/barycenters/plot_debiased_barycenter.py b/examples/barycenters/plot_debiased_barycenter.py
new file mode 100644
index 0000000..2a603dd
--- /dev/null
+++ b/examples/barycenters/plot_debiased_barycenter.py
@@ -0,0 +1,131 @@
+# -*- coding: utf-8 -*-
+"""
+=================================
+Debiased Sinkhorn barycenter demo
+=================================
+
+This example illustrates the computation of the debiased Sinkhorn barycenter
+as proposed in [37]_.
+
+
+.. [37] Janati, H., Cuturi, M., Gramfort, A. Debiased Sinkhorn barycenters.
+ Proceedings of the 37th International Conference on Machine Learning,
+ PMLR 119:4692-4701, 2020.
+"""
+
+# Author: Hicham Janati <hicham.janati100@gmail.com>
+#
+# License: MIT License
+# sphinx_gallery_thumbnail_number = 3
+
+import os
+from pathlib import Path
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+import ot
+from ot.bregman import (barycenter, barycenter_debiased,
+ convolutional_barycenter2d,
+ convolutional_barycenter2d_debiased)
+
+##############################################################################
+# Debiased barycenter of 1D Gaussians
+# ------------------------------------
+
+#%% parameters
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+
+# creating matrix A containing all distributions
+A = np.vstack((a1, a2)).T
+n_distributions = A.shape[1]
+
+# loss matrix + normalization
+M = ot.utils.dist0(n)
+M /= M.max()
+
+#%% barycenter computation
+
+alpha = 0.2 # 0<=alpha<=1
+weights = np.array([1 - alpha, alpha])
+
+epsilons = [5e-3, 1e-2, 5e-2]
+
+
+bars = [barycenter(A, M, reg, weights) for reg in epsilons]
+bars_debiased = [barycenter_debiased(A, M, reg, weights) for reg in epsilons]
+labels = ["Sinkhorn barycenter", "Debiased barycenter"]
+colors = ["indianred", "gold"]
+
+f, axes = plt.subplots(1, len(epsilons), tight_layout=True, sharey=True,
+ figsize=(12, 4), num=1)
+for ax, eps, bar, bar_debiased in zip(axes, epsilons, bars, bars_debiased):
+ ax.plot(A[:, 0], color="k", ls="--", label="Input data", alpha=0.3)
+ ax.plot(A[:, 1], color="k", ls="--", alpha=0.3)
+ for data, label, color in zip([bar, bar_debiased], labels, colors):
+ ax.plot(data, color=color, label=label, lw=2)
+ ax.set_title(r"$\varepsilon = %.3f$" % eps)
+plt.legend()
+plt.show()
+
+
+##############################################################################
+# Debiased barycenter of 2D images
+# ---------------------------------
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
+f1 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2]
+f2 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2]
+
+A = np.asarray([f1, f2]) + 1e-2
+A /= A.sum(axis=(1, 2))[:, None, None]
+
+##############################################################################
+# Display the input images
+
+fig, axes = plt.subplots(1, 2, figsize=(7, 4), num=2)
+for ax, img in zip(axes, A):
+ ax.imshow(img, cmap="Greys")
+ ax.axis("off")
+fig.tight_layout()
+plt.show()
+
+
+##############################################################################
+# Barycenter computation and visualization
+# ----------------------------------------
+#
+
+bars_sinkhorn, bars_debiased = [], []
+epsilons = [5e-3, 7e-3, 1e-2]
+for eps in epsilons:
+ bar = convolutional_barycenter2d(A, eps)
+ bar_debiased, log = convolutional_barycenter2d_debiased(A, eps, log=True)
+ bars_sinkhorn.append(bar)
+ bars_debiased.append(bar_debiased)
+
+titles = ["Sinkhorn", "Debiased"]
+all_bars = [bars_sinkhorn, bars_debiased]
+fig, axes = plt.subplots(2, 3, figsize=(8, 6), num=3)
+for jj, (method, ax_row, bars) in enumerate(zip(titles, axes, all_bars)):
+ for ii, (ax, img, eps) in enumerate(zip(ax_row, bars, epsilons)):
+ ax.imshow(img, cmap="Greys")
+ if jj == 0:
+ ax.set_title(r"$\varepsilon = %.3f$" % eps, fontsize=13)
+ ax.set_xticks([])
+ ax.set_yticks([])
+ ax.spines['top'].set_visible(False)
+ ax.spines['right'].set_visible(False)
+ ax.spines['bottom'].set_visible(False)
+ ax.spines['left'].set_visible(False)
+ if ii == 0:
+ ax.set_ylabel(method, fontsize=15)
+fig.tight_layout()
+plt.show()
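
The difference between the debiased and plain estimators above comes down to
a single call. A minimal sketch comparing barycenter and barycenter_debiased
on the same inputs (standalone, assuming POT >= 0.8):

import numpy as np
import ot
from ot.bregman import barycenter, barycenter_debiased

n = 50
a1 = ot.datasets.make_1D_gauss(n, m=15, s=4)
a2 = ot.datasets.make_1D_gauss(n, m=35, s=6)
A = np.vstack((a1, a2)).T  # one input distribution per column
M = ot.utils.dist0(n)      # squared distance cost on the grid
M /= M.max()

weights = np.array([0.5, 0.5])
bar = barycenter(A, M, 1e-2, weights)              # entropy-blurred
bar_db = barycenter_debiased(A, M, 1e-2, weights)  # debiased, sharper
print(bar.sum(), bar_db.sum())  # both remain probability vectors
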
diff --git a/examples/barycenters/plot_free_support_barycenter.py b/examples/barycenters/plot_free_support_barycenter.py
index 27ddc8e..2d68a39 100644
--- a/examples/barycenters/plot_free_support_barycenter.py
+++ b/examples/barycenters/plot_free_support_barycenter.py
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
"""
-====================================================
+========================================================
2D free support Wasserstein barycenters of distributions
-====================================================
+========================================================
Illustration of 2D Wasserstein barycenters if distributions are weighted
sum of diracs.
diff --git a/examples/domain-adaptation/plot_otda_color_images.py b/examples/domain-adaptation/plot_otda_color_images.py
index 929365e..06dc8ab 100644
--- a/examples/domain-adaptation/plot_otda_color_images.py
+++ b/examples/domain-adaptation/plot_otda_color_images.py
@@ -19,17 +19,20 @@ SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
# sphinx_gallery_thumbnail_number = 2
+import os
+from pathlib import Path
+
import numpy as np
-import matplotlib.pylab as pl
+from matplotlib import pyplot as plt
import ot
-r = np.random.RandomState(42)
+rng = np.random.RandomState(42)
-def im2mat(I):
+def im2mat(img):
"""Converts an image to matrix (one pixel per line)"""
- return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+ return img.reshape((img.shape[0] * img.shape[1], img.shape[2]))
def mat2im(X, shape):
@@ -37,8 +40,8 @@ def mat2im(X, shape):
return X.reshape(shape)
-def minmax(I):
- return np.clip(I, 0, 1)
+def minmax(img):
+ return np.clip(img, 0, 1)
##############################################################################
@@ -46,16 +49,19 @@ def minmax(I):
# -------------
# Loading images
-I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256
-I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
+
+I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256
+I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256
X1 = im2mat(I1)
X2 = im2mat(I2)
# training samples
-nb = 1000
-idx1 = r.randint(X1.shape[0], size=(nb,))
-idx2 = r.randint(X2.shape[0], size=(nb,))
+nb = 500
+idx1 = rng.randint(X1.shape[0], size=(nb,))
+idx2 = rng.randint(X2.shape[0], size=(nb,))
Xs = X1[idx1, :]
Xt = X2[idx2, :]
@@ -65,39 +71,39 @@ Xt = X2[idx2, :]
# Plot original image
# -------------------
-pl.figure(1, figsize=(6.4, 3))
+plt.figure(1, figsize=(6.4, 3))
-pl.subplot(1, 2, 1)
-pl.imshow(I1)
-pl.axis('off')
-pl.title('Image 1')
+plt.subplot(1, 2, 1)
+plt.imshow(I1)
+plt.axis('off')
+plt.title('Image 1')
-pl.subplot(1, 2, 2)
-pl.imshow(I2)
-pl.axis('off')
-pl.title('Image 2')
+plt.subplot(1, 2, 2)
+plt.imshow(I2)
+plt.axis('off')
+plt.title('Image 2')
##############################################################################
# Scatter plot of colors
# ----------------------
-pl.figure(2, figsize=(6.4, 3))
+plt.figure(2, figsize=(6.4, 3))
-pl.subplot(1, 2, 1)
-pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
-pl.axis([0, 1, 0, 1])
-pl.xlabel('Red')
-pl.ylabel('Blue')
-pl.title('Image 1')
+plt.subplot(1, 2, 1)
+plt.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
+plt.axis([0, 1, 0, 1])
+plt.xlabel('Red')
+plt.ylabel('Blue')
+plt.title('Image 1')
-pl.subplot(1, 2, 2)
-pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
-pl.axis([0, 1, 0, 1])
-pl.xlabel('Red')
-pl.ylabel('Blue')
-pl.title('Image 2')
-pl.tight_layout()
+plt.subplot(1, 2, 2)
+plt.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
+plt.axis([0, 1, 0, 1])
+plt.xlabel('Red')
+plt.ylabel('Blue')
+plt.title('Image 2')
+plt.tight_layout()
##############################################################################
@@ -130,37 +136,37 @@ I2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape))
# Plot new images
# ---------------
-pl.figure(3, figsize=(8, 4))
+plt.figure(3, figsize=(8, 4))
-pl.subplot(2, 3, 1)
-pl.imshow(I1)
-pl.axis('off')
-pl.title('Image 1')
+plt.subplot(2, 3, 1)
+plt.imshow(I1)
+plt.axis('off')
+plt.title('Image 1')
-pl.subplot(2, 3, 2)
-pl.imshow(I1t)
-pl.axis('off')
-pl.title('Image 1 Adapt')
+plt.subplot(2, 3, 2)
+plt.imshow(I1t)
+plt.axis('off')
+plt.title('Image 1 Adapt')
-pl.subplot(2, 3, 3)
-pl.imshow(I1te)
-pl.axis('off')
-pl.title('Image 1 Adapt (reg)')
+plt.subplot(2, 3, 3)
+plt.imshow(I1te)
+plt.axis('off')
+plt.title('Image 1 Adapt (reg)')
-pl.subplot(2, 3, 4)
-pl.imshow(I2)
-pl.axis('off')
-pl.title('Image 2')
+plt.subplot(2, 3, 4)
+plt.imshow(I2)
+plt.axis('off')
+plt.title('Image 2')
-pl.subplot(2, 3, 5)
-pl.imshow(I2t)
-pl.axis('off')
-pl.title('Image 2 Adapt')
+plt.subplot(2, 3, 5)
+plt.imshow(I2t)
+plt.axis('off')
+plt.title('Image 2 Adapt')
-pl.subplot(2, 3, 6)
-pl.imshow(I2te)
-pl.axis('off')
-pl.title('Image 2 Adapt (reg)')
-pl.tight_layout()
+plt.subplot(2, 3, 6)
+plt.imshow(I2te)
+plt.axis('off')
+plt.title('Image 2 Adapt (reg)')
+plt.tight_layout()
-pl.show()
+plt.show()
diff --git a/examples/domain-adaptation/plot_otda_jcpot.py b/examples/domain-adaptation/plot_otda_jcpot.py
index c495690..0d974f4 100644
--- a/examples/domain-adaptation/plot_otda_jcpot.py
+++ b/examples/domain-adaptation/plot_otda_jcpot.py
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
"""
-========================
+================================
OT for multi-source target shift
-========================
+================================
This example introduces a target shift problem with two 2D source and 1 target domain.
diff --git a/examples/domain-adaptation/plot_otda_linear_mapping.py b/examples/domain-adaptation/plot_otda_linear_mapping.py
index dbf16b8..a44096a 100644
--- a/examples/domain-adaptation/plot_otda_linear_mapping.py
+++ b/examples/domain-adaptation/plot_otda_linear_mapping.py
@@ -13,9 +13,11 @@ Linear OT mapping estimation
# License: MIT License
# sphinx_gallery_thumbnail_number = 2
+import os
+from pathlib import Path
import numpy as np
-import pylab as pl
+from matplotlib import pyplot as plt
import ot
##############################################################################
@@ -26,17 +28,19 @@ n = 1000
d = 2
sigma = .1
+rng = np.random.RandomState(42)
+
# source samples
-angles = np.random.rand(n, 1) * 2 * np.pi
+angles = rng.rand(n, 1) * 2 * np.pi
xs = np.concatenate((np.sin(angles), np.cos(angles)),
- axis=1) + sigma * np.random.randn(n, 2)
+ axis=1) + sigma * rng.randn(n, 2)
xs[:n // 2, 1] += 2
# target samples
-anglet = np.random.rand(n, 1) * 2 * np.pi
+anglet = rng.rand(n, 1) * 2 * np.pi
xt = np.concatenate((np.sin(anglet), np.cos(anglet)),
- axis=1) + sigma * np.random.randn(n, 2)
+ axis=1) + sigma * rng.randn(n, 2)
xt[:n // 2, 1] += 2
@@ -48,9 +52,9 @@ xt = xt.dot(A) + b
# Plot data
# ---------
-pl.figure(1, (5, 5))
-pl.plot(xs[:, 0], xs[:, 1], '+')
-pl.plot(xt[:, 0], xt[:, 1], 'o')
+plt.figure(1, (5, 5))
+plt.plot(xs[:, 0], xs[:, 1], '+')
+plt.plot(xt[:, 0], xt[:, 1], 'o')
##############################################################################
@@ -66,22 +70,22 @@ xst = xs.dot(Ae) + be
# Plot transported samples
# ------------------------
-pl.figure(1, (5, 5))
-pl.clf()
-pl.plot(xs[:, 0], xs[:, 1], '+')
-pl.plot(xt[:, 0], xt[:, 1], 'o')
-pl.plot(xst[:, 0], xst[:, 1], '+')
+plt.figure(1, (5, 5))
+plt.clf()
+plt.plot(xs[:, 0], xs[:, 1], '+')
+plt.plot(xt[:, 0], xt[:, 1], 'o')
+plt.plot(xst[:, 0], xst[:, 1], '+')
-pl.show()
+plt.show()
##############################################################################
# Load image data
# ---------------
-def im2mat(I):
+def im2mat(img):
"""Converts and image to matrix (one pixel per line)"""
- return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+ return img.reshape((img.shape[0] * img.shape[1], img.shape[2]))
def mat2im(X, shape):
@@ -89,13 +93,16 @@ def mat2im(X, shape):
return X.reshape(shape)
-def minmax(I):
- return np.clip(I, 0, 1)
+def minmax(img):
+ return np.clip(img, 0, 1)
# Loading images
-I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256
-I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
+
+I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256
+I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256
X1 = im2mat(I1)
@@ -123,24 +130,24 @@ I2t = minmax(mat2im(xts, I2.shape))
# Plot transformed images
# -----------------------
-pl.figure(2, figsize=(10, 7))
+plt.figure(2, figsize=(10, 7))
-pl.subplot(2, 2, 1)
-pl.imshow(I1)
-pl.axis('off')
-pl.title('Im. 1')
+plt.subplot(2, 2, 1)
+plt.imshow(I1)
+plt.axis('off')
+plt.title('Im. 1')
-pl.subplot(2, 2, 2)
-pl.imshow(I2)
-pl.axis('off')
-pl.title('Im. 2')
+plt.subplot(2, 2, 2)
+plt.imshow(I2)
+plt.axis('off')
+plt.title('Im. 2')
-pl.subplot(2, 2, 3)
-pl.imshow(I1t)
-pl.axis('off')
-pl.title('Mapping Im. 1')
+plt.subplot(2, 2, 3)
+plt.imshow(I1t)
+plt.axis('off')
+plt.title('Mapping Im. 1')
-pl.subplot(2, 2, 4)
-pl.imshow(I2t)
-pl.axis('off')
-pl.title('Inverse mapping Im. 2')
+plt.subplot(2, 2, 4)
+plt.imshow(I2t)
+plt.axis('off')
+plt.title('Inverse mapping Im. 2')
diff --git a/examples/domain-adaptation/plot_otda_mapping_colors_images.py b/examples/domain-adaptation/plot_otda_mapping_colors_images.py
index ee5c8b0..dbece70 100644
--- a/examples/domain-adaptation/plot_otda_mapping_colors_images.py
+++ b/examples/domain-adaptation/plot_otda_mapping_colors_images.py
@@ -21,17 +21,19 @@ discrete optimal transport", Neural Information Processing Systems (NIPS), 2016.
# License: MIT License
# sphinx_gallery_thumbnail_number = 3
+import os
+from pathlib import Path
import numpy as np
-import matplotlib.pylab as pl
+from matplotlib import pyplot as plt
import ot
-r = np.random.RandomState(42)
+rng = np.random.RandomState(42)
-def im2mat(I):
+def im2mat(img):
"""Converts and image to matrix (one pixel per line)"""
- return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+ return img.reshape((img.shape[0] * img.shape[1], img.shape[2]))
def mat2im(X, shape):
@@ -39,8 +41,8 @@ def mat2im(X, shape):
return X.reshape(shape)
-def minmax(I):
- return np.clip(I, 0, 1)
+def minmax(img):
+ return np.clip(img, 0, 1)
##############################################################################
@@ -48,17 +50,19 @@ def minmax(I):
# -------------
# Loading images
-I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256
-I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
+I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256
+I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256
X1 = im2mat(I1)
X2 = im2mat(I2)
# training samples
-nb = 1000
-idx1 = r.randint(X1.shape[0], size=(nb,))
-idx2 = r.randint(X2.shape[0], size=(nb,))
+nb = 500
+idx1 = rng.randint(X1.shape[0], size=(nb,))
+idx2 = rng.randint(X2.shape[0], size=(nb,))
Xs = X1[idx1, :]
Xt = X2[idx2, :]
@@ -99,76 +103,76 @@ Image_mapping_gaussian = minmax(mat2im(X1tn, I1.shape))
# Plot original images
# --------------------
-pl.figure(1, figsize=(6.4, 3))
-pl.subplot(1, 2, 1)
-pl.imshow(I1)
-pl.axis('off')
-pl.title('Image 1')
+plt.figure(1, figsize=(6.4, 3))
+plt.subplot(1, 2, 1)
+plt.imshow(I1)
+plt.axis('off')
+plt.title('Image 1')
-pl.subplot(1, 2, 2)
-pl.imshow(I2)
-pl.axis('off')
-pl.title('Image 2')
-pl.tight_layout()
+plt.subplot(1, 2, 2)
+plt.imshow(I2)
+plt.axis('off')
+plt.title('Image 2')
+plt.tight_layout()
##############################################################################
# Plot pixel values distribution
# ------------------------------
-pl.figure(2, figsize=(6.4, 5))
+plt.figure(2, figsize=(6.4, 5))
-pl.subplot(1, 2, 1)
-pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
-pl.axis([0, 1, 0, 1])
-pl.xlabel('Red')
-pl.ylabel('Blue')
-pl.title('Image 1')
+plt.subplot(1, 2, 1)
+plt.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
+plt.axis([0, 1, 0, 1])
+plt.xlabel('Red')
+plt.ylabel('Blue')
+plt.title('Image 1')
-pl.subplot(1, 2, 2)
-pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
-pl.axis([0, 1, 0, 1])
-pl.xlabel('Red')
-pl.ylabel('Blue')
-pl.title('Image 2')
-pl.tight_layout()
+plt.subplot(1, 2, 2)
+plt.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
+plt.axis([0, 1, 0, 1])
+plt.xlabel('Red')
+plt.ylabel('Blue')
+plt.title('Image 2')
+plt.tight_layout()
##############################################################################
# Plot transformed images
# -----------------------
-pl.figure(2, figsize=(10, 5))
+plt.figure(2, figsize=(10, 5))
-pl.subplot(2, 3, 1)
-pl.imshow(I1)
-pl.axis('off')
-pl.title('Im. 1')
+plt.subplot(2, 3, 1)
+plt.imshow(I1)
+plt.axis('off')
+plt.title('Im. 1')
-pl.subplot(2, 3, 4)
-pl.imshow(I2)
-pl.axis('off')
-pl.title('Im. 2')
+plt.subplot(2, 3, 4)
+plt.imshow(I2)
+plt.axis('off')
+plt.title('Im. 2')
-pl.subplot(2, 3, 2)
-pl.imshow(Image_emd)
-pl.axis('off')
-pl.title('EmdTransport')
+plt.subplot(2, 3, 2)
+plt.imshow(Image_emd)
+plt.axis('off')
+plt.title('EmdTransport')
-pl.subplot(2, 3, 5)
-pl.imshow(Image_sinkhorn)
-pl.axis('off')
-pl.title('SinkhornTransport')
+plt.subplot(2, 3, 5)
+plt.imshow(Image_sinkhorn)
+plt.axis('off')
+plt.title('SinkhornTransport')
-pl.subplot(2, 3, 3)
-pl.imshow(Image_mapping_linear)
-pl.axis('off')
-pl.title('MappingTransport (linear)')
+plt.subplot(2, 3, 3)
+plt.imshow(Image_mapping_linear)
+plt.axis('off')
+plt.title('MappingTransport (linear)')
-pl.subplot(2, 3, 6)
-pl.imshow(Image_mapping_gaussian)
-pl.axis('off')
-pl.title('MappingTransport (gaussian)')
-pl.tight_layout()
+plt.subplot(2, 3, 6)
+plt.imshow(Image_mapping_gaussian)
+plt.axis('off')
+plt.title('MappingTransport (gaussian)')
+plt.tight_layout()
-pl.show()
+plt.show()
diff --git a/examples/gromov/plot_barycenter_fgw.py b/examples/gromov/plot_barycenter_fgw.py
index 3f81765..556e08f 100644
--- a/examples/gromov/plot_barycenter_fgw.py
+++ b/examples/gromov/plot_barycenter_fgw.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""
=================================
-Plot graphs' barycenter using FGW
+Plot graph barycenters using FGW
=================================
This example illustrates the computation of the barycenter of labeled graphs using
diff --git a/examples/gromov/plot_fgw.py b/examples/gromov/plot_fgw.py
index 97fe619..5475fb3 100644
--- a/examples/gromov/plot_fgw.py
+++ b/examples/gromov/plot_fgw.py
@@ -26,7 +26,7 @@ from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein
##############################################################################
# Generate data
-# ---------
+# -------------
#%% parameters
# We create two 1D random measures
@@ -76,7 +76,7 @@ pl.show()
##############################################################################
# Create structure matrices and across-feature distance matrix
-# ---------
+# ------------------------------------------------------------
#%% Structure matrices and across-features distance matrix
C1 = ot.dist(xs)
@@ -88,7 +88,7 @@ Got = ot.emd([], [], M)
##############################################################################
# Plot matrices
-# ---------
+# -------------
#%%
cmap = 'Reds'
@@ -131,7 +131,7 @@ pl.show()
##############################################################################
# Compute FGW/GW
-# ---------
+# --------------
#%% Computing FGW and GW
alpha = 1e-3
@@ -145,7 +145,7 @@ Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True,
##############################################################################
# Visualize transport matrices
-# ---------
+# ----------------------------
#%% visu OT matrix
cmap = 'Blues'
diff --git a/examples/gromov/plot_gromov.py b/examples/gromov/plot_gromov.py
index deb2f86..5a362cf 100644
--- a/examples/gromov/plot_gromov.py
+++ b/examples/gromov/plot_gromov.py
@@ -104,3 +104,37 @@ pl.imshow(gw, cmap='jet')
pl.title('Entropic Gromov Wasserstein')
pl.show()
+
+#############################################################################
+#
+# Compute GW with a scalable stochastic method with any loss function
+# ----------------------------------------------------------------------
+
+
+def loss(x, y):
+ return np.abs(x - y)
+
+
+pgw, plog = ot.gromov.pointwise_gromov_wasserstein(C1, C2, p, q, loss, max_iter=100,
+ log=True)
+
+sgw, slog = ot.gromov.sampled_gromov_wasserstein(C1, C2, p, q, loss, epsilon=0.1, max_iter=100,
+ log=True)
+
+print('Pointwise Gromov-Wasserstein distance estimated: ' + str(plog['gw_dist_estimated']))
+print('Standard deviation estimated: ' + str(plog['gw_dist_std']))
+print('Sampled Gromov-Wasserstein distance: ' + str(slog['gw_dist_estimated']))
+print('Standard deviation estimated: ' + str(slog['gw_dist_std']))
+
+
+pl.figure(1, (10, 5))
+
+pl.subplot(1, 2, 1)
+pl.imshow(pgw.toarray(), cmap='jet')
+pl.title('Pointwise Gromov Wasserstein')
+
+pl.subplot(1, 2, 2)
+pl.imshow(sgw, cmap='jet')
+pl.title('Sampled Gromov Wasserstein')
+
+pl.show()
diff --git a/examples/gromov/plot_gromov_barycenter.py b/examples/gromov/plot_gromov_barycenter.py
index f6f031a..7fe081f 100755
--- a/examples/gromov/plot_gromov_barycenter.py
+++ b/examples/gromov/plot_gromov_barycenter.py
@@ -13,11 +13,13 @@ computation in POT.
#
# License: MIT License
+import os
+from pathlib import Path
import numpy as np
import scipy as sp
-import matplotlib.pylab as pl
+from matplotlib import pyplot as plt
from sklearn import manifold
from sklearn.decomposition import PCA
@@ -84,22 +86,24 @@ def smacof_mds(C, dim, max_iter=3000, eps=1e-9):
# The four distributions are constructed from 4 simple images
-def im2mat(I):
+def im2mat(img):
"""Converts and image to matrix (one pixel per line)"""
- return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+ return img.reshape((img.shape[0] * img.shape[1], img.shape[2]))
-square = pl.imread('../../data/square.png').astype(np.float64)[:, :, 2]
-cross = pl.imread('../../data/cross.png').astype(np.float64)[:, :, 2]
-triangle = pl.imread('../../data/triangle.png').astype(np.float64)[:, :, 2]
-star = pl.imread('../../data/star.png').astype(np.float64)[:, :, 2]
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
+
+square = plt.imread(os.path.join(data_path, 'square.png')).astype(np.float64)[:, :, 2]
+cross = plt.imread(os.path.join(data_path, 'cross.png')).astype(np.float64)[:, :, 2]
+triangle = plt.imread(os.path.join(data_path, 'triangle.png')).astype(np.float64)[:, :, 2]
+star = plt.imread(os.path.join(data_path, 'star.png')).astype(np.float64)[:, :, 2]
shapes = [square, cross, triangle, star]
S = 4
xs = [[] for i in range(S)]
-
for nb in range(4):
for i in range(8):
for j in range(8):
@@ -184,64 +188,64 @@ npost23 = [smacof_mds(Ct23[s], 2) for s in range(2)]
npost23 = [clf.fit_transform(npost23[s]) for s in range(2)]
-fig = pl.figure(figsize=(10, 10))
+fig = plt.figure(figsize=(10, 10))
-ax1 = pl.subplot2grid((4, 4), (0, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax1 = plt.subplot2grid((4, 4), (0, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax1.scatter(npos[0][:, 0], npos[0][:, 1], color='r')
-ax2 = pl.subplot2grid((4, 4), (0, 1))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax2 = plt.subplot2grid((4, 4), (0, 1))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax2.scatter(npost01[1][:, 0], npost01[1][:, 1], color='b')
-ax3 = pl.subplot2grid((4, 4), (0, 2))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax3 = plt.subplot2grid((4, 4), (0, 2))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax3.scatter(npost01[0][:, 0], npost01[0][:, 1], color='b')
-ax4 = pl.subplot2grid((4, 4), (0, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax4 = plt.subplot2grid((4, 4), (0, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax4.scatter(npos[1][:, 0], npos[1][:, 1], color='r')
-ax5 = pl.subplot2grid((4, 4), (1, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax5 = plt.subplot2grid((4, 4), (1, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax5.scatter(npost02[1][:, 0], npost02[1][:, 1], color='b')
-ax6 = pl.subplot2grid((4, 4), (1, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax6 = plt.subplot2grid((4, 4), (1, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax6.scatter(npost13[1][:, 0], npost13[1][:, 1], color='b')
-ax7 = pl.subplot2grid((4, 4), (2, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax7 = plt.subplot2grid((4, 4), (2, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax7.scatter(npost02[0][:, 0], npost02[0][:, 1], color='b')
-ax8 = pl.subplot2grid((4, 4), (2, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax8 = plt.subplot2grid((4, 4), (2, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax8.scatter(npost13[0][:, 0], npost13[0][:, 1], color='b')
-ax9 = pl.subplot2grid((4, 4), (3, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax9 = plt.subplot2grid((4, 4), (3, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax9.scatter(npos[2][:, 0], npos[2][:, 1], color='r')
-ax10 = pl.subplot2grid((4, 4), (3, 1))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax10 = plt.subplot2grid((4, 4), (3, 1))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax10.scatter(npost23[1][:, 0], npost23[1][:, 1], color='b')
-ax11 = pl.subplot2grid((4, 4), (3, 2))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax11 = plt.subplot2grid((4, 4), (3, 2))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax11.scatter(npost23[0][:, 0], npost23[0][:, 1], color='b')
-ax12 = pl.subplot2grid((4, 4), (3, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax12 = plt.subplot2grid((4, 4), (3, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax12.scatter(npos[3][:, 0], npos[3][:, 1], color='r')
diff --git a/examples/plot_Intro_OT.py b/examples/plot_Intro_OT.py
new file mode 100644
index 0000000..2e2c6fd
--- /dev/null
+++ b/examples/plot_Intro_OT.py
@@ -0,0 +1,373 @@
+# coding: utf-8
+"""
+=============================================
+Introduction to Optimal Transport with Python
+=============================================
+
+This example gives an introduction to the use of Optimal Transport in Python.
+
+"""
+
+# Author: Remi Flamary, Nicolas Courty, Aurelie Boisbunon
+#
+# License: MIT License
+# sphinx_gallery_thumbnail_number = 1
+
+##############################################################################
+# POT Python Optimal Transport Toolbox
+# ------------------------------------
+#
+# POT installation
+# ```````````````````
+#
+# * Install with pip::
+#
+# pip install pot
+# * Install with conda::
+#
+# conda install -c conda-forge pot
+#
+# Import the toolbox
+# ```````````````````
+#
+
+import numpy as np # always need it
+import pylab as pl # do the plots
+
+import ot  # the POT optimal transport toolbox
+
+import time
+
+##############################################################################
+# Getting help
+# `````````````
+#
+# Online documentation: `<https://pythonot.github.io/all.html>`_
+#
+# Or inline help:
+#
+
+help(ot.dist)
+
+
+##############################################################################
+# First OT Problem
+# ----------------
+#
+# We will solve the Bakery/Cafés problem of transporting croissants from a
+# number of bakeries to cafés in a city (in this case Manhattan). We did a
+# quick Google Maps search in Manhattan for bakeries and cafés:
+#
+# .. image:: images/bak.png
+# :align: center
+# :alt: bakery-cafe-manhattan
+# :width: 600px
+# :height: 280px
+#
+# We extracted from this search their positions and generated fictional
+# production and sale numbers (that both sum to the same value).
+#
+# We have access to the positions of the bakeries ``bakery_pos`` and their
+# respective production ``bakery_prod``, which describe the source
+# distribution. The cafés where the croissants are sold are likewise defined
+# by their position ``cafe_pos`` and their sales ``cafe_prod``, and describe
+# the target distribution. For fun we also provide a map ``Imap`` that
+# illustrates the position of these shops in the city.
+#
+#
+# Now we load the data
+#
+#
+
+data = np.load('../data/manhattan.npz')
+
+bakery_pos = data['bakery_pos']
+bakery_prod = data['bakery_prod']
+cafe_pos = data['cafe_pos']
+cafe_prod = data['cafe_prod']
+Imap = data['Imap']
+
+print('Bakery production: {}'.format(bakery_prod))
+print('Cafe sale: {}'.format(cafe_prod))
+print('Total croissants : {}'.format(cafe_prod.sum()))
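+
+# Balanced OT requires both marginals to carry the same total mass; as a quick
+# sanity check (a minimal sketch), we can verify this on the loaded data:
+np.testing.assert_allclose(bakery_prod.sum(), cafe_prod.sum())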
+
+
+##############################################################################
+# Plotting bakeries in the city
+# -----------------------------
+#
+# Next we plot the position of the bakeries and cafés on the map. The size of
+# each circle is proportional to the production (bakeries) or sales (cafés).
+#
+
+pl.figure(1, (7, 6))
+pl.clf()
+pl.imshow(Imap, interpolation='bilinear') # plot the map
+pl.scatter(bakery_pos[:, 0], bakery_pos[:, 1], s=bakery_prod, c='r', ec='k', label='Bakeries')
+pl.scatter(cafe_pos[:, 0], cafe_pos[:, 1], s=cafe_prod, c='b', ec='k', label='Cafés')
+pl.legend()
+pl.title('Manhattan Bakeries and Cafés')
+
+
+##############################################################################
+# Cost matrix
+# -----------
+#
+#
+# We can now compute the cost matrix between the bakeries and the cafés, which
+# will be the transport cost matrix. This can be done using the
+# `ot.dist <https://pythonot.github.io/all.html#ot.dist>`_ function, which
+# defaults to the squared Euclidean distance but can also compute other
+# metrics such as the cityblock (Manhattan) distance.
+#
+
+C = ot.dist(bakery_pos, cafe_pos)
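+
+# For illustration only (not used in the rest of this example), the ``metric``
+# parameter of ``ot.dist`` selects the ground metric, e.g. the cityblock one:
+C_city = ot.dist(bakery_pos, cafe_pos, metric='cityblock')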
+
+labels = [str(i) for i in range(len(bakery_prod))]
+f = pl.figure(2, (14, 7))
+pl.clf()
+pl.subplot(121)
+pl.imshow(Imap, interpolation='bilinear') # plot the map
+for i in range(len(cafe_pos)):
+ pl.text(cafe_pos[i, 0], cafe_pos[i, 1], labels[i], color='b',
+ fontsize=14, fontweight='bold', ha='center', va='center')
+for i in range(len(bakery_pos)):
+ pl.text(bakery_pos[i, 0], bakery_pos[i, 1], labels[i], color='r',
+ fontsize=14, fontweight='bold', ha='center', va='center')
+pl.title('Manhattan Bakeries and Cafés')
+
+ax = pl.subplot(122)
+im = pl.imshow(C, cmap="coolwarm")
+pl.title('Cost matrix')
+cbar = pl.colorbar(im, ax=ax, shrink=0.5, use_gridspec=True)
+cbar.ax.set_ylabel("cost", rotation=-90, va="bottom")
+
+pl.xlabel('Cafés')
+pl.ylabel('Bakeries')
+pl.tight_layout()
+
+
+##############################################################################
+# The red cells in the matrix image show the bakeries and cafés that are
+# further away, and thus more costly to transport from one to the other, while
+# the blue ones show those that are very close to each other, with respect to
+# the squared Euclidean distance.
+
+
+##############################################################################
+# Solving the OT problem with `ot.emd <https://pythonot.github.io/all.html#ot.emd>`_
+# -----------------------------------------------------------------------------------
+
+start = time.time()
+ot_emd = ot.emd(bakery_prod, cafe_prod, C)
+time_emd = time.time() - start
+
+##############################################################################
+# The function returns the transport matrix, which we can then visualize (next section).
+
+##############################################################################
+# Transportation plan visualization
+# `````````````````````````````````
+#
+# A good visualization of the OT matrix in the 2D plane is to denote the
+# transportation of mass between a bakery and a café by a line. This can easily
+# be done with a double ``for`` loop.
+#
+# In order to make it more interpretable one can also use the ``alpha``
+# parameter of plot and set it to ``alpha=G[i,j]/G.max()``.
+
+# Plot the matrix and the map
+f = pl.figure(3, (14, 7))
+pl.clf()
+pl.subplot(121)
+pl.imshow(Imap, interpolation='bilinear') # plot the map
+for i in range(len(bakery_pos)):
+ for j in range(len(cafe_pos)):
+ pl.plot([bakery_pos[i, 0], cafe_pos[j, 0]], [bakery_pos[i, 1], cafe_pos[j, 1]],
+ '-k', lw=3. * ot_emd[i, j] / ot_emd.max())
+for i in range(len(cafe_pos)):
+ pl.text(cafe_pos[i, 0], cafe_pos[i, 1], labels[i], color='b', fontsize=14,
+ fontweight='bold', ha='center', va='center')
+for i in range(len(bakery_pos)):
+ pl.text(bakery_pos[i, 0], bakery_pos[i, 1], labels[i], color='r', fontsize=14,
+ fontweight='bold', ha='center', va='center')
+pl.title('Manhattan Bakeries and Cafés')
+
+ax = pl.subplot(122)
+im = pl.imshow(ot_emd)
+for i in range(len(bakery_prod)):
+ for j in range(len(cafe_prod)):
+ text = ax.text(j, i, '{0:g}'.format(ot_emd[i, j]),
+ ha="center", va="center", color="w")
+pl.title('Transport matrix')
+
+pl.xlabel('Cafés')
+pl.ylabel('Bakeries')
+pl.tight_layout()
+
+##############################################################################
+# The transport matrix gives the number of croissants that can be transported
+# from each bakery to each café. We can see that the bakeries only need to
+# transport croissants to one or two cafés, the transport matrix being very
+# sparse.
+
+##############################################################################
+# OT loss and dual variables
+# --------------------------
+#
+# The resulting Wasserstein loss is of the form:
+#
+# .. math::
+# W=\sum_{i,j}\gamma_{i,j}C_{i,j}
+#
+# where :math:`\gamma` is the optimal transport matrix.
+#
+
+W = np.sum(ot_emd * C)
+print('Wasserstein loss (EMD) = {0:.2f}'.format(W))
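+
+##############################################################################
+# The section title also mentions dual variables. As a short sketch (assuming
+# the standard POT API, where ``ot.emd`` called with ``log=True`` returns the
+# dual potentials in ``log['u']`` and ``log['v']``), they can be recovered as:
+
+ot_emd2, log_emd = ot.emd(bakery_prod, cafe_prod, C, log=True)
+print('Dual potentials of the bakeries: {}'.format(log_emd['u']))
+print('Dual potentials of the cafés: {}'.format(log_emd['v']))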
+
+##############################################################################
+# Regularized OT with Sinkhorn
+# ----------------------------
+#
+# The Sinkhorn algorithm is very simple to code. You can implement it directly
+# using the following pseudo-code
+#
+# .. image:: images/sinkhorn.png
+# :align: center
+# :alt: Sinkhorn algorithm
+# :width: 440px
+# :height: 240px
+#
+# In this algorithm, :math:`\oslash` corresponds to the element-wise division.
+#
+# An alternative is to use the POT toolbox with
+# `ot.sinkhorn <https://pythonot.github.io/all.html#ot.sinkhorn>`_
+#
+# Be careful of numerical problems. A good pre-processing for Sinkhorn is to
+# divide the cost matrix ``C`` by its maximum value.
+
+##############################################################################
+# Algorithm
+# `````````
+
+# Compute Sinkhorn transport matrix from algorithm
+reg = 0.1
+K = np.exp(-C / C.max() / reg)
+nit = 100
+u = np.ones((len(bakery_prod), ))
+for i in range(nit):
+ v = cafe_prod / np.dot(K.T, u)
+ u = bakery_prod / (np.dot(K, v))
+ot_sink_algo = np.atleast_2d(u).T * (K * v.T) # Equivalent to np.dot(np.diag(u), np.dot(K, np.diag(v)))
+
+# Compute Sinkhorn transport matrix with POT
+ot_sinkhorn = ot.sinkhorn(bakery_prod, cafe_prod, reg=reg, M=C / C.max())
+
+# Difference between the 2
+print('Difference between algo and ot.sinkhorn = {0:.2g}'.format(np.sum(np.power(ot_sink_algo - ot_sinkhorn, 2))))
+
+##############################################################################
+# Plot the matrix and the map
+# ```````````````````````````
+
+print('Min. of Sinkhorn\'s transport matrix = {0:.2g}'.format(np.min(ot_sinkhorn)))
+
+f = pl.figure(4, (13, 6))
+pl.clf()
+pl.subplot(121)
+pl.imshow(Imap, interpolation='bilinear') # plot the map
+for i in range(len(bakery_pos)):
+ for j in range(len(cafe_pos)):
+ pl.plot([bakery_pos[i, 0], cafe_pos[j, 0]],
+ [bakery_pos[i, 1], cafe_pos[j, 1]],
+ '-k', lw=3. * ot_sinkhorn[i, j] / ot_sinkhorn.max())
+for i in range(len(cafe_pos)):
+ pl.text(cafe_pos[i, 0], cafe_pos[i, 1], labels[i], color='b',
+ fontsize=14, fontweight='bold', ha='center', va='center')
+for i in range(len(bakery_pos)):
+ pl.text(bakery_pos[i, 0], bakery_pos[i, 1], labels[i], color='r',
+ fontsize=14, fontweight='bold', ha='center', va='center')
+pl.title('Manhattan Bakeries and Cafés')
+
+ax = pl.subplot(122)
+im = pl.imshow(ot_sinkhorn)
+for i in range(len(bakery_prod)):
+ for j in range(len(cafe_prod)):
+ text = ax.text(j, i, np.round(ot_sinkhorn[i, j], 1),
+ ha="center", va="center", color="w")
+pl.title('Transport matrix')
+
+pl.xlabel('Cafés')
+pl.ylabel('Bakeries')
+pl.tight_layout()
+
+
+##############################################################################
+# We notice right away that the matrix is not sparse at all with Sinkhorn,
+# each bakery delivering croissants to all 5 cafés with that solution. Also,
+# this solution gives a transport with fractions, which does not make sense
+# in the case of croissants. This was not the case with EMD.
+
+##############################################################################
+# Varying the regularization parameter in Sinkhorn
+# ````````````````````````````````````````````````
+#
+
+reg_parameter = np.logspace(-3, 0, 20)
+W_sinkhorn_reg = np.zeros((len(reg_parameter), ))
+time_sinkhorn_reg = np.zeros((len(reg_parameter), ))
+
+f = pl.figure(5, (14, 5))
+pl.clf()
+max_ot = 100 # plot matrices with the same colorbar
+for k in range(len(reg_parameter)):
+ start = time.time()
+ ot_sinkhorn = ot.sinkhorn(bakery_prod, cafe_prod, reg=reg_parameter[k], M=C / C.max())
+ time_sinkhorn_reg[k] = time.time() - start
+
+ if k % 4 == 0 and k > 0: # we only plot a few
+        ax = pl.subplot(1, 5, k // 4)
+ im = pl.imshow(ot_sinkhorn, vmin=0, vmax=max_ot)
+ pl.title('reg={0:.2g}'.format(reg_parameter[k]))
+ pl.xlabel('Cafés')
+ pl.ylabel('Bakeries')
+
+ # Compute the Wasserstein loss for Sinkhorn, and compare with EMD
+ W_sinkhorn_reg[k] = np.sum(ot_sinkhorn * C)
+pl.tight_layout()
+
+
+##############################################################################
+# This series of graphs shows that the Sinkhorn solution starts out very
+# similar to the EMD solution (although not sparse) for very small values of
+# the regularization parameter, and tends to a more uniform solution as the
+# regularization parameter increases.
+#
+
+##############################################################################
+# Wasserstein loss and computational time
+# ```````````````````````````````````````
+#
+
+# Plot the matrix and the map
+f = pl.figure(6, (4, 4))
+pl.clf()
+pl.title("Comparison between Sinkhorn and EMD")
+
+pl.plot(reg_parameter, W_sinkhorn_reg, 'o', label="Sinkhorn")
+XLim = pl.xlim()
+pl.plot(XLim, [W, W], '--k', label="EMD")
+pl.legend()
+pl.xlabel("reg")
+pl.ylabel("Wasserstein loss")
+
+##############################################################################
+# In this last graph, we show the impact of the regularization parameter on
+# the Wasserstein loss. We can see that higher
+# values of ``reg`` lead to a much higher Wasserstein loss.
+#
+# The Wasserstein loss of EMD is displayed for
+# comparison. The Wasserstein loss of Sinkhorn can be a little lower than that
+# of EMD for low values of ``reg``, but it quickly gets much higher.
+#
diff --git a/examples/plot_OT_1D_smooth.py b/examples/plot_OT_1D_smooth.py
index 75cd295..b07f99f 100644
--- a/examples/plot_OT_1D_smooth.py
+++ b/examples/plot_OT_1D_smooth.py
@@ -87,7 +87,7 @@ pl.show()
##############################################################################
# Solve Smooth OT
-# --------------
+# ---------------
#%% Smooth OT with KL regularization
diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py
index 1544e82..af1bc12 100644
--- a/examples/plot_OT_2D_samples.py
+++ b/examples/plot_OT_2D_samples.py
@@ -107,7 +107,7 @@ pl.show()
##############################################################################
# Empirical Sinkhorn
-# ----------------
+# -------------------
#%% sinkhorn
diff --git a/examples/sliced-wasserstein/README.txt b/examples/sliced-wasserstein/README.txt
new file mode 100644
index 0000000..a575345
--- /dev/null
+++ b/examples/sliced-wasserstein/README.txt
@@ -0,0 +1,4 @@
+
+
+Sliced Wasserstein Distance
+--------------------------- \ No newline at end of file
diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py
new file mode 100644
index 0000000..7d73907
--- /dev/null
+++ b/examples/sliced-wasserstein/plot_variance.py
@@ -0,0 +1,86 @@
+# -*- coding: utf-8 -*-
+"""
+==============================
+2D Sliced Wasserstein Distance
+==============================
+
+This example illustrates the computation of the sliced Wasserstein distance as
+proposed in [31].
+
+[31] Bonneel, Nicolas, et al. "Sliced and Radon Wasserstein Barycenters of
+Measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45.
+
+"""
+
+# Author: Adrien Corenflos <adrien.corenflos@aalto.fi>
+#
+# License: MIT License
+
+import matplotlib.pylab as pl
+import numpy as np
+
+import ot
+
+##############################################################################
+# Generate data
+# -------------
+
+# %% parameters and data generation
+
+n = 500 # nb samples
+
+mu_s = np.array([0, 0])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([4, 4])
+cov_t = np.array([[1, -.8], [-.8, 1]])
+
+xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
+xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
+
+a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples
+
+##############################################################################
+# Plot data
+# ---------
+
+# %% plot samples
+
+pl.figure(1)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('Source and target distributions')
+
+###############################################################################
+# Sliced Wasserstein distance for different seeds and number of projections
+# -------------------------------------------------------------------------
+
+n_seed = 50
+n_projections_arr = np.logspace(0, 3, 25, dtype=int)
+res = np.empty((n_seed, len(n_projections_arr)))
+
+# %% Compute statistics
+for seed in range(n_seed):
+ for i, n_projections in enumerate(n_projections_arr):
+ res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed=seed)
+
+res_mean = np.mean(res, axis=0)
+res_std = np.std(res, axis=0)
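+
+# As a rough sanity check (a minimal sketch of a single Monte-Carlo sample,
+# assuming uniform weights), the same quantity can be estimated by projecting
+# the samples on one random direction and solving the resulting 1D problem:
+theta = np.random.randn(2)
+theta /= np.linalg.norm(theta)
+sw_one_proj = ot.wasserstein_1d(xs @ theta, xt @ theta, a, b, p=2) ** 0.5
+print('Single-projection estimate: {0:.3f}'.format(sw_one_proj))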
+
+###############################################################################
+# Plot Sliced Wasserstein Distance
+# --------------------------------
+
+pl.figure(2)
+pl.plot(n_projections_arr, res_mean, label="SWD")
+pl.fill_between(n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5)
+
+pl.legend()
+pl.xscale('log')
+
+pl.xlabel("Number of projections")
+pl.ylabel("Distance")
+pl.title('Sliced Wasserstein Distance with 95% confidence interval')
+
+pl.show()
diff --git a/examples/unbalanced-partial/plot_UOT_1D.py b/examples/unbalanced-partial/plot_UOT_1D.py
index 2ea8b05..183849c 100644
--- a/examples/unbalanced-partial/plot_UOT_1D.py
+++ b/examples/unbalanced-partial/plot_UOT_1D.py
@@ -61,8 +61,7 @@ ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
##############################################################################
# Solve Unbalanced Sinkhorn
-# --------------
-
+# -------------------------
# Sinkhorn
diff --git a/examples/unbalanced-partial/plot_partial_wass_and_gromov.py b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py
index 0c5cbf9..ac4194c 100755
--- a/examples/unbalanced-partial/plot_partial_wass_and_gromov.py
+++ b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py
@@ -4,7 +4,7 @@
Partial Wasserstein and Gromov-Wasserstein example
==================================================
-This example is designed to show how to use the Partial (Gromov-)Wassertsein
+This example is designed to show how to use the Partial (Gromov-)Wasserstein
distance computation in POT.
"""
@@ -123,11 +123,12 @@ C1 = sp.spatial.distance.cdist(xs, xs)
C2 = sp.spatial.distance.cdist(xt, xt)
# transport 100% of the mass
-print('-----m = 1')
+print('------m = 1')
m = 1
res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True)
res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
- m=m, log=True)
+ m=m, log=True,
+ verbose=True)
print('Gromov-Wasserstein distance (m = 1): ' + str(log0['partial_gw_dist']))
print('Entropic Gromov-Wasserstein distance (m = 1): ' + str(log['partial_gw_dist']))
@@ -136,18 +137,20 @@ pl.figure(1, (10, 5))
pl.title("mass to be transported m = 1")
pl.subplot(1, 2, 1)
pl.imshow(res0, cmap='jet')
-pl.title('Wasserstein')
+pl.title('Gromov-Wasserstein')
pl.subplot(1, 2, 2)
pl.imshow(res, cmap='jet')
-pl.title('Entropic Wasserstein')
+pl.title('Entropic Gromov-Wasserstein')
pl.show()
# transport 2/3 of the mass
-print('-----m = 2/3')
+print('------m = 2/3')
m = 2 / 3
-res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True)
+res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True,
+ verbose=True)
res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
- m=m, log=True)
+ m=m, log=True,
+ verbose=True)
print('Partial Gromov-Wasserstein distance (m = 2/3): ' +
      str(log0['partial_gw_dist']))
@@ -158,8 +161,8 @@ pl.figure(1, (10, 5))
pl.title("mass to be transported m = 2/3")
pl.subplot(1, 2, 1)
pl.imshow(res0, cmap='jet')
-pl.title('Partial Wasserstein')
+pl.title('Partial Gromov-Wasserstein')
pl.subplot(1, 2, 2)
pl.imshow(res, cmap='jet')
-pl.title('Entropic partial Wasserstein')
+pl.title('Entropic partial Gromov-Wasserstein')
pl.show()
diff --git a/examples/unbalanced-partial/plot_regpath.py b/examples/unbalanced-partial/plot_regpath.py
new file mode 100644
index 0000000..4a51c2d
--- /dev/null
+++ b/examples/unbalanced-partial/plot_regpath.py
@@ -0,0 +1,135 @@
+# -*- coding: utf-8 -*-
+"""
+================================================================
+Regularization path of l2-penalized unbalanced optimal transport
+================================================================
+This example illustrates the regularization path for 2D unbalanced
+optimal transport. We present both the fully relaxed case
+and the semi-relaxed case.
+
+[Chapel et al., 2021] Chapel, L., Flamary, R., Wu, H., Févotte, C.,
+and Gasso, G. (2021). Unbalanced optimal transport through non-negative
+penalized linear regression.
+"""
+
+# Author: Haoran Wu <haoran.wu@univ-ubs.fr>
+# License: MIT License
+
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% parameters and data generation
+
+n = 50 # nb samples
+
+mu_s = np.array([-1, -1])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([4, 4])
+cov_t = np.array([[1, -.8], [-.8, 1]])
+
+np.random.seed(0)
+xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
+xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
+
+a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples
+
+# loss matrix
+M = ot.dist(xs, xt)
+M /= M.max()
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% plot 2 distribution samples
+
+pl.figure(1)
+pl.scatter(xs[:, 0], xs[:, 1], c='C0', label='Source')
+pl.scatter(xt[:, 0], xt[:, 1], c='C1', label='Target')
+pl.legend(loc=2)
+pl.title('Source and target distributions')
+pl.show()
+
+##############################################################################
+# Compute semi-relaxed and fully relaxed regularization paths
+# -----------------------------------------------------------
+
+#%%
+final_gamma = 1e-8
+t, t_list, g_list = ot.regpath.regularization_path(a, b, M, reg=final_gamma,
+ semi_relaxed=False)
+t2, t_list2, g_list2 = ot.regpath.regularization_path(a, b, M, reg=final_gamma,
+ semi_relaxed=True)
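+
+# A quick, hedged check (assuming, as in the POT API, that ``g_list`` stores
+# the value of gamma at each breakpoint of the path): count the breakpoints
+# found for each setting.
+print('Number of breakpoints (fully relaxed): {}'.format(len(g_list)))
+print('Number of breakpoints (semi-relaxed): {}'.format(len(g_list2)))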
+
+
+##############################################################################
+# Plot the regularization path
+# ----------------------------
+
+#%% fully relaxed l2-penalized UOT
+
+pl.figure(2)
+selected_gamma = [2e-1, 1e-1, 5e-2, 1e-3]
+for p in range(4):
+ tp = ot.regpath.compute_transport_plan(selected_gamma[p], g_list,
+ t_list)
+ P = tp.reshape((n, n))
+ pl.subplot(2, 2, p + 1)
+ if P.sum() > 0:
+ P = P / P.max()
+ for i in range(n):
+ for j in range(n):
+ if P[i, j] > 0:
+ pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2',
+ alpha=P[i, j] * 0.3)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 2,
+ label='Re-weighted source', alpha=1)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 2,
+ label='Re-weighted target', alpha=1)
+ pl.plot([], [], color='C2', alpha=0.8, label='OT plan')
+ pl.title(r'$\ell_2$ UOT $\gamma$={}'.format(selected_gamma[p]),
+ fontsize=11)
+ if p < 2:
+ pl.xticks(())
+pl.show()
+
+
+##############################################################################
+# Plot the semi-relaxed regularization path
+# -----------------------------------------
+
+#%% semi-relaxed l2-penalized UOT
+
+pl.figure(3)
+selected_gamma = [10, 1, 1e-1, 1e-2]
+for p in range(4):
+ tp = ot.regpath.compute_transport_plan(selected_gamma[p], g_list2,
+ t_list2)
+ P = tp.reshape((n, n))
+ pl.subplot(2, 2, p + 1)
+ if P.sum() > 0:
+ P = P / P.max()
+ for i in range(n):
+ for j in range(n):
+ if P[i, j] > 0:
+ pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2',
+ alpha=P[i, j] * 0.3)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=1, label='Target marginal')
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * 2 * (1 + p),
+ label='Source marginal', alpha=1)
+ pl.plot([], [], color='C2', alpha=0.8, label='OT plan')
+    pl.title(r'Semi-relaxed $\ell_2$ UOT $\gamma$={}'.format(selected_gamma[p]),
+ fontsize=11)
+ if p < 2:
+ pl.xticks(())
+pl.show()