summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-06-02 21:07:29 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-06-02 21:07:29 +0200
commit69852030a6d1b68f3283b5727c6b944a9c7f5e73 (patch)
tree8efd8f97f82f22991d09321f9fec3a4a8e9aa1bf
parent7706056bb9c0396188201570f9399e636df63df7 (diff)
Some test
-rw-r--r--src/python/gudhi/representations/kernel_methods.py2
-rw-r--r--src/python/gudhi/representations/metrics.py2
-rwxr-xr-xsrc/python/test/test_representations.py33
3 files changed, 34 insertions, 3 deletions
diff --git a/src/python/gudhi/representations/kernel_methods.py b/src/python/gudhi/representations/kernel_methods.py
index 6e4f0619..23fd23c7 100644
--- a/src/python/gudhi/representations/kernel_methods.py
+++ b/src/python/gudhi/representations/kernel_methods.py
@@ -75,7 +75,7 @@ def pairwise_persistence_diagram_kernels(X, Y=None, kernel="sliced_wasserstein",
numpy array of shape (nxm): kernel matrix.
"""
XX = np.reshape(np.arange(len(X)), [-1,1])
- YY = None if Y is None else np.reshape(np.arange(len(Y)), [-1,1])
+ YY = None if Y is None or Y is X else np.reshape(np.arange(len(Y)), [-1,1])
if kernel == "sliced_wasserstein":
return np.exp(-pairwise_persistence_diagram_distances(X, Y, metric="sliced_wasserstein", num_directions=kwargs["num_directions"], n_jobs=n_jobs) / kwargs["bandwidth"])
elif kernel == "persistence_fisher":
diff --git a/src/python/gudhi/representations/metrics.py b/src/python/gudhi/representations/metrics.py
index 84907160..cf2e0879 100644
--- a/src/python/gudhi/representations/metrics.py
+++ b/src/python/gudhi/representations/metrics.py
@@ -164,7 +164,7 @@ def pairwise_persistence_diagram_distances(X, Y=None, metric="bottleneck", n_job
numpy array of shape (nxm): distance matrix
"""
XX = np.reshape(np.arange(len(X)), [-1,1])
- YY = None if Y is None else np.reshape(np.arange(len(Y)), [-1,1])
+ YY = None if Y is None or Y is X else np.reshape(np.arange(len(Y)), [-1,1])
if metric == "bottleneck":
try:
from .. import bottleneck_distance
diff --git a/src/python/test/test_representations.py b/src/python/test/test_representations.py
index dba7f952..589cee00 100755
--- a/src/python/test/test_representations.py
+++ b/src/python/test/test_representations.py
@@ -1,12 +1,43 @@
import os
import sys
import matplotlib.pyplot as plt
+import numpy as np
+import pytest
+
def test_representations_examples():
# Disable graphics for testing purposes
- plt.show = lambda:None
+ plt.show = lambda: None
here = os.path.dirname(os.path.realpath(__file__))
sys.path.append(here + "/../example")
import diagram_vectorizations_distances_kernels
return None
+
+
+from gudhi.representations.metrics import *
+from gudhi.representations.kernel_methods import *
+
+
+def _n_diags(n):
+ l = []
+ for _ in range(n):
+ a = np.random.rand(50, 2)
+ a[:, 1] += a[:, 0] # So that y >= x
+ l.append(a)
+ return l
+
+
+def test_multiple():
+ l1 = _n_diags(9)
+ l2 = _n_diags(11)
+ l1b = l1.copy()
+ d1 = pairwise_persistence_diagram_distances(l1, e=0.00001, n_jobs=4)
+ d2 = BottleneckDistance(epsilon=0.00001).fit_transform(l1)
+ d3 = pairwise_persistence_diagram_distances(l1, l1b, e=0.00001, n_jobs=4)
+ assert d1 == pytest.approx(d2)
+ assert d3 == pytest.approx(d2, abs=1e-5) # Because of 0 entries (on the diagonal)
+ d1 = pairwise_persistence_diagram_distances(l1, l2, metric="wasserstein", order=2, internal_p=2)
+ d2 = WassersteinDistance(order=2, internal_p=2, n_jobs=4).fit(l2).transform(l1)
+ print(d1.shape, d2.shape)
+ assert d1 == pytest.approx(d2, rel=.02)