summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorCédric Vincent-Cuaz <cedvincentcuaz@gmail.com>2023-03-16 08:05:54 +0100
committerGitHub <noreply@github.com>2023-03-16 08:05:54 +0100
commit583501652517c4f1dbd8572e9f942551a9e54a1f (patch)
treefadb96f888924b2d1bef01b78486e97a88ebcd42 /test
parent8f56effe7320991ebdc6457a2cf1d3b6648a09d1 (diff)
[MRG] fix bugs of gw_entropic and armijo to run on gpu (#446)
* maj gw/ srgw/ generic cg solver * correct pep8 on current state * fix bug previous tests * fix pep8 * fix bug srGW constC in loss and gradient * fix doc html * fix doc html * start updating test_optim.py * update tests gromov and optim - plus fix gromov dependencies * add symmetry feature to entropic gw * add symmetry feature to entropic gw * add exemple for sr(F)GW matchings * small stuff * remove (reg,M) from line-search/ complete srgw tests with backend * remove backend repetitions / rename fG to costG/ fix innerlog to True * fix pep8 * take comments into account / new nx parameters still to test * factor (f)gw2 + test new backend parameters in ot.gromov + harmonize stopping criterions * split gromov.py in ot/gromov/ + update test_gromov with helper_backend functions * manual documentaion gromov * remove circular autosummary * trying stuff * debug documentation * alphabetic ordering of module * merge into branch * add note in entropic gw solvers * fix exemples/gromov doc * add fixed issue to releases.md * fix bugs of gw_entropic and armijo to run on gpu * add pr to releases.md * fix pep8 * fix call to backend in line_search_armijo * correct docstring generic_conditional_gradient --------- Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test')
-rw-r--r--test/test_gromov.py16
-rw-r--r--test/test_optim.py55
2 files changed, 68 insertions, 3 deletions
diff --git a/test/test_gromov.py b/test/test_gromov.py
index cfccce7..80b6df4 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -214,6 +214,7 @@ def test_gromov2_gradients():
C11 = torch.tensor(C1, requires_grad=True, device=device)
C12 = torch.tensor(C2, requires_grad=True, device=device)
+ # Test with exact line-search
val = ot.gromov_wasserstein2(C11, C12, p1, q1)
val.backward()
@@ -224,6 +225,21 @@ def test_gromov2_gradients():
assert C11.shape == C11.grad.shape
assert C12.shape == C12.grad.shape
+ # Test with armijo line-search
+ q1.grad = None
+ p1.grad = None
+ C11.grad = None
+ C12.grad = None
+ val = ot.gromov_wasserstein2(C11, C12, p1, q1, armijo=True)
+
+ val.backward()
+
+ assert val.device == p1.device
+ assert q1.shape == q1.grad.shape
+ assert p1.shape == p1.grad.shape
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+
def test_gw_helper_backend(nx):
n_samples = 20 # nb samples
diff --git a/test/test_optim.py b/test/test_optim.py
index 129fe22..a43e704 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -135,16 +135,18 @@ def test_line_search_armijo(nx):
xk = np.array([[0.25, 0.25], [0.25, 0.25]])
pk = np.array([[-0.25, 0.25], [0.25, -0.25]])
gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]])
- old_fval = -123
+ old_fval = -123.
xkb, pkb, gfkb = nx.from_numpy(xk, pk, gfk)
+ def f(x):
+ return 1.
# Should not throw an exception and return 0. for alpha
alpha, a, b = ot.optim.line_search_armijo(
- lambda x: 1, xkb, pkb, gfkb, old_fval
+ f, xkb, pkb, gfkb, old_fval
)
alpha_np, anp, bnp = ot.optim.line_search_armijo(
- lambda x: 1, xk, pk, gfk, old_fval
+ f, xk, pk, gfk, old_fval
)
assert a == anp
assert b == bnp
@@ -182,3 +184,50 @@ def test_line_search_armijo(nx):
old_fval = f(xk)
alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval)
np.testing.assert_allclose(alpha, 0.1)
+
+
+def test_line_search_armijo_dtype_device(nx):
+ for tp in nx.__type_list__:
+ def f(x):
+ return nx.sum((x - 5.0) ** 2)
+
+ def grad(x):
+ return 2 * (x - 5.0)
+
+ xk = np.array([[[-5.0, -5.0]]])
+ pk = np.array([[[100.0, 100.0]]])
+ xkb, pkb = nx.from_numpy(xk, pk, type_as=tp)
+ gfkb = grad(xkb)
+ old_fval = f(xkb)
+
+ # chech the case where the optimum is on the direction
+ alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval)
+ alpha = nx.to_numpy(alpha)
+ np.testing.assert_allclose(alpha, 0.1)
+ nx.assert_same_dtype_device(old_fval, fval)
+
+ # check the case where the direction is not far enough
+ pk = np.array([[[3.0, 3.0]]])
+ pkb = nx.from_numpy(pk, type_as=tp)
+ alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval, alpha0=1.0)
+ alpha = nx.to_numpy(alpha)
+ np.testing.assert_allclose(alpha, 1.0)
+ nx.assert_same_dtype_device(old_fval, fval)
+
+ # check the case where checking the wrong direction
+ alpha, _, fval = ot.optim.line_search_armijo(f, xkb, -pkb, gfkb, old_fval)
+ alpha = nx.to_numpy(alpha)
+
+ assert alpha <= 0
+ nx.assert_same_dtype_device(old_fval, fval)
+
+ # check the case where the point is not a vector
+ xkb = nx.from_numpy(np.array(-5.0), type_as=tp)
+ pkb = nx.from_numpy(np.array(100), type_as=tp)
+ gfkb = grad(xkb)
+ old_fval = f(xkb)
+ alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval)
+ alpha = nx.to_numpy(alpha)
+
+ np.testing.assert_allclose(alpha, 0.1)
+ nx.assert_same_dtype_device(old_fval, fval)