diff options
author | Clément Bonet <32179275+clbonet@users.noreply.github.com> | 2023-05-05 10:53:48 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-05 10:53:48 +0200 |
commit | 7e0ea27ad9cad31cfc2181430d837c0a77a61568 (patch) | |
tree | 0a41128a975500bfef52a4c21b5af634adecc71a /test/test_backend.py | |
parent | 83dc498b496087aea293df1445442d8728435211 (diff) |
[MRG] Fix bug SSW backend (#471)
* fix bug np vs torch matmul
* typo error
* einsum projections ssw
* Test broadcast matmul
* einsum projections ssw
* Test broadcast matmul
* projections SSW with einsum
* reduce number of samples in test wasserstein_circle_unif
* Update releases.md
---------
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test/test_backend.py')
-rw-r--r-- | test/test_backend.py | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/test/test_backend.py b/test/test_backend.py index 5351e52..fedc62f 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -298,6 +298,8 @@ def test_empty_backend(): nx.transpose(M) with pytest.raises(NotImplementedError): nx.detach(M) + with pytest.raises(NotImplementedError): + nx.matmul(M, M.T) def test_func_backends(nx): @@ -308,6 +310,9 @@ def test_func_backends(nx): v = rnd.randn(3) val = np.array([1.0]) + M1 = rnd.randn(1, 2, 10, 10) + M2 = rnd.randn(3, 1, 10, 10) + # Sparse tensors test sp_row = np.array([0, 3, 1, 0, 3]) sp_col = np.array([0, 3, 1, 2, 2]) @@ -326,6 +331,9 @@ def test_func_backends(nx): SquareMb = nx.from_numpy(SquareM) vb = nx.from_numpy(v) + M1b = nx.from_numpy(M1) + M2b = nx.from_numpy(M2) + val = nx.from_numpy(val) sp_rowb = nx.from_numpy(sp_row) @@ -661,6 +669,13 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(B)) lst_name.append("detach B") + A = nx.matmul(Mb, Mb.T) + lst_b.append(nx.to_numpy(A)) + lst_name.append("matmul") + A = nx.matmul(M1b, M2b) + lst_b.append(nx.to_numpy(A)) + lst_name.append("matmul broadcast") + assert not nx.array_equal(Mb, vb), "array_equal (shape)" assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true" assert not nx.array_equal( |