From 64cf2fc4f9a9331d510afd93e9bd3b8963ff879e Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Mon, 24 Jul 2017 15:28:43 +0200 Subject: tets barycenter --- test/test_bregman.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) (limited to 'test') diff --git a/test/test_bregman.py b/test/test_bregman.py index 78666c7..2dd3498 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -72,3 +72,32 @@ def test_sinkhorn_variants(): assert np.allclose(G0, Gs, atol=1e-05) assert np.allclose(G0, Ges, atol=1e-05) assert np.allclose(G0, Gerr) + + +def test_bary(): + + n = 100 # nb bins + + # bin positions + x = np.arange(n, dtype=np.float64) + + # Gaussian distributions + a1 = ot.datasets.get_1D_gauss(n, m=30, s=10) # m= mean, s= std + a2 = ot.datasets.get_1D_gauss(n, m=40, s=10) + + # creating matrix A containing all distributions + A = np.vstack((a1, a2)).T + n_distributions = A.shape[1] + + # loss matrix + normalization + M = ot.utils.dist0(n) + M /= M.max() + + alpha = 0.5 # 0<=alpha<=1 + weights = np.array([1 - alpha, alpha]) + + # wasserstein + reg = 1e-3 + bary_wass = ot.bregman.barycenter(A, M, reg, weights) + + assert np.allclose(1, np.sum(bary_wass)) -- cgit v1.2.3