summaryrefslogtreecommitdiff
path: root/src/python/test
diff options
context:
space:
mode:
authorMathieuCarriere <mathieu.carriere3@gmail.com>2022-04-16 11:21:09 +0200
committerMathieuCarriere <mathieu.carriere3@gmail.com>2022-04-16 11:21:09 +0200
commitcc723a7a3735a44491bd1085b6bb6c47272b73ed (patch)
treedfe9e0a100a2fd8a9f8b6505068c39fe52f44bd0 /src/python/test
parent27f8df308e3ed935e4ef9f62d23717efebdf36ae (diff)
fix test
Diffstat (limited to 'src/python/test')
-rw-r--r--src/python/test/test_diff.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/src/python/test/test_diff.py b/src/python/test/test_diff.py
index bab0d10c..e0c99d07 100644
--- a/src/python/test/test_diff.py
+++ b/src/python/test/test_diff.py
@@ -22,7 +22,7 @@ def test_cubical_diff():
cl = CubicalLayer(dimensions=[0])
with tf.GradientTape() as tape:
- dgm = cl.call(X)[0][0]
+ dgm = cl.call(X)[0]
loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
grads = tape.gradient(loss, [X])
assert np.abs(grads[0].numpy()-np.array([[0.,0.,0.],[0.,.5,0.],[0.,0.,-.5]])).sum() <= 1e-6
@@ -34,7 +34,7 @@ def test_nonsquare_cubical_diff():
cl = CubicalLayer(dimensions=[0])
with tf.GradientTape() as tape:
- dgm = cl.call(X)[0][0]
+ dgm = cl.call(X)[0]
loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
grads = tape.gradient(loss, [X])
assert np.abs(grads[0].numpy()-np.array([[0.,0.5,-0.5],[0.,0.,0.]])).sum() <= 1e-6