summaryrefslogtreecommitdiff
path: root/test/test_backend.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_backend.py')
-rw-r--r--test/test_backend.py15
1 files changed, 15 insertions, 0 deletions
diff --git a/test/test_backend.py b/test/test_backend.py
index 5351e52..fedc62f 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -298,6 +298,8 @@ def test_empty_backend():
nx.transpose(M)
with pytest.raises(NotImplementedError):
nx.detach(M)
+ with pytest.raises(NotImplementedError):
+ nx.matmul(M, M.T)
def test_func_backends(nx):
@@ -308,6 +310,9 @@ def test_func_backends(nx):
v = rnd.randn(3)
val = np.array([1.0])
+ M1 = rnd.randn(1, 2, 10, 10)
+ M2 = rnd.randn(3, 1, 10, 10)
+
# Sparse tensors test
sp_row = np.array([0, 3, 1, 0, 3])
sp_col = np.array([0, 3, 1, 2, 2])
@@ -326,6 +331,9 @@ def test_func_backends(nx):
SquareMb = nx.from_numpy(SquareM)
vb = nx.from_numpy(v)
+ M1b = nx.from_numpy(M1)
+ M2b = nx.from_numpy(M2)
+
val = nx.from_numpy(val)
sp_rowb = nx.from_numpy(sp_row)
@@ -661,6 +669,13 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(B))
lst_name.append("detach B")
+ A = nx.matmul(Mb, Mb.T)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("matmul")
+ A = nx.matmul(M1b, M2b)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("matmul broadcast")
+
assert not nx.array_equal(Mb, vb), "array_equal (shape)"
assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true"
assert not nx.array_equal(