From 68d74902bcd3d988fff8cb7713314063f04c0089 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Wed, 26 Jul 2017 11:34:11 +0200 Subject: numpy assert + n_bins --- test/test_bregman.py | 63 ++++++++++++++++++++++++++-------------------------- 1 file changed, 31 insertions(+), 32 deletions(-) (limited to 'test') diff --git a/test/test_bregman.py b/test/test_bregman.py index aaa2efc..1638ef6 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -3,15 +3,12 @@ import numpy as np import ot -# import pytest - - def test_sinkhorn(): # test sinkhorn n = 100 - np.random.seed(0) + rng = np.random.RandomState(0) - x = np.random.randn(n, 2) + x = rng.randn(n, 2) u = ot.utils.unif(n) M = ot.dist(x, x) @@ -19,45 +16,47 @@ def test_sinkhorn(): G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10) # check constratints - assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn - assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn + np.testing.assert_allclose( + u, G.sum(1), atol=1e-05) # cf convergence sinkhorn + np.testing.assert_allclose( + u, G.sum(0), atol=1e-05) # cf convergence sinkhorn def test_sinkhorn_empty(): # test sinkhorn n = 100 - np.random.seed(0) + rng = np.random.RandomState(0) - x = np.random.randn(n, 2) + x = rng.randn(n, 2) u = ot.utils.unif(n) M = ot.dist(x, x) G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, verbose=True, log=True) # check constratints - assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn - assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn + np.testing.assert_allclose(u, G.sum(1), atol=1e-05) + np.testing.assert_allclose(u, G.sum(0), atol=1e-05) G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, method='sinkhorn_stabilized', verbose=True, log=True) # check constratints - assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn - assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn + np.testing.assert_allclose(u, G.sum(1), atol=1e-05) + np.testing.assert_allclose(u, G.sum(0), atol=1e-05) G, log = ot.sinkhorn( [], [], M, 1, stopThr=1e-10, method='sinkhorn_epsilon_scaling', verbose=True, log=True) # check constratints - assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn - assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn + np.testing.assert_allclose(u, G.sum(1), atol=1e-05) + np.testing.assert_allclose(u, G.sum(0), atol=1e-05) def test_sinkhorn_variants(): # test sinkhorn n = 100 - np.random.seed(0) + rng = np.random.RandomState(0) - x = np.random.randn(n, 2) + x = rng.randn(n, 2) u = ot.utils.unif(n) M = ot.dist(x, x) @@ -69,24 +68,24 @@ def test_sinkhorn_variants(): Gerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10) # check values - assert np.allclose(G0, Gs, atol=1e-05) - assert np.allclose(G0, Ges, atol=1e-05) - assert np.allclose(G0, Gerr) + np.testing.assert_allclose(G0, Gs, atol=1e-05) + np.testing.assert_allclose(G0, Ges, atol=1e-05) + np.testing.assert_allclose(G0, Gerr) def test_bary(): - n = 100 # nb bins + n_bins = 100 # nb bins # Gaussian distributions - a1 = ot.datasets.get_1D_gauss(n, m=30, s=10) # m= mean, s= std - a2 = ot.datasets.get_1D_gauss(n, m=40, s=10) + a1 = ot.datasets.get_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std + a2 = ot.datasets.get_1D_gauss(n_bins, m=40, s=10) # creating matrix A containing all distributions A = np.vstack((a1, a2)).T # loss matrix + normalization - M = ot.utils.dist0(n) + M = ot.utils.dist0(n_bins) M /= M.max() alpha = 0.5 # 0<=alpha<=1 @@ -96,26 +95,26 @@ def test_bary(): reg = 1e-3 bary_wass = ot.bregman.barycenter(A, M, reg, weights) - assert np.allclose(1, np.sum(bary_wass)) + np.testing.assert_allclose(1, np.sum(bary_wass)) ot.bregman.barycenter(A, M, reg, log=True, verbose=True) def test_unmix(): - n = 50 # nb bins + n_bins = 50 # nb bins # Gaussian distributions - a1 = ot.datasets.get_1D_gauss(n, m=20, s=10) # m= mean, s= std - a2 = ot.datasets.get_1D_gauss(n, m=40, s=10) + a1 = ot.datasets.get_1D_gauss(n_bins, m=20, s=10) # m= mean, s= std + a2 = ot.datasets.get_1D_gauss(n_bins, m=40, s=10) - a = ot.datasets.get_1D_gauss(n, m=30, s=10) + a = ot.datasets.get_1D_gauss(n_bins, m=30, s=10) # creating matrix A containing all distributions D = np.vstack((a1, a2)).T # loss matrix + normalization - M = ot.utils.dist0(n) + M = ot.utils.dist0(n_bins) M /= M.max() M0 = ot.utils.dist0(2) @@ -126,8 +125,8 @@ def test_unmix(): reg = 1e-3 um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01,) - assert np.allclose(1, np.sum(um), rtol=1e-03, atol=1e-03) - assert np.allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03) + np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03) + np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03) ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01, log=True, verbose=True) -- cgit v1.2.3