diff options
Diffstat (limited to 'examples/plot_compute_emd.py')
-rw-r--r-- | examples/plot_compute_emd.py | 26 |
1 files changed, 22 insertions, 4 deletions
diff --git a/examples/plot_compute_emd.py b/examples/plot_compute_emd.py index 893eecf..704da0e 100644 --- a/examples/plot_compute_emd.py +++ b/examples/plot_compute_emd.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- """ -==================== -1D optimal transport -==================== +================= +Plot multiple EMD +================= """ @@ -16,6 +16,10 @@ import ot from ot.datasets import get_1D_gauss as gauss +############################################################################## +# Generate data +############################################################################## + #%% parameters n = 100 # nb bins @@ -40,6 +44,11 @@ M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean') M /= M.max() M2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'sqeuclidean') M2 /= M2.max() + +############################################################################## +# Plot data +############################################################################## + #%% plot the distributions pl.figure(1) @@ -51,10 +60,15 @@ pl.plot(x, B, label='Target distributions') pl.title('Target distributions') pl.tight_layout() + +############################################################################## +# Compute EMD for the different losses +############################################################################## + #%% Compute and plot distributions and loss matrix d_emd = ot.emd2(a, B, M) # direct computation of EMD -d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M3 +d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M2 pl.figure(2) @@ -63,6 +77,10 @@ pl.plot(d_emd2, label='Squared Euclidean EMD') pl.title('EMD distances') pl.legend() +############################################################################## +# Compute Sinkhorn for the different losses +############################################################################## + #%% reg = 1e-2 d_sinkhorn = ot.sinkhorn2(a, B, M, reg) |