diff options
Diffstat (limited to 'src/python')
-rw-r--r-- | src/python/gudhi/point_cloud/knn.py | 2 | ||||
-rw-r--r-- | src/python/gudhi/simplex_tree.pyx | 2 | ||||
-rw-r--r-- | src/python/setup.py.in | 5 | ||||
-rwxr-xr-x | src/python/test/test_tomato.py | 2 |
4 files changed, 5 insertions, 6 deletions
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py index 994be3b6..829bf1bf 100644 --- a/src/python/gudhi/point_cloud/knn.py +++ b/src/python/gudhi/point_cloud/knn.py @@ -111,7 +111,7 @@ class KNearestNeighbors: nargs = { k: v for k, v in self.params.items() if k in {"p", "n_jobs", "metric_params", "algorithm", "leaf_size"} } - self.nn = NearestNeighbors(self.k, metric=self.metric, **nargs) + self.nn = NearestNeighbors(n_neighbors=self.k, metric=self.metric, **nargs) self.nn.fit(X) if self.params["implementation"] == "hnsw": diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx index d7991417..be08a3a1 100644 --- a/src/python/gudhi/simplex_tree.pyx +++ b/src/python/gudhi/simplex_tree.pyx @@ -11,7 +11,7 @@ from cython.operator import dereference, preincrement from libc.stdint cimport intptr_t import numpy from numpy import array as np_array -cimport simplex_tree +cimport gudhi.simplex_tree __author__ = "Vincent Rouvreau" __copyright__ = "Copyright (C) 2016 Inria" diff --git a/src/python/setup.py.in b/src/python/setup.py.in index 98d058fc..65f5446e 100644 --- a/src/python/setup.py.in +++ b/src/python/setup.py.in @@ -41,10 +41,9 @@ for module in cython_modules: libraries=libraries, library_dirs=library_dirs, include_dirs=include_dirs, - runtime_library_dirs=runtime_library_dirs, - cython_directives = {'language_level': str(sys.version_info[0])},)) + runtime_library_dirs=runtime_library_dirs,)) -ext_modules = cythonize(ext_modules) +ext_modules = cythonize(ext_modules, compiler_directives={'language_level': str(sys.version_info[0])}) for module in pybind11_modules: my_include_dirs = include_dirs + [pybind11.get_include(False), pybind11.get_include(True)] diff --git a/src/python/test/test_tomato.py b/src/python/test/test_tomato.py index ecab03c4..c571f799 100755 --- a/src/python/test/test_tomato.py +++ b/src/python/test/test_tomato.py @@ -37,7 +37,7 @@ def test_tomato_1(): t = Tomato(metric="euclidean", graph_type="radius", r=4.7, k=4) t.fit(a) assert t.max_weight_per_cc_.size == 2 - assert np.array_equal(t.neighbors_, [[0, 1, 2], [0, 1, 2], [0, 1, 2], [3, 4, 5, 6], [3, 4, 5], [3, 4, 5], [3, 6]]) + assert t.neighbors_ == [[0, 1, 2], [0, 1, 2], [0, 1, 2], [3, 4, 5, 6], [3, 4, 5], [3, 4, 5], [3, 6]] t.plot_diagram() t = Tomato(graph_type="radius", r=4.7, k=4, symmetrize_graph=True) |