From a54775103541ea37f54269de1ba1e1396a6d7b30 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Fri, 24 Apr 2020 17:32:57 +0200 Subject: exmaples in sections --- examples/unbalanced-partial/README.txt | 3 + examples/unbalanced-partial/plot_UOT_1D.py | 76 ++++++++++ .../unbalanced-partial/plot_UOT_barycenter_1D.py | 166 +++++++++++++++++++++ .../plot_partial_wass_and_gromov.py | 165 ++++++++++++++++++++ 4 files changed, 410 insertions(+) create mode 100644 examples/unbalanced-partial/README.txt create mode 100644 examples/unbalanced-partial/plot_UOT_1D.py create mode 100644 examples/unbalanced-partial/plot_UOT_barycenter_1D.py create mode 100755 examples/unbalanced-partial/plot_partial_wass_and_gromov.py (limited to 'examples/unbalanced-partial') diff --git a/examples/unbalanced-partial/README.txt b/examples/unbalanced-partial/README.txt new file mode 100644 index 0000000..2f404f0 --- /dev/null +++ b/examples/unbalanced-partial/README.txt @@ -0,0 +1,3 @@ + +Unbalanced and Partial OT +------------------------- \ No newline at end of file diff --git a/examples/unbalanced-partial/plot_UOT_1D.py b/examples/unbalanced-partial/plot_UOT_1D.py new file mode 100644 index 0000000..2ea8b05 --- /dev/null +++ b/examples/unbalanced-partial/plot_UOT_1D.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +""" +=============================== +1D Unbalanced optimal transport +=============================== + +This example illustrates the computation of Unbalanced Optimal transport +using a Kullback-Leibler relaxation. +""" + +# Author: Hicham Janati +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot +import ot.plot +from ot.datasets import make_1D_gauss as gauss + +############################################################################## +# Generate data +# ------------- + + +#%% parameters + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a = gauss(n, m=20, s=5) # m= mean, s= std +b = gauss(n, m=60, s=10) + +# make distributions unbalanced +b *= 5. + +# loss matrix +M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1))) +M /= M.max() + + +############################################################################## +# Plot distributions and loss matrix +# ---------------------------------- + +#%% plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') +pl.legend() + +# plot distributions and loss matrix + +pl.figure(2, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') + + +############################################################################## +# Solve Unbalanced Sinkhorn +# -------------- + + +# Sinkhorn + +epsilon = 0.1 # entropy parameter +alpha = 1. # Unbalanced KL relaxation parameter +Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True) + +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn') + +pl.show() diff --git a/examples/unbalanced-partial/plot_UOT_barycenter_1D.py b/examples/unbalanced-partial/plot_UOT_barycenter_1D.py new file mode 100644 index 0000000..931798b --- /dev/null +++ b/examples/unbalanced-partial/plot_UOT_barycenter_1D.py @@ -0,0 +1,166 @@ +# -*- coding: utf-8 -*- +""" +=========================================================== +1D Wasserstein barycenter demo for Unbalanced distributions +=========================================================== + +This example illustrates the computation of regularized Wassersyein Barycenter +as proposed in [10] for Unbalanced inputs. + + +[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + +""" + +# Author: Hicham Janati +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 2 + +import numpy as np +import matplotlib.pylab as pl +import ot +# necessary for 3d plot even if not used +from mpl_toolkits.mplot3d import Axes3D # noqa +from matplotlib.collections import PolyCollection + +############################################################################## +# Generate data +# ------------- + +# parameters + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std +a2 = ot.datasets.make_1D_gauss(n, m=60, s=8) + +# make unbalanced dists +a2 *= 3. + +# creating matrix A containing all distributions +A = np.vstack((a1, a2)).T +n_distributions = A.shape[1] + +# loss matrix + normalization +M = ot.utils.dist0(n) +M /= M.max() + +############################################################################## +# Plot data +# --------- + +# plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') +pl.tight_layout() + +############################################################################## +# Barycenter computation +# ---------------------- + +# non weighted barycenter computation + +weight = 0.5 # 0<=weight<=1 +weights = np.array([1 - weight, weight]) + +# l2bary +bary_l2 = A.dot(weights) + +# wasserstein +reg = 1e-3 +alpha = 1. + +bary_wass = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights=weights) + +pl.figure(2) +pl.clf() +pl.subplot(2, 1, 1) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') + +pl.subplot(2, 1, 2) +pl.plot(x, bary_l2, 'r', label='l2') +pl.plot(x, bary_wass, 'g', label='Wasserstein') +pl.legend() +pl.title('Barycenters') +pl.tight_layout() + +############################################################################## +# Barycentric interpolation +# ------------------------- + +# barycenter interpolation + +n_weight = 11 +weight_list = np.linspace(0, 1, n_weight) + + +B_l2 = np.zeros((n, n_weight)) + +B_wass = np.copy(B_l2) + +for i in range(0, n_weight): + weight = weight_list[i] + weights = np.array([1 - weight, weight]) + B_l2[:, i] = A.dot(weights) + B_wass[:, i] = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights=weights) + + +# plot interpolation + +pl.figure(3) + +cmap = pl.cm.get_cmap('viridis') +verts = [] +zs = weight_list +for i, z in enumerate(zs): + ys = B_l2[:, i] + verts.append(list(zip(x, ys))) + +ax = pl.gcf().gca(projection='3d') + +poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list]) +poly.set_alpha(0.7) +ax.add_collection3d(poly, zs=zs, zdir='y') +ax.set_xlabel('x') +ax.set_xlim3d(0, n) +ax.set_ylabel(r'$\alpha$') +ax.set_ylim3d(0, 1) +ax.set_zlabel('') +ax.set_zlim3d(0, B_l2.max() * 1.01) +pl.title('Barycenter interpolation with l2') +pl.tight_layout() + +pl.figure(4) +cmap = pl.cm.get_cmap('viridis') +verts = [] +zs = weight_list +for i, z in enumerate(zs): + ys = B_wass[:, i] + verts.append(list(zip(x, ys))) + +ax = pl.gcf().gca(projection='3d') + +poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list]) +poly.set_alpha(0.7) +ax.add_collection3d(poly, zs=zs, zdir='y') +ax.set_xlabel('x') +ax.set_xlim3d(0, n) +ax.set_ylabel(r'$\alpha$') +ax.set_ylim3d(0, 1) +ax.set_zlabel('') +ax.set_zlim3d(0, B_l2.max() * 1.01) +pl.title('Barycenter interpolation with Wasserstein') +pl.tight_layout() + +pl.show() diff --git a/examples/unbalanced-partial/plot_partial_wass_and_gromov.py b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py new file mode 100755 index 0000000..0c5cbf9 --- /dev/null +++ b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py @@ -0,0 +1,165 @@ +# -*- coding: utf-8 -*- +""" +================================================== +Partial Wasserstein and Gromov-Wasserstein example +================================================== + +This example is designed to show how to use the Partial (Gromov-)Wassertsein +distance computation in POT. +""" + +# Author: Laetitia Chapel +# License: MIT License + +# sphinx_gallery_thumbnail_number = 2 + +# necessary for 3d plot even if not used +from mpl_toolkits.mplot3d import Axes3D # noqa +import scipy as sp +import numpy as np +import matplotlib.pylab as pl +import ot + + +############################################################################# +# +# Sample two 2D Gaussian distributions and plot them +# -------------------------------------------------- +# +# For demonstration purpose, we sample two Gaussian distributions in 2-d +# spaces and add some random noise. + + +n_samples = 20 # nb samples (gaussian) +n_noise = 20 # nb of samples (noise) + +mu = np.array([0, 0]) +cov = np.array([[1, 0], [0, 2]]) + +xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov) +xs = np.append(xs, (np.random.rand(n_noise, 2) + 1) * 4).reshape((-1, 2)) +xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov) +xt = np.append(xt, (np.random.rand(n_noise, 2) + 1) * -3).reshape((-1, 2)) + +M = sp.spatial.distance.cdist(xs, xt) + +fig = pl.figure() +ax1 = fig.add_subplot(131) +ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +ax2 = fig.add_subplot(132) +ax2.scatter(xt[:, 0], xt[:, 1], color='r') +ax3 = fig.add_subplot(133) +ax3.imshow(M) +pl.show() + +############################################################################# +# +# Compute partial Wasserstein plans and distance +# ---------------------------------------------- + +p = ot.unif(n_samples + n_noise) +q = ot.unif(n_samples + n_noise) + +w0, log0 = ot.partial.partial_wasserstein(p, q, M, m=0.5, log=True) +w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=0.1, m=0.5, + log=True) + +print('Partial Wasserstein distance (m = 0.5): ' + str(log0['partial_w_dist'])) +print('Entropic partial Wasserstein distance (m = 0.5): ' + + str(log['partial_w_dist'])) + +pl.figure(1, (10, 5)) +pl.subplot(1, 2, 1) +pl.imshow(w0, cmap='jet') +pl.title('Partial Wasserstein') +pl.subplot(1, 2, 2) +pl.imshow(w, cmap='jet') +pl.title('Entropic partial Wasserstein') +pl.show() + + +############################################################################# +# +# Sample one 2D and 3D Gaussian distributions and plot them +# --------------------------------------------------------- +# +# The Gromov-Wasserstein distance allows to compute distances with samples that +# do not belong to the same metric space. For demonstration purpose, we sample +# two Gaussian distributions in 2- and 3-dimensional spaces. + +n_samples = 20 # nb samples +n_noise = 10 # nb of samples (noise) + +p = ot.unif(n_samples + n_noise) +q = ot.unif(n_samples + n_noise) + +mu_s = np.array([0, 0]) +cov_s = np.array([[1, 0], [0, 1]]) + +mu_t = np.array([0, 0, 0]) +cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + + +xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s) +xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) + 1) * 4)), axis=0) +P = sp.linalg.sqrtm(cov_t) +xt = np.random.randn(n_samples, 3).dot(P) + mu_t +xt = np.concatenate((xt, ((np.random.rand(n_noise, 3) + 1) * 10)), axis=0) + +fig = pl.figure() +ax1 = fig.add_subplot(121) +ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +ax2 = fig.add_subplot(122, projection='3d') +ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r') +pl.show() + + +############################################################################# +# +# Compute partial Gromov-Wasserstein plans and distance +# ----------------------------------------------------- + +C1 = sp.spatial.distance.cdist(xs, xs) +C2 = sp.spatial.distance.cdist(xt, xt) + +# transport 100% of the mass +print('-----m = 1') +m = 1 +res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True) +res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10, + m=m, log=True) + +print('Wasserstein distance (m = 1): ' + str(log0['partial_gw_dist'])) +print('Entropic Wasserstein distance (m = 1): ' + str(log['partial_gw_dist'])) + +pl.figure(1, (10, 5)) +pl.title("mass to be transported m = 1") +pl.subplot(1, 2, 1) +pl.imshow(res0, cmap='jet') +pl.title('Wasserstein') +pl.subplot(1, 2, 2) +pl.imshow(res, cmap='jet') +pl.title('Entropic Wasserstein') +pl.show() + +# transport 2/3 of the mass +print('-----m = 2/3') +m = 2 / 3 +res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True) +res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10, + m=m, log=True) + +print('Partial Wasserstein distance (m = 2/3): ' + + str(log0['partial_gw_dist'])) +print('Entropic partial Wasserstein distance (m = 2/3): ' + + str(log['partial_gw_dist'])) + +pl.figure(1, (10, 5)) +pl.title("mass to be transported m = 2/3") +pl.subplot(1, 2, 1) +pl.imshow(res0, cmap='jet') +pl.title('Partial Wasserstein') +pl.subplot(1, 2, 2) +pl.imshow(res, cmap='jet') +pl.title('Entropic partial Wasserstein') +pl.show() -- cgit v1.2.3