summaryrefslogtreecommitdiff
path: root/docs/source/auto_examples/plot_compute_emd.rst
diff options
context:
space:
mode:
Diffstat (limited to 'docs/source/auto_examples/plot_compute_emd.rst')
-rw-r--r--docs/source/auto_examples/plot_compute_emd.rst153
1 files changed, 109 insertions, 44 deletions
diff --git a/docs/source/auto_examples/plot_compute_emd.rst b/docs/source/auto_examples/plot_compute_emd.rst
index 4c7445b..cdbc620 100644
--- a/docs/source/auto_examples/plot_compute_emd.rst
+++ b/docs/source/auto_examples/plot_compute_emd.rst
@@ -3,101 +3,166 @@
.. _sphx_glr_auto_examples_plot_compute_emd.py:
-====================
-1D optimal transport
-====================
+=================
+Plot multiple EMD
+=================
-@author: rflamary
+Shows how to compute multiple EMD and Sinkhorn with two differnt
+ground metrics and plot their values for diffeent distributions.
-.. rst-class:: sphx-glr-horizontal
+.. code-block:: python
- *
- .. image:: /auto_examples/images/sphx_glr_plot_compute_emd_001.png
- :scale: 47
+ # Author: Remi Flamary <remi.flamary@unice.fr>
+ #
+ # License: MIT License
- *
+ import numpy as np
+ import matplotlib.pylab as pl
+ import ot
+ from ot.datasets import get_1D_gauss as gauss
- .. image:: /auto_examples/images/sphx_glr_plot_compute_emd_002.png
- :scale: 47
-.. code-block:: python
- import numpy as np
- import matplotlib.pylab as pl
- import ot
- from ot.datasets import get_1D_gauss as gauss
+Generate data
+-------------
+
+
+
+.. code-block:: python
#%% parameters
- n=100 # nb bins
- n_target=50 # nb target distributions
+ n = 100 # nb bins
+ n_target = 50 # nb target distributions
# bin positions
- x=np.arange(n,dtype=np.float64)
+ x = np.arange(n, dtype=np.float64)
- lst_m=np.linspace(20,90,n_target)
+ lst_m = np.linspace(20, 90, n_target)
# Gaussian distributions
- a=gauss(n,m=20,s=5) # m= mean, s= std
+ a = gauss(n, m=20, s=5) # m= mean, s= std
- B=np.zeros((n,n_target))
+ B = np.zeros((n, n_target))
- for i,m in enumerate(lst_m):
- B[:,i]=gauss(n,m=m,s=5)
+ for i, m in enumerate(lst_m):
+ B[:, i] = gauss(n, m=m, s=5)
# loss matrix and normalization
- M=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'euclidean')
- M/=M.max()
- M2=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'sqeuclidean')
- M2/=M2.max()
+ M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean')
+ M /= M.max()
+ M2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'sqeuclidean')
+ M2 /= M2.max()
+
+
+
+
+
+
+
+Plot data
+---------
+
+
+
+.. code-block:: python
+
+
#%% plot the distributions
pl.figure(1)
- pl.subplot(2,1,1)
- pl.plot(x,a,'b',label='Source distribution')
+ pl.subplot(2, 1, 1)
+ pl.plot(x, a, 'b', label='Source distribution')
pl.title('Source distribution')
- pl.subplot(2,1,2)
- pl.plot(x,B,label='Target distributions')
+ pl.subplot(2, 1, 2)
+ pl.plot(x, B, label='Target distributions')
pl.title('Target distributions')
+ pl.tight_layout()
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_compute_emd_001.png
+ :align: center
+
+
+
+
+Compute EMD for the different losses
+------------------------------------
+
+
+
+.. code-block:: python
+
#%% Compute and plot distributions and loss matrix
- d_emd=ot.emd2(a,B,M) # direct computation of EMD
- d_emd2=ot.emd2(a,B,M2) # direct computation of EMD with loss M3
+ d_emd = ot.emd2(a, B, M) # direct computation of EMD
+ d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M2
pl.figure(2)
- pl.plot(d_emd,label='Euclidean EMD')
- pl.plot(d_emd2,label='Squared Euclidean EMD')
+ pl.plot(d_emd, label='Euclidean EMD')
+ pl.plot(d_emd2, label='Squared Euclidean EMD')
pl.title('EMD distances')
pl.legend()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_compute_emd_003.png
+ :align: center
+
+
+
+
+Compute Sinkhorn for the different losses
+-----------------------------------------
+
+
+
+.. code-block:: python
+
+
#%%
- reg=1e-2
- d_sinkhorn=ot.sinkhorn(a,B,M,reg)
- d_sinkhorn2=ot.sinkhorn(a,B,M2,reg)
+ reg = 1e-2
+ d_sinkhorn = ot.sinkhorn2(a, B, M, reg)
+ d_sinkhorn2 = ot.sinkhorn2(a, B, M2, reg)
pl.figure(2)
pl.clf()
- pl.plot(d_emd,label='Euclidean EMD')
- pl.plot(d_emd2,label='Squared Euclidean EMD')
- pl.plot(d_sinkhorn,'+',label='Euclidean Sinkhorn')
- pl.plot(d_sinkhorn2,'+',label='Squared Euclidean Sinkhorn')
+ pl.plot(d_emd, label='Euclidean EMD')
+ pl.plot(d_emd2, label='Squared Euclidean EMD')
+ pl.plot(d_sinkhorn, '+', label='Euclidean Sinkhorn')
+ pl.plot(d_sinkhorn2, '+', label='Squared Euclidean Sinkhorn')
pl.title('EMD distances')
pl.legend()
-**Total running time of the script:** ( 0 minutes 0.521 seconds)
+
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_compute_emd_004.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 0 minutes 0.697 seconds)