diff options
Diffstat (limited to 'examples/unbalanced-partial')
-rw-r--r-- | examples/unbalanced-partial/README.txt | 3 | ||||
-rw-r--r-- | examples/unbalanced-partial/plot_UOT_1D.py | 76 | ||||
-rw-r--r-- | examples/unbalanced-partial/plot_UOT_barycenter_1D.py | 166 | ||||
-rwxr-xr-x | examples/unbalanced-partial/plot_partial_wass_and_gromov.py | 165 |
4 files changed, 410 insertions, 0 deletions
diff --git a/examples/unbalanced-partial/README.txt b/examples/unbalanced-partial/README.txt new file mode 100644 index 0000000..2f404f0 --- /dev/null +++ b/examples/unbalanced-partial/README.txt @@ -0,0 +1,3 @@ + +Unbalanced and Partial OT +-------------------------
\ No newline at end of file diff --git a/examples/unbalanced-partial/plot_UOT_1D.py b/examples/unbalanced-partial/plot_UOT_1D.py new file mode 100644 index 0000000..2ea8b05 --- /dev/null +++ b/examples/unbalanced-partial/plot_UOT_1D.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +""" +=============================== +1D Unbalanced optimal transport +=============================== + +This example illustrates the computation of Unbalanced Optimal transport +using a Kullback-Leibler relaxation. +""" + +# Author: Hicham Janati <hicham.janati@inria.fr> +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot +import ot.plot +from ot.datasets import make_1D_gauss as gauss + +############################################################################## +# Generate data +# ------------- + + +#%% parameters + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a = gauss(n, m=20, s=5) # m= mean, s= std +b = gauss(n, m=60, s=10) + +# make distributions unbalanced +b *= 5. + +# loss matrix +M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1))) +M /= M.max() + + +############################################################################## +# Plot distributions and loss matrix +# ---------------------------------- + +#%% plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') +pl.legend() + +# plot distributions and loss matrix + +pl.figure(2, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') + + +############################################################################## +# Solve Unbalanced Sinkhorn +# -------------- + + +# Sinkhorn + +epsilon = 0.1 # entropy parameter +alpha = 1. # Unbalanced KL relaxation parameter +Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True) + +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn') + +pl.show() diff --git a/examples/unbalanced-partial/plot_UOT_barycenter_1D.py b/examples/unbalanced-partial/plot_UOT_barycenter_1D.py new file mode 100644 index 0000000..931798b --- /dev/null +++ b/examples/unbalanced-partial/plot_UOT_barycenter_1D.py @@ -0,0 +1,166 @@ +# -*- coding: utf-8 -*- +""" +=========================================================== +1D Wasserstein barycenter demo for Unbalanced distributions +=========================================================== + +This example illustrates the computation of regularized Wassersyein Barycenter +as proposed in [10] for Unbalanced inputs. + + +[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + +""" + +# Author: Hicham Janati <hicham.janati@inria.fr> +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 2 + +import numpy as np +import matplotlib.pylab as pl +import ot +# necessary for 3d plot even if not used +from mpl_toolkits.mplot3d import Axes3D # noqa +from matplotlib.collections import PolyCollection + +############################################################################## +# Generate data +# ------------- + +# parameters + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std +a2 = ot.datasets.make_1D_gauss(n, m=60, s=8) + +# make unbalanced dists +a2 *= 3. + +# creating matrix A containing all distributions +A = np.vstack((a1, a2)).T +n_distributions = A.shape[1] + +# loss matrix + normalization +M = ot.utils.dist0(n) +M /= M.max() + +############################################################################## +# Plot data +# --------- + +# plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') +pl.tight_layout() + +############################################################################## +# Barycenter computation +# ---------------------- + +# non weighted barycenter computation + +weight = 0.5 # 0<=weight<=1 +weights = np.array([1 - weight, weight]) + +# l2bary +bary_l2 = A.dot(weights) + +# wasserstein +reg = 1e-3 +alpha = 1. + +bary_wass = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights=weights) + +pl.figure(2) +pl.clf() +pl.subplot(2, 1, 1) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') + +pl.subplot(2, 1, 2) +pl.plot(x, bary_l2, 'r', label='l2') +pl.plot(x, bary_wass, 'g', label='Wasserstein') +pl.legend() +pl.title('Barycenters') +pl.tight_layout() + +############################################################################## +# Barycentric interpolation +# ------------------------- + +# barycenter interpolation + +n_weight = 11 +weight_list = np.linspace(0, 1, n_weight) + + +B_l2 = np.zeros((n, n_weight)) + +B_wass = np.copy(B_l2) + +for i in range(0, n_weight): + weight = weight_list[i] + weights = np.array([1 - weight, weight]) + B_l2[:, i] = A.dot(weights) + B_wass[:, i] = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights=weights) + + +# plot interpolation + +pl.figure(3) + +cmap = pl.cm.get_cmap('viridis') +verts = [] +zs = weight_list +for i, z in enumerate(zs): + ys = B_l2[:, i] + verts.append(list(zip(x, ys))) + +ax = pl.gcf().gca(projection='3d') + +poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list]) +poly.set_alpha(0.7) +ax.add_collection3d(poly, zs=zs, zdir='y') +ax.set_xlabel('x') +ax.set_xlim3d(0, n) +ax.set_ylabel(r'$\alpha$') +ax.set_ylim3d(0, 1) +ax.set_zlabel('') +ax.set_zlim3d(0, B_l2.max() * 1.01) +pl.title('Barycenter interpolation with l2') +pl.tight_layout() + +pl.figure(4) +cmap = pl.cm.get_cmap('viridis') +verts = [] +zs = weight_list +for i, z in enumerate(zs): + ys = B_wass[:, i] + verts.append(list(zip(x, ys))) + +ax = pl.gcf().gca(projection='3d') + +poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list]) +poly.set_alpha(0.7) +ax.add_collection3d(poly, zs=zs, zdir='y') +ax.set_xlabel('x') +ax.set_xlim3d(0, n) +ax.set_ylabel(r'$\alpha$') +ax.set_ylim3d(0, 1) +ax.set_zlabel('') +ax.set_zlim3d(0, B_l2.max() * 1.01) +pl.title('Barycenter interpolation with Wasserstein') +pl.tight_layout() + +pl.show() diff --git a/examples/unbalanced-partial/plot_partial_wass_and_gromov.py b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py new file mode 100755 index 0000000..0c5cbf9 --- /dev/null +++ b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py @@ -0,0 +1,165 @@ +# -*- coding: utf-8 -*-
+"""
+==================================================
+Partial Wasserstein and Gromov-Wasserstein example
+==================================================
+
+This example is designed to show how to use the Partial (Gromov-)Wassertsein
+distance computation in POT.
+"""
+
+# Author: Laetitia Chapel <laetitia.chapel@irisa.fr>
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 2
+
+# necessary for 3d plot even if not used
+from mpl_toolkits.mplot3d import Axes3D # noqa
+import scipy as sp
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+
+
+#############################################################################
+#
+# Sample two 2D Gaussian distributions and plot them
+# --------------------------------------------------
+#
+# For demonstration purpose, we sample two Gaussian distributions in 2-d
+# spaces and add some random noise.
+
+
+n_samples = 20 # nb samples (gaussian)
+n_noise = 20 # nb of samples (noise)
+
+mu = np.array([0, 0])
+cov = np.array([[1, 0], [0, 2]])
+
+xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+xs = np.append(xs, (np.random.rand(n_noise, 2) + 1) * 4).reshape((-1, 2))
+xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+xt = np.append(xt, (np.random.rand(n_noise, 2) + 1) * -3).reshape((-1, 2))
+
+M = sp.spatial.distance.cdist(xs, xt)
+
+fig = pl.figure()
+ax1 = fig.add_subplot(131)
+ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ax2 = fig.add_subplot(132)
+ax2.scatter(xt[:, 0], xt[:, 1], color='r')
+ax3 = fig.add_subplot(133)
+ax3.imshow(M)
+pl.show()
+
+#############################################################################
+#
+# Compute partial Wasserstein plans and distance
+# ----------------------------------------------
+
+p = ot.unif(n_samples + n_noise)
+q = ot.unif(n_samples + n_noise)
+
+w0, log0 = ot.partial.partial_wasserstein(p, q, M, m=0.5, log=True)
+w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=0.1, m=0.5,
+ log=True)
+
+print('Partial Wasserstein distance (m = 0.5): ' + str(log0['partial_w_dist']))
+print('Entropic partial Wasserstein distance (m = 0.5): ' +
+ str(log['partial_w_dist']))
+
+pl.figure(1, (10, 5))
+pl.subplot(1, 2, 1)
+pl.imshow(w0, cmap='jet')
+pl.title('Partial Wasserstein')
+pl.subplot(1, 2, 2)
+pl.imshow(w, cmap='jet')
+pl.title('Entropic partial Wasserstein')
+pl.show()
+
+
+#############################################################################
+#
+# Sample one 2D and 3D Gaussian distributions and plot them
+# ---------------------------------------------------------
+#
+# The Gromov-Wasserstein distance allows to compute distances with samples that
+# do not belong to the same metric space. For demonstration purpose, we sample
+# two Gaussian distributions in 2- and 3-dimensional spaces.
+
+n_samples = 20 # nb samples
+n_noise = 10 # nb of samples (noise)
+
+p = ot.unif(n_samples + n_noise)
+q = ot.unif(n_samples + n_noise)
+
+mu_s = np.array([0, 0])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([0, 0, 0])
+cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
+
+
+xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
+xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) + 1) * 4)), axis=0)
+P = sp.linalg.sqrtm(cov_t)
+xt = np.random.randn(n_samples, 3).dot(P) + mu_t
+xt = np.concatenate((xt, ((np.random.rand(n_noise, 3) + 1) * 10)), axis=0)
+
+fig = pl.figure()
+ax1 = fig.add_subplot(121)
+ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ax2 = fig.add_subplot(122, projection='3d')
+ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')
+pl.show()
+
+
+#############################################################################
+#
+# Compute partial Gromov-Wasserstein plans and distance
+# -----------------------------------------------------
+
+C1 = sp.spatial.distance.cdist(xs, xs)
+C2 = sp.spatial.distance.cdist(xt, xt)
+
+# transport 100% of the mass
+print('-----m = 1')
+m = 1
+res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True)
+res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
+ m=m, log=True)
+
+print('Wasserstein distance (m = 1): ' + str(log0['partial_gw_dist']))
+print('Entropic Wasserstein distance (m = 1): ' + str(log['partial_gw_dist']))
+
+pl.figure(1, (10, 5))
+pl.title("mass to be transported m = 1")
+pl.subplot(1, 2, 1)
+pl.imshow(res0, cmap='jet')
+pl.title('Wasserstein')
+pl.subplot(1, 2, 2)
+pl.imshow(res, cmap='jet')
+pl.title('Entropic Wasserstein')
+pl.show()
+
+# transport 2/3 of the mass
+print('-----m = 2/3')
+m = 2 / 3
+res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True)
+res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
+ m=m, log=True)
+
+print('Partial Wasserstein distance (m = 2/3): ' +
+ str(log0['partial_gw_dist']))
+print('Entropic partial Wasserstein distance (m = 2/3): ' +
+ str(log['partial_gw_dist']))
+
+pl.figure(1, (10, 5))
+pl.title("mass to be transported m = 2/3")
+pl.subplot(1, 2, 1)
+pl.imshow(res0, cmap='jet')
+pl.title('Partial Wasserstein')
+pl.subplot(1, 2, 2)
+pl.imshow(res, cmap='jet')
+pl.title('Entropic partial Wasserstein')
+pl.show()
|