diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2017-07-24 15:45:09 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2017-07-24 15:45:09 +0200 |
commit | bd705ed847dd7e43082e9d2771a59e539d6b7440 (patch) | |
tree | 6aa66499a61a487eebe59b3c75cdbec97e7019b1 /test | |
parent | 33f3d309209baa8c5e127d02f00aae0660ed7bfb (diff) |
add test yunmlix and bary
Diffstat (limited to 'test')
-rw-r--r-- | test/test_bregman.py | 34 | ||||
-rw-r--r-- | test/test_gpu.py | 2 | ||||
-rw-r--r-- | test/test_ot.py | 2 |
3 files changed, 36 insertions, 2 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py index b204fe4..025568c 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -97,3 +97,37 @@ def test_bary(): bary_wass = ot.bregman.barycenter(A, M, reg, weights) assert np.allclose(1, np.sum(bary_wass)) + + ot.bregman.barycenter(A, M, reg, log=True, verbose=True) + + +def test_unmix(): + + n = 50 # nb bins + + # Gaussian distributions + a1 = ot.datasets.get_1D_gauss(n, m=20, s=10) # m= mean, s= std + a2 = ot.datasets.get_1D_gauss(n, m=40, s=10) + + a = ot.datasets.get_1D_gauss(n, m=30, s=10) + + # creating matrix A containing all distributions + D = np.vstack((a1, a2)).T + + # loss matrix + normalization + M = ot.utils.dist0(n) + M /= M.max() + + M0 = ot.utils.dist0(2) + M0 /= M0.max() + h0 = ot.unif(2) + + # wasserstein + reg = 1e-3 + um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01,) + + assert np.allclose(1, np.sum(um), rtol=1e-03, atol=1e-03) + assert np.allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03) + + ot.bregman.unmix(a, D, M, M0, h0, reg, + 1, alpha=0.01, log=True, verbose=True) diff --git a/test/test_gpu.py b/test/test_gpu.py index 9cc39d7..24797f2 100644 --- a/test/test_gpu.py +++ b/test/test_gpu.py @@ -48,7 +48,7 @@ def test_gpu_sinkhorn_lpl1(): print("min:{:.3E}, max:{:.3E}, mean:{:.3E}, std:{:.3E}" .format(np.min(r), np.max(r), np.mean(r), np.std(r))) - for n in [50, 100, 500, 1000]: + for n in [50, 100, 500]: print(n) a = np.random.rand(n // 4, 100) labels_a = np.random.randint(10, size=(n // 4)) diff --git a/test/test_ot.py b/test/test_ot.py index 3897397..5bf65c6 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -76,7 +76,7 @@ def test_emd2_multi(): # Gaussian distributions a = gauss(n, m=20, s=5) # m= mean, s= std - ls = np.arange(20, 1000, 10) + ls = np.arange(20, 1000, 20) nb = len(ls) b = np.zeros((n, nb)) for i in range(nb): |