diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2020-06-02 21:07:29 +0200 |
---|---|---|
committer | Marc Glisse <marc.glisse@inria.fr> | 2020-06-02 21:07:29 +0200 |
commit | 69852030a6d1b68f3283b5727c6b944a9c7f5e73 (patch) | |
tree | 8efd8f97f82f22991d09321f9fec3a4a8e9aa1bf /src/python/gudhi/representations | |
parent | 7706056bb9c0396188201570f9399e636df63df7 (diff) |
Some test
Diffstat (limited to 'src/python/gudhi/representations')
-rw-r--r-- | src/python/gudhi/representations/kernel_methods.py | 2 | ||||
-rw-r--r-- | src/python/gudhi/representations/metrics.py | 2 |
2 files changed, 2 insertions, 2 deletions
diff --git a/src/python/gudhi/representations/kernel_methods.py b/src/python/gudhi/representations/kernel_methods.py index 6e4f0619..23fd23c7 100644 --- a/src/python/gudhi/representations/kernel_methods.py +++ b/src/python/gudhi/representations/kernel_methods.py @@ -75,7 +75,7 @@ def pairwise_persistence_diagram_kernels(X, Y=None, kernel="sliced_wasserstein", numpy array of shape (nxm): kernel matrix. """ XX = np.reshape(np.arange(len(X)), [-1,1]) - YY = None if Y is None else np.reshape(np.arange(len(Y)), [-1,1]) + YY = None if Y is None or Y is X else np.reshape(np.arange(len(Y)), [-1,1]) if kernel == "sliced_wasserstein": return np.exp(-pairwise_persistence_diagram_distances(X, Y, metric="sliced_wasserstein", num_directions=kwargs["num_directions"], n_jobs=n_jobs) / kwargs["bandwidth"]) elif kernel == "persistence_fisher": diff --git a/src/python/gudhi/representations/metrics.py b/src/python/gudhi/representations/metrics.py index 84907160..cf2e0879 100644 --- a/src/python/gudhi/representations/metrics.py +++ b/src/python/gudhi/representations/metrics.py @@ -164,7 +164,7 @@ def pairwise_persistence_diagram_distances(X, Y=None, metric="bottleneck", n_job numpy array of shape (nxm): distance matrix """ XX = np.reshape(np.arange(len(X)), [-1,1]) - YY = None if Y is None else np.reshape(np.arange(len(Y)), [-1,1]) + YY = None if Y is None or Y is X else np.reshape(np.arange(len(Y)), [-1,1]) if metric == "bottleneck": try: from .. import bottleneck_distance |