diff options
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r-- | test/test_bregman.py | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py index b204fe4..025568c 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -97,3 +97,37 @@ def test_bary(): bary_wass = ot.bregman.barycenter(A, M, reg, weights) assert np.allclose(1, np.sum(bary_wass)) + + ot.bregman.barycenter(A, M, reg, log=True, verbose=True) + + +def test_unmix(): + + n = 50 # nb bins + + # Gaussian distributions + a1 = ot.datasets.get_1D_gauss(n, m=20, s=10) # m= mean, s= std + a2 = ot.datasets.get_1D_gauss(n, m=40, s=10) + + a = ot.datasets.get_1D_gauss(n, m=30, s=10) + + # creating matrix A containing all distributions + D = np.vstack((a1, a2)).T + + # loss matrix + normalization + M = ot.utils.dist0(n) + M /= M.max() + + M0 = ot.utils.dist0(2) + M0 /= M0.max() + h0 = ot.unif(2) + + # wasserstein + reg = 1e-3 + um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01,) + + assert np.allclose(1, np.sum(um), rtol=1e-03, atol=1e-03) + assert np.allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03) + + ot.bregman.unmix(a, D, M, M0, h0, reg, + 1, alpha=0.01, log=True, verbose=True) |