summaryrefslogtreecommitdiff
path: root/examples/plot_OT_L1_vs_L2.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/plot_OT_L1_vs_L2.py')
-rw-r--r--examples/plot_OT_L1_vs_L2.py34
1 files changed, 17 insertions, 17 deletions
diff --git a/examples/plot_OT_L1_vs_L2.py b/examples/plot_OT_L1_vs_L2.py
index 60353ab..cce51f8 100644
--- a/examples/plot_OT_L1_vs_L2.py
+++ b/examples/plot_OT_L1_vs_L2.py
@@ -1,10 +1,10 @@
# -*- coding: utf-8 -*-
"""
-==========================================
-2D Optimal transport for different metrics
-==========================================
+================================================
+Optimal Transport with different gournd metrics
+================================================
-2D OT on empirical distributio with different gound metric.
+2D OT on empirical distributio with different ground metric.
Stole the figure idea from Fig. 1 and 2 in
https://arxiv.org/pdf/1706.07650.pdf
@@ -23,7 +23,7 @@ import matplotlib.pylab as pl
import ot
import ot.plot
-##############################################################################
+# %%
# Dataset 1 : uniform sampling
# ----------------------------
@@ -46,7 +46,7 @@ M2 = ot.dist(xs, xt, metric='sqeuclidean')
M2 /= M2.max()
# loss matrix
-Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))
+Mp = ot.dist(xs, xt, metric='cityblock')
Mp /= Mp.max()
# Data
@@ -71,7 +71,7 @@ pl.title('Squared Euclidean cost')
pl.subplot(1, 3, 3)
pl.imshow(Mp, interpolation='nearest')
-pl.title('Sqrt Euclidean cost')
+pl.title('L1 (cityblock cost')
pl.tight_layout()
##############################################################################
@@ -109,22 +109,22 @@ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.axis('equal')
# pl.legend(loc=0)
-pl.title('OT sqrt Euclidean')
+pl.title('OT L1 (cityblock)')
pl.tight_layout()
pl.show()
-##############################################################################
+# %%
# Dataset 2 : Partial circle
# --------------------------
-n = 50 # nb samples
+n = 20 # nb samples
xtot = np.zeros((n + 1, 2))
xtot[:, 0] = np.cos(
- (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)
+ (np.arange(n + 1) + 1.0) * 0.8 / (n + 2) * 2 * np.pi)
xtot[:, 1] = np.sin(
- (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)
+ (np.arange(n + 1) + 1.0) * 0.8 / (n + 2) * 2 * np.pi)
xs = xtot[:n, :]
xt = xtot[1:, :]
@@ -140,7 +140,7 @@ M2 = ot.dist(xs, xt, metric='sqeuclidean')
M2 /= M2.max()
# loss matrix
-Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))
+Mp = ot.dist(xs, xt, metric='cityblock')
Mp /= Mp.max()
@@ -150,7 +150,7 @@ pl.clf()
pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.axis('equal')
-pl.title('Source and traget distributions')
+pl.title('Source and target distributions')
# Cost matrices
@@ -166,13 +166,13 @@ pl.title('Squared Euclidean cost')
pl.subplot(1, 3, 3)
pl.imshow(Mp, interpolation='nearest')
-pl.title('Sqrt Euclidean cost')
+pl.title('L1 (cityblock) cost')
pl.tight_layout()
##############################################################################
# Dataset 2 : Plot OT Matrices
# -----------------------------
-
+#
#%% EMD
G1 = ot.emd(a, b, M1)
@@ -204,7 +204,7 @@ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.axis('equal')
# pl.legend(loc=0)
-pl.title('OT sqrt Euclidean')
+pl.title('OT L1 (cityblock)')
pl.tight_layout()
pl.show()