summaryrefslogtreecommitdiff
path: root/test/test_backend.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_backend.py')
-rw-r--r--test/test_backend.py66
1 files changed, 62 insertions, 4 deletions
diff --git a/test/test_backend.py b/test/test_backend.py
index 027c4cd..311c075 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -218,6 +218,8 @@ def test_empty_backend():
with pytest.raises(NotImplementedError):
nx.argmax(M)
with pytest.raises(NotImplementedError):
+ nx.argmin(M)
+ with pytest.raises(NotImplementedError):
nx.mean(M)
with pytest.raises(NotImplementedError):
nx.std(M)
@@ -264,12 +266,27 @@ def test_empty_backend():
nx.device_type(M)
with pytest.raises(NotImplementedError):
nx._bench(lambda x: x, M, n_runs=1)
+ with pytest.raises(NotImplementedError):
+ nx.solve(M, v)
+ with pytest.raises(NotImplementedError):
+ nx.trace(M)
+ with pytest.raises(NotImplementedError):
+ nx.inv(M)
+ with pytest.raises(NotImplementedError):
+ nx.sqrtm(M)
+ with pytest.raises(NotImplementedError):
+ nx.isfinite(M)
+ with pytest.raises(NotImplementedError):
+ nx.array_equal(M, M)
+ with pytest.raises(NotImplementedError):
+ nx.is_floating_point(M)
def test_func_backends(nx):
rnd = np.random.RandomState(0)
M = rnd.randn(10, 3)
+ SquareM = rnd.randn(10, 10)
v = rnd.randn(3)
val = np.array([1.0])
@@ -288,6 +305,7 @@ def test_func_backends(nx):
lst_name = []
Mb = nx.from_numpy(M)
+ SquareMb = nx.from_numpy(SquareM)
vb = nx.from_numpy(v)
val = nx.from_numpy(val)
@@ -467,6 +485,10 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(A))
lst_name.append('argmax')
+ A = nx.argmin(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('argmin')
+
A = nx.mean(Mb)
lst_b.append(nx.to_numpy(A))
lst_name.append('mean')
@@ -529,7 +551,11 @@ def test_func_backends(nx):
A = nx.where(Mb >= nx.stack([nx.linspace(0, 1, 10)] * 3, axis=1), Mb, 0.0)
lst_b.append(nx.to_numpy(A))
- lst_name.append('where')
+ lst_name.append('where (cond, x, y)')
+
+ A = nx.where(nx.from_numpy(np.array([True, False])))
+ lst_b.append(nx.to_numpy(nx.stack(A)))
+ lst_name.append('where (cond)')
A = nx.copy(Mb)
lst_b.append(nx.to_numpy(A))
@@ -550,15 +576,47 @@ def test_func_backends(nx):
nx._bench(lambda x: x, M, n_runs=1)
+ A = nx.solve(SquareMb, Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('solve')
+
+ A = nx.trace(SquareMb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('trace')
+
+ A = nx.inv(SquareMb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('matrix inverse')
+
+ A = nx.sqrtm(SquareMb.T @ SquareMb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("matrix square root")
+
+ A = nx.concatenate([vb, nx.from_numpy(np.array([np.inf, np.nan]))], axis=0)
+ A = nx.isfinite(A)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("isfinite")
+
+ assert not nx.array_equal(Mb, vb), "array_equal (shape)"
+ assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true"
+ assert not nx.array_equal(
+ Mb, Mb + nx.eye(*list(Mb.shape))
+ ), "array_equal (elements) - expected false"
+
+ assert nx.is_floating_point(Mb), "is_floating_point - expected true"
+ assert not nx.is_floating_point(
+ nx.from_numpy(np.array([0, 1, 2], dtype=int))
+ ), "is_floating_point - expected false"
+
lst_tot.append(lst_b)
lst_np = lst_tot[0]
lst_b = lst_tot[1]
for a1, a2, name in zip(lst_np, lst_b, lst_name):
- if not np.allclose(a1, a2):
- print('Assert fail on: ', name)
- assert np.allclose(a1, a2, atol=1e-7)
+ np.testing.assert_allclose(
+ a2, a1, atol=1e-7, err_msg=f'ASSERT FAILED ON: {name}'
+ )
def test_random_backends(nx):