summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorVivien Seguy <vivienseguy@h40.57.229.10.1016746.vlan.kuins.net>2018-07-09 16:49:21 +0900
committerVivien Seguy <vivienseguy@h40.57.229.10.1016746.vlan.kuins.net>2018-07-09 16:49:21 +0900
commit46712790c1276f1ecb3496362a8117e153782ede (patch)
tree15c66c04b24396078b60b8f0a89b7a9cc6476b50 /examples
parent67ddb92e28d6bb44eb65686419e255c2ce3311eb (diff)
add test free support barycenter algorithm + cleaning
Diffstat (limited to 'examples')
-rw-r--r--examples/plot_free_support_barycenter.py29
1 files changed, 14 insertions, 15 deletions
diff --git a/examples/plot_free_support_barycenter.py b/examples/plot_free_support_barycenter.py
index 5b08507..b2e62c8 100644
--- a/examples/plot_free_support_barycenter.py
+++ b/examples/plot_free_support_barycenter.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""
====================================================
-2D Wasserstein barycenters of distributions
+2D free support Wasserstein barycenters of distributions
====================================================
Illustration of 2D Wasserstein barycenters if discributions that are weighted
@@ -15,7 +15,8 @@ sum of diracs.
import numpy as np
import matplotlib.pylab as pl
-import ot.plot
+import ot
+
##############################################################################
# Generate data
@@ -28,16 +29,16 @@ measures_weights = []
for i in range(N):
- n = np.random.randint(low=1, high=20) # nb samples
+ n_i = np.random.randint(low=1, high=20) # nb samples
- mu = np.random.normal(0., 4., (d,))
+ mu_i = np.random.normal(0., 4., (d,)) # Gaussian mean
- A = np.random.rand(d, d)
- cov = np.dot(A, A.transpose())
+ A_i = np.random.rand(d, d)
+ cov_i = np.dot(A_i, A_i.transpose()) # Gaussian covariance matrix
- x_i = ot.datasets.make_2D_samples_gauss(n, mu, cov)
- b_i = np.random.uniform(0., 1., (n,))
- b_i = b_i / np.sum(b_i)
+ x_i = ot.datasets.make_2D_samples_gauss(n_i, mu_i, cov_i) # Dirac locations
+ b_i = np.random.uniform(0., 1., (n_i,))
+ b_i = b_i / np.sum(b_i) # Dirac weights
measures_locations.append(x_i)
measures_weights.append(b_i)
@@ -47,19 +48,17 @@ for i in range(N):
# Compute free support barycenter
# -------------
-k = 10
-X_init = np.random.normal(0., 1., (k, d))
-b = np.ones((k,)) / k
+k = 10 # number of Diracs of the barycenter
+X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations
+b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized)
-X = ot.lp.cvx.free_support_barycenter(measures_locations, measures_weights, X_init, b)
+X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b)
##############################################################################
# Plot data
# ---------
-#%% plot samples
-
pl.figure(1)
for (x_i, b_i) in zip(measures_locations, measures_weights):
color = np.random.randint(low=1, high=10 * N)