summaryrefslogtreecommitdiff
path: root/examples/plot_compute_emd.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/plot_compute_emd.py')
-rw-r--r--examples/plot_compute_emd.py72
1 files changed, 50 insertions, 22 deletions
diff --git a/examples/plot_compute_emd.py b/examples/plot_compute_emd.py
index 527a847..36cc7da 100644
--- a/examples/plot_compute_emd.py
+++ b/examples/plot_compute_emd.py
@@ -1,10 +1,10 @@
# -*- coding: utf-8 -*-
"""
-=================
-Plot multiple EMD
-=================
+==================
+OT distances in 1D
+==================
-Shows how to compute multiple EMD and Sinkhorn with two different
+Shows how to compute multiple Wassersein and Sinkhorn with two different
ground metrics and plot their values for different distributions.
@@ -14,7 +14,7 @@ ground metrics and plot their values for different distributions.
#
# License: MIT License
-# sphinx_gallery_thumbnail_number = 3
+# sphinx_gallery_thumbnail_number = 2
import numpy as np
import matplotlib.pylab as pl
@@ -29,7 +29,7 @@ from ot.datasets import make_1D_gauss as gauss
#%% parameters
n = 100 # nb bins
-n_target = 50 # nb target distributions
+n_target = 20 # nb target distributions
# bin positions
@@ -47,9 +47,9 @@ for i, m in enumerate(lst_m):
# loss matrix and normalization
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean')
-M /= M.max()
+M /= M.max() * 0.1
M2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'sqeuclidean')
-M2 /= M2.max()
+M2 /= M2.max() * 0.1
##############################################################################
# Plot data
@@ -59,10 +59,12 @@ M2 /= M2.max()
pl.figure(1)
pl.subplot(2, 1, 1)
-pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, a, 'r', label='Source distribution')
pl.title('Source distribution')
pl.subplot(2, 1, 2)
-pl.plot(x, B, label='Target distributions')
+for i in range(n_target):
+ pl.plot(x, B[:, i], 'b', alpha=i / n_target)
+pl.plot(x, B[:, -1], 'b', label='Target distributions')
pl.title('Target distributions')
pl.tight_layout()
@@ -73,14 +75,27 @@ pl.tight_layout()
#%% Compute and plot distributions and loss matrix
-d_emd = ot.emd2(a, B, M) # direct computation of EMD
-d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M2
-
+d_emd = ot.emd2(a, B, M) # direct computation of OT loss
+d_emd2 = ot.emd2(a, B, M2) # direct computation of OT loss with metrixc M2
+d_tv = [np.sum(abs(a - B[:, i])) for i in range(n_target)]
pl.figure(2)
-pl.plot(d_emd, label='Euclidean EMD')
-pl.plot(d_emd2, label='Squared Euclidean EMD')
-pl.title('EMD distances')
+pl.subplot(2, 1, 1)
+pl.plot(x, a, 'r', label='Source distribution')
+pl.title('Distributions')
+for i in range(n_target):
+ pl.plot(x, B[:, i], 'b', alpha=i / n_target)
+pl.plot(x, B[:, -1], 'b', label='Target distributions')
+pl.ylim((-.01, 0.13))
+pl.xticks(())
+pl.legend()
+pl.subplot(2, 1, 2)
+pl.plot(d_emd, label='Euclidean OT')
+pl.plot(d_emd2, label='Squared Euclidean OT')
+pl.plot(d_tv, label='Total Variation (TV)')
+#pl.xlim((-7,23))
+pl.xlabel('Displacement')
+pl.title('Divergences')
pl.legend()
##############################################################################
@@ -88,17 +103,30 @@ pl.legend()
# -----------------------------------------
#%%
-reg = 1e-2
+reg = 1e-1
d_sinkhorn = ot.sinkhorn2(a, B, M, reg)
d_sinkhorn2 = ot.sinkhorn2(a, B, M2, reg)
-pl.figure(2)
+pl.figure(3)
pl.clf()
-pl.plot(d_emd, label='Euclidean EMD')
-pl.plot(d_emd2, label='Squared Euclidean EMD')
+
+pl.subplot(2, 1, 1)
+pl.plot(x, a, 'r', label='Source distribution')
+pl.title('Distributions')
+for i in range(n_target):
+ pl.plot(x, B[:, i], 'b', alpha=i / n_target)
+pl.plot(x, B[:, -1], 'b', label='Target distributions')
+pl.ylim((-.01, 0.13))
+pl.xticks(())
+pl.legend()
+pl.subplot(2, 1, 2)
+pl.plot(d_emd, label='Euclidean OT')
+pl.plot(d_emd2, label='Squared Euclidean OT')
pl.plot(d_sinkhorn, '+', label='Euclidean Sinkhorn')
pl.plot(d_sinkhorn2, '+', label='Squared Euclidean Sinkhorn')
-pl.title('EMD distances')
+pl.plot(d_tv, label='Total Variation (TV)')
+#pl.xlim((-7,23))
+pl.xlabel('Displacement')
+pl.title('Divergences')
pl.legend()
-
pl.show()