summaryrefslogtreecommitdiff
path: root/test/test_ot.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-07-24 11:15:33 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-07-24 11:15:33 +0200
commit5a6b5de9b2f28c93bef1a9db2e3b94693c05ff4f (patch)
tree1f7457a028ef71253be36c44fb87c2e4131e909a /test/test_ot.py
parent82da63f1020835a412f6174500099694a78ab6be (diff)
add proper testing
Diffstat (limited to 'test/test_ot.py')
-rw-r--r--test/test_ot.py55
1 files changed, 55 insertions, 0 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
new file mode 100644
index 0000000..51ee510
--- /dev/null
+++ b/test/test_ot.py
@@ -0,0 +1,55 @@
+
+
+import ot
+import numpy as np
+
+#import pytest
+
+
+def test_doctest():
+
+ import doctest
+
+ # test lp solver
+ doctest.testmod(ot.lp, verbose=True)
+
+ # test bregman solver
+ doctest.testmod(ot.bregman, verbose=True)
+
+
+#@pytest.mark.skip(reason="Seems to be a conflict between pytest and multiprocessing")
+def test_emd_multi():
+
+ from ot.datasets import get_1D_gauss as gauss
+
+ n = 1000 # nb bins
+
+ # bin positions
+ x = np.arange(n, dtype=np.float64)
+
+ # Gaussian distributions
+ a = gauss(n, m=20, s=5) # m= mean, s= std
+
+ ls = np.arange(20, 1000, 10)
+ nb = len(ls)
+ b = np.zeros((n, nb))
+ for i in range(nb):
+ b[:, i] = gauss(n, m=ls[i], s=10)
+
+ # loss matrix
+ M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+ # M/=M.max()
+
+ print('Computing {} EMD '.format(nb))
+
+ # emd loss 1 proc
+ ot.tic()
+ emd1 = ot.emd2(a, b, M, 1)
+ ot.toc('1 proc : {} s')
+
+ # emd loss multipro proc
+ ot.tic()
+ emdn = ot.emd2(a, b, M)
+ ot.toc('multi proc : {} s')
+
+ assert np.allclose(emd1, emdn)