diff options
Diffstat (limited to 'examples/others')
-rw-r--r-- | examples/others/plot_WeakOT_VS_OT.py | 98 | ||||
-rw-r--r-- | examples/others/plot_factored_coupling.py | 86 | ||||
-rw-r--r-- | examples/others/plot_logo.py | 112 | ||||
-rw-r--r-- | examples/others/plot_screenkhorn_1D.py | 71 | ||||
-rw-r--r-- | examples/others/plot_stochastic.py | 189 |
5 files changed, 556 insertions, 0 deletions
diff --git a/examples/others/plot_WeakOT_VS_OT.py b/examples/others/plot_WeakOT_VS_OT.py new file mode 100644 index 0000000..a29c875 --- /dev/null +++ b/examples/others/plot_WeakOT_VS_OT.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +""" +==================================================== +Weak Optimal Transport VS exact Optimal Transport +==================================================== + +Illustration of 2D optimal transport between distributions that are weighted +sum of diracs. The OT matrix is plotted with the samples. + +""" + +# Author: Remi Flamary <remi.flamary@polytechnique.edu> +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 4 + +import numpy as np +import matplotlib.pylab as pl +import ot +import ot.plot + +############################################################################## +# Generate data an plot it +# ------------------------ + +#%% parameters and data generation + +n = 50 # nb samples + +mu_s = np.array([0, 0]) +cov_s = np.array([[1, 0], [0, 1]]) + +mu_t = np.array([4, 4]) +cov_t = np.array([[1, -.8], [-.8, 1]]) + +xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) +xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) + +a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples + +# loss matrix +M = ot.dist(xs, xt) +M /= M.max() + +#%% plot samples + +pl.figure(1) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.legend(loc=0) +pl.title('Source and target distributions') + +pl.figure(2) +pl.imshow(M, interpolation='nearest') +pl.title('Cost matrix M') + + +############################################################################## +# Compute Weak OT and exact OT solutions +# -------------------------------------- + +#%% EMD + +G0 = ot.emd(a, b, M) + +#%% Weak OT + +Gweak = ot.weak_optimal_transport(xs, xt, a, b) + + +############################################################################## +# Plot weak OT and exact OT solutions +# -------------------------------------- + +pl.figure(3, (8, 5)) + +pl.subplot(1, 2, 1) +pl.imshow(G0, interpolation='nearest') +pl.title('OT matrix') + +pl.subplot(1, 2, 2) +pl.imshow(Gweak, interpolation='nearest') +pl.title('Weak OT matrix') + +pl.figure(4, (8, 5)) + +pl.subplot(1, 2, 1) +ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1]) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.title('OT matrix with samples') + +pl.subplot(1, 2, 2) +ot.plot.plot2D_samples_mat(xs, xt, Gweak, c=[.5, .5, 1]) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.title('Weak OT matrix with samples') diff --git a/examples/others/plot_factored_coupling.py b/examples/others/plot_factored_coupling.py new file mode 100644 index 0000000..b5b1c9f --- /dev/null +++ b/examples/others/plot_factored_coupling.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +""" +========================================== +Optimal transport with factored couplings +========================================== + +Illustration of the factored coupling OT between 2D empirical distributions + +""" + +# Author: Remi Flamary <remi.flamary@polytechnique.edu> +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 2 + +import numpy as np +import matplotlib.pylab as pl +import ot +import ot.plot + +# %% +# Generate data an plot it +# ------------------------ + +# parameters and data generation + +np.random.seed(42) + +n = 100 # nb samples + +xs = np.random.rand(n, 2) - .5 + +xs = xs + np.sign(xs) + +xt = np.random.rand(n, 2) - .5 + +a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples + +#%% plot samples + +pl.figure(1) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.legend(loc=0) +pl.title('Source and target distributions') + + +# %% +# Compute Factore OT and exact OT solutions +# -------------------------------------- + +#%% EMD +M = ot.dist(xs, xt) +G0 = ot.emd(a, b, M) + +#%% factored OT OT + +Ga, Gb, xb = ot.factored_optimal_transport(xs, xt, a, b, r=4) + + +# %% +# Plot factored OT and exact OT solutions +# -------------------------------------- + +pl.figure(2, (14, 4)) + +pl.subplot(1, 3, 1) +ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.2, .2, .2], alpha=0.1) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.title('Exact OT with samples') + +pl.subplot(1, 3, 2) +ot.plot.plot2D_samples_mat(xs, xb, Ga, c=[.6, .6, .9], alpha=0.5) +ot.plot.plot2D_samples_mat(xb, xt, Gb, c=[.9, .6, .6], alpha=0.5) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.plot(xb[:, 0], xb[:, 1], 'og', label='Template samples') +pl.title('Factored OT with template samples') + +pl.subplot(1, 3, 3) +ot.plot.plot2D_samples_mat(xs, xt, Ga.dot(Gb), c=[.2, .2, .2], alpha=0.1) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.title('Factored OT low rank OT plan') diff --git a/examples/others/plot_logo.py b/examples/others/plot_logo.py new file mode 100644 index 0000000..bb4f640 --- /dev/null +++ b/examples/others/plot_logo.py @@ -0,0 +1,112 @@ + +# -*- coding: utf-8 -*- +r""" +======================= +Logo of the POT toolbox +======================= + +In this example we plot the logo of the POT toolbox. + +This logo is that it is done 100% in Python and generated using +matplotlib and ploting teh solution of the EMD solver from POT. + +""" + +# Author: Remi Flamary <remi.flamary@polytechnique.edu> +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 1 + +# %% Load modules +import numpy as np +import matplotlib.pyplot as pl +import ot + +# %% +# Data for logo +# ------------- + + +# Letter P +p1 = np.array([[0, 6.], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], ]) +p2 = np.array([[1.5, 6], [2, 4], [2, 5], [1.5, 3], [0.5, 2], [.5, 1], ]) + +# Letter O +o1 = np.array([[0, 6.], [-1, 5], [-1.5, 4], [-1.5, 3], [-1, 2], [0, 1], ]) +o2 = np.array([[1, 6.], [2, 5], [2.5, 4], [2.5, 3], [2, 2], [1, 1], ]) + +# Scaling and translation for letter O +o1[:, 0] += 6.4 +o2[:, 0] += 6.4 +o1[:, 0] *= 0.6 +o2[:, 0] *= 0.6 + +# Letter T +t1 = np.array([[-1, 6.], [-1, 5], [0, 4], [0, 3], [0, 2], [0, 1], ]) +t2 = np.array([[1.5, 6.], [1.5, 5], [0.5, 4], [0.5, 3], [0.5, 2], [0.5, 1], ]) + +# Translating the T +t1[:, 0] += 7.1 +t2[:, 0] += 7.1 + +# Concatenate all letters +x1 = np.concatenate((p1, o1, t1), axis=0) +x2 = np.concatenate((p2, o2, t2), axis=0) + +# Horizontal and vertical scaling +sx = 1.0 +sy = .5 +x1[:, 0] *= sx +x1[:, 1] *= sy +x2[:, 0] *= sx +x2[:, 1] *= sy + +# %% +# Plot the logo (clear background) +# -------------------------------- + +# Solve OT problem between the points +M = ot.dist(x1, x2, metric='euclidean') +T = ot.emd([], [], M) + +pl.figure(1, (3.5, 1.1)) +pl.clf() +# plot the OT plan +for i in range(M.shape[0]): + for j in range(M.shape[1]): + if T[i, j] > 1e-8: + pl.plot([x1[i, 0], x2[j, 0]], [x1[i, 1], x2[j, 1]], color='k', alpha=0.6, linewidth=3, zorder=1) +# plot the samples +pl.plot(x1[:, 0], x1[:, 1], 'o', markerfacecolor='C3', markeredgecolor='k') +pl.plot(x2[:, 0], x2[:, 1], 'o', markerfacecolor='b', markeredgecolor='k') + + +pl.axis('equal') +pl.axis('off') + +# Save logo file +# pl.savefig('logo.svg', dpi=150, transparent=True, bbox_inches='tight') +# pl.savefig('logo.png', dpi=150, transparent=True, bbox_inches='tight') + +# %% +# Plot the logo (dark background) +# -------------------------------- + +pl.figure(2, (3.5, 1.1), facecolor='darkgray') +pl.clf() +# plot the OT plan +for i in range(M.shape[0]): + for j in range(M.shape[1]): + if T[i, j] > 1e-8: + pl.plot([x1[i, 0], x2[j, 0]], [x1[i, 1], x2[j, 1]], color='w', alpha=0.8, linewidth=3, zorder=1) +# plot the samples +pl.plot(x1[:, 0], x1[:, 1], 'o', markerfacecolor='w', markeredgecolor='w') +pl.plot(x2[:, 0], x2[:, 1], 'o', markerfacecolor='w', markeredgecolor='w') + +pl.axis('equal') +pl.axis('off') + +# Save logo file +# pl.savefig('logo_dark.svg', dpi=150, transparent=True, bbox_inches='tight') +# pl.savefig('logo_dark.png', dpi=150, transparent=True, bbox_inches='tight') diff --git a/examples/others/plot_screenkhorn_1D.py b/examples/others/plot_screenkhorn_1D.py new file mode 100644 index 0000000..2023649 --- /dev/null +++ b/examples/others/plot_screenkhorn_1D.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +""" +======================================== +Screened optimal transport (Screenkhorn) +======================================== + +This example illustrates the computation of Screenkhorn [26]. + +[26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). +Screening Sinkhorn Algorithm for Regularized Optimal Transport, +Advances in Neural Information Processing Systems 33 (NeurIPS). +""" + +# Author: Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com> +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot.plot +from ot.datasets import make_1D_gauss as gauss +from ot.bregman import screenkhorn + +############################################################################## +# Generate data +# ------------- + +#%% parameters + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a = gauss(n, m=20, s=5) # m= mean, s= std +b = gauss(n, m=60, s=10) + +# loss matrix +M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1))) +M /= M.max() + +############################################################################## +# Plot distributions and loss matrix +# ---------------------------------- + +#%% plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') +pl.legend() + +# plot distributions and loss matrix + +pl.figure(2, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') + +############################################################################## +# Solve Screenkhorn +# ----------------------- + +# Screenkhorn +lambd = 2e-03 # entropy parameter +ns_budget = 30 # budget number of points to be keeped in the source distribution +nt_budget = 30 # budget number of points to be keeped in the target distribution + +G_screen = screenkhorn(a, b, M, lambd, ns_budget, nt_budget, uniform=False, restricted=True, verbose=True) +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, G_screen, 'OT matrix Screenkhorn') +pl.show() diff --git a/examples/others/plot_stochastic.py b/examples/others/plot_stochastic.py new file mode 100644 index 0000000..3a1ef31 --- /dev/null +++ b/examples/others/plot_stochastic.py @@ -0,0 +1,189 @@ +""" +=================== +Stochastic examples +=================== + +This example is designed to show how to use the stochatic optimization +algorithms for discrete and semi-continuous measures from the POT library. + +[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. +Stochastic Optimization for Large-scale Optimal Transport. +Advances in Neural Information Processing Systems (2016). + +[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & +Blondel, M. Large-scale Optimal Transport and Mapping Estimation. +International Conference on Learning Representation (2018) + +""" + +# Author: Kilian Fatras <kilian.fatras@gmail.com> +# +# License: MIT License + +import matplotlib.pylab as pl +import numpy as np +import ot +import ot.plot + + +############################################################################# +# Compute the Transportation Matrix for the Semi-Dual Problem +# ----------------------------------------------------------- +# +# Discrete case +# ````````````` +# +# Sample two discrete measures for the discrete case and compute their cost +# matrix c. + +n_source = 7 +n_target = 4 +reg = 1 +numItermax = 1000 + +a = ot.utils.unif(n_source) +b = ot.utils.unif(n_target) + +rng = np.random.RandomState(0) +X_source = rng.randn(n_source, 2) +Y_target = rng.randn(n_target, 2) +M = ot.dist(X_source, Y_target) + +############################################################################# +# Call the "SAG" method to find the transportation matrix in the discrete case + +method = "SAG" +sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, + numItermax) +print(sag_pi) + +############################################################################# +# Semi-Continuous Case +# ```````````````````` +# +# Sample one general measure a, one discrete measures b for the semicontinous +# case, the points where source and target measures are defined and compute the +# cost matrix. + +n_source = 7 +n_target = 4 +reg = 1 +numItermax = 1000 +log = True + +a = ot.utils.unif(n_source) +b = ot.utils.unif(n_target) + +rng = np.random.RandomState(0) +X_source = rng.randn(n_source, 2) +Y_target = rng.randn(n_target, 2) +M = ot.dist(X_source, Y_target) + +############################################################################# +# Call the "ASGD" method to find the transportation matrix in the semicontinous +# case. + +method = "ASGD" +asgd_pi, log_asgd = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, + numItermax, log=log) +print(log_asgd['alpha'], log_asgd['beta']) +print(asgd_pi) + +############################################################################# +# Compare the results with the Sinkhorn algorithm + +sinkhorn_pi = ot.sinkhorn(a, b, M, reg) +print(sinkhorn_pi) + + +############################################################################## +# Plot Transportation Matrices +# ```````````````````````````` +# +# For SAG + +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, sag_pi, 'semi-dual : OT matrix SAG') +pl.show() + + +############################################################################## +# For ASGD + +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, asgd_pi, 'semi-dual : OT matrix ASGD') +pl.show() + + +############################################################################## +# For Sinkhorn + +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn') +pl.show() + + +############################################################################# +# Compute the Transportation Matrix for the Dual Problem +# ------------------------------------------------------ +# +# Semi-continuous case +# ```````````````````` +# +# Sample one general measure a, one discrete measures b for the semi-continuous +# case and compute the cost matrix c. + +n_source = 7 +n_target = 4 +reg = 1 +numItermax = 100000 +lr = 0.1 +batch_size = 3 +log = True + +a = ot.utils.unif(n_source) +b = ot.utils.unif(n_target) + +rng = np.random.RandomState(0) +X_source = rng.randn(n_source, 2) +Y_target = rng.randn(n_target, 2) +M = ot.dist(X_source, Y_target) + +############################################################################# +# +# Call the "SGD" dual method to find the transportation matrix in the +# semi-continuous case + +sgd_dual_pi, log_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg, + batch_size, numItermax, + lr, log=log) +print(log_sgd['alpha'], log_sgd['beta']) +print(sgd_dual_pi) + +############################################################################# +# +# Compare the results with the Sinkhorn algorithm +# ``````````````````````````````````````````````` +# +# Call the Sinkhorn algorithm from POT + +sinkhorn_pi = ot.sinkhorn(a, b, M, reg) +print(sinkhorn_pi) + +############################################################################## +# Plot Transportation Matrices +# ```````````````````````````` +# +# For SGD + +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, sgd_dual_pi, 'dual : OT matrix SGD') +pl.show() + + +############################################################################## +# For Sinkhorn + +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn') +pl.show() |