summaryrefslogtreecommitdiff
path: root/examples/plot_OT_L1_vs_L2.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/plot_OT_L1_vs_L2.py')
-rw-r--r--examples/plot_OT_L1_vs_L2.py80
1 files changed, 40 insertions, 40 deletions
diff --git a/examples/plot_OT_L1_vs_L2.py b/examples/plot_OT_L1_vs_L2.py
index e11d6ad..86d902b 100644
--- a/examples/plot_OT_L1_vs_L2.py
+++ b/examples/plot_OT_L1_vs_L2.py
@@ -12,7 +12,7 @@ https://arxiv.org/pdf/1706.07650.pdf
"""
import numpy as np
-import matplotlib.pylab as plt
+import matplotlib.pylab as pl
import ot
#%% parameters and data generation
@@ -55,58 +55,58 @@ for data in range(2):
#%% plot samples
- plt.figure(1 + 3 * data, figsize=(7, 3))
- plt.clf()
- plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
- plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
- plt.axis('equal')
- plt.title('Source and traget distributions')
+ pl.figure(1 + 3 * data, figsize=(7, 3))
+ pl.clf()
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.axis('equal')
+ pl.title('Source and traget distributions')
- plt.figure(2 + 3 * data, figsize=(7, 3))
+ pl.figure(2 + 3 * data, figsize=(7, 3))
- plt.subplot(1, 3, 1)
- plt.imshow(M1, interpolation='nearest')
- plt.title('Euclidean cost')
+ pl.subplot(1, 3, 1)
+ pl.imshow(M1, interpolation='nearest')
+ pl.title('Euclidean cost')
- plt.subplot(1, 3, 2)
- plt.imshow(M2, interpolation='nearest')
- plt.title('Squared Euclidean cost')
+ pl.subplot(1, 3, 2)
+ pl.imshow(M2, interpolation='nearest')
+ pl.title('Squared Euclidean cost')
- plt.subplot(1, 3, 3)
- plt.imshow(Mp, interpolation='nearest')
- plt.title('Sqrt Euclidean cost')
- plt.tight_layout()
+ pl.subplot(1, 3, 3)
+ pl.imshow(Mp, interpolation='nearest')
+ pl.title('Sqrt Euclidean cost')
+ pl.tight_layout()
#%% EMD
G1 = ot.emd(a, b, M1)
G2 = ot.emd(a, b, M2)
Gp = ot.emd(a, b, Mp)
- plt.figure(3 + 3 * data, figsize=(7, 3))
+ pl.figure(3 + 3 * data, figsize=(7, 3))
- plt.subplot(1, 3, 1)
+ pl.subplot(1, 3, 1)
ot.plot.plot2D_samples_mat(xs, xt, G1, c=[.5, .5, 1])
- plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
- plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
- plt.axis('equal')
- # plt.legend(loc=0)
- plt.title('OT Euclidean')
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.axis('equal')
+ # pl.legend(loc=0)
+ pl.title('OT Euclidean')
- plt.subplot(1, 3, 2)
+ pl.subplot(1, 3, 2)
ot.plot.plot2D_samples_mat(xs, xt, G2, c=[.5, .5, 1])
- plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
- plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
- plt.axis('equal')
- # plt.legend(loc=0)
- plt.title('OT squared Euclidean')
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.axis('equal')
+ # pl.legend(loc=0)
+ pl.title('OT squared Euclidean')
- plt.subplot(1, 3, 3)
+ pl.subplot(1, 3, 3)
ot.plot.plot2D_samples_mat(xs, xt, Gp, c=[.5, .5, 1])
- plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
- plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
- plt.axis('equal')
- # plt.legend(loc=0)
- plt.title('OT sqrt Euclidean')
- plt.tight_layout()
-
-plt.show()
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.axis('equal')
+ # pl.legend(loc=0)
+ pl.title('OT sqrt Euclidean')
+ pl.tight_layout()
+
+pl.show()