From 98ce4ccd3536d95c609ee1c5b737ced85d68f786 Mon Sep 17 00:00:00 2001 From: vivienseguy Date: Thu, 5 Jul 2018 18:26:55 +0900 Subject: free support barycenter --- examples/plot_free_support_barycenter.py | 9 ++++----- ot/lp/cvx.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/examples/plot_free_support_barycenter.py b/examples/plot_free_support_barycenter.py index a733745..274cf76 100644 --- a/examples/plot_free_support_barycenter.py +++ b/examples/plot_free_support_barycenter.py @@ -17,7 +17,6 @@ import numpy as np import matplotlib.pylab as pl import ot.plot - ############################################################################## # Generate data # ------------- @@ -29,10 +28,10 @@ measures_weights = [] for i in range(N): - n = np.rand.int(low=1, high=20) # nb samples + n = np.random.randint(low=1, high=20) # nb samples mu = np.random.normal(0., 1., (d,)) - cov = np.random.normal(0., 1., (d,d)) + cov = np.random.uniform(0., 1., (d,d)) xs = ot.datasets.make_2D_samples_gauss(n, mu, cov) b = np.random.uniform(0., 1., n) @@ -49,7 +48,7 @@ b_init = np.ones((k,)) / k ############################################################################## # Compute free support barycenter # ------------- -X = ot.lp.barycenter(measures_locations, measures_weights, X_init, b_init) +X = ot.lp.cvx.free_support_barycenter(measures_locations, measures_weights, X_init, b_init) ############################################################################## @@ -60,7 +59,7 @@ X = ot.lp.barycenter(measures_locations, measures_weights, X_init, b_init) pl.figure(1) for (xs, b) in zip(measures_locations, measures_weights): - pl.scatter(xs[:, 0], xs[:, 1], s=b, c=np.tile(np.rand(0. ,255., size=(3,)), (1,b.size(0))) , label='Data measures') + pl.scatter(xs[:, 0], xs[:, 1], s=b, c=np.tile(np.random.uniform(0. ,255., size=(3,)), (1,b.size(0))) , label='Data measures') pl.scatter(xs[:, 0], xs[:, 1], s=b, c='black' , label='2-Wasserstein barycenter') pl.legend(loc=0) pl.title('Data measures and their barycenter') diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index 91a5922..b74960f 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -217,7 +217,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b_init M_i = ot.dist(X, measure_locations_i) T_i = ot.emd(b_init, measure_weights_i, M_i) - T_sum += np.reshape(1. / b_init, (-1, 1)) * np.matmul(T_i, measure_locations_i) + T_sum = T_sum + weight_i*np.reshape(1. / b_init, (-1, 1)) * np.matmul(T_i, measure_locations_i) displacement_square_norm = np.sum(np.square(X-T_sum)) X = T_sum -- cgit v1.2.3