summaryrefslogtreecommitdiff
path: root/src/python/gudhi/representations
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-06-02 21:07:29 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-06-02 21:07:29 +0200
commit69852030a6d1b68f3283b5727c6b944a9c7f5e73 (patch)
tree8efd8f97f82f22991d09321f9fec3a4a8e9aa1bf /src/python/gudhi/representations
parent7706056bb9c0396188201570f9399e636df63df7 (diff)
Some test
Diffstat (limited to 'src/python/gudhi/representations')
-rw-r--r--src/python/gudhi/representations/kernel_methods.py2
-rw-r--r--src/python/gudhi/representations/metrics.py2
2 files changed, 2 insertions, 2 deletions
diff --git a/src/python/gudhi/representations/kernel_methods.py b/src/python/gudhi/representations/kernel_methods.py
index 6e4f0619..23fd23c7 100644
--- a/src/python/gudhi/representations/kernel_methods.py
+++ b/src/python/gudhi/representations/kernel_methods.py
@@ -75,7 +75,7 @@ def pairwise_persistence_diagram_kernels(X, Y=None, kernel="sliced_wasserstein",
numpy array of shape (nxm): kernel matrix.
"""
XX = np.reshape(np.arange(len(X)), [-1,1])
- YY = None if Y is None else np.reshape(np.arange(len(Y)), [-1,1])
+ YY = None if Y is None or Y is X else np.reshape(np.arange(len(Y)), [-1,1])
if kernel == "sliced_wasserstein":
return np.exp(-pairwise_persistence_diagram_distances(X, Y, metric="sliced_wasserstein", num_directions=kwargs["num_directions"], n_jobs=n_jobs) / kwargs["bandwidth"])
elif kernel == "persistence_fisher":
diff --git a/src/python/gudhi/representations/metrics.py b/src/python/gudhi/representations/metrics.py
index 84907160..cf2e0879 100644
--- a/src/python/gudhi/representations/metrics.py
+++ b/src/python/gudhi/representations/metrics.py
@@ -164,7 +164,7 @@ def pairwise_persistence_diagram_distances(X, Y=None, metric="bottleneck", n_job
numpy array of shape (nxm): distance matrix
"""
XX = np.reshape(np.arange(len(X)), [-1,1])
- YY = None if Y is None else np.reshape(np.arange(len(Y)), [-1,1])
+ YY = None if Y is None or Y is X else np.reshape(np.arange(len(Y)), [-1,1])
if metric == "bottleneck":
try:
from .. import bottleneck_distance