diff options
author | Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com> | 2023-03-16 08:05:54 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-16 08:05:54 +0100 |
commit | 583501652517c4f1dbd8572e9f942551a9e54a1f (patch) | |
tree | fadb96f888924b2d1bef01b78486e97a88ebcd42 | |
parent | 8f56effe7320991ebdc6457a2cf1d3b6648a09d1 (diff) |
[MRG] fix bugs of gw_entropic and armijo to run on gpu (#446)
* maj gw/ srgw/ generic cg solver
* correct pep8 on current state
* fix bug previous tests
* fix pep8
* fix bug srGW constC in loss and gradient
* fix doc html
* fix doc html
* start updating test_optim.py
* update tests gromov and optim - plus fix gromov dependencies
* add symmetry feature to entropic gw
* add symmetry feature to entropic gw
* add exemple for sr(F)GW matchings
* small stuff
* remove (reg,M) from line-search/ complete srgw tests with backend
* remove backend repetitions / rename fG to costG/ fix innerlog to True
* fix pep8
* take comments into account / new nx parameters still to test
* factor (f)gw2 + test new backend parameters in ot.gromov + harmonize stopping criterions
* split gromov.py in ot/gromov/ + update test_gromov with helper_backend functions
* manual documentaion gromov
* remove circular autosummary
* trying stuff
* debug documentation
* alphabetic ordering of module
* merge into branch
* add note in entropic gw solvers
* fix exemples/gromov doc
* add fixed issue to releases.md
* fix bugs of gw_entropic and armijo to run on gpu
* add pr to releases.md
* fix pep8
* fix call to backend in line_search_armijo
* correct docstring generic_conditional_gradient
---------
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
-rw-r--r-- | RELEASES.md | 3 | ||||
-rw-r--r-- | ot/gromov/_bregman.py | 5 | ||||
-rw-r--r-- | ot/optim.py | 38 | ||||
-rw-r--r-- | test/test_gromov.py | 16 | ||||
-rw-r--r-- | test/test_optim.py | 55 |
5 files changed, 100 insertions, 17 deletions
diff --git a/RELEASES.md b/RELEASES.md index da4d7bb..b6e12d9 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -47,6 +47,7 @@ PR #413) that explicitly specified `stopThr=1e-9` (Issue #421, PR #422). - Fixed a bug breaking an example where we would try to make an array of arrays of different shapes (Issue #424, PR #425) - Fixed an issue with the documentation gallery section (PR #444) +- Fixed issues with cuda variables for `line_search_armijo` and `entropic_gromov_wasserstein` (Issue #445, #PR 446) ## 0.8.2 @@ -571,4 +572,4 @@ It provides the following solvers: * Optimal transport for domain adaptation with group lasso regularization * Conditional gradient and Generalized conditional gradient for regularized OT. -Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder. +Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.
\ No newline at end of file diff --git a/ot/gromov/_bregman.py b/ot/gromov/_bregman.py index 5b2f959..b0cccfb 100644 --- a/ot/gromov/_bregman.py +++ b/ot/gromov/_bregman.py @@ -11,9 +11,6 @@ Bregman projections solvers for entropic Gromov-Wasserstein # # License: MIT License -import numpy as np - - from ..bregman import sinkhorn from ..utils import dist, list_to_array, check_random_state from ..backend import get_backend @@ -109,7 +106,7 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, symmetric=None, T = G0 constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, nx) if symmetric is None: - symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10) + symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) if not symmetric: constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, nx) cpt = 0 diff --git a/ot/optim.py b/ot/optim.py index 201f898..58e5596 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -35,6 +35,9 @@ def line_search_armijo( Find an approximate minimum of :math:`f(x_k + \alpha \cdot p_k)` that satisfies the armijo conditions. + .. note:: If the loss function f returns a float (resp. a 1d array) then + the returned alpha and fa are float (resp. 1d arrays). + Parameters ---------- f : callable @@ -45,7 +48,7 @@ def line_search_armijo( descent direction gfk : array-like gradient of `f` at :math:`x_k` - old_fval : float + old_fval : float or 1d array loss value at :math:`x_k` args : tuple, optional arguments given to `f` @@ -61,42 +64,59 @@ def line_search_armijo( If let to its default value None, a backend test will be conducted. Returns ------- - alpha : float + alpha : float or 1d array step that satisfy armijo conditions fc : int nb of function call - fa : float + fa : float or 1d array loss value at step alpha """ if nx is None: xk, pk, gfk = list_to_array(xk, pk, gfk) - nx = get_backend(xk, pk) + xk0, pk0 = xk, pk + nx = get_backend(xk0, pk0) + else: + xk0, pk0 = xk, pk if len(xk.shape) == 0: xk = nx.reshape(xk, (-1,)) + xk = nx.to_numpy(xk) + pk = nx.to_numpy(pk) + gfk = nx.to_numpy(gfk) + fc = [0] def phi(alpha1): + # The callable function operates on nx backend fc[0] += 1 - return f(xk + alpha1 * pk, *args) + alpha10 = nx.from_numpy(alpha1) + fval = f(xk0 + alpha10 * pk0, *args) + if type(fval) is float: + # prevent bug from nx.to_numpy that can look for .cpu or .gpu + return fval + else: + return nx.to_numpy(fval) if old_fval is None: phi0 = phi(0.) - else: + elif type(old_fval) is float: + # prevent bug from nx.to_numpy that can look for .cpu or .gpu phi0 = old_fval + else: + phi0 = nx.to_numpy(old_fval) - derphi0 = nx.sum(pk * gfk) # Quickfix for matrices + derphi0 = np.sum(pk * gfk) # Quickfix for matrices alpha, phi1 = scalar_search_armijo( phi, phi0, derphi0, c1=c1, alpha0=alpha0) if alpha is None: - return 0., fc[0], phi0 + return 0., fc[0], nx.from_numpy(phi0, type_as=xk0) else: if alpha_min is not None or alpha_max is not None: alpha = np.clip(alpha, alpha_min, alpha_max) - return float(alpha), fc[0], phi1 + return nx.from_numpy(alpha, type_as=xk0), fc[0], nx.from_numpy(phi1, type_as=xk0) def generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_search, G0=None, diff --git a/test/test_gromov.py b/test/test_gromov.py index cfccce7..80b6df4 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -214,6 +214,7 @@ def test_gromov2_gradients(): C11 = torch.tensor(C1, requires_grad=True, device=device)
C12 = torch.tensor(C2, requires_grad=True, device=device)
+ # Test with exact line-search
val = ot.gromov_wasserstein2(C11, C12, p1, q1)
val.backward()
@@ -224,6 +225,21 @@ def test_gromov2_gradients(): assert C11.shape == C11.grad.shape
assert C12.shape == C12.grad.shape
+ # Test with armijo line-search
+ q1.grad = None
+ p1.grad = None
+ C11.grad = None
+ C12.grad = None
+ val = ot.gromov_wasserstein2(C11, C12, p1, q1, armijo=True)
+
+ val.backward()
+
+ assert val.device == p1.device
+ assert q1.shape == q1.grad.shape
+ assert p1.shape == p1.grad.shape
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+
def test_gw_helper_backend(nx):
n_samples = 20 # nb samples
diff --git a/test/test_optim.py b/test/test_optim.py index 129fe22..a43e704 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -135,16 +135,18 @@ def test_line_search_armijo(nx): xk = np.array([[0.25, 0.25], [0.25, 0.25]]) pk = np.array([[-0.25, 0.25], [0.25, -0.25]]) gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]]) - old_fval = -123 + old_fval = -123. xkb, pkb, gfkb = nx.from_numpy(xk, pk, gfk) + def f(x): + return 1. # Should not throw an exception and return 0. for alpha alpha, a, b = ot.optim.line_search_armijo( - lambda x: 1, xkb, pkb, gfkb, old_fval + f, xkb, pkb, gfkb, old_fval ) alpha_np, anp, bnp = ot.optim.line_search_armijo( - lambda x: 1, xk, pk, gfk, old_fval + f, xk, pk, gfk, old_fval ) assert a == anp assert b == bnp @@ -182,3 +184,50 @@ def test_line_search_armijo(nx): old_fval = f(xk) alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval) np.testing.assert_allclose(alpha, 0.1) + + +def test_line_search_armijo_dtype_device(nx): + for tp in nx.__type_list__: + def f(x): + return nx.sum((x - 5.0) ** 2) + + def grad(x): + return 2 * (x - 5.0) + + xk = np.array([[[-5.0, -5.0]]]) + pk = np.array([[[100.0, 100.0]]]) + xkb, pkb = nx.from_numpy(xk, pk, type_as=tp) + gfkb = grad(xkb) + old_fval = f(xkb) + + # chech the case where the optimum is on the direction + alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval) + alpha = nx.to_numpy(alpha) + np.testing.assert_allclose(alpha, 0.1) + nx.assert_same_dtype_device(old_fval, fval) + + # check the case where the direction is not far enough + pk = np.array([[[3.0, 3.0]]]) + pkb = nx.from_numpy(pk, type_as=tp) + alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval, alpha0=1.0) + alpha = nx.to_numpy(alpha) + np.testing.assert_allclose(alpha, 1.0) + nx.assert_same_dtype_device(old_fval, fval) + + # check the case where checking the wrong direction + alpha, _, fval = ot.optim.line_search_armijo(f, xkb, -pkb, gfkb, old_fval) + alpha = nx.to_numpy(alpha) + + assert alpha <= 0 + nx.assert_same_dtype_device(old_fval, fval) + + # check the case where the point is not a vector + xkb = nx.from_numpy(np.array(-5.0), type_as=tp) + pkb = nx.from_numpy(np.array(100), type_as=tp) + gfkb = grad(xkb) + old_fval = f(xkb) + alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval) + alpha = nx.to_numpy(alpha) + + np.testing.assert_allclose(alpha, 0.1) + nx.assert_same_dtype_device(old_fval, fval) |