From 767171593f2a98a26b9a39bf110a45085e3b982e Mon Sep 17 00:00:00 2001 From: Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> Date: Thu, 24 Mar 2022 10:53:47 +0100 Subject: [MRG] Domain adaptation and unbalanced solvers with backend support (#343) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * First draft * Add matrix inverse and square root to backend * Eigen decomposition for older versions of pytorch (1.8.1 and older) * Corrected eigen decomposition for pytorch 1.8.1 and older * Spectral theorem is a thing * Optimization * small optimization * More functions converted * pep8 * remove a warning and prepare torch meshgrid for future torch release (which will change default indexing) * dots and pep8 * Meshgrid corrected for older version and prepared for future versions changes * New backend functions * Base transport * LinearTransport * All transport classes + pep8 * PR added to release file * Jcpot barycenter test * unbalanced with backend * pep8 * bug solve * test of domain adaptation with backends * solve bug for tic toc & macos * solving scipy deprecation warning * solving scipy deprecation warning attempt2 * solving scipy deprecation warning attempt3 * A warning is triggered when a float->int conversion is detected * bug solve * docs * release file updated * Better handling of float->int conversion in EMD * Corrected test for is_floating_point * docs * release file updated * cupy does not allow implicit cast * fromnumpy * added test * test da tf jax * test unbalanced with no provided histogram * using type_as argument in unif function correctly * pep8 * transport plan cast in emd changed behaviour, now trying to cast as histogram's dtype, defaulting to cost matrix Co-authored-by: RĂ©mi Flamary --- test/test_backend.py | 66 ++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 62 insertions(+), 4 deletions(-) (limited to 'test/test_backend.py') diff --git a/test/test_backend.py b/test/test_backend.py index 027c4cd..311c075 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -217,6 +217,8 @@ def test_empty_backend(): nx.zero_pad(M, v) with pytest.raises(NotImplementedError): nx.argmax(M) + with pytest.raises(NotImplementedError): + nx.argmin(M) with pytest.raises(NotImplementedError): nx.mean(M) with pytest.raises(NotImplementedError): @@ -264,12 +266,27 @@ def test_empty_backend(): nx.device_type(M) with pytest.raises(NotImplementedError): nx._bench(lambda x: x, M, n_runs=1) + with pytest.raises(NotImplementedError): + nx.solve(M, v) + with pytest.raises(NotImplementedError): + nx.trace(M) + with pytest.raises(NotImplementedError): + nx.inv(M) + with pytest.raises(NotImplementedError): + nx.sqrtm(M) + with pytest.raises(NotImplementedError): + nx.isfinite(M) + with pytest.raises(NotImplementedError): + nx.array_equal(M, M) + with pytest.raises(NotImplementedError): + nx.is_floating_point(M) def test_func_backends(nx): rnd = np.random.RandomState(0) M = rnd.randn(10, 3) + SquareM = rnd.randn(10, 10) v = rnd.randn(3) val = np.array([1.0]) @@ -288,6 +305,7 @@ def test_func_backends(nx): lst_name = [] Mb = nx.from_numpy(M) + SquareMb = nx.from_numpy(SquareM) vb = nx.from_numpy(v) val = nx.from_numpy(val) @@ -467,6 +485,10 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append('argmax') + A = nx.argmin(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('argmin') + A = nx.mean(Mb) lst_b.append(nx.to_numpy(A)) lst_name.append('mean') @@ -529,7 +551,11 @@ def test_func_backends(nx): A = nx.where(Mb >= nx.stack([nx.linspace(0, 1, 10)] * 3, axis=1), Mb, 0.0) lst_b.append(nx.to_numpy(A)) - lst_name.append('where') + lst_name.append('where (cond, x, y)') + + A = nx.where(nx.from_numpy(np.array([True, False]))) + lst_b.append(nx.to_numpy(nx.stack(A))) + lst_name.append('where (cond)') A = nx.copy(Mb) lst_b.append(nx.to_numpy(A)) @@ -550,15 +576,47 @@ def test_func_backends(nx): nx._bench(lambda x: x, M, n_runs=1) + A = nx.solve(SquareMb, Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('solve') + + A = nx.trace(SquareMb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('trace') + + A = nx.inv(SquareMb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('matrix inverse') + + A = nx.sqrtm(SquareMb.T @ SquareMb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("matrix square root") + + A = nx.concatenate([vb, nx.from_numpy(np.array([np.inf, np.nan]))], axis=0) + A = nx.isfinite(A) + lst_b.append(nx.to_numpy(A)) + lst_name.append("isfinite") + + assert not nx.array_equal(Mb, vb), "array_equal (shape)" + assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true" + assert not nx.array_equal( + Mb, Mb + nx.eye(*list(Mb.shape)) + ), "array_equal (elements) - expected false" + + assert nx.is_floating_point(Mb), "is_floating_point - expected true" + assert not nx.is_floating_point( + nx.from_numpy(np.array([0, 1, 2], dtype=int)) + ), "is_floating_point - expected false" + lst_tot.append(lst_b) lst_np = lst_tot[0] lst_b = lst_tot[1] for a1, a2, name in zip(lst_np, lst_b, lst_name): - if not np.allclose(a1, a2): - print('Assert fail on: ', name) - assert np.allclose(a1, a2, atol=1e-7) + np.testing.assert_allclose( + a2, a1, atol=1e-7, err_msg=f'ASSERT FAILED ON: {name}' + ) def test_random_backends(nx): -- cgit v1.2.3