summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorVivien Seguy <vivienseguy@Viviens-MacBook-Pro.local>2018-07-06 01:58:47 +0900
committerVivien Seguy <vivienseguy@Viviens-MacBook-Pro.local>2018-07-06 01:58:47 +0900
commit2c7b98009f33e278a2e7e95a035c6a6231bec44e (patch)
tree9c5b5730ec159f9cbb832bc1bf9d221092cbb14b /examples
parente39f04a9465bd9f1447423eb2a592cc9356589a9 (diff)
add free support barycenter algorithm
Diffstat (limited to 'examples')
-rw-r--r--examples/plot_free_support_barycenter.py35
1 files changed, 18 insertions, 17 deletions
diff --git a/examples/plot_free_support_barycenter.py b/examples/plot_free_support_barycenter.py
index 61671cf..42e22fc 100644
--- a/examples/plot_free_support_barycenter.py
+++ b/examples/plot_free_support_barycenter.py
@@ -21,7 +21,7 @@ import ot.plot
# Generate data
# -------------
#%% parameters and data generation
-N = 6
+N = 3
d = 2
measures_locations = []
measures_weights = []
@@ -33,24 +33,25 @@ for i in range(N):
mu = np.random.normal(0., 4., (d,))
A = np.random.rand(d, d)
- cov = np.dot(A,A.transpose())
+ cov = np.dot(A, A.transpose())
- xs = ot.datasets.make_2D_samples_gauss(n, mu, cov)
- b = np.random.uniform(0., 1., (n,))
- b = b/np.sum(b)
+ x_i = ot.datasets.make_2D_samples_gauss(n, mu, cov)
+ b_i = np.random.uniform(0., 1., (n,))
+ b_i = b_i / np.sum(b_i)
- measures_locations.append(xs)
- measures_weights.append(b)
-
-k = 10
-X_init = np.random.normal(0., 1., (k,d))
-b_init = np.ones((k,)) / k
+ measures_locations.append(x_i)
+ measures_weights.append(b_i)
##############################################################################
# Compute free support barycenter
# -------------
-X = ot.lp.cvx.free_support_barycenter(measures_locations, measures_weights, X_init, b_init)
+
+k = 10
+X_init = np.random.normal(0., 1., (k, d))
+b = np.ones((k,)) / k
+
+X = ot.lp.cvx.free_support_barycenter(measures_locations, measures_weights, X_init, b)
##############################################################################
@@ -60,10 +61,10 @@ X = ot.lp.cvx.free_support_barycenter(measures_locations, measures_weights, X_in
#%% plot samples
pl.figure(1)
-for (xs, b) in zip(measures_locations, measures_weights):
- color = np.random.randint(low=1, high=10*N)
- pl.scatter(xs[:, 0], xs[:, 1], s=b*1000, label='input measure')
-pl.scatter(X[:, 0], X[:, 1], s=b_init*1000, c='black' , marker='^', label='2-Wasserstein barycenter')
+for (x_i, b_i) in zip(measures_locations, measures_weights):
+ color = np.random.randint(low=1, high=10 * N)
+ pl.scatter(x_i[:, 0], x_i[:, 1], s=b * 1000, label='input measure')
+pl.scatter(X[:, 0], X[:, 1], s=b * 1000, c='black', marker='^', label='2-Wasserstein barycenter')
pl.title('Data measures and their barycenter')
pl.legend(loc=0)
-pl.show() \ No newline at end of file
+pl.show()