diff options
Diffstat (limited to 'src/python/gudhi/tensorflow/perslay.py')
-rw-r--r-- | src/python/gudhi/tensorflow/perslay.py | 18 |
1 files changed, 9 insertions, 9 deletions
diff --git a/src/python/gudhi/tensorflow/perslay.py b/src/python/gudhi/tensorflow/perslay.py index 69acc529..9976c5f3 100644 --- a/src/python/gudhi/tensorflow/perslay.py +++ b/src/python/gudhi/tensorflow/perslay.py @@ -8,7 +8,7 @@ # - YYYY/MM Author: Description of the modification import tensorflow as tf -import numpy as np +import math class GridPerslayWeight(tf.keras.layers.Layer): """ @@ -156,7 +156,7 @@ class GaussianPerslayPhi(tf.keras.layers.Layer): for _ in range(2): diagrams_d = tf.expand_dims(diagrams_d,-1) dists = tf.math.square(diagrams_d-mu) / (2*tf.math.square(self.variance)) - gauss = tf.math.exp(tf.math.reduce_sum(-dists, axis=2)) / (2*np.pi*tf.math.square(self.variance)) + gauss = tf.math.exp(tf.math.reduce_sum(-dists, axis=2)) / (2*math.pi*tf.math.square(self.variance)) output = tf.expand_dims(gauss,-1) output_shape = M[0].shape + tuple([1]) return output, output_shape @@ -191,7 +191,7 @@ class TentPerslayPhi(tf.keras.layers.Layer): """ samples_d = tf.expand_dims(tf.expand_dims(self.samples,0),0) xs, ys = diagrams[:,:,0:1], diagrams[:,:,1:2] - output = tf.math.maximum(.5*(ys-xs) - tf.math.abs(samples_d-.5*(ys+xs)), np.array([0.])) + output = tf.math.maximum(.5*(ys-xs) - tf.math.abs(samples_d-.5*(ys+xs)), tf.constant([0.])) output_shape = self.samples.shape return output, output_shape @@ -238,17 +238,17 @@ class Perslay(tf.keras.layers.Layer): def __init__(self, weight, phi, perm_op, rho, **kwargs): """ Constructor for the Perslay class. - + Parameters: - weight (function): weight function for the persistence diagram points. Can be either :class:`~gudhi.tensorflow.GridPerslayWeight`, :class:`~gudhi.tensorflow.GaussianMixturePerslayWeight`, :class:`~gudhi.tensorflow.PowerPerslayWeight`, or a custom function. - phi (function): transformation function for the persistence diagram points. Can be either :class:`~gudhi.tensorflow.GaussianPerslayPhi`, :class:`~gudhi.tensorflow.TentPerslayPhi`, :class:`~gudhi.tensorflow.FlatPerslayPhi`, or a custom function. - perm_op (function): permutation invariant function, such as `tf.math.reduce_sum`, `tf.math.reduce_mean`, `tf.math.reduce_max`, `tf.math.reduce_min`, or a custom function. If perm_op is the string "topk" (where k is a number), this function will be computed as `tf.math.top_k` with parameter `int(k)`. + weight (function): weight function for the persistence diagram points. Can be either :class:`~gudhi.tensorflow.perslay.GridPerslayWeight`, :class:`~gudhi.tensorflow.perslay.GaussianMixturePerslayWeight`, :class:`~gudhi.tensorflow.perslay.PowerPerslayWeight`, or a custom TensorFlow function that takes persistence diagrams as argument (represented as an (n x None x 2) ragged tensor, where n is the number of diagrams). + phi (function): transformation function for the persistence diagram points. Can be either :class:`~gudhi.tensorflow.perslay.GaussianPerslayPhi`, :class:`~gudhi.tensorflow.perslay.TentPerslayPhi`, :class:`~gudhi.tensorflow.perslay.FlatPerslayPhi`, or a custom TensorFlow class (that can have trainable parameters) with a method `call` that takes persistence diagrams as argument (represented as an (n x None x 2) ragged tensor, where n is the number of diagrams). + perm_op (function): permutation invariant function, such as `tf.math.reduce_sum`, `tf.math.reduce_mean`, `tf.math.reduce_max`, `tf.math.reduce_min`, or a custom TensorFlow function that takes two arguments: a tensor and an axis on which to apply the permutation invariant operation. If perm_op is the string "topk" (where k is a number), this function will be computed as `tf.math.top_k` with parameter `int(k)`. rho (function): postprocessing function that is applied after the permutation invariant operation. Can be any TensorFlow layer. """ super().__init__(dynamic=True, **kwargs) self.weight = weight self.phi = phi - self.pop = perm_op + self.perm_op = perm_op self.rho = rho def build(self, input_shape): @@ -270,7 +270,7 @@ class Perslay(tf.keras.layers.Layer): weight = tf.expand_dims(weight, -1) vector = tf.math.multiply(vector, weight) - permop = self.pop + permop = self.perm_op if type(permop) == str and permop[:3] == 'top': k = int(permop[3:]) vector = vector.to_tensor(default_value=-1e10) |