summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-07-24 15:45:09 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-07-24 15:45:09 +0200
commitbd705ed847dd7e43082e9d2771a59e539d6b7440 (patch)
tree6aa66499a61a487eebe59b3c75cdbec97e7019b1 /test
parent33f3d309209baa8c5e127d02f00aae0660ed7bfb (diff)
add test yunmlix and bary
Diffstat (limited to 'test')
-rw-r--r--test/test_bregman.py34
-rw-r--r--test/test_gpu.py2
-rw-r--r--test/test_ot.py2
3 files changed, 36 insertions, 2 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index b204fe4..025568c 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -97,3 +97,37 @@ def test_bary():
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
assert np.allclose(1, np.sum(bary_wass))
+
+ ot.bregman.barycenter(A, M, reg, log=True, verbose=True)
+
+
+def test_unmix():
+
+ n = 50 # nb bins
+
+ # Gaussian distributions
+ a1 = ot.datasets.get_1D_gauss(n, m=20, s=10) # m= mean, s= std
+ a2 = ot.datasets.get_1D_gauss(n, m=40, s=10)
+
+ a = ot.datasets.get_1D_gauss(n, m=30, s=10)
+
+ # creating matrix A containing all distributions
+ D = np.vstack((a1, a2)).T
+
+ # loss matrix + normalization
+ M = ot.utils.dist0(n)
+ M /= M.max()
+
+ M0 = ot.utils.dist0(2)
+ M0 /= M0.max()
+ h0 = ot.unif(2)
+
+ # wasserstein
+ reg = 1e-3
+ um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01,)
+
+ assert np.allclose(1, np.sum(um), rtol=1e-03, atol=1e-03)
+ assert np.allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03)
+
+ ot.bregman.unmix(a, D, M, M0, h0, reg,
+ 1, alpha=0.01, log=True, verbose=True)
diff --git a/test/test_gpu.py b/test/test_gpu.py
index 9cc39d7..24797f2 100644
--- a/test/test_gpu.py
+++ b/test/test_gpu.py
@@ -48,7 +48,7 @@ def test_gpu_sinkhorn_lpl1():
print("min:{:.3E}, max:{:.3E}, mean:{:.3E}, std:{:.3E}"
.format(np.min(r), np.max(r), np.mean(r), np.std(r)))
- for n in [50, 100, 500, 1000]:
+ for n in [50, 100, 500]:
print(n)
a = np.random.rand(n // 4, 100)
labels_a = np.random.randint(10, size=(n // 4))
diff --git a/test/test_ot.py b/test/test_ot.py
index 3897397..5bf65c6 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -76,7 +76,7 @@ def test_emd2_multi():
# Gaussian distributions
a = gauss(n, m=20, s=5) # m= mean, s= std
- ls = np.arange(20, 1000, 10)
+ ls = np.arange(20, 1000, 20)
nb = len(ls)
b = np.zeros((n, nb))
for i in range(nb):