From 15f222eecf3b427c59f09ec3bec17983377d96a2 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Wed, 25 Dec 2019 22:27:52 +0100 Subject: Copy hera headers in user_version --- src/cmake/modules/GUDHI_user_version_target.cmake | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'src/cmake/modules/GUDHI_user_version_target.cmake') diff --git a/src/cmake/modules/GUDHI_user_version_target.cmake b/src/cmake/modules/GUDHI_user_version_target.cmake index 4fa74330..2527dee9 100644 --- a/src/cmake/modules/GUDHI_user_version_target.cmake +++ b/src/cmake/modules/GUDHI_user_version_target.cmake @@ -56,6 +56,9 @@ add_custom_command(TARGET user_version PRE_BUILD COMMAND ${CMAKE_COMMAND} -E add_custom_command(TARGET user_version PRE_BUILD COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_SOURCE_DIR}/src/GudhUI ${GUDHI_USER_VERSION_DIR}/GudhUI) +add_custom_command(TARGET user_version PRE_BUILD COMMAND ${CMAKE_COMMAND} -E + copy_directory ${CMAKE_SOURCE_DIR}/ext/hera/geom_matching/wasserstein/include ${GUDHI_USER_VERSION_DIR}/hera/wasserstein) + set(GUDHI_DIRECTORIES "doc;example;concept;utilities") set(GUDHI_INCLUDE_DIRECTORIES "include/gudhi") @@ -95,4 +98,4 @@ foreach(GUDHI_MODULE ${GUDHI_MODULES_FULL_LIST}) endforeach() endforeach(GUDHI_INCLUDE_DIRECTORY ${GUDHI_INCLUDE_DIRECTORIES}) -endforeach(GUDHI_MODULE ${GUDHI_MODULES_FULL_LIST}) \ No newline at end of file +endforeach(GUDHI_MODULE ${GUDHI_MODULES_FULL_LIST}) -- cgit v1.2.3 From c2e22942c35e894d5c1ddc429eb32687c61538c8 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Thu, 26 Dec 2019 10:22:47 +0100 Subject: Basic binding for wasserstein_distance --- src/cmake/modules/GUDHI_user_version_target.cmake | 2 +- src/python/gudhi/hera.cc | 48 +++++++++++++++++++++++ src/python/setup.py.in | 26 +++++++++++- 3 files changed, 73 insertions(+), 3 deletions(-) create mode 100644 src/python/gudhi/hera.cc (limited to 'src/cmake/modules/GUDHI_user_version_target.cmake') diff --git a/src/cmake/modules/GUDHI_user_version_target.cmake b/src/cmake/modules/GUDHI_user_version_target.cmake index 2527dee9..9a05386f 100644 --- a/src/cmake/modules/GUDHI_user_version_target.cmake +++ b/src/cmake/modules/GUDHI_user_version_target.cmake @@ -57,7 +57,7 @@ add_custom_command(TARGET user_version PRE_BUILD COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_SOURCE_DIR}/src/GudhUI ${GUDHI_USER_VERSION_DIR}/GudhUI) add_custom_command(TARGET user_version PRE_BUILD COMMAND ${CMAKE_COMMAND} -E - copy_directory ${CMAKE_SOURCE_DIR}/ext/hera/geom_matching/wasserstein/include ${GUDHI_USER_VERSION_DIR}/hera/wasserstein) + copy_directory ${CMAKE_SOURCE_DIR}/ext/hera/geom_matching/wasserstein/include ${GUDHI_USER_VERSION_DIR}/ext/hera/geom_matching/wasserstein/include) set(GUDHI_DIRECTORIES "doc;example;concept;utilities") diff --git a/src/python/gudhi/hera.cc b/src/python/gudhi/hera.cc new file mode 100644 index 00000000..7cef9425 --- /dev/null +++ b/src/python/gudhi/hera.cc @@ -0,0 +1,48 @@ +#include +#include + +#include + +#include + +#include + +namespace py = pybind11; +typedef py::array_t Dgm; + +namespace hera { +template <> struct DiagramTraits{ + //using Container = void; + using PointType = std::array; + using RealType = double; + + static RealType get_x(const PointType& p) { return std::get<0>(p); } + static RealType get_y(const PointType& p) { return std::get<1>(p); } +}; +} + +double wasserstein_distance( + Dgm d1, + Dgm d2) +{ + py::buffer_info buf1 = d1.request(); + py::buffer_info buf2 = d2.request(); + if(buf1.ndim!=2 || buf1.shape[1]!=2) + throw std::runtime_error("Diagram 1 must be an array of size n x 2"); + if(buf2.ndim!=2 || buf2.shape[1]!=2) + throw std::runtime_error("Diagram 1 must be an array of size n x 2"); + typedef hera::DiagramTraits::PointType Point; + auto p1 = (Point*)buf1.ptr; + auto p2 = (Point*)buf2.ptr; + auto diag1 = boost::make_iterator_range(p1, p1+buf1.shape[0]); + auto diag2 = boost::make_iterator_range(p2, p2+buf2.shape[0]); + + hera::AuctionParams params; + return hera::wasserstein_dist(diag1, diag2, params); +} + +PYBIND11_MODULE(hera, m) { + m.def("wasserstein_distance", &wasserstein_distance, R"pbdoc( + Compute the Wasserstein distance between two diagrams + )pbdoc"); +} diff --git a/src/python/setup.py.in b/src/python/setup.py.in index 3f1d4424..f7ffd146 100644 --- a/src/python/setup.py.in +++ b/src/python/setup.py.in @@ -26,6 +26,19 @@ library_dirs=[@GUDHI_PYTHON_LIBRARY_DIRS@] include_dirs = [numpy_get_include(), '@CMAKE_CURRENT_SOURCE_DIR@/gudhi/', @GUDHI_PYTHON_INCLUDE_DIRS@] runtime_library_dirs=[@GUDHI_PYTHON_RUNTIME_LIBRARY_DIRS@] +class get_pybind_include(object): + """Helper class to determine the pybind11 include path + The purpose of this class is to postpone importing pybind11 + until it is actually installed, so that the ``get_include()`` + method can be invoked. """ + + def __init__(self, user=False): + self.user = user + + def __str__(self): + import pybind11 + return pybind11.get_include(self.user) + # Create ext_modules list from module list ext_modules = [] for module in modules: @@ -39,6 +52,15 @@ for module in modules: library_dirs=library_dirs, include_dirs=include_dirs, runtime_library_dirs=runtime_library_dirs,)) +ext_modules.append(Extension( + 'gudhi.hera', + sources = [source_dir + 'hera.cc'], + language = 'c++', + extra_compile_args=extra_compile_args + ['-fvisibility=hidden'], # FIXME + include_dirs = include_dirs + + ['@CMAKE_SOURCE_DIR@/ext/hera/geom_matching/wasserstein/include', + get_pybind_include(False), get_pybind_include(True)] + )) setup( name = 'gudhi', @@ -48,6 +70,6 @@ setup( version='@GUDHI_VERSION@', url='http://gudhi.gforge.inria.fr/', ext_modules = cythonize(ext_modules), - install_requires = ['cython','numpy >= 1.9',], - setup_requires = ['numpy >= 1.9',], + install_requires = ['cython','numpy >= 1.9','pybind11',], + setup_requires = ['numpy >= 1.9','pybind11',], ) -- cgit v1.2.3