summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2019-12-26 10:22:47 +0100
committerMarc Glisse <marc.glisse@inria.fr>2019-12-26 10:22:47 +0100
commitc2e22942c35e894d5c1ddc429eb32687c61538c8 (patch)
treeaaf991edcadf848a78122fcaa8c94fbf6d54933e
parent15f222eecf3b427c59f09ec3bec17983377d96a2 (diff)
Basic binding for wasserstein_distance
-rw-r--r--src/cmake/modules/GUDHI_user_version_target.cmake2
-rw-r--r--src/python/gudhi/hera.cc48
-rw-r--r--src/python/setup.py.in26
3 files changed, 73 insertions, 3 deletions
diff --git a/src/cmake/modules/GUDHI_user_version_target.cmake b/src/cmake/modules/GUDHI_user_version_target.cmake
index 2527dee9..9a05386f 100644
--- a/src/cmake/modules/GUDHI_user_version_target.cmake
+++ b/src/cmake/modules/GUDHI_user_version_target.cmake
@@ -57,7 +57,7 @@ add_custom_command(TARGET user_version PRE_BUILD COMMAND ${CMAKE_COMMAND} -E
copy_directory ${CMAKE_SOURCE_DIR}/src/GudhUI ${GUDHI_USER_VERSION_DIR}/GudhUI)
add_custom_command(TARGET user_version PRE_BUILD COMMAND ${CMAKE_COMMAND} -E
- copy_directory ${CMAKE_SOURCE_DIR}/ext/hera/geom_matching/wasserstein/include ${GUDHI_USER_VERSION_DIR}/hera/wasserstein)
+ copy_directory ${CMAKE_SOURCE_DIR}/ext/hera/geom_matching/wasserstein/include ${GUDHI_USER_VERSION_DIR}/ext/hera/geom_matching/wasserstein/include)
set(GUDHI_DIRECTORIES "doc;example;concept;utilities")
diff --git a/src/python/gudhi/hera.cc b/src/python/gudhi/hera.cc
new file mode 100644
index 00000000..7cef9425
--- /dev/null
+++ b/src/python/gudhi/hera.cc
@@ -0,0 +1,48 @@
+#include <pybind11/pybind11.h>
+#include <pybind11/numpy.h>
+
+#include <boost/range/iterator_range.hpp>
+
+#include <wasserstein.h>
+
+#include <array>
+
+namespace py = pybind11;
+typedef py::array_t<double, py::array::c_style | py::array::forcecast> Dgm;
+
+namespace hera {
+template <> struct DiagramTraits<Dgm>{
+ //using Container = void;
+ using PointType = std::array<double,2>;
+ using RealType = double;
+
+ static RealType get_x(const PointType& p) { return std::get<0>(p); }
+ static RealType get_y(const PointType& p) { return std::get<1>(p); }
+};
+}
+
+double wasserstein_distance(
+ Dgm d1,
+ Dgm d2)
+{
+ py::buffer_info buf1 = d1.request();
+ py::buffer_info buf2 = d2.request();
+ if(buf1.ndim!=2 || buf1.shape[1]!=2)
+ throw std::runtime_error("Diagram 1 must be an array of size n x 2");
+ if(buf2.ndim!=2 || buf2.shape[1]!=2)
+ throw std::runtime_error("Diagram 1 must be an array of size n x 2");
+ typedef hera::DiagramTraits<Dgm>::PointType Point;
+ auto p1 = (Point*)buf1.ptr;
+ auto p2 = (Point*)buf2.ptr;
+ auto diag1 = boost::make_iterator_range(p1, p1+buf1.shape[0]);
+ auto diag2 = boost::make_iterator_range(p2, p2+buf2.shape[0]);
+
+ hera::AuctionParams<double> params;
+ return hera::wasserstein_dist(diag1, diag2, params);
+}
+
+PYBIND11_MODULE(hera, m) {
+ m.def("wasserstein_distance", &wasserstein_distance, R"pbdoc(
+ Compute the Wasserstein distance between two diagrams
+ )pbdoc");
+}
diff --git a/src/python/setup.py.in b/src/python/setup.py.in
index 3f1d4424..f7ffd146 100644
--- a/src/python/setup.py.in
+++ b/src/python/setup.py.in
@@ -26,6 +26,19 @@ library_dirs=[@GUDHI_PYTHON_LIBRARY_DIRS@]
include_dirs = [numpy_get_include(), '@CMAKE_CURRENT_SOURCE_DIR@/gudhi/', @GUDHI_PYTHON_INCLUDE_DIRS@]
runtime_library_dirs=[@GUDHI_PYTHON_RUNTIME_LIBRARY_DIRS@]
+class get_pybind_include(object):
+ """Helper class to determine the pybind11 include path
+ The purpose of this class is to postpone importing pybind11
+ until it is actually installed, so that the ``get_include()``
+ method can be invoked. """
+
+ def __init__(self, user=False):
+ self.user = user
+
+ def __str__(self):
+ import pybind11
+ return pybind11.get_include(self.user)
+
# Create ext_modules list from module list
ext_modules = []
for module in modules:
@@ -39,6 +52,15 @@ for module in modules:
library_dirs=library_dirs,
include_dirs=include_dirs,
runtime_library_dirs=runtime_library_dirs,))
+ext_modules.append(Extension(
+ 'gudhi.hera',
+ sources = [source_dir + 'hera.cc'],
+ language = 'c++',
+ extra_compile_args=extra_compile_args + ['-fvisibility=hidden'], # FIXME
+ include_dirs = include_dirs +
+ ['@CMAKE_SOURCE_DIR@/ext/hera/geom_matching/wasserstein/include',
+ get_pybind_include(False), get_pybind_include(True)]
+ ))
setup(
name = 'gudhi',
@@ -48,6 +70,6 @@ setup(
version='@GUDHI_VERSION@',
url='http://gudhi.gforge.inria.fr/',
ext_modules = cythonize(ext_modules),
- install_requires = ['cython','numpy >= 1.9',],
- setup_requires = ['numpy >= 1.9',],
+ install_requires = ['cython','numpy >= 1.9','pybind11',],
+ setup_requires = ['numpy >= 1.9','pybind11',],
)