summaryrefslogtreecommitdiff
path: root/src/python/gudhi/sktda/vector_methods.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/sktda/vector_methods.py')
-rw-r--r--src/python/gudhi/sktda/vector_methods.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/src/python/gudhi/sktda/vector_methods.py b/src/python/gudhi/sktda/vector_methods.py
index 1f304eaf..42ded45f 100644
--- a/src/python/gudhi/sktda/vector_methods.py
+++ b/src/python/gudhi/sktda/vector_methods.py
@@ -424,15 +424,15 @@ class ComplexPolynomial(BaseEstimator, TransformerMixin):
"""
This is a class for computing complex polynomials from a list of persistence diagrams. The persistence diagram points are seen as the roots of some complex polynomial, whose coefficients are returned in a complex vector. See https://link.springer.com/chapter/10.1007%2F978-3-319-23231-7_27 for more details.
"""
- def __init__(self, F="R", threshold=10):
+ def __init__(self, polynomial_type="R", threshold=10):
"""
Constructor for the ComplexPolynomial class.
Attributes:
- F (char): either "R", "S" or "T" (default "R"). Type of complex polynomial that is going to be computed (explained in https://link.springer.com/chapter/10.1007%2F978-3-319-23231-7_27).
+ polynomial_type (char): either "R", "S" or "T" (default "R"). Type of complex polynomial that is going to be computed (explained in https://link.springer.com/chapter/10.1007%2F978-3-319-23231-7_27).
threshold (int): number of coefficients (default 10). This is the dimension of the complex vector of coefficients, i.e. the number of coefficients corresponding to the largest degree terms of the polynomial. If -1, this threshold is computed from the list of persistence diagrams by considering the one with the largest number of points and using the dimension of its corresponding complex vector of coefficients as threshold.
"""
- self.threshold, self.F = threshold, F
+ self.threshold, self.polynomial_type = threshold, polynomial_type
def fit(self, X, y=None):
"""
@@ -462,13 +462,13 @@ class ComplexPolynomial(BaseEstimator, TransformerMixin):
Xfit = np.zeros([len(X), thresh]) + 1j * np.zeros([len(X), thresh])
for d in range(len(X)):
D, N = X[d], X[d].shape[0]
- if self.F == "R":
+ if self.polynomial_type == "R":
roots = D[:,0] + 1j * D[:,1]
- elif self.F == "S":
+ elif self.polynomial_type == "S":
alpha = np.linalg.norm(D, axis=1)
alpha = np.where(alpha==0, np.ones(N), alpha)
roots = np.multiply( np.multiply( (D[:,0]+1j*D[:,1]), (D[:,1]-D[:,0]) ), 1./(np.sqrt(2)*alpha) )
- elif self.F == "T":
+ elif self.polynomial_type == "T":
alpha = np.linalg.norm(D, axis=1)
roots = np.multiply( (D[:,1]-D[:,0])/2, np.cos(alpha) - np.sin(alpha) + 1j * (np.cos(alpha) + np.sin(alpha)) )
coeff = [0] * (N+1)