summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorvivienseguy <vivienseguy@gmail.com>2018-07-05 18:26:55 +0900
committervivienseguy <vivienseguy@gmail.com>2018-07-05 18:26:55 +0900
commit98ce4ccd3536d95c609ee1c5b737ced85d68f786 (patch)
treed5716a2854e15c9acbb1aa40d09549fa87f7a51e
parent3f23fa1a950ffde4a4224a6343a504a0c5b7851b (diff)
free support barycenter
-rw-r--r--examples/plot_free_support_barycenter.py9
-rw-r--r--ot/lp/cvx.py2
2 files changed, 5 insertions, 6 deletions
diff --git a/examples/plot_free_support_barycenter.py b/examples/plot_free_support_barycenter.py
index a733745..274cf76 100644
--- a/examples/plot_free_support_barycenter.py
+++ b/examples/plot_free_support_barycenter.py
@@ -17,7 +17,6 @@ import numpy as np
import matplotlib.pylab as pl
import ot.plot
-
##############################################################################
# Generate data
# -------------
@@ -29,10 +28,10 @@ measures_weights = []
for i in range(N):
- n = np.rand.int(low=1, high=20) # nb samples
+ n = np.random.randint(low=1, high=20) # nb samples
mu = np.random.normal(0., 1., (d,))
- cov = np.random.normal(0., 1., (d,d))
+ cov = np.random.uniform(0., 1., (d,d))
xs = ot.datasets.make_2D_samples_gauss(n, mu, cov)
b = np.random.uniform(0., 1., n)
@@ -49,7 +48,7 @@ b_init = np.ones((k,)) / k
##############################################################################
# Compute free support barycenter
# -------------
-X = ot.lp.barycenter(measures_locations, measures_weights, X_init, b_init)
+X = ot.lp.cvx.free_support_barycenter(measures_locations, measures_weights, X_init, b_init)
##############################################################################
@@ -60,7 +59,7 @@ X = ot.lp.barycenter(measures_locations, measures_weights, X_init, b_init)
pl.figure(1)
for (xs, b) in zip(measures_locations, measures_weights):
- pl.scatter(xs[:, 0], xs[:, 1], s=b, c=np.tile(np.rand(0. ,255., size=(3,)), (1,b.size(0))) , label='Data measures')
+ pl.scatter(xs[:, 0], xs[:, 1], s=b, c=np.tile(np.random.uniform(0. ,255., size=(3,)), (1,b.size(0))) , label='Data measures')
pl.scatter(xs[:, 0], xs[:, 1], s=b, c='black' , label='2-Wasserstein barycenter')
pl.legend(loc=0)
pl.title('Data measures and their barycenter')
diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py
index 91a5922..b74960f 100644
--- a/ot/lp/cvx.py
+++ b/ot/lp/cvx.py
@@ -217,7 +217,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b_init
M_i = ot.dist(X, measure_locations_i)
T_i = ot.emd(b_init, measure_weights_i, M_i)
- T_sum += np.reshape(1. / b_init, (-1, 1)) * np.matmul(T_i, measure_locations_i)
+ T_sum = T_sum + weight_i*np.reshape(1. / b_init, (-1, 1)) * np.matmul(T_i, measure_locations_i)
displacement_square_norm = np.sum(np.square(X-T_sum))
X = T_sum