summaryrefslogtreecommitdiff
path: root/test/test_utils.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-07-24 14:26:25 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-07-24 14:26:25 +0200
commitb2f91f24796a996a82db41e91f56ba6a51989159 (patch)
tree81f8ce82e2b7474378746fa591149352d89a7775 /test/test_utils.py
parent1cf304cee298e2752ce29c83e5201f593722c3af (diff)
full coveragre utils
Diffstat (limited to 'test/test_utils.py')
-rw-r--r--test/test_utils.py46
1 files changed, 46 insertions, 0 deletions
diff --git a/test/test_utils.py b/test/test_utils.py
index 3219fce..e85e5b7 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -71,6 +71,52 @@ def test_dist():
D[i, j] = np.sum(np.square(x[i, :] - x[j, :]))
D2 = ot.dist(x, x)
+ D3 = ot.dist(x)
# dist shoul return squared euclidean
assert np.allclose(D, D2)
+ assert np.allclose(D, D3)
+
+
+def test_dist0():
+
+ n = 100
+ M = ot.utils.dist0(n, method='lin_square')
+
+ # dist0 default to linear sampling with quadratic loss
+ assert np.allclose(M[0, -1], (n - 1) * (n - 1))
+
+
+def test_dots():
+
+ n1, n2, n3, n4 = 100, 50, 200, 100
+
+ A = np.random.randn(n1, n2)
+ B = np.random.randn(n2, n3)
+ C = np.random.randn(n3, n4)
+
+ X1 = ot.utils.dots(A, B, C)
+
+ X2 = A.dot(B.dot(C))
+
+ assert np.allclose(X1, X2)
+
+
+def test_clean_zeros():
+
+ n = 100
+ nz = 50
+ nz2 = 20
+ u1 = ot.unif(n)
+ u1[:nz] = 0
+ u1 = u1 / u1.sum()
+ u2 = ot.unif(n)
+ u2[:nz2] = 0
+ u2 = u2 / u2.sum()
+
+ M = ot.utils.dist0(n)
+
+ a, b, M2 = ot.utils.clean_zeros(u1, u2, M)
+
+ assert len(a) == n - nz
+ assert len(b) == n - nz2