summaryrefslogtreecommitdiff
path: root/examples/plot_OT_L1_vs_L2.py
diff options
context:
space:
mode:
authorAlexandre Gramfort <alexandre.gramfort@m4x.org>2017-07-12 22:51:17 +0200
committerAlexandre Gramfort <alexandre.gramfort@m4x.org>2017-07-20 14:05:12 +0200
commit25ef32ff892fba105a4a116a804b1e4f08ae57cd (patch)
treecc58adb198141841e2e5dab9c4f33253556f8f65 /examples/plot_OT_L1_vs_L2.py
parentd6091dae858a82f69e6843859f164269fa338c6b (diff)
more
Diffstat (limited to 'examples/plot_OT_L1_vs_L2.py')
-rw-r--r--examples/plot_OT_L1_vs_L2.py80
1 files changed, 40 insertions, 40 deletions
diff --git a/examples/plot_OT_L1_vs_L2.py b/examples/plot_OT_L1_vs_L2.py
index e11d6ad..86d902b 100644
--- a/examples/plot_OT_L1_vs_L2.py
+++ b/examples/plot_OT_L1_vs_L2.py
@@ -12,7 +12,7 @@ https://arxiv.org/pdf/1706.07650.pdf
"""
import numpy as np
-import matplotlib.pylab as plt
+import matplotlib.pylab as pl
import ot
#%% parameters and data generation
@@ -55,58 +55,58 @@ for data in range(2):
#%% plot samples
- plt.figure(1 + 3 * data, figsize=(7, 3))
- plt.clf()
- plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
- plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
- plt.axis('equal')
- plt.title('Source and traget distributions')
+ pl.figure(1 + 3 * data, figsize=(7, 3))
+ pl.clf()
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.axis('equal')
+ pl.title('Source and traget distributions')
- plt.figure(2 + 3 * data, figsize=(7, 3))
+ pl.figure(2 + 3 * data, figsize=(7, 3))
- plt.subplot(1, 3, 1)
- plt.imshow(M1, interpolation='nearest')
- plt.title('Euclidean cost')
+ pl.subplot(1, 3, 1)
+ pl.imshow(M1, interpolation='nearest')
+ pl.title('Euclidean cost')
- plt.subplot(1, 3, 2)
- plt.imshow(M2, interpolation='nearest')
- plt.title('Squared Euclidean cost')
+ pl.subplot(1, 3, 2)
+ pl.imshow(M2, interpolation='nearest')
+ pl.title('Squared Euclidean cost')
- plt.subplot(1, 3, 3)
- plt.imshow(Mp, interpolation='nearest')
- plt.title('Sqrt Euclidean cost')
- plt.tight_layout()
+ pl.subplot(1, 3, 3)
+ pl.imshow(Mp, interpolation='nearest')
+ pl.title('Sqrt Euclidean cost')
+ pl.tight_layout()
#%% EMD
G1 = ot.emd(a, b, M1)
G2 = ot.emd(a, b, M2)
Gp = ot.emd(a, b, Mp)
- plt.figure(3 + 3 * data, figsize=(7, 3))
+ pl.figure(3 + 3 * data, figsize=(7, 3))
- plt.subplot(1, 3, 1)
+ pl.subplot(1, 3, 1)
ot.plot.plot2D_samples_mat(xs, xt, G1, c=[.5, .5, 1])
- plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
- plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
- plt.axis('equal')
- # plt.legend(loc=0)
- plt.title('OT Euclidean')
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.axis('equal')
+ # pl.legend(loc=0)
+ pl.title('OT Euclidean')
- plt.subplot(1, 3, 2)
+ pl.subplot(1, 3, 2)
ot.plot.plot2D_samples_mat(xs, xt, G2, c=[.5, .5, 1])
- plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
- plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
- plt.axis('equal')
- # plt.legend(loc=0)
- plt.title('OT squared Euclidean')
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.axis('equal')
+ # pl.legend(loc=0)
+ pl.title('OT squared Euclidean')
- plt.subplot(1, 3, 3)
+ pl.subplot(1, 3, 3)
ot.plot.plot2D_samples_mat(xs, xt, Gp, c=[.5, .5, 1])
- plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
- plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
- plt.axis('equal')
- # plt.legend(loc=0)
- plt.title('OT sqrt Euclidean')
- plt.tight_layout()
-
-plt.show()
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.axis('equal')
+ # pl.legend(loc=0)
+ pl.title('OT sqrt Euclidean')
+ pl.tight_layout()
+
+pl.show()