diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2017-07-24 15:28:43 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2017-07-24 15:28:43 +0200 |
commit | 64cf2fc4f9a9331d510afd93e9bd3b8963ff879e (patch) | |
tree | cf62fc6bc040fee9da3c61978fa356e888385a26 /test | |
parent | 83ecc6df836d1a6b05bd641dfef465cc02b25b8f (diff) |
tets barycenter
Diffstat (limited to 'test')
-rw-r--r-- | test/test_bregman.py | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py index 78666c7..2dd3498 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -72,3 +72,32 @@ def test_sinkhorn_variants(): assert np.allclose(G0, Gs, atol=1e-05) assert np.allclose(G0, Ges, atol=1e-05) assert np.allclose(G0, Gerr) + + +def test_bary(): + + n = 100 # nb bins + + # bin positions + x = np.arange(n, dtype=np.float64) + + # Gaussian distributions + a1 = ot.datasets.get_1D_gauss(n, m=30, s=10) # m= mean, s= std + a2 = ot.datasets.get_1D_gauss(n, m=40, s=10) + + # creating matrix A containing all distributions + A = np.vstack((a1, a2)).T + n_distributions = A.shape[1] + + # loss matrix + normalization + M = ot.utils.dist0(n) + M /= M.max() + + alpha = 0.5 # 0<=alpha<=1 + weights = np.array([1 - alpha, alpha]) + + # wasserstein + reg = 1e-3 + bary_wass = ot.bregman.barycenter(A, M, reg, weights) + + assert np.allclose(1, np.sum(bary_wass)) |