summaryrefslogtreecommitdiff
path: root/src/python
diff options
context:
space:
mode:
Diffstat (limited to 'src/python')
-rw-r--r--src/python/gudhi/hera/bottleneck.cc2
-rw-r--r--src/python/gudhi/hera/wasserstein.cc10
-rw-r--r--src/python/setup.py.in6
3 files changed, 11 insertions, 7 deletions
diff --git a/src/python/gudhi/hera/bottleneck.cc b/src/python/gudhi/hera/bottleneck.cc
index 0cb562ce..ec461f7c 100644
--- a/src/python/gudhi/hera/bottleneck.cc
+++ b/src/python/gudhi/hera/bottleneck.cc
@@ -16,7 +16,7 @@
using py::ssize_t;
#endif
-#include <bottleneck.h> // Hera
+#include <hera/bottleneck.h> // Hera
double bottleneck_distance(Dgm d1, Dgm d2, double delta)
{
diff --git a/src/python/gudhi/hera/wasserstein.cc b/src/python/gudhi/hera/wasserstein.cc
index fa0cf8aa..3516352e 100644
--- a/src/python/gudhi/hera/wasserstein.cc
+++ b/src/python/gudhi/hera/wasserstein.cc
@@ -8,10 +8,16 @@
* - YYYY/MM Author: Description of the modification
*/
-#include <wasserstein.h> // Hera
-
#include <pybind11_diagram_utils.h>
+#ifdef _MSC_VER
+// https://github.com/grey-narn/hera/issues/3
+// ssize_t is a non-standard type (well, posix)
+using py::ssize_t;
+#endif
+
+#include <hera/wasserstein.h> // Hera
+
double wasserstein_distance(
Dgm d1, Dgm d2,
double wasserstein_power, double internal_p,
diff --git a/src/python/setup.py.in b/src/python/setup.py.in
index 2c67c2c5..1ecbe985 100644
--- a/src/python/setup.py.in
+++ b/src/python/setup.py.in
@@ -48,10 +48,8 @@ ext_modules = cythonize(ext_modules, compiler_directives={'language_level': '3'}
for module in pybind11_modules:
my_include_dirs = include_dirs + [pybind11.get_include(False), pybind11.get_include(True)]
- if module == 'hera/wasserstein':
- my_include_dirs = ['@HERA_WASSERSTEIN_INCLUDE_DIR@'] + my_include_dirs
- elif module == 'hera/bottleneck':
- my_include_dirs = ['@HERA_BOTTLENECK_INCLUDE_DIR@'] + my_include_dirs
+ if module.startswith('hera/'):
+ my_include_dirs = ['@HERA_INCLUDE_DIR@'] + my_include_dirs
ext_modules.append(Extension(
'gudhi.' + module.replace('/', '.'),
sources = [source_dir + module + '.cc'],