summaryrefslogtreecommitdiff
path: root/test/test_bregman.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-07-26 11:34:11 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-07-26 11:34:11 +0200
commit68d74902bcd3d988fff8cb7713314063f04c0089 (patch)
tree55f7fa521a252e9b67eecd2d2708003daae7e6c4 /test/test_bregman.py
parent46f297f678de0051dc6d5067291d1e1046b4705e (diff)
numpy assert + n_bins
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r--test/test_bregman.py63
1 files changed, 31 insertions, 32 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index aaa2efc..1638ef6 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -3,15 +3,12 @@ import numpy as np
import ot
-# import pytest
-
-
def test_sinkhorn():
# test sinkhorn
n = 100
- np.random.seed(0)
+ rng = np.random.RandomState(0)
- x = np.random.randn(n, 2)
+ x = rng.randn(n, 2)
u = ot.utils.unif(n)
M = ot.dist(x, x)
@@ -19,45 +16,47 @@ def test_sinkhorn():
G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
# check constratints
- assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
- assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+ np.testing.assert_allclose(
+ u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
+ np.testing.assert_allclose(
+ u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
def test_sinkhorn_empty():
# test sinkhorn
n = 100
- np.random.seed(0)
+ rng = np.random.RandomState(0)
- x = np.random.randn(n, 2)
+ x = rng.randn(n, 2)
u = ot.utils.unif(n)
M = ot.dist(x, x)
G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, verbose=True, log=True)
# check constratints
- assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
- assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+ np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
+ np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10,
method='sinkhorn_stabilized', verbose=True, log=True)
# check constratints
- assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
- assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+ np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
+ np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
G, log = ot.sinkhorn(
[], [], M, 1, stopThr=1e-10, method='sinkhorn_epsilon_scaling',
verbose=True, log=True)
# check constratints
- assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
- assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+ np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
+ np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
def test_sinkhorn_variants():
# test sinkhorn
n = 100
- np.random.seed(0)
+ rng = np.random.RandomState(0)
- x = np.random.randn(n, 2)
+ x = rng.randn(n, 2)
u = ot.utils.unif(n)
M = ot.dist(x, x)
@@ -69,24 +68,24 @@ def test_sinkhorn_variants():
Gerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10)
# check values
- assert np.allclose(G0, Gs, atol=1e-05)
- assert np.allclose(G0, Ges, atol=1e-05)
- assert np.allclose(G0, Gerr)
+ np.testing.assert_allclose(G0, Gs, atol=1e-05)
+ np.testing.assert_allclose(G0, Ges, atol=1e-05)
+ np.testing.assert_allclose(G0, Gerr)
def test_bary():
- n = 100 # nb bins
+ n_bins = 100 # nb bins
# Gaussian distributions
- a1 = ot.datasets.get_1D_gauss(n, m=30, s=10) # m= mean, s= std
- a2 = ot.datasets.get_1D_gauss(n, m=40, s=10)
+ a1 = ot.datasets.get_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std
+ a2 = ot.datasets.get_1D_gauss(n_bins, m=40, s=10)
# creating matrix A containing all distributions
A = np.vstack((a1, a2)).T
# loss matrix + normalization
- M = ot.utils.dist0(n)
+ M = ot.utils.dist0(n_bins)
M /= M.max()
alpha = 0.5 # 0<=alpha<=1
@@ -96,26 +95,26 @@ def test_bary():
reg = 1e-3
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
- assert np.allclose(1, np.sum(bary_wass))
+ np.testing.assert_allclose(1, np.sum(bary_wass))
ot.bregman.barycenter(A, M, reg, log=True, verbose=True)
def test_unmix():
- n = 50 # nb bins
+ n_bins = 50 # nb bins
# Gaussian distributions
- a1 = ot.datasets.get_1D_gauss(n, m=20, s=10) # m= mean, s= std
- a2 = ot.datasets.get_1D_gauss(n, m=40, s=10)
+ a1 = ot.datasets.get_1D_gauss(n_bins, m=20, s=10) # m= mean, s= std
+ a2 = ot.datasets.get_1D_gauss(n_bins, m=40, s=10)
- a = ot.datasets.get_1D_gauss(n, m=30, s=10)
+ a = ot.datasets.get_1D_gauss(n_bins, m=30, s=10)
# creating matrix A containing all distributions
D = np.vstack((a1, a2)).T
# loss matrix + normalization
- M = ot.utils.dist0(n)
+ M = ot.utils.dist0(n_bins)
M /= M.max()
M0 = ot.utils.dist0(2)
@@ -126,8 +125,8 @@ def test_unmix():
reg = 1e-3
um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01,)
- assert np.allclose(1, np.sum(um), rtol=1e-03, atol=1e-03)
- assert np.allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03)
+ np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03)
+ np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03)
ot.bregman.unmix(a, D, M, M0, h0, reg,
1, alpha=0.01, log=True, verbose=True)