author     ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>   2021-10-25 17:35:36 +0200
committer  GitHub <noreply@github.com>   2021-10-25 17:35:36 +0200
commit     76450dddf8dd62b9714b72e99ae075516246d433 (patch)
tree       67de8de1c185cc8e7fc33a1fc0613015824d1fbb /test
parent     7a65086dd340265d0223eb8ffb5c9a5152a82dff (diff)
[MRG] Backend for optim (#282)
* Backend for optim
* Bug solve
* Doc update
* backend tests now with fixture
* Unused imports removed
* Docs
* Docs
* Docs
* Outer product backend docs
* Prettier docs
* Pep8
* Mistakes corrected

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test')
-rw-r--r--  test/test_backend.py   22
-rw-r--r--  test/test_optim.py     78
-rw-r--r--  test/test_ot.py         6
-rw-r--r--  test/test_utils.py      7
4 files changed, 77 insertions, 36 deletions
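
As the commit message notes, the backend tests now rely on a fixture: the @pytest.mark.parametrize('nx', backend_list) decorators removed below are replaced by an `nx` argument supplied from a shared conftest. That fixture definition is not part of the files shown in this diff; a minimal sketch of what it could look like, assuming get_backend_list() from ot.backend as used in test_backend.py, is:

# conftest.py -- hypothetical sketch, not part of this diff
# Supplies the `nx` fixture so every test runs once per available backend,
# replacing the per-test @pytest.mark.parametrize('nx', backend_list) decorators.
import pytest

from ot.backend import get_backend_list

backend_list = get_backend_list()


@pytest.fixture(params=backend_list)
def nx(request):
    # request.param is one backend instance (NumpyBackend, TorchBackend, ...)
    yield request.param
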
diff --git a/test/test_backend.py b/test/test_backend.py
index 859da5a..5853282 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -17,9 +17,6 @@ from numpy.testing import assert_array_almost_equal_nulp
from ot.backend import get_backend, get_backend_list, to_numpy
-backend_list = get_backend_list()
-
-
def test_get_backend_list():
lst = get_backend_list()
@@ -28,7 +25,6 @@ def test_get_backend_list():
assert isinstance(lst[0], ot.backend.NumpyBackend)
-@pytest.mark.parametrize('nx', backend_list)
def test_to_numpy(nx):
v = nx.zeros(10)
@@ -92,7 +88,6 @@ def test_get_backend():
get_backend(A, B2)
-@pytest.mark.parametrize('nx', backend_list)
def test_convert_between_backends(nx):
A = np.zeros((3, 2))
@@ -181,6 +176,8 @@ def test_empty_backend():
with pytest.raises(NotImplementedError):
nx.flip(M)
with pytest.raises(NotImplementedError):
+ nx.outer(v, v)
+ with pytest.raises(NotImplementedError):
nx.clip(M, -1, 1)
with pytest.raises(NotImplementedError):
nx.repeat(M, 0, 1)
@@ -208,10 +205,11 @@ def test_empty_backend():
nx.logsumexp(M)
with pytest.raises(NotImplementedError):
nx.stack([M, M])
+ with pytest.raises(NotImplementedError):
+ nx.reshape(M, (5, 3, 2))
-@pytest.mark.parametrize('backend', backend_list)
-def test_func_backends(backend):
+def test_func_backends(nx):
rnd = np.random.RandomState(0)
M = rnd.randn(10, 3)
@@ -220,7 +218,7 @@ def test_func_backends(backend):
lst_tot = []
- for nx in [ot.backend.NumpyBackend(), backend]:
+ for nx in [ot.backend.NumpyBackend(), nx]:
print('Backend: ', nx.__name__)
@@ -371,6 +369,10 @@ def test_func_backends(backend):
lst_b.append(nx.to_numpy(A))
lst_name.append('flip')
+ A = nx.outer(vb, vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('outer')
+
A = nx.clip(vb, 0, 1)
lst_b.append(nx.to_numpy(A))
lst_name.append('clip')
@@ -432,6 +434,10 @@ def test_func_backends(backend):
lst_b.append(nx.to_numpy(A))
lst_name.append('stack')
+ A = nx.reshape(Mb, (5, 3, 2))
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('reshape')
+
lst_tot.append(lst_b)
lst_np = lst_tot[0]
diff --git a/test/test_optim.py b/test/test_optim.py
index 94995d5..4efd9b1 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -8,7 +8,7 @@ import numpy as np
import ot
-def test_conditional_gradient():
+def test_conditional_gradient(nx):
n_bins = 100 # nb bins
np.random.seed(0)
@@ -29,15 +29,25 @@ def test_conditional_gradient():
def df(G):
return G
+ def fb(G):
+ return 0.5 * nx.sum(G ** 2)
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M, type_as=ab)
+
reg = 1e-1
G, log = ot.optim.cg(a, b, M, reg, f, df, verbose=True, log=True)
+ Gb, log = ot.optim.cg(ab, bb, Mb, reg, fb, df, verbose=True, log=True)
+ Gb = nx.to_numpy(Gb)
- np.testing.assert_allclose(a, G.sum(1))
- np.testing.assert_allclose(b, G.sum(0))
+ np.testing.assert_allclose(Gb, G)
+ np.testing.assert_allclose(a, Gb.sum(1))
+ np.testing.assert_allclose(b, Gb.sum(0))
-def test_conditional_gradient_itermax():
+def test_conditional_gradient_itermax(nx):
n = 100 # nb samples
mu_s = np.array([0, 0])
@@ -61,16 +71,27 @@ def test_conditional_gradient_itermax():
def df(G):
return G
+ def fb(G):
+ return 0.5 * nx.sum(G ** 2)
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M, type_as=ab)
+
reg = 1e-1
G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=10000,
verbose=True, log=True)
+ Gb, log = ot.optim.cg(ab, bb, Mb, reg, fb, df, numItermaxEmd=10000,
+ verbose=True, log=True)
+ Gb = nx.to_numpy(Gb)
- np.testing.assert_allclose(a, G.sum(1))
- np.testing.assert_allclose(b, G.sum(0))
+ np.testing.assert_allclose(Gb, G)
+ np.testing.assert_allclose(a, Gb.sum(1))
+ np.testing.assert_allclose(b, Gb.sum(0))
-def test_generalized_conditional_gradient():
+def test_generalized_conditional_gradient(nx):
n_bins = 100 # nb bins
np.random.seed(0)
@@ -91,13 +112,23 @@ def test_generalized_conditional_gradient():
def df(G):
return G
+ def fb(G):
+ return 0.5 * nx.sum(G ** 2)
+
reg1 = 1e-3
reg2 = 1e-1
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M, type_as=ab)
+
G, log = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True, log=True)
+ Gb, log = ot.optim.gcg(ab, bb, Mb, reg1, reg2, fb, df, verbose=True, log=True)
+ Gb = nx.to_numpy(Gb)
- np.testing.assert_allclose(a, G.sum(1), atol=1e-05)
- np.testing.assert_allclose(b, G.sum(0), atol=1e-05)
+ np.testing.assert_allclose(Gb, G)
+ np.testing.assert_allclose(a, Gb.sum(1), atol=1e-05)
+ np.testing.assert_allclose(b, Gb.sum(0), atol=1e-05)
def test_solve_1d_linesearch_quad_funct():
@@ -106,24 +137,31 @@ def test_solve_1d_linesearch_quad_funct():
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1)
-def test_line_search_armijo():
+def test_line_search_armijo(nx):
xk = np.array([[0.25, 0.25], [0.25, 0.25]])
pk = np.array([[-0.25, 0.25], [0.25, -0.25]])
gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]])
old_fval = -123
# Should not throw an exception and return None for alpha
- alpha, _, _ = ot.optim.line_search_armijo(lambda x: 1, xk, pk, gfk, old_fval)
+ alpha, a, b = ot.optim.line_search_armijo(
+ lambda x: 1, nx.from_numpy(xk), nx.from_numpy(pk), nx.from_numpy(gfk), old_fval
+ )
+ alpha_np, anp, bnp = ot.optim.line_search_armijo(
+ lambda x: 1, xk, pk, gfk, old_fval
+ )
+ assert a == anp
+ assert b == bnp
assert alpha is None
# check line search armijo
def f(x):
- return np.sum((x - 5.0) ** 2)
+ return nx.sum((x - 5.0) ** 2)
def grad(x):
return 2 * (x - 5.0)
- xk = np.array([[[-5.0, -5.0]]])
- pk = np.array([[[100.0, 100.0]]])
+ xk = nx.from_numpy(np.array([[[-5.0, -5.0]]]))
+ pk = nx.from_numpy(np.array([[[100.0, 100.0]]]))
gfk = grad(xk)
old_fval = f(xk)
@@ -132,10 +170,18 @@ def test_line_search_armijo():
np.testing.assert_allclose(alpha, 0.1)
# check the case where the direction is not far enough
- pk = np.array([[[3.0, 3.0]]])
+ pk = nx.from_numpy(np.array([[[3.0, 3.0]]]))
alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval, alpha0=1.0)
np.testing.assert_allclose(alpha, 1.0)
- # check the case where the checking the wrong direction
+ # check the case where checking the wrong direction
alpha, _, _ = ot.optim.line_search_armijo(f, xk, -pk, gfk, old_fval)
assert alpha <= 0
+
+ # check the case where the point is not a vector
+ xk = nx.from_numpy(np.array(-5.0))
+ pk = nx.from_numpy(np.array(100.0))
+ gfk = grad(xk)
+ old_fval = f(xk)
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval)
+ np.testing.assert_allclose(alpha, 0.1)
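
The tests updated in test_optim.py above all follow the same pattern: run the solver on plain NumPy arrays, convert the inputs with nx.from_numpy (using type_as so all arrays share the backend's dtype and device), run the solver again on the backend, bring the result back with nx.to_numpy, and compare the two results. A condensed, hypothetical sketch of that pattern (names and sizes are illustrative, not taken from the diff):

# Hypothetical condensed sketch of the backend-agnostic test pattern used above;
# `nx` is the backend fixture.
import numpy as np
import ot


def check_cg_matches_numpy(nx):
    rng = np.random.RandomState(0)
    a = ot.unif(5)
    b = ot.unif(5)
    M = rng.rand(5, 5)

    def f(G):
        return 0.5 * np.sum(G ** 2)

    def fb(G):
        return 0.5 * nx.sum(G ** 2)

    def df(G):
        return G

    # reference run on plain NumPy arrays
    G = ot.optim.cg(a, b, M, 1e-1, f, df)

    # same run on the backend, then converted back to NumPy for comparison
    ab, bb = nx.from_numpy(a), nx.from_numpy(b)
    Mb = nx.from_numpy(M, type_as=ab)  # keep dtype/device consistent with ab
    Gb = nx.to_numpy(ot.optim.cg(ab, bb, Mb, 1e-1, fb, df))

    np.testing.assert_allclose(Gb, G)
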
diff --git a/test/test_ot.py b/test/test_ot.py
index 3e953dc..4dfc510 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -12,9 +12,7 @@ from scipy.stats import wasserstein_distance
import ot
from ot.datasets import make_1D_gauss as gauss
-from ot.backend import get_backend_list, torch
-
-backend_list = get_backend_list()
+from ot.backend import torch
def test_emd_dimension_and_mass_mismatch():
@@ -37,7 +35,6 @@ def test_emd_dimension_and_mass_mismatch():
np.testing.assert_raises(AssertionError, ot.emd, a, b, M)
-@pytest.mark.parametrize('nx', backend_list)
def test_emd_backends(nx):
n_samples = 100
n_features = 2
@@ -59,7 +56,6 @@ def test_emd_backends(nx):
np.allclose(G, nx.to_numpy(Gb))
-@pytest.mark.parametrize('nx', backend_list)
def test_emd2_backends(nx):
n_samples = 100
n_features = 2
diff --git a/test/test_utils.py b/test/test_utils.py
index 76b1faa..60ad5d3 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -4,17 +4,11 @@
#
# License: MIT License
-import pytest
import ot
import numpy as np
import sys
-from ot.backend import get_backend_list
-backend_list = get_backend_list()
-
-
-@pytest.mark.parametrize('nx', backend_list)
def test_proj_simplex(nx):
n = 10
rng = np.random.RandomState(0)
@@ -119,7 +113,6 @@ def test_dist():
np.testing.assert_allclose(D, D3, atol=1e-14)
-@ pytest.mark.parametrize('nx', backend_list)
def test_dist_backends(nx):
n = 100