summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2022-04-05 11:57:10 +0200
committerGitHub <noreply@github.com>2022-04-05 11:57:10 +0200
commitad02112d4288f3efdd5bc6fc6e45444313bba871 (patch)
treef6cd539450c2ed36cf5d7014debfd82e8b9fddfb /examples
parent0afd84d744a472903d427e3c7ae32e55fdd7b9a7 (diff)
[MRG] Update examples in the doc (#359)
* add transparent color logo * add transparent color logo * move screenkhorn * move stochastic and install ffmpeg on circleci * try something * add sudo * install ffmpeg before python * cleanup examples * test svg scrapper * add animation for reg path * better example OT sivergence * update ttles and add plots * update free support * proper figure indexes * have less frame sin animation * update readme and release file * add tests for python 3.10
Diffstat (limited to 'examples')
-rw-r--r--examples/backends/plot_sliced_wass_grad_flow_pytorch.py2
-rw-r--r--examples/backends/plot_wass1d_torch.py8
-rw-r--r--examples/barycenters/plot_free_support_barycenter.py55
-rw-r--r--examples/others/plot_logo.py8
-rw-r--r--examples/others/plot_screenkhorn_1D.py (renamed from examples/plot_screenkhorn_1D.py)6
-rw-r--r--examples/others/plot_stochastic.py (renamed from examples/plot_stochastic.py)0
-rw-r--r--examples/plot_OT_1D.py12
-rw-r--r--examples/plot_OT_1D_smooth.py6
-rw-r--r--examples/plot_OT_2D_samples.py2
-rw-r--r--examples/plot_OT_L1_vs_L2.py32
-rw-r--r--examples/plot_compute_emd.py72
-rw-r--r--examples/plot_optim_OTreg.py38
-rw-r--r--examples/sliced-wasserstein/README.txt2
-rw-r--r--examples/sliced-wasserstein/plot_variance.py8
-rw-r--r--examples/unbalanced-partial/plot_UOT_1D.py17
-rw-r--r--examples/unbalanced-partial/plot_regpath.py88
16 files changed, 260 insertions, 96 deletions
diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
index 05b9952..cf5d64d 100644
--- a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
+++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
@@ -27,6 +27,8 @@ Machine Learning (pp. 4104-4113). PMLR.
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 4
+
# %%
# Loading the data
diff --git a/examples/backends/plot_wass1d_torch.py b/examples/backends/plot_wass1d_torch.py
index 0abdd6d..cd8e2fd 100644
--- a/examples/backends/plot_wass1d_torch.py
+++ b/examples/backends/plot_wass1d_torch.py
@@ -1,9 +1,9 @@
r"""
-=================================
-Wasserstein 1D with PyTorch
-=================================
+=================================================
+Wasserstein 1D (flow and barycenter) with PyTorch
+=================================================
-In this small example, we consider the following minization problem:
+In this small example, we consider the following minimization problem:
.. math::
\mu^* = \min_\mu W(\mu,\nu)
diff --git a/examples/barycenters/plot_free_support_barycenter.py b/examples/barycenters/plot_free_support_barycenter.py
index 2d68a39..226dfeb 100644
--- a/examples/barycenters/plot_free_support_barycenter.py
+++ b/examples/barycenters/plot_free_support_barycenter.py
@@ -9,61 +9,62 @@ sum of diracs.
"""
-# Author: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>
+# Authors: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>
+# Rémi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 2
+
import numpy as np
import matplotlib.pylab as pl
import ot
-##############################################################################
+# %%
# Generate data
# -------------
-N = 3
+N = 2
d = 2
-measures_locations = []
-measures_weights = []
-
-for i in range(N):
- n_i = np.random.randint(low=1, high=20) # nb samples
+I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::4, ::4, 2]
+I2 = pl.imread('../../data/duck.png').astype(np.float64)[::4, ::4, 2]
- mu_i = np.random.normal(0., 4., (d,)) # Gaussian mean
+sz = I2.shape[0]
+XX, YY = np.meshgrid(np.arange(sz), np.arange(sz))
- A_i = np.random.rand(d, d)
- cov_i = np.dot(A_i, A_i.transpose()) # Gaussian covariance matrix
+x1 = np.stack((XX[I1 == 0], YY[I1 == 0]), 1) * 1.0
+x2 = np.stack((XX[I2 == 0] + 80, -YY[I2 == 0] + 32), 1) * 1.0
+x3 = np.stack((XX[I2 == 0], -YY[I2 == 0] + 32), 1) * 1.0
- x_i = ot.datasets.make_2D_samples_gauss(n_i, mu_i, cov_i) # Dirac locations
- b_i = np.random.uniform(0., 1., (n_i,))
- b_i = b_i / np.sum(b_i) # Dirac weights
+measures_locations = [x1, x2]
+measures_weights = [ot.unif(x1.shape[0]), ot.unif(x2.shape[0])]
- measures_locations.append(x_i)
- measures_weights.append(b_i)
+pl.figure(1, (12, 4))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)
+pl.title('Distributions')
-##############################################################################
+# %%
# Compute free support barycenter
# -------------------------------
-k = 10 # number of Diracs of the barycenter
+k = 200 # number of Diracs of the barycenter
X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations
b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized)
X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b)
-
-##############################################################################
-# Plot data
+# %%
+# Plot the barycenter
# ---------
-pl.figure(1)
-for (x_i, b_i) in zip(measures_locations, measures_weights):
- color = np.random.randint(low=1, high=10 * N)
- pl.scatter(x_i[:, 0], x_i[:, 1], s=b_i * 1000, label='input measure')
-pl.scatter(X[:, 0], X[:, 1], s=b * 1000, c='black', marker='^', label='2-Wasserstein barycenter')
+pl.figure(2, (8, 3))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)
+pl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker='s', label='2-Wasserstein barycenter')
pl.title('Data measures and their barycenter')
-pl.legend(loc=0)
+pl.legend(loc="lower right")
pl.show()
diff --git a/examples/others/plot_logo.py b/examples/others/plot_logo.py
index afddcad..9414371 100644
--- a/examples/others/plot_logo.py
+++ b/examples/others/plot_logo.py
@@ -7,8 +7,8 @@ Logo of the POT toolbox
In this example we plot the logo of the POT toolbox.
-A specificity of this logo is that it is done 100% in Python and generated using
-matplotlib using the EMD solver from POT.
+This logo is that it is done 100% in Python and generated using
+matplotlib and ploting teh solution of the EMD solver from POT.
"""
@@ -86,8 +86,8 @@ pl.axis('equal')
pl.axis('off')
# Save logo file
-# pl.savefig('logo.svg', dpi=150, bbox_inches='tight')
-# pl.savefig('logo.png', dpi=150, bbox_inches='tight')
+# pl.savefig('logo.svg', dpi=150, transparent=True, bbox_inches='tight')
+# pl.savefig('logo.png', dpi=150, transparent=True, bbox_inches='tight')
# %%
# Plot the logo (dark background)
diff --git a/examples/plot_screenkhorn_1D.py b/examples/others/plot_screenkhorn_1D.py
index 785642a..2023649 100644
--- a/examples/plot_screenkhorn_1D.py
+++ b/examples/others/plot_screenkhorn_1D.py
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
"""
-===============================
-1D Screened optimal transport
-===============================
+========================================
+Screened optimal transport (Screenkhorn)
+========================================
This example illustrates the computation of Screenkhorn [26].
diff --git a/examples/plot_stochastic.py b/examples/others/plot_stochastic.py
index 3a1ef31..3a1ef31 100644
--- a/examples/plot_stochastic.py
+++ b/examples/others/plot_stochastic.py
diff --git a/examples/plot_OT_1D.py b/examples/plot_OT_1D.py
index 15ead96..62f0b7d 100644
--- a/examples/plot_OT_1D.py
+++ b/examples/plot_OT_1D.py
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
"""
-====================
-1D optimal transport
-====================
+======================================
+Optimal Transport for 1D distributions
+======================================
This example illustrates the computation of EMD and Sinkhorn transport plans
and their visualization.
@@ -64,7 +64,11 @@ ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
#%% EMD
-G0 = ot.emd(a, b, M)
+# use fast 1D solver
+G0 = ot.emd_1d(x, x, a, b)
+
+# Equivalent to
+# G0 = ot.emd(a, b, M)
pl.figure(3, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
diff --git a/examples/plot_OT_1D_smooth.py b/examples/plot_OT_1D_smooth.py
index b07f99f..5415e4f 100644
--- a/examples/plot_OT_1D_smooth.py
+++ b/examples/plot_OT_1D_smooth.py
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
"""
-===========================
-1D smooth optimal transport
-===========================
+================================
+Smooth optimal transport example
+================================
This example illustrates the computation of EMD, Sinkhorn and smooth OT plans
and their visualization.
diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py
index c3a7cd8..1d82fb8 100644
--- a/examples/plot_OT_2D_samples.py
+++ b/examples/plot_OT_2D_samples.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""
====================================================
-2D Optimal transport between empirical distributions
+Optimal Transport between 2D empirical distributions
====================================================
Illustration of 2D optimal transport between discributions that are weighted
diff --git a/examples/plot_OT_L1_vs_L2.py b/examples/plot_OT_L1_vs_L2.py
index cb94574..cce51f8 100644
--- a/examples/plot_OT_L1_vs_L2.py
+++ b/examples/plot_OT_L1_vs_L2.py
@@ -1,10 +1,10 @@
# -*- coding: utf-8 -*-
"""
-==========================================
-2D Optimal transport for different metrics
-==========================================
+================================================
+Optimal Transport with different gournd metrics
+================================================
-2D OT on empirical distributio with different gound metric.
+2D OT on empirical distributio with different ground metric.
Stole the figure idea from Fig. 1 and 2 in
https://arxiv.org/pdf/1706.07650.pdf
@@ -23,7 +23,7 @@ import matplotlib.pylab as pl
import ot
import ot.plot
-##############################################################################
+# %%
# Dataset 1 : uniform sampling
# ----------------------------
@@ -46,7 +46,7 @@ M2 = ot.dist(xs, xt, metric='sqeuclidean')
M2 /= M2.max()
# loss matrix
-Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))
+Mp = ot.dist(xs, xt, metric='cityblock')
Mp /= Mp.max()
# Data
@@ -71,7 +71,7 @@ pl.title('Squared Euclidean cost')
pl.subplot(1, 3, 3)
pl.imshow(Mp, interpolation='nearest')
-pl.title('Sqrt Euclidean cost')
+pl.title('L1 (cityblock cost')
pl.tight_layout()
##############################################################################
@@ -109,22 +109,22 @@ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.axis('equal')
# pl.legend(loc=0)
-pl.title('OT sqrt Euclidean')
+pl.title('OT L1 (cityblock)')
pl.tight_layout()
pl.show()
-##############################################################################
+# %%
# Dataset 2 : Partial circle
# --------------------------
-n = 50 # nb samples
+n = 20 # nb samples
xtot = np.zeros((n + 1, 2))
xtot[:, 0] = np.cos(
- (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)
+ (np.arange(n + 1) + 1.0) * 0.8 / (n + 2) * 2 * np.pi)
xtot[:, 1] = np.sin(
- (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)
+ (np.arange(n + 1) + 1.0) * 0.8 / (n + 2) * 2 * np.pi)
xs = xtot[:n, :]
xt = xtot[1:, :]
@@ -140,7 +140,7 @@ M2 = ot.dist(xs, xt, metric='sqeuclidean')
M2 /= M2.max()
# loss matrix
-Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))
+Mp = ot.dist(xs, xt, metric='cityblock')
Mp /= Mp.max()
@@ -166,13 +166,13 @@ pl.title('Squared Euclidean cost')
pl.subplot(1, 3, 3)
pl.imshow(Mp, interpolation='nearest')
-pl.title('Sqrt Euclidean cost')
+pl.title('L1 (cityblock) cost')
pl.tight_layout()
##############################################################################
# Dataset 2 : Plot OT Matrices
# -----------------------------
-
+#
#%% EMD
G1 = ot.emd(a, b, M1)
@@ -204,7 +204,7 @@ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.axis('equal')
# pl.legend(loc=0)
-pl.title('OT sqrt Euclidean')
+pl.title('OT L1 (cityblock)')
pl.tight_layout()
pl.show()
diff --git a/examples/plot_compute_emd.py b/examples/plot_compute_emd.py
index 527a847..36cc7da 100644
--- a/examples/plot_compute_emd.py
+++ b/examples/plot_compute_emd.py
@@ -1,10 +1,10 @@
# -*- coding: utf-8 -*-
"""
-=================
-Plot multiple EMD
-=================
+==================
+OT distances in 1D
+==================
-Shows how to compute multiple EMD and Sinkhorn with two different
+Shows how to compute multiple Wassersein and Sinkhorn with two different
ground metrics and plot their values for different distributions.
@@ -14,7 +14,7 @@ ground metrics and plot their values for different distributions.
#
# License: MIT License
-# sphinx_gallery_thumbnail_number = 3
+# sphinx_gallery_thumbnail_number = 2
import numpy as np
import matplotlib.pylab as pl
@@ -29,7 +29,7 @@ from ot.datasets import make_1D_gauss as gauss
#%% parameters
n = 100 # nb bins
-n_target = 50 # nb target distributions
+n_target = 20 # nb target distributions
# bin positions
@@ -47,9 +47,9 @@ for i, m in enumerate(lst_m):
# loss matrix and normalization
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean')
-M /= M.max()
+M /= M.max() * 0.1
M2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'sqeuclidean')
-M2 /= M2.max()
+M2 /= M2.max() * 0.1
##############################################################################
# Plot data
@@ -59,10 +59,12 @@ M2 /= M2.max()
pl.figure(1)
pl.subplot(2, 1, 1)
-pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, a, 'r', label='Source distribution')
pl.title('Source distribution')
pl.subplot(2, 1, 2)
-pl.plot(x, B, label='Target distributions')
+for i in range(n_target):
+ pl.plot(x, B[:, i], 'b', alpha=i / n_target)
+pl.plot(x, B[:, -1], 'b', label='Target distributions')
pl.title('Target distributions')
pl.tight_layout()
@@ -73,14 +75,27 @@ pl.tight_layout()
#%% Compute and plot distributions and loss matrix
-d_emd = ot.emd2(a, B, M) # direct computation of EMD
-d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M2
-
+d_emd = ot.emd2(a, B, M) # direct computation of OT loss
+d_emd2 = ot.emd2(a, B, M2) # direct computation of OT loss with metrixc M2
+d_tv = [np.sum(abs(a - B[:, i])) for i in range(n_target)]
pl.figure(2)
-pl.plot(d_emd, label='Euclidean EMD')
-pl.plot(d_emd2, label='Squared Euclidean EMD')
-pl.title('EMD distances')
+pl.subplot(2, 1, 1)
+pl.plot(x, a, 'r', label='Source distribution')
+pl.title('Distributions')
+for i in range(n_target):
+ pl.plot(x, B[:, i], 'b', alpha=i / n_target)
+pl.plot(x, B[:, -1], 'b', label='Target distributions')
+pl.ylim((-.01, 0.13))
+pl.xticks(())
+pl.legend()
+pl.subplot(2, 1, 2)
+pl.plot(d_emd, label='Euclidean OT')
+pl.plot(d_emd2, label='Squared Euclidean OT')
+pl.plot(d_tv, label='Total Variation (TV)')
+#pl.xlim((-7,23))
+pl.xlabel('Displacement')
+pl.title('Divergences')
pl.legend()
##############################################################################
@@ -88,17 +103,30 @@ pl.legend()
# -----------------------------------------
#%%
-reg = 1e-2
+reg = 1e-1
d_sinkhorn = ot.sinkhorn2(a, B, M, reg)
d_sinkhorn2 = ot.sinkhorn2(a, B, M2, reg)
-pl.figure(2)
+pl.figure(3)
pl.clf()
-pl.plot(d_emd, label='Euclidean EMD')
-pl.plot(d_emd2, label='Squared Euclidean EMD')
+
+pl.subplot(2, 1, 1)
+pl.plot(x, a, 'r', label='Source distribution')
+pl.title('Distributions')
+for i in range(n_target):
+ pl.plot(x, B[:, i], 'b', alpha=i / n_target)
+pl.plot(x, B[:, -1], 'b', label='Target distributions')
+pl.ylim((-.01, 0.13))
+pl.xticks(())
+pl.legend()
+pl.subplot(2, 1, 2)
+pl.plot(d_emd, label='Euclidean OT')
+pl.plot(d_emd2, label='Squared Euclidean OT')
pl.plot(d_sinkhorn, '+', label='Euclidean Sinkhorn')
pl.plot(d_sinkhorn2, '+', label='Squared Euclidean Sinkhorn')
-pl.title('EMD distances')
+pl.plot(d_tv, label='Total Variation (TV)')
+#pl.xlim((-7,23))
+pl.xlabel('Displacement')
+pl.title('Divergences')
pl.legend()
-
pl.show()
diff --git a/examples/plot_optim_OTreg.py b/examples/plot_optim_OTreg.py
index 5eb15bd..7b021d2 100644
--- a/examples/plot_optim_OTreg.py
+++ b/examples/plot_optim_OTreg.py
@@ -24,7 +24,7 @@ arXiv preprint arXiv:1510.06567.
"""
-# sphinx_gallery_thumbnail_number = 4
+# sphinx_gallery_thumbnail_number = 5
import numpy as np
import matplotlib.pylab as pl
@@ -58,7 +58,7 @@ M /= M.max()
G0 = ot.emd(a, b, M)
-pl.figure(3, figsize=(5, 5))
+pl.figure(1, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
##############################################################################
@@ -80,7 +80,7 @@ reg = 1e-1
Gl2 = ot.optim.cg(a, b, M, reg, f, df, verbose=True)
-pl.figure(3)
+pl.figure(2)
ot.plot.plot1D_mat(a, b, Gl2, 'OT matrix Frob. reg')
##############################################################################
@@ -102,7 +102,7 @@ reg = 1e-3
Ge = ot.optim.cg(a, b, M, reg, f, df, verbose=True)
-pl.figure(4, figsize=(5, 5))
+pl.figure(3, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Ge, 'OT matrix Entrop. reg')
##############################################################################
@@ -125,6 +125,34 @@ reg2 = 1e-1
Gel2 = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True)
-pl.figure(5, figsize=(5, 5))
+pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Gel2, 'OT entropic + matrix Frob. reg')
pl.show()
+
+
+# %%
+# Comparison of the OT matrices
+
+nvisu = 40
+
+pl.figure(5, figsize=(10, 4))
+
+pl.subplot(2, 2, 1)
+pl.imshow(G0[:nvisu, :])
+pl.axis('off')
+pl.title('Exact OT')
+
+pl.subplot(2, 2, 2)
+pl.imshow(Gl2[:nvisu, :])
+pl.axis('off')
+pl.title('Frobenius reg.')
+
+pl.subplot(2, 2, 3)
+pl.imshow(Ge[:nvisu, :])
+pl.axis('off')
+pl.title('Entropic reg.')
+
+pl.subplot(2, 2, 4)
+pl.imshow(Gel2[:nvisu, :])
+pl.axis('off')
+pl.title('Entropic + Frobenius reg.')
diff --git a/examples/sliced-wasserstein/README.txt b/examples/sliced-wasserstein/README.txt
index a575345..73e6122 100644
--- a/examples/sliced-wasserstein/README.txt
+++ b/examples/sliced-wasserstein/README.txt
@@ -1,4 +1,4 @@
Sliced Wasserstein Distance
---------------------------- \ No newline at end of file
+---------------------------
diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py
index 7d73907..f12b522 100644
--- a/examples/sliced-wasserstein/plot_variance.py
+++ b/examples/sliced-wasserstein/plot_variance.py
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
"""
-==============================
-2D Sliced Wasserstein Distance
-==============================
+===============================================
+Sliced Wasserstein Distance on 2D distributions
+===============================================
This example illustrates the computation of the sliced Wasserstein Distance as
proposed in [31].
@@ -16,6 +16,8 @@ measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 2
+
import matplotlib.pylab as pl
import numpy as np
diff --git a/examples/unbalanced-partial/plot_UOT_1D.py b/examples/unbalanced-partial/plot_UOT_1D.py
index 183849c..06dd02d 100644
--- a/examples/unbalanced-partial/plot_UOT_1D.py
+++ b/examples/unbalanced-partial/plot_UOT_1D.py
@@ -12,6 +12,8 @@ using a Kullback-Leibler relaxation.
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 4
+
import numpy as np
import matplotlib.pylab as pl
import ot
@@ -69,7 +71,20 @@ epsilon = 0.1 # entropy parameter
alpha = 1. # Unbalanced KL relaxation parameter
Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True)
-pl.figure(4, figsize=(5, 5))
+pl.figure(3, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn')
pl.show()
+
+
+# %%
+# plot the transported mass
+# -------------------------
+
+pl.figure(4, figsize=(6.4, 3))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+pl.fill(x, Gs.sum(1), 'b', alpha=0.5, label='Transported source')
+pl.fill(x, Gs.sum(0), 'r', alpha=0.5, label='Transported target')
+pl.legend(loc='upper right')
+pl.title('Distributions and transported mass for UOT')
diff --git a/examples/unbalanced-partial/plot_regpath.py b/examples/unbalanced-partial/plot_regpath.py
index 4a51c2d..782e8c2 100644
--- a/examples/unbalanced-partial/plot_regpath.py
+++ b/examples/unbalanced-partial/plot_regpath.py
@@ -15,11 +15,12 @@ penalized linear regression.
# Author: Haoran Wu <haoran.wu@univ-ubs.fr>
# License: MIT License
+# sphinx_gallery_thumbnail_number = 2
import numpy as np
import matplotlib.pylab as pl
import ot
-
+import matplotlib.animation as animation
##############################################################################
# Generate data
# -------------
@@ -72,6 +73,9 @@ t2, t_list2, g_list2 = ot.regpath.regularization_path(a, b, M, reg=final_gamma,
##############################################################################
# Plot the regularization path
# ----------------
+#
+# The OT plan is ploted as a function of $\gamma$ that is the inverse of the
+# weight on the marginal relaxations.
#%% fully relaxed l2-penalized UOT
@@ -103,13 +107,53 @@ for p in range(4):
pl.show()
+# %%
+# Animation of the regpath for UOT l2
+# ------------------------
+
+nv = 100
+g_list_v = np.logspace(-.5, -2.5, nv)
+
+pl.figure(3)
+
+
+def _update_plot(iv):
+ pl.clf()
+ tp = ot.regpath.compute_transport_plan(g_list_v[iv], g_list,
+ t_list)
+ P = tp.reshape((n, n))
+ if P.sum() > 0:
+ P = P / P.max()
+ for i in range(n):
+ for j in range(n):
+ if P[i, j] > 0:
+ pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2',
+ alpha=P[i, j] * 0.5)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 4,
+ label='Re-weighted source', alpha=1)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 4,
+ label='Re-weighted target', alpha=1)
+ pl.plot([], [], color='C2', alpha=0.8, label='OT plan')
+ pl.title(r'$\ell_2$ UOT $\gamma$={:1.3f}'.format(g_list_v[iv]),
+ fontsize=11)
+ return 1
+
+
+i = 0
+_update_plot(i)
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000)
+
+
##############################################################################
# Plot the semi-relaxed regularization path
# -------------------
#%% semi-relaxed l2-penalized UOT
-pl.figure(3)
+pl.figure(4)
selected_gamma = [10, 1, 1e-1, 1e-2]
for p in range(4):
tp = ot.regpath.compute_transport_plan(selected_gamma[p], g_list2,
@@ -133,3 +177,43 @@ for p in range(4):
if p < 2:
pl.xticks(())
pl.show()
+
+
+# %%
+# Animation of the regpath for semi-relaxed UOT l2
+# ------------------------
+
+nv = 100
+g_list_v = np.logspace(2.5, -2, nv)
+
+pl.figure(5)
+
+
+def _update_plot(iv):
+ pl.clf()
+ tp = ot.regpath.compute_transport_plan(g_list_v[iv], g_list2,
+ t_list2)
+ P = tp.reshape((n, n))
+ if P.sum() > 0:
+ P = P / P.max()
+ for i in range(n):
+ for j in range(n):
+ if P[i, j] > 0:
+ pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2',
+ alpha=P[i, j] * 0.5)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 4,
+ label='Re-weighted source', alpha=1)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 4,
+ label='Re-weighted target', alpha=1)
+ pl.plot([], [], color='C2', alpha=0.8, label='OT plan')
+ pl.title(r'Semi-relaxed $\ell_2$ UOT $\gamma$={:1.3f}'.format(g_list_v[iv]),
+ fontsize=11)
+ return 1
+
+
+i = 0
+_update_plot(i)
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000)