From 76450dddf8dd62b9714b72e99ae075516246d433 Mon Sep 17 00:00:00 2001 From: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> Date: Mon, 25 Oct 2021 17:35:36 +0200 Subject: [MRG] Backend for optim (#282) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Backend for optim * Bug solve * Doc update * backend tests now with fixture * Unused imports removed * Docs * Docs * Docs * Outer product backend docs * Prettier docs * Pep8 * Mistakes corrected Co-authored-by: RĂ©mi Flamary --- test/test_backend.py | 22 +++++++++------ test/test_optim.py | 78 +++++++++++++++++++++++++++++++++++++++++----------- test/test_ot.py | 6 +--- test/test_utils.py | 7 ----- 4 files changed, 77 insertions(+), 36 deletions(-) (limited to 'test') diff --git a/test/test_backend.py b/test/test_backend.py index 859da5a..5853282 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -17,9 +17,6 @@ from numpy.testing import assert_array_almost_equal_nulp from ot.backend import get_backend, get_backend_list, to_numpy -backend_list = get_backend_list() - - def test_get_backend_list(): lst = get_backend_list() @@ -28,7 +25,6 @@ def test_get_backend_list(): assert isinstance(lst[0], ot.backend.NumpyBackend) -@pytest.mark.parametrize('nx', backend_list) def test_to_numpy(nx): v = nx.zeros(10) @@ -92,7 +88,6 @@ def test_get_backend(): get_backend(A, B2) -@pytest.mark.parametrize('nx', backend_list) def test_convert_between_backends(nx): A = np.zeros((3, 2)) @@ -180,6 +175,8 @@ def test_empty_backend(): nx.searchsorted(v, v) with pytest.raises(NotImplementedError): nx.flip(M) + with pytest.raises(NotImplementedError): + nx.outer(v, v) with pytest.raises(NotImplementedError): nx.clip(M, -1, 1) with pytest.raises(NotImplementedError): @@ -208,10 +205,11 @@ def test_empty_backend(): nx.logsumexp(M) with pytest.raises(NotImplementedError): nx.stack([M, M]) + with pytest.raises(NotImplementedError): + nx.reshape(M, (5, 3, 2)) -@pytest.mark.parametrize('backend', backend_list) -def test_func_backends(backend): +def test_func_backends(nx): rnd = np.random.RandomState(0) M = rnd.randn(10, 3) @@ -220,7 +218,7 @@ def test_func_backends(backend): lst_tot = [] - for nx in [ot.backend.NumpyBackend(), backend]: + for nx in [ot.backend.NumpyBackend(), nx]: print('Backend: ', nx.__name__) @@ -371,6 +369,10 @@ def test_func_backends(backend): lst_b.append(nx.to_numpy(A)) lst_name.append('flip') + A = nx.outer(vb, vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('outer') + A = nx.clip(vb, 0, 1) lst_b.append(nx.to_numpy(A)) lst_name.append('clip') @@ -432,6 +434,10 @@ def test_func_backends(backend): lst_b.append(nx.to_numpy(A)) lst_name.append('stack') + A = nx.reshape(Mb, (5, 3, 2)) + lst_b.append(nx.to_numpy(A)) + lst_name.append('reshape') + lst_tot.append(lst_b) lst_np = lst_tot[0] diff --git a/test/test_optim.py b/test/test_optim.py index 94995d5..4efd9b1 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -8,7 +8,7 @@ import numpy as np import ot -def test_conditional_gradient(): +def test_conditional_gradient(nx): n_bins = 100 # nb bins np.random.seed(0) @@ -29,15 +29,25 @@ def test_conditional_gradient(): def df(G): return G + def fb(G): + return 0.5 * nx.sum(G ** 2) + + ab = nx.from_numpy(a) + bb = nx.from_numpy(b) + Mb = nx.from_numpy(M, type_as=ab) + reg = 1e-1 G, log = ot.optim.cg(a, b, M, reg, f, df, verbose=True, log=True) + Gb, log = ot.optim.cg(ab, bb, Mb, reg, fb, df, verbose=True, log=True) + Gb = nx.to_numpy(Gb) - np.testing.assert_allclose(a, G.sum(1)) - np.testing.assert_allclose(b, G.sum(0)) + np.testing.assert_allclose(Gb, G) + np.testing.assert_allclose(a, Gb.sum(1)) + np.testing.assert_allclose(b, Gb.sum(0)) -def test_conditional_gradient_itermax(): +def test_conditional_gradient_itermax(nx): n = 100 # nb samples mu_s = np.array([0, 0]) @@ -61,16 +71,27 @@ def test_conditional_gradient_itermax(): def df(G): return G + def fb(G): + return 0.5 * nx.sum(G ** 2) + + ab = nx.from_numpy(a) + bb = nx.from_numpy(b) + Mb = nx.from_numpy(M, type_as=ab) + reg = 1e-1 G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=10000, verbose=True, log=True) + Gb, log = ot.optim.cg(ab, bb, Mb, reg, fb, df, numItermaxEmd=10000, + verbose=True, log=True) + Gb = nx.to_numpy(Gb) - np.testing.assert_allclose(a, G.sum(1)) - np.testing.assert_allclose(b, G.sum(0)) + np.testing.assert_allclose(Gb, G) + np.testing.assert_allclose(a, Gb.sum(1)) + np.testing.assert_allclose(b, Gb.sum(0)) -def test_generalized_conditional_gradient(): +def test_generalized_conditional_gradient(nx): n_bins = 100 # nb bins np.random.seed(0) @@ -91,13 +112,23 @@ def test_generalized_conditional_gradient(): def df(G): return G + def fb(G): + return 0.5 * nx.sum(G ** 2) + reg1 = 1e-3 reg2 = 1e-1 + ab = nx.from_numpy(a) + bb = nx.from_numpy(b) + Mb = nx.from_numpy(M, type_as=ab) + G, log = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True, log=True) + Gb, log = ot.optim.gcg(ab, bb, Mb, reg1, reg2, fb, df, verbose=True, log=True) + Gb = nx.to_numpy(Gb) - np.testing.assert_allclose(a, G.sum(1), atol=1e-05) - np.testing.assert_allclose(b, G.sum(0), atol=1e-05) + np.testing.assert_allclose(Gb, G) + np.testing.assert_allclose(a, Gb.sum(1), atol=1e-05) + np.testing.assert_allclose(b, Gb.sum(0), atol=1e-05) def test_solve_1d_linesearch_quad_funct(): @@ -106,24 +137,31 @@ def test_solve_1d_linesearch_quad_funct(): np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1) -def test_line_search_armijo(): +def test_line_search_armijo(nx): xk = np.array([[0.25, 0.25], [0.25, 0.25]]) pk = np.array([[-0.25, 0.25], [0.25, -0.25]]) gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]]) old_fval = -123 # Should not throw an exception and return None for alpha - alpha, _, _ = ot.optim.line_search_armijo(lambda x: 1, xk, pk, gfk, old_fval) + alpha, a, b = ot.optim.line_search_armijo( + lambda x: 1, nx.from_numpy(xk), nx.from_numpy(pk), nx.from_numpy(gfk), old_fval + ) + alpha_np, anp, bnp = ot.optim.line_search_armijo( + lambda x: 1, xk, pk, gfk, old_fval + ) + assert a == anp + assert b == bnp assert alpha is None # check line search armijo def f(x): - return np.sum((x - 5.0) ** 2) + return nx.sum((x - 5.0) ** 2) def grad(x): return 2 * (x - 5.0) - xk = np.array([[[-5.0, -5.0]]]) - pk = np.array([[[100.0, 100.0]]]) + xk = nx.from_numpy(np.array([[[-5.0, -5.0]]])) + pk = nx.from_numpy(np.array([[[100.0, 100.0]]])) gfk = grad(xk) old_fval = f(xk) @@ -132,10 +170,18 @@ def test_line_search_armijo(): np.testing.assert_allclose(alpha, 0.1) # check the case where the direction is not far enough - pk = np.array([[[3.0, 3.0]]]) + pk = nx.from_numpy(np.array([[[3.0, 3.0]]])) alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval, alpha0=1.0) np.testing.assert_allclose(alpha, 1.0) - # check the case where the checking the wrong direction + # check the case where checking the wrong direction alpha, _, _ = ot.optim.line_search_armijo(f, xk, -pk, gfk, old_fval) assert alpha <= 0 + + # check the case where the point is not a vector + xk = nx.from_numpy(np.array(-5.0)) + pk = nx.from_numpy(np.array(100.0)) + gfk = grad(xk) + old_fval = f(xk) + alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval) + np.testing.assert_allclose(alpha, 0.1) diff --git a/test/test_ot.py b/test/test_ot.py index 3e953dc..4dfc510 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -12,9 +12,7 @@ from scipy.stats import wasserstein_distance import ot from ot.datasets import make_1D_gauss as gauss -from ot.backend import get_backend_list, torch - -backend_list = get_backend_list() +from ot.backend import torch def test_emd_dimension_and_mass_mismatch(): @@ -37,7 +35,6 @@ def test_emd_dimension_and_mass_mismatch(): np.testing.assert_raises(AssertionError, ot.emd, a, b, M) -@pytest.mark.parametrize('nx', backend_list) def test_emd_backends(nx): n_samples = 100 n_features = 2 @@ -59,7 +56,6 @@ def test_emd_backends(nx): np.allclose(G, nx.to_numpy(Gb)) -@pytest.mark.parametrize('nx', backend_list) def test_emd2_backends(nx): n_samples = 100 n_features = 2 diff --git a/test/test_utils.py b/test/test_utils.py index 76b1faa..60ad5d3 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -4,17 +4,11 @@ # # License: MIT License -import pytest import ot import numpy as np import sys -from ot.backend import get_backend_list -backend_list = get_backend_list() - - -@pytest.mark.parametrize('nx', backend_list) def test_proj_simplex(nx): n = 10 rng = np.random.RandomState(0) @@ -119,7 +113,6 @@ def test_dist(): np.testing.assert_allclose(D, D3, atol=1e-14) -@ pytest.mark.parametrize('nx', backend_list) def test_dist_backends(nx): n = 100 -- cgit v1.2.3