diff options
author | ROUVREAU Vincent <vincent.rouvreau@inria.fr> | 2020-08-18 10:55:42 +0200 |
---|---|---|
committer | ROUVREAU Vincent <vincent.rouvreau@inria.fr> | 2020-08-18 10:55:42 +0200 |
commit | a1cd7e9ead030654a1fdb6cfd50408103c458529 (patch) | |
tree | 9786156bfb00d5b4f85dda2458b087d60d1bc1a8 /src/python/test/test_representations.py | |
parent | 85eec1ba750d56b66e3739dc486c6205f49fb31e (diff) | |
parent | 4737aaeb36a4ff3b27d7bcbb374911197ed09e5a (diff) |
Merge master and resolve conflicts
Diffstat (limited to 'src/python/test/test_representations.py')
-rwxr-xr-x | src/python/test/test_representations.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/src/python/test/test_representations.py b/src/python/test/test_representations.py index 589cee00..e5c211a0 100755 --- a/src/python/test/test_representations.py +++ b/src/python/test/test_representations.py @@ -4,6 +4,8 @@ import matplotlib.pyplot as plt import numpy as np import pytest +from sklearn.cluster import KMeans + def test_representations_examples(): # Disable graphics for testing purposes @@ -15,6 +17,7 @@ def test_representations_examples(): return None +from gudhi.representations.vector_methods import Atol from gudhi.representations.metrics import * from gudhi.representations.kernel_methods import * @@ -41,3 +44,17 @@ def test_multiple(): d2 = WassersteinDistance(order=2, internal_p=2, n_jobs=4).fit(l2).transform(l1) print(d1.shape, d2.shape) assert d1 == pytest.approx(d2, rel=.02) + + +def test_dummy_atol(): + a = np.array([[1, 2, 4], [1, 4, 0], [1, 0, 4]]) + b = np.array([[4, 2, 0], [4, 4, 0], [4, 0, 2]]) + c = np.array([[3, 2, -1], [1, 2, -1]]) + + for weighting_method in ["cloud", "iidproba"]: + for contrast in ["gaussian", "laplacian", "indicator"]: + atol_vectoriser = Atol(quantiser=KMeans(n_clusters=1, random_state=202006), weighting_method=weighting_method, contrast=contrast) + atol_vectoriser.fit([a, b, c]) + atol_vectoriser(a) + atol_vectoriser.transform(X=[a, b, c]) + |