summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMathieuCarriere <mathieu.carriere3@gmail.com>2019-11-04 18:18:10 -0500
committerMathieuCarriere <mathieu.carriere3@gmail.com>2019-11-04 18:18:10 -0500
commit38135d576d92cea7b4e355e2e70a4fe322d4b6e7 (patch)
treeea01aac00feda5edc9efb2443f7b82eef38b1de7
parent77e7ee6e197aa8f0cf0fc0065c8d12e7c543e21f (diff)
fixes to Marc's comments
-rwxr-xr-xsrc/python/example/ex_diagrams.py2
-rw-r--r--src/python/gudhi/sktda/kernel_methods.py2
-rw-r--r--src/python/gudhi/sktda/metrics.py2
-rw-r--r--src/python/gudhi/sktda/vector_methods.py12
4 files changed, 9 insertions, 9 deletions
diff --git a/src/python/example/ex_diagrams.py b/src/python/example/ex_diagrams.py
index f12304bd..a6a36b7c 100755
--- a/src/python/example/ex_diagrams.py
+++ b/src/python/example/ex_diagrams.py
@@ -47,7 +47,7 @@ plt.plot(bc[0])
plt.title("Betti Curve")
plt.show()
-CP = ComplexPolynomial(threshold=-1, F="T")
+CP = ComplexPolynomial(threshold=-1, polynomial_type="T")
cp = CP.fit_transform(diags)
print("Complex polynomial is " + str(cp[0,:]))
diff --git a/src/python/gudhi/sktda/kernel_methods.py b/src/python/gudhi/sktda/kernel_methods.py
index 20cda49b..b8d4ab3a 100644
--- a/src/python/gudhi/sktda/kernel_methods.py
+++ b/src/python/gudhi/sktda/kernel_methods.py
@@ -22,7 +22,7 @@ class SlicedWassersteinKernel(BaseEstimator, TransformerMixin):
Attributes:
bandwidth (double): bandwidth of the Gaussian kernel applied to the sliced Wasserstein distance (default 1.).
- num_directions (int): number of lines evenly sampled on [-pi/2,pi/2] in order to approximate and speed up the kernel computation (default 10). If -1, the exact kernel is computed.
+ num_directions (int): number of lines evenly sampled from [-pi/2,pi/2] in order to approximate and speed up the kernel computation (default 10). If -1, the exact kernel is computed.
"""
self.bandwidth = bandwidth
self.sw_ = SlicedWassersteinDistance(num_directions=num_directions)
diff --git a/src/python/gudhi/sktda/metrics.py b/src/python/gudhi/sktda/metrics.py
index 816441b6..e85fd14d 100644
--- a/src/python/gudhi/sktda/metrics.py
+++ b/src/python/gudhi/sktda/metrics.py
@@ -26,7 +26,7 @@ class SlicedWassersteinDistance(BaseEstimator, TransformerMixin):
Constructor for the SlicedWassersteinDistance class.
Attributes:
- num_directions (int): number of lines to sample uniformly from [-pi,pi] in order to approximate and speed up the distance computation (default 10).
+ num_directions (int): number of lines evenly sampled from [-pi/2,pi/2] in order to approximate and speed up the distance computation (default 10).
"""
self.num_directions = num_directions
thetas = np.linspace(-np.pi/2, np.pi/2, num=self.num_directions+1)[np.newaxis,:-1]
diff --git a/src/python/gudhi/sktda/vector_methods.py b/src/python/gudhi/sktda/vector_methods.py
index 1f304eaf..42ded45f 100644
--- a/src/python/gudhi/sktda/vector_methods.py
+++ b/src/python/gudhi/sktda/vector_methods.py
@@ -424,15 +424,15 @@ class ComplexPolynomial(BaseEstimator, TransformerMixin):
"""
This is a class for computing complex polynomials from a list of persistence diagrams. The persistence diagram points are seen as the roots of some complex polynomial, whose coefficients are returned in a complex vector. See https://link.springer.com/chapter/10.1007%2F978-3-319-23231-7_27 for more details.
"""
- def __init__(self, F="R", threshold=10):
+ def __init__(self, polynomial_type="R", threshold=10):
"""
Constructor for the ComplexPolynomial class.
Attributes:
- F (char): either "R", "S" or "T" (default "R"). Type of complex polynomial that is going to be computed (explained in https://link.springer.com/chapter/10.1007%2F978-3-319-23231-7_27).
+ polynomial_type (char): either "R", "S" or "T" (default "R"). Type of complex polynomial that is going to be computed (explained in https://link.springer.com/chapter/10.1007%2F978-3-319-23231-7_27).
threshold (int): number of coefficients (default 10). This is the dimension of the complex vector of coefficients, i.e. the number of coefficients corresponding to the largest degree terms of the polynomial. If -1, this threshold is computed from the list of persistence diagrams by considering the one with the largest number of points and using the dimension of its corresponding complex vector of coefficients as threshold.
"""
- self.threshold, self.F = threshold, F
+ self.threshold, self.polynomial_type = threshold, polynomial_type
def fit(self, X, y=None):
"""
@@ -462,13 +462,13 @@ class ComplexPolynomial(BaseEstimator, TransformerMixin):
Xfit = np.zeros([len(X), thresh]) + 1j * np.zeros([len(X), thresh])
for d in range(len(X)):
D, N = X[d], X[d].shape[0]
- if self.F == "R":
+ if self.polynomial_type == "R":
roots = D[:,0] + 1j * D[:,1]
- elif self.F == "S":
+ elif self.polynomial_type == "S":
alpha = np.linalg.norm(D, axis=1)
alpha = np.where(alpha==0, np.ones(N), alpha)
roots = np.multiply( np.multiply( (D[:,0]+1j*D[:,1]), (D[:,1]-D[:,0]) ), 1./(np.sqrt(2)*alpha) )
- elif self.F == "T":
+ elif self.polynomial_type == "T":
alpha = np.linalg.norm(D, axis=1)
roots = np.multiply( (D[:,1]-D[:,0])/2, np.cos(alpha) - np.sin(alpha) + 1j * (np.cos(alpha) + np.sin(alpha)) )
coeff = [0] * (N+1)