summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-07-24 15:28:43 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-07-24 15:28:43 +0200
commit64cf2fc4f9a9331d510afd93e9bd3b8963ff879e (patch)
treecf62fc6bc040fee9da3c61978fa356e888385a26
parent83ecc6df836d1a6b05bd641dfef465cc02b25b8f (diff)
tets barycenter
-rw-r--r--test/test_bregman.py29
1 files changed, 29 insertions, 0 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 78666c7..2dd3498 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -72,3 +72,32 @@ def test_sinkhorn_variants():
assert np.allclose(G0, Gs, atol=1e-05)
assert np.allclose(G0, Ges, atol=1e-05)
assert np.allclose(G0, Gerr)
+
+
+def test_bary():
+
+ n = 100 # nb bins
+
+ # bin positions
+ x = np.arange(n, dtype=np.float64)
+
+ # Gaussian distributions
+ a1 = ot.datasets.get_1D_gauss(n, m=30, s=10) # m= mean, s= std
+ a2 = ot.datasets.get_1D_gauss(n, m=40, s=10)
+
+ # creating matrix A containing all distributions
+ A = np.vstack((a1, a2)).T
+ n_distributions = A.shape[1]
+
+ # loss matrix + normalization
+ M = ot.utils.dist0(n)
+ M /= M.max()
+
+ alpha = 0.5 # 0<=alpha<=1
+ weights = np.array([1 - alpha, alpha])
+
+ # wasserstein
+ reg = 1e-3
+ bary_wass = ot.bregman.barycenter(A, M, reg, weights)
+
+ assert np.allclose(1, np.sum(bary_wass))