summaryrefslogtreecommitdiff
path: root/src/python/test
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-05-26 12:06:40 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-05-26 12:06:40 +0200
commit59763a7a4a7be162aee9552e5db4e86fa2225cb6 (patch)
treec865dc0ca9e97ba28dac58a78e3612621de1040f /src/python/test
parent5a78dc70afe172e8f38bff09639be21fc92b1fb4 (diff)
test
Diffstat (limited to 'src/python/test')
-rwxr-xr-xsrc/python/test/test_tomato.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/src/python/test/test_tomato.py b/src/python/test/test_tomato.py
index a4cab654..ecab03c4 100755
--- a/src/python/test/test_tomato.py
+++ b/src/python/test/test_tomato.py
@@ -17,7 +17,7 @@ import matplotlib.pyplot as plt
plt.show = lambda: None
-def test_tomato_something():
+def test_tomato_1():
a = [(1, 2), (1.1, 1.9), (0.9, 1.8), (10, 0), (10.1, 0.05), (10.2, -0.1), (5.4, 0)]
t = Tomato(metric="euclidean", n_clusters=2, k=4, n_jobs=-1, eps=0.05)
assert np.array_equal(t.fit_predict(a), [1, 1, 1, 0, 0, 0, 0]) # or with swapped 0 and 1
@@ -31,6 +31,9 @@ def test_tomato_something():
assert t.n_clusters_ == 1
assert (t.labels_ == 0).all()
+ t = Tomato(graph_type="radius", r=0.1, metric="cosine", k=3)
+ assert np.array_equal(t.fit_predict(a), [1, 1, 1, 0, 0, 0, 0]) # or with swapped 0 and 1
+
t = Tomato(metric="euclidean", graph_type="radius", r=4.7, k=4)
t.fit(a)
assert t.max_weight_per_cc_.size == 2