From e39f04a9465bd9f1447423eb2a592cc9356589a9 Mon Sep 17 00:00:00 2001 From: Vivien Seguy Date: Thu, 5 Jul 2018 19:01:10 +0900 Subject: add free support barycenter algorithm --- examples/plot_free_support_barycenter.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) (limited to 'examples') diff --git a/examples/plot_free_support_barycenter.py b/examples/plot_free_support_barycenter.py index 274cf76..61671cf 100644 --- a/examples/plot_free_support_barycenter.py +++ b/examples/plot_free_support_barycenter.py @@ -21,7 +21,7 @@ import ot.plot # Generate data # ------------- #%% parameters and data generation -N = 4 +N = 6 d = 2 measures_locations = [] measures_weights = [] @@ -30,11 +30,13 @@ for i in range(N): n = np.random.randint(low=1, high=20) # nb samples - mu = np.random.normal(0., 1., (d,)) - cov = np.random.uniform(0., 1., (d,d)) + mu = np.random.normal(0., 4., (d,)) + + A = np.random.rand(d, d) + cov = np.dot(A,A.transpose()) xs = ot.datasets.make_2D_samples_gauss(n, mu, cov) - b = np.random.uniform(0., 1., n) + b = np.random.uniform(0., 1., (n,)) b = b/np.sum(b) measures_locations.append(xs) @@ -59,7 +61,9 @@ X = ot.lp.cvx.free_support_barycenter(measures_locations, measures_weights, X_in pl.figure(1) for (xs, b) in zip(measures_locations, measures_weights): - pl.scatter(xs[:, 0], xs[:, 1], s=b, c=np.tile(np.random.uniform(0. ,255., size=(3,)), (1,b.size(0))) , label='Data measures') -pl.scatter(xs[:, 0], xs[:, 1], s=b, c='black' , label='2-Wasserstein barycenter') -pl.legend(loc=0) + color = np.random.randint(low=1, high=10*N) + pl.scatter(xs[:, 0], xs[:, 1], s=b*1000, label='input measure') +pl.scatter(X[:, 0], X[:, 1], s=b_init*1000, c='black' , marker='^', label='2-Wasserstein barycenter') pl.title('Data measures and their barycenter') +pl.legend(loc=0) +pl.show() \ No newline at end of file -- cgit v1.2.3