summaryrefslogtreecommitdiff
path: root/examples/plot_barycenter_lp_vs_entropic.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/plot_barycenter_lp_vs_entropic.py')
-rw-r--r--examples/plot_barycenter_lp_vs_entropic.py29
1 files changed, 9 insertions, 20 deletions
diff --git a/examples/plot_barycenter_lp_vs_entropic.py b/examples/plot_barycenter_lp_vs_entropic.py
index 6936bbb..b82765e 100644
--- a/examples/plot_barycenter_lp_vs_entropic.py
+++ b/examples/plot_barycenter_lp_vs_entropic.py
@@ -15,8 +15,6 @@ Wasserstein problems. SIAM Journal on Imaging Sciences, 9(1), 320-343.
Iterative Bregman projections for regularized transportation problems
SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
-
-
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
@@ -32,8 +30,8 @@ from matplotlib.collections import PolyCollection # noqa
#import ot.lp.cvx as cvx
-#
-# Generate data
+##############################################################################
+# Gaussian Data
# -------------
#%% parameters
@@ -47,8 +45,8 @@ x = np.arange(n, dtype=np.float64)
# Gaussian distributions
# Gaussian distributions
-a1 = ot.datasets.get_1D_gauss(n, m=20, s=5) # m= mean, s= std
-a2 = ot.datasets.get_1D_gauss(n, m=60, s=8)
+a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
# creating matrix A containing all distributions
A = np.vstack((a1, a2)).T
@@ -58,9 +56,6 @@ n_distributions = A.shape[1]
M = ot.utils.dist0(n)
M /= M.max()
-#
-# Plot data
-# ---------
#%% plot the distributions
@@ -70,10 +65,6 @@ for i in range(n_distributions):
pl.title('Distributions')
pl.tight_layout()
-#
-# Barycenter computation
-# ----------------------
-
#%% barycenter computation
alpha = 0.5 # 0<=alpha<=1
@@ -110,6 +101,10 @@ pl.tight_layout()
problems.append([A, [bary_l2, bary_wass, bary_wass2]])
+##############################################################################
+# Dirac Data
+# ----------
+
#%% parameters
a1 = 1.0 * (x > 10) * (x < 50)
@@ -135,9 +130,6 @@ for i in range(n_distributions):
pl.title('Distributions')
pl.tight_layout()
-#
-# Barycenter computation
-# ----------------------
#%% barycenter computation
@@ -207,9 +199,6 @@ for i in range(n_distributions):
pl.title('Distributions')
pl.tight_layout()
-#
-# Barycenter computation
-# ----------------------
#%% barycenter computation
@@ -249,7 +238,7 @@ pl.title('Barycenters')
pl.tight_layout()
-#
+##############################################################################
# Final figure
# ------------
#