summaryrefslogtreecommitdiff
path: root/examples/plot_compute_emd.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-08-31 09:28:37 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-08-31 09:28:37 +0200
commit212f3889b1114026765cda0134e02766daa82af2 (patch)
treef9ea2d2566d1544b3409152f8ebbc8ca706c96e2 /examples/plot_compute_emd.py
parentec67362de5ec785e3871eac75a8aa477857092c4 (diff)
update tests
Diffstat (limited to 'examples/plot_compute_emd.py')
-rw-r--r--examples/plot_compute_emd.py26
1 files changed, 22 insertions, 4 deletions
diff --git a/examples/plot_compute_emd.py b/examples/plot_compute_emd.py
index 893eecf..704da0e 100644
--- a/examples/plot_compute_emd.py
+++ b/examples/plot_compute_emd.py
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
"""
-====================
-1D optimal transport
-====================
+=================
+Plot multiple EMD
+=================
"""
@@ -16,6 +16,10 @@ import ot
from ot.datasets import get_1D_gauss as gauss
+##############################################################################
+# Generate data
+##############################################################################
+
#%% parameters
n = 100 # nb bins
@@ -40,6 +44,11 @@ M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean')
M /= M.max()
M2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'sqeuclidean')
M2 /= M2.max()
+
+##############################################################################
+# Plot data
+##############################################################################
+
#%% plot the distributions
pl.figure(1)
@@ -51,10 +60,15 @@ pl.plot(x, B, label='Target distributions')
pl.title('Target distributions')
pl.tight_layout()
+
+##############################################################################
+# Compute EMD for the different losses
+##############################################################################
+
#%% Compute and plot distributions and loss matrix
d_emd = ot.emd2(a, B, M) # direct computation of EMD
-d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M3
+d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M2
pl.figure(2)
@@ -63,6 +77,10 @@ pl.plot(d_emd2, label='Squared Euclidean EMD')
pl.title('EMD distances')
pl.legend()
+##############################################################################
+# Compute Sinkhorn for the different losses
+##############################################################################
+
#%%
reg = 1e-2
d_sinkhorn = ot.sinkhorn2(a, B, M, reg)