From 25ef32ff892fba105a4a116a804b1e4f08ae57cd Mon Sep 17 00:00:00 2001 From: Alexandre Gramfort Date: Wed, 12 Jul 2017 22:51:17 +0200 Subject: more --- examples/plot_OT_L1_vs_L2.py | 80 ++++++++++++++++++++++---------------------- 1 file changed, 40 insertions(+), 40 deletions(-) (limited to 'examples/plot_OT_L1_vs_L2.py') diff --git a/examples/plot_OT_L1_vs_L2.py b/examples/plot_OT_L1_vs_L2.py index e11d6ad..86d902b 100644 --- a/examples/plot_OT_L1_vs_L2.py +++ b/examples/plot_OT_L1_vs_L2.py @@ -12,7 +12,7 @@ https://arxiv.org/pdf/1706.07650.pdf """ import numpy as np -import matplotlib.pylab as plt +import matplotlib.pylab as pl import ot #%% parameters and data generation @@ -55,58 +55,58 @@ for data in range(2): #%% plot samples - plt.figure(1 + 3 * data, figsize=(7, 3)) - plt.clf() - plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') - plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') - plt.axis('equal') - plt.title('Source and traget distributions') + pl.figure(1 + 3 * data, figsize=(7, 3)) + pl.clf() + pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') + pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') + pl.axis('equal') + pl.title('Source and traget distributions') - plt.figure(2 + 3 * data, figsize=(7, 3)) + pl.figure(2 + 3 * data, figsize=(7, 3)) - plt.subplot(1, 3, 1) - plt.imshow(M1, interpolation='nearest') - plt.title('Euclidean cost') + pl.subplot(1, 3, 1) + pl.imshow(M1, interpolation='nearest') + pl.title('Euclidean cost') - plt.subplot(1, 3, 2) - plt.imshow(M2, interpolation='nearest') - plt.title('Squared Euclidean cost') + pl.subplot(1, 3, 2) + pl.imshow(M2, interpolation='nearest') + pl.title('Squared Euclidean cost') - plt.subplot(1, 3, 3) - plt.imshow(Mp, interpolation='nearest') - plt.title('Sqrt Euclidean cost') - plt.tight_layout() + pl.subplot(1, 3, 3) + pl.imshow(Mp, interpolation='nearest') + pl.title('Sqrt Euclidean cost') + pl.tight_layout() #%% EMD G1 = ot.emd(a, b, M1) G2 = ot.emd(a, b, M2) Gp = ot.emd(a, b, Mp) - plt.figure(3 + 3 * data, figsize=(7, 3)) + pl.figure(3 + 3 * data, figsize=(7, 3)) - plt.subplot(1, 3, 1) + pl.subplot(1, 3, 1) ot.plot.plot2D_samples_mat(xs, xt, G1, c=[.5, .5, 1]) - plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') - plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') - plt.axis('equal') - # plt.legend(loc=0) - plt.title('OT Euclidean') + pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') + pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') + pl.axis('equal') + # pl.legend(loc=0) + pl.title('OT Euclidean') - plt.subplot(1, 3, 2) + pl.subplot(1, 3, 2) ot.plot.plot2D_samples_mat(xs, xt, G2, c=[.5, .5, 1]) - plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') - plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') - plt.axis('equal') - # plt.legend(loc=0) - plt.title('OT squared Euclidean') + pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') + pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') + pl.axis('equal') + # pl.legend(loc=0) + pl.title('OT squared Euclidean') - plt.subplot(1, 3, 3) + pl.subplot(1, 3, 3) ot.plot.plot2D_samples_mat(xs, xt, Gp, c=[.5, .5, 1]) - plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') - plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') - plt.axis('equal') - # plt.legend(loc=0) - plt.title('OT sqrt Euclidean') - plt.tight_layout() - -plt.show() + pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') + pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') + pl.axis('equal') + # pl.legend(loc=0) + pl.title('OT sqrt Euclidean') + pl.tight_layout() + +pl.show() -- cgit v1.2.3