summaryrefslogtreecommitdiff
path: root/src/python/gudhi/tensorflow/perslay.py
diff options
context:
space:
mode:
authorMathieuCarriere <mathieu.carriere3@gmail.com>2022-11-05 22:33:21 +0100
committerMathieuCarriere <mathieu.carriere3@gmail.com>2022-11-05 22:33:21 +0100
commit85a93e6432771b7439ea7e2403dc702a66481033 (patch)
tree722e7ba778e4ab26af962d9d0f3b5452e46b6f24 /src/python/gudhi/tensorflow/perslay.py
parent077476721702b54d18a48a072b4f75de20046d8a (diff)
added a few comments in the doc
Diffstat (limited to 'src/python/gudhi/tensorflow/perslay.py')
-rw-r--r--src/python/gudhi/tensorflow/perslay.py18
1 files changed, 9 insertions, 9 deletions
diff --git a/src/python/gudhi/tensorflow/perslay.py b/src/python/gudhi/tensorflow/perslay.py
index 69acc529..9976c5f3 100644
--- a/src/python/gudhi/tensorflow/perslay.py
+++ b/src/python/gudhi/tensorflow/perslay.py
@@ -8,7 +8,7 @@
# - YYYY/MM Author: Description of the modification
import tensorflow as tf
-import numpy as np
+import math
class GridPerslayWeight(tf.keras.layers.Layer):
"""
@@ -156,7 +156,7 @@ class GaussianPerslayPhi(tf.keras.layers.Layer):
for _ in range(2):
diagrams_d = tf.expand_dims(diagrams_d,-1)
dists = tf.math.square(diagrams_d-mu) / (2*tf.math.square(self.variance))
- gauss = tf.math.exp(tf.math.reduce_sum(-dists, axis=2)) / (2*np.pi*tf.math.square(self.variance))
+ gauss = tf.math.exp(tf.math.reduce_sum(-dists, axis=2)) / (2*math.pi*tf.math.square(self.variance))
output = tf.expand_dims(gauss,-1)
output_shape = M[0].shape + tuple([1])
return output, output_shape
@@ -191,7 +191,7 @@ class TentPerslayPhi(tf.keras.layers.Layer):
"""
samples_d = tf.expand_dims(tf.expand_dims(self.samples,0),0)
xs, ys = diagrams[:,:,0:1], diagrams[:,:,1:2]
- output = tf.math.maximum(.5*(ys-xs) - tf.math.abs(samples_d-.5*(ys+xs)), np.array([0.]))
+ output = tf.math.maximum(.5*(ys-xs) - tf.math.abs(samples_d-.5*(ys+xs)), tf.constant([0.]))
output_shape = self.samples.shape
return output, output_shape
@@ -238,17 +238,17 @@ class Perslay(tf.keras.layers.Layer):
def __init__(self, weight, phi, perm_op, rho, **kwargs):
"""
Constructor for the Perslay class.
-
+
Parameters:
- weight (function): weight function for the persistence diagram points. Can be either :class:`~gudhi.tensorflow.GridPerslayWeight`, :class:`~gudhi.tensorflow.GaussianMixturePerslayWeight`, :class:`~gudhi.tensorflow.PowerPerslayWeight`, or a custom function.
- phi (function): transformation function for the persistence diagram points. Can be either :class:`~gudhi.tensorflow.GaussianPerslayPhi`, :class:`~gudhi.tensorflow.TentPerslayPhi`, :class:`~gudhi.tensorflow.FlatPerslayPhi`, or a custom function.
- perm_op (function): permutation invariant function, such as `tf.math.reduce_sum`, `tf.math.reduce_mean`, `tf.math.reduce_max`, `tf.math.reduce_min`, or a custom function. If perm_op is the string "topk" (where k is a number), this function will be computed as `tf.math.top_k` with parameter `int(k)`.
+ weight (function): weight function for the persistence diagram points. Can be either :class:`~gudhi.tensorflow.perslay.GridPerslayWeight`, :class:`~gudhi.tensorflow.perslay.GaussianMixturePerslayWeight`, :class:`~gudhi.tensorflow.perslay.PowerPerslayWeight`, or a custom TensorFlow function that takes persistence diagrams as argument (represented as an (n x None x 2) ragged tensor, where n is the number of diagrams).
+ phi (function): transformation function for the persistence diagram points. Can be either :class:`~gudhi.tensorflow.perslay.GaussianPerslayPhi`, :class:`~gudhi.tensorflow.perslay.TentPerslayPhi`, :class:`~gudhi.tensorflow.perslay.FlatPerslayPhi`, or a custom TensorFlow class (that can have trainable parameters) with a method `call` that takes persistence diagrams as argument (represented as an (n x None x 2) ragged tensor, where n is the number of diagrams).
+ perm_op (function): permutation invariant function, such as `tf.math.reduce_sum`, `tf.math.reduce_mean`, `tf.math.reduce_max`, `tf.math.reduce_min`, or a custom TensorFlow function that takes two arguments: a tensor and an axis on which to apply the permutation invariant operation. If perm_op is the string "topk" (where k is a number), this function will be computed as `tf.math.top_k` with parameter `int(k)`.
rho (function): postprocessing function that is applied after the permutation invariant operation. Can be any TensorFlow layer.
"""
super().__init__(dynamic=True, **kwargs)
self.weight = weight
self.phi = phi
- self.pop = perm_op
+ self.perm_op = perm_op
self.rho = rho
def build(self, input_shape):
@@ -270,7 +270,7 @@ class Perslay(tf.keras.layers.Layer):
weight = tf.expand_dims(weight, -1)
vector = tf.math.multiply(vector, weight)
- permop = self.pop
+ permop = self.perm_op
if type(permop) == str and permop[:3] == 'top':
k = int(permop[3:])
vector = vector.to_tensor(default_value=-1e10)