diff options
Diffstat (limited to 'examples/plot_compute_emd.py')
-rw-r--r-- | examples/plot_compute_emd.py | 102 |
1 files changed, 102 insertions, 0 deletions
diff --git a/examples/plot_compute_emd.py b/examples/plot_compute_emd.py new file mode 100644 index 0000000..7ed2b01 --- /dev/null +++ b/examples/plot_compute_emd.py @@ -0,0 +1,102 @@ +# -*- coding: utf-8 -*- +""" +================= +Plot multiple EMD +================= + +Shows how to compute multiple EMD and Sinkhorn with two differnt +ground metrics and plot their values for diffeent distributions. + + +""" + +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot +from ot.datasets import make_1D_gauss as gauss + + +############################################################################## +# Generate data +# ------------- + +#%% parameters + +n = 100 # nb bins +n_target = 50 # nb target distributions + + +# bin positions +x = np.arange(n, dtype=np.float64) + +lst_m = np.linspace(20, 90, n_target) + +# Gaussian distributions +a = gauss(n, m=20, s=5) # m= mean, s= std + +B = np.zeros((n, n_target)) + +for i, m in enumerate(lst_m): + B[:, i] = gauss(n, m=m, s=5) + +# loss matrix and normalization +M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean') +M /= M.max() +M2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'sqeuclidean') +M2 /= M2.max() + +############################################################################## +# Plot data +# --------- + +#%% plot the distributions + +pl.figure(1) +pl.subplot(2, 1, 1) +pl.plot(x, a, 'b', label='Source distribution') +pl.title('Source distribution') +pl.subplot(2, 1, 2) +pl.plot(x, B, label='Target distributions') +pl.title('Target distributions') +pl.tight_layout() + + +############################################################################## +# Compute EMD for the different losses +# ------------------------------------ + +#%% Compute and plot distributions and loss matrix + +d_emd = ot.emd2(a, B, M) # direct computation of EMD +d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M2 + + +pl.figure(2) +pl.plot(d_emd, label='Euclidean EMD') +pl.plot(d_emd2, label='Squared Euclidean EMD') +pl.title('EMD distances') +pl.legend() + +############################################################################## +# Compute Sinkhorn for the different losses +# ----------------------------------------- + +#%% +reg = 1e-2 +d_sinkhorn = ot.sinkhorn2(a, B, M, reg) +d_sinkhorn2 = ot.sinkhorn2(a, B, M2, reg) + +pl.figure(2) +pl.clf() +pl.plot(d_emd, label='Euclidean EMD') +pl.plot(d_emd2, label='Squared Euclidean EMD') +pl.plot(d_sinkhorn, '+', label='Euclidean Sinkhorn') +pl.plot(d_sinkhorn2, '+', label='Squared Euclidean Sinkhorn') +pl.title('EMD distances') +pl.legend() + +pl.show() |