diff options
Diffstat (limited to 'docs/source/auto_examples/plot_compute_emd.py')
-rw-r--r-- | docs/source/auto_examples/plot_compute_emd.py | 30 |
1 files changed, 26 insertions, 4 deletions
diff --git a/docs/source/auto_examples/plot_compute_emd.py b/docs/source/auto_examples/plot_compute_emd.py index 893eecf..b688f93 100644 --- a/docs/source/auto_examples/plot_compute_emd.py +++ b/docs/source/auto_examples/plot_compute_emd.py @@ -1,8 +1,12 @@ # -*- coding: utf-8 -*- """ -==================== -1D optimal transport -==================== +================= +Plot multiple EMD +================= + +Shows how to compute multiple EMD and Sinkhorn with two differnt +ground metrics and plot their values for diffeent distributions. + """ @@ -16,6 +20,10 @@ import ot from ot.datasets import get_1D_gauss as gauss +############################################################################## +# Generate data +############################################################################## + #%% parameters n = 100 # nb bins @@ -40,6 +48,11 @@ M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean') M /= M.max() M2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'sqeuclidean') M2 /= M2.max() + +############################################################################## +# Plot data +############################################################################## + #%% plot the distributions pl.figure(1) @@ -51,10 +64,15 @@ pl.plot(x, B, label='Target distributions') pl.title('Target distributions') pl.tight_layout() + +############################################################################## +# Compute EMD for the different losses +############################################################################## + #%% Compute and plot distributions and loss matrix d_emd = ot.emd2(a, B, M) # direct computation of EMD -d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M3 +d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M2 pl.figure(2) @@ -63,6 +81,10 @@ pl.plot(d_emd2, label='Squared Euclidean EMD') pl.title('EMD distances') pl.legend() +############################################################################## +# Compute Sinkhorn for the different losses +############################################################################## + #%% reg = 1e-2 d_sinkhorn = ot.sinkhorn2(a, B, M, reg) |