summaryrefslogtreecommitdiff
path: root/src/python/test
diff options
context:
space:
mode:
author MathieuCarriere <mathieu.carriere3@gmail.com> 2021-11-12 09:46:22 +0100
committer MathieuCarriere <mathieu.carriere3@gmail.com> 2021-11-12 09:46:22 +0100
commit 6ae793a8cad4503d1795e227d40d85d43954d1dd (patch)
tree e46c9cc11628456739008f281a860a4df1d20775 /src/python/test
parent 3f1a6e659611dce2913fddc93b01480f05fb7983 (diff)
removed unraveling in cubical
Diffstat (limited to 'src/python/test')
-rw-r--r-- src/python/test/test_diff.py 13
1 file changed, 12 insertions, 1 deletion
diff --git a/src/python/test/test_diff.py b/src/python/test/test_diff.py
index f49eff7b..e0c99d07 100644
--- a/src/python/test/test_diff.py
+++ b/src/python/test/test_diff.py
@@ -15,7 +15,6 @@ def test_rips_diff():
grads = tape.gradient(loss, [X])
assert np.abs(grads[0].numpy()-np.array([[-.5,-.5],[.5,.5]])).sum() <= 1e-6
-
def test_cubical_diff():
Xinit = np.array([[0.,2.,2.],[2.,2.,2.],[2.,2.,1.]], dtype=np.float32)
@@ -28,6 +27,18 @@ def test_cubical_diff():
grads = tape.gradient(loss, [X])
assert np.abs(grads[0].numpy()-np.array([[0.,0.,0.],[0.,.5,0.],[0.,0.,-.5]])).sum() <= 1e-6
+def test_nonsquare_cubical_diff():
+
+ Xinit = np.array([[-1.,1.,0.],[1.,1.,1.]], dtype=np.float32)
+ X = tf.Variable(initial_value=Xinit, trainable=True)
+ cl = CubicalLayer(dimensions=[0])
+
+ with tf.GradientTape() as tape:
+ dgm = cl.call(X)[0]
+ loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
+ grads = tape.gradient(loss, [X])
+ assert np.abs(grads[0].numpy()-np.array([[0.,0.5,-0.5],[0.,0.,0.]])).sum() <= 1e-6
+
def test_st_diff():
st = gd.SimplexTree()