From b2f91f24796a996a82db41e91f56ba6a51989159 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Mon, 24 Jul 2017 14:26:25 +0200 Subject: full coveragre utils --- test/test_utils.py | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) (limited to 'test/test_utils.py') diff --git a/test/test_utils.py b/test/test_utils.py index 3219fce..e85e5b7 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -71,6 +71,52 @@ def test_dist(): D[i, j] = np.sum(np.square(x[i, :] - x[j, :])) D2 = ot.dist(x, x) + D3 = ot.dist(x) # dist shoul return squared euclidean assert np.allclose(D, D2) + assert np.allclose(D, D3) + + +def test_dist0(): + + n = 100 + M = ot.utils.dist0(n, method='lin_square') + + # dist0 default to linear sampling with quadratic loss + assert np.allclose(M[0, -1], (n - 1) * (n - 1)) + + +def test_dots(): + + n1, n2, n3, n4 = 100, 50, 200, 100 + + A = np.random.randn(n1, n2) + B = np.random.randn(n2, n3) + C = np.random.randn(n3, n4) + + X1 = ot.utils.dots(A, B, C) + + X2 = A.dot(B.dot(C)) + + assert np.allclose(X1, X2) + + +def test_clean_zeros(): + + n = 100 + nz = 50 + nz2 = 20 + u1 = ot.unif(n) + u1[:nz] = 0 + u1 = u1 / u1.sum() + u2 = ot.unif(n) + u2[:nz2] = 0 + u2 = u2 / u2.sum() + + M = ot.utils.dist0(n) + + a, b, M2 = ot.utils.clean_zeros(u1, u2, M) + + assert len(a) == n - nz + assert len(b) == n - nz2 -- cgit v1.2.3