summaryrefslogtreecommitdiff
path: root/test/test_optim.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_optim.py')
-rw-r--r--test/test_optim.py78
1 files changed, 62 insertions, 16 deletions
diff --git a/test/test_optim.py b/test/test_optim.py
index 94995d5..4efd9b1 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -8,7 +8,7 @@ import numpy as np
import ot
-def test_conditional_gradient():
+def test_conditional_gradient(nx):
n_bins = 100 # nb bins
np.random.seed(0)
@@ -29,15 +29,25 @@ def test_conditional_gradient():
def df(G):
return G
+ def fb(G):
+ return 0.5 * nx.sum(G ** 2)
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M, type_as=ab)
+
reg = 1e-1
G, log = ot.optim.cg(a, b, M, reg, f, df, verbose=True, log=True)
+ Gb, log = ot.optim.cg(ab, bb, Mb, reg, fb, df, verbose=True, log=True)
+ Gb = nx.to_numpy(Gb)
- np.testing.assert_allclose(a, G.sum(1))
- np.testing.assert_allclose(b, G.sum(0))
+ np.testing.assert_allclose(Gb, G)
+ np.testing.assert_allclose(a, Gb.sum(1))
+ np.testing.assert_allclose(b, Gb.sum(0))
-def test_conditional_gradient_itermax():
+def test_conditional_gradient_itermax(nx):
n = 100 # nb samples
mu_s = np.array([0, 0])
@@ -61,16 +71,27 @@ def test_conditional_gradient_itermax():
def df(G):
return G
+ def fb(G):
+ return 0.5 * nx.sum(G ** 2)
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M, type_as=ab)
+
reg = 1e-1
G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=10000,
verbose=True, log=True)
+ Gb, log = ot.optim.cg(ab, bb, Mb, reg, fb, df, numItermaxEmd=10000,
+ verbose=True, log=True)
+ Gb = nx.to_numpy(Gb)
- np.testing.assert_allclose(a, G.sum(1))
- np.testing.assert_allclose(b, G.sum(0))
+ np.testing.assert_allclose(Gb, G)
+ np.testing.assert_allclose(a, Gb.sum(1))
+ np.testing.assert_allclose(b, Gb.sum(0))
-def test_generalized_conditional_gradient():
+def test_generalized_conditional_gradient(nx):
n_bins = 100 # nb bins
np.random.seed(0)
@@ -91,13 +112,23 @@ def test_generalized_conditional_gradient():
def df(G):
return G
+ def fb(G):
+ return 0.5 * nx.sum(G ** 2)
+
reg1 = 1e-3
reg2 = 1e-1
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M, type_as=ab)
+
G, log = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True, log=True)
+ Gb, log = ot.optim.gcg(ab, bb, Mb, reg1, reg2, fb, df, verbose=True, log=True)
+ Gb = nx.to_numpy(Gb)
- np.testing.assert_allclose(a, G.sum(1), atol=1e-05)
- np.testing.assert_allclose(b, G.sum(0), atol=1e-05)
+ np.testing.assert_allclose(Gb, G)
+ np.testing.assert_allclose(a, Gb.sum(1), atol=1e-05)
+ np.testing.assert_allclose(b, Gb.sum(0), atol=1e-05)
def test_solve_1d_linesearch_quad_funct():
@@ -106,24 +137,31 @@ def test_solve_1d_linesearch_quad_funct():
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1)
-def test_line_search_armijo():
+def test_line_search_armijo(nx):
xk = np.array([[0.25, 0.25], [0.25, 0.25]])
pk = np.array([[-0.25, 0.25], [0.25, -0.25]])
gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]])
old_fval = -123
# Should not throw an exception and return None for alpha
- alpha, _, _ = ot.optim.line_search_armijo(lambda x: 1, xk, pk, gfk, old_fval)
+ alpha, a, b = ot.optim.line_search_armijo(
+ lambda x: 1, nx.from_numpy(xk), nx.from_numpy(pk), nx.from_numpy(gfk), old_fval
+ )
+ alpha_np, anp, bnp = ot.optim.line_search_armijo(
+ lambda x: 1, xk, pk, gfk, old_fval
+ )
+ assert a == anp
+ assert b == bnp
assert alpha is None
# check line search armijo
def f(x):
- return np.sum((x - 5.0) ** 2)
+ return nx.sum((x - 5.0) ** 2)
def grad(x):
return 2 * (x - 5.0)
- xk = np.array([[[-5.0, -5.0]]])
- pk = np.array([[[100.0, 100.0]]])
+ xk = nx.from_numpy(np.array([[[-5.0, -5.0]]]))
+ pk = nx.from_numpy(np.array([[[100.0, 100.0]]]))
gfk = grad(xk)
old_fval = f(xk)
@@ -132,10 +170,18 @@ def test_line_search_armijo():
np.testing.assert_allclose(alpha, 0.1)
# check the case where the direction is not far enough
- pk = np.array([[[3.0, 3.0]]])
+ pk = nx.from_numpy(np.array([[[3.0, 3.0]]]))
alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval, alpha0=1.0)
np.testing.assert_allclose(alpha, 1.0)
- # check the case where the checking the wrong direction
+ # check the case where checking the wrong direction
alpha, _, _ = ot.optim.line_search_armijo(f, xk, -pk, gfk, old_fval)
assert alpha <= 0
+
+ # check the case where the point is not a vector
+ xk = nx.from_numpy(np.array(-5.0))
+ pk = nx.from_numpy(np.array(100.0))
+ gfk = grad(xk)
+ old_fval = f(xk)
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval)
+ np.testing.assert_allclose(alpha, 0.1)