summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorArnur Nigmetov <nigmetov@tugraz.at>2020-03-04 00:33:51 +0100
committerArnur Nigmetov <nigmetov@tugraz.at>2020-03-04 00:33:51 +0100
commit3809e4071827a5959f27e472514eaed08ba6d15e (patch)
tree113fd1c373e4a04b568f8daf25324efafff6b107
parentd56c07b093bfea1690a81ebbef41b8bb9c7c2464 (diff)
Make matching distance header-only.
-rw-r--r--matching/CMakeLists.txt66
-rw-r--r--matching/include/bifiltration.h72
-rw-r--r--matching/include/bifiltration.hpp (renamed from matching/src/bifiltration.cpp)134
-rw-r--r--matching/include/box.h77
-rw-r--r--matching/include/box.hpp52
-rw-r--r--matching/include/cell_with_value.h47
-rw-r--r--matching/include/cell_with_value.hpp (renamed from matching/src/cell_with_value.cpp)89
-rw-r--r--matching/include/common_defs.h4
-rw-r--r--matching/include/common_util.h113
-rw-r--r--matching/include/common_util.hpp96
-rw-r--r--matching/include/dual_box.h78
-rw-r--r--matching/include/dual_box.hpp (renamed from matching/src/dual_box.cpp)102
-rw-r--r--matching/include/dual_point.h27
-rw-r--r--matching/include/dual_point.hpp (renamed from matching/src/dual_point.cpp)77
-rw-r--r--matching/include/matching_distance.h203
-rw-r--r--matching/include/matching_distance.hpp326
-rw-r--r--matching/include/persistence_module.h34
-rw-r--r--matching/include/persistence_module.hpp (renamed from matching/src/persistence_module.cpp)52
-rw-r--r--matching/include/simplex.h70
-rw-r--r--matching/include/simplex.hpp79
-rw-r--r--matching/src/box.cpp61
-rw-r--r--matching/src/common_util.cpp243
-rw-r--r--matching/src/main.cpp29
-rw-r--r--matching/src/matching_distance.cpp150
-rw-r--r--matching/src/simplex.cpp121
-rw-r--r--matching/src/test_generator.cpp19
-rw-r--r--matching/src/tests/test_common.cpp66
-rw-r--r--matching/src/tests/test_matching_distance.cpp22
28 files changed, 1110 insertions, 1399 deletions
diff --git a/matching/CMakeLists.txt b/matching/CMakeLists.txt
index 9384328..121e25c 100644
--- a/matching/CMakeLists.txt
+++ b/matching/CMakeLists.txt
@@ -29,84 +29,34 @@ if (NOT WIN32)
endif (NOT WIN32)
file(GLOB BT_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/../bottleneck/include/*.h ${CMAKE_CURRENT_SOURCE_DIR}/../bottleneck/include/*.hpp)
-file(GLOB MD_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/include/*.h)
+file(GLOB MD_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/include/*.h ${CMAKE_CURRENT_SOURCE_DIR}/include/*.hpp)
file(GLOB SRC_FILES ${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp)
file(GLOB SRC_TEST_FILES ${CMAKE_CURRENT_SOURCE_DIR}/src/tests/*.cpp)
find_package(Threads)
-set(libraries ${libraries} "stdc++fs" ${CMAKE_THREAD_LIBS_INIT} ${OpenMP_CXX_LIBRARIES})
+set(libraries ${libraries} "stdc++fs" ${CMAKE_THREAD_LIBS_INIT})
find_package(OpenMP)
if (OPENMP_FOUND)
+set(libraries ${libraries} ${OpenMP_CXX_LIBRARIES})
set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}")
endif()
-#add_executable(matching_distance ${SRC_FILES} ${BT_HEADERS} ${MD_HEADERS})
-add_executable(matching_distance "src/main.cpp"
- "src/box.cpp"
- "src/common_util.cpp"
- "src/persistence_module.cpp"
- "src/simplex.cpp"
- "src/bifiltration.cpp"
- "src/matching_distance.cpp"
- "src/dual_box.cpp"
- "src/dual_point.cpp"
- "include/box.h"
- "include/common_util.h"
- "include/persistence_module.h"
- "include/simplex.h"
- "include/bifiltration.h"
- "include/matching_distance.h"
- "include/dual_box.h"
- "include/dual_point.h"
- ${BT_HEADERS} include/cell_with_value.h src/cell_with_value.cpp)
+add_executable(matching_distance "src/main.cpp" ${MD_HEADERS} ${BT_HEADERS} )
target_link_libraries(matching_distance PUBLIC ${libraries})
-#add_executable(matching_distance_test ${SRC_TEST_FILES} ${BT_HEADERS} ${MD_HEADERS})
-add_executable(matching_distance_test ${SRC_TEST_FILES}
- "src/box.cpp"
- "src/common_util.cpp"
- "src/persistence_module.cpp"
- "src/simplex.cpp"
- "src/bifiltration.cpp"
- "src/matching_distance.cpp"
- "src/dual_box.cpp"
- "src/dual_point.cpp"
- "include/box.h"
- "include/common_util.h"
- "include/persistence_module.h"
- "include/simplex.h"
- "include/bifiltration.h"
- "include/matching_distance.h"
- "include/dual_box.h"
- "include/dual_point.h"
- ${BT_HEADERS} src/tests/test_common.cpp src/common_util.cpp src/tests/test_matching_distance.cpp src/cell_with_value.cpp)
+add_executable(matching_distance_test ${SRC_TEST_FILES} ${BT_HEADERS} ${MD_HEADERS})
target_link_libraries(matching_distance_test PUBLIC ${libraries})
add_executable(test_generator "src/test_generator.cpp"
- "src/box.cpp"
- "src/common_util.cpp"
- "src/persistence_module.cpp"
- "src/simplex.cpp"
- "src/bifiltration.cpp"
- "src/matching_distance.cpp"
- "src/dual_box.cpp"
- "src/dual_point.cpp"
- "include/box.h"
- "include/common_util.h"
- "include/persistence_module.h"
- "include/simplex.h"
- "include/bifiltration.h"
- "include/matching_distance.h"
- "include/dual_box.h"
- "include/dual_point.h"
- ${BT_HEADERS} src/cell_with_value.cpp)
+ ${MD_HEADERS}
+ ${BT_HEADERS})
target_link_libraries(test_generator PUBLIC ${libraries})
-#add_executable(matching_distance "src/main.cpp" "src/box.cpp" "src/common_util.cpp" "src/line.cpp" "src/persistence_module.cpp" ${BT_HEADERS} ${MD_HEADERS})
+#add_executable(matching_distance "include/main.cpp" "src/box.cpp" "src/common_util.cpp" "src/line.cpp" "src/persistence_module.hpp" ${BT_HEADERS} ${MD_HEADERS})
diff --git a/matching/include/bifiltration.h b/matching/include/bifiltration.h
index f505ed9..4dd8662 100644
--- a/matching/include/bifiltration.h
+++ b/matching/include/bifiltration.h
@@ -3,19 +3,30 @@
#include <string>
#include <ostream>
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <cassert>
#include "common_util.h"
#include "box.h"
#include "simplex.h"
#include "dual_point.h"
+#include "phat/boundary_matrix.h"
+#include "phat/compute_persistence_pairs.h"
+
+#include "spdlog/spdlog.h"
+#include "spdlog/fmt/fmt.h"
+#include "spdlog/fmt/ostr.h"
+
+#include "common_util.h"
namespace md {
+ template<class Real>
class Bifiltration {
public:
- using Diagram = std::vector<std::pair<Real, Real>>;
- using Box = md::Box;
- using SimplexVector = std::vector<Simplex>;
+ using SimplexVector = std::vector<Simplex<Real>>;
Bifiltration() = default;
@@ -36,7 +47,7 @@ namespace md {
init();
}
- Diagram weighted_slice_diagram(const DualPoint& line, int dim) const;
+ Diagram<Real> weighted_slice_diagram(const DualPoint<Real>& line, int dim) const;
SimplexVector simplices() const { return simplices_; }
@@ -48,14 +59,12 @@ namespace md {
Real minimal_coordinate() const;
// return box that contains positions of all simplices
- Box bounding_box() const;
+ Box<Real> bounding_box() const;
void sanity_check() const;
int maximal_dim() const { return maximal_dim_; }
- friend std::ostream& operator<<(std::ostream& os, const Bifiltration& bif);
-
Real max_x() const;
Real max_y() const;
@@ -64,7 +73,7 @@ namespace md {
Real min_y() const;
- void add_simplex(Index _id, Point birth, int _dim, const Column& _bdry);
+ void add_simplex(Index _id, Point<Real> birth, int _dim, const Column& _bdry);
void save(const std::string& filename, BifiltrationFormat format = BifiltrationFormat::rivet); // save to file
@@ -72,11 +81,8 @@ namespace md {
private:
SimplexVector simplices_;
- // axes names, for rivet bifiltration format only
- std::string parameter_1_name_ {"axis_1"};
- std::string parameter_2_name_ {"axis_2"};
- Box bounding_box_;
+ Box<Real> bounding_box_;
int maximal_dim_ {-1};
void init();
@@ -97,13 +103,15 @@ namespace md {
};
- std::ostream& operator<<(std::ostream& os, const Bifiltration& bif);
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const Bifiltration<Real>& bif);
+ template<class Real>
class BifiltrationProxy {
public:
- BifiltrationProxy(const Bifiltration& bif, int dim = 0);
+ BifiltrationProxy(const Bifiltration<Real>& bif, int dim = 0);
// return critical values of simplices that are important for current dimension (dim and dim+1)
- PointVec positions() const;
+ PointVec<Real> positions() const;
// set current dimension
int set_dim(int new_dim);
@@ -111,46 +119,22 @@ namespace md {
int maximal_dim() const;
void translate(Real a);
Real minimal_coordinate() const;
- Box bounding_box() const;
+ Box<Real> bounding_box() const;
Real max_x() const;
Real max_y() const;
Real min_x() const;
Real min_y() const;
- Diagram weighted_slice_diagram(const DualPoint& slice) const;
+ Diagram<Real> weighted_slice_diagram(const DualPoint<Real>& slice) const;
private:
int dim_ { 0 };
- mutable PointVec cached_positions_;
- Bifiltration bif_;
+ mutable PointVec<Real> cached_positions_;
+ Bifiltration<Real> bif_;
void cache_positions() const;
};
}
-
+#include "bifiltration.hpp"
#endif //MATCHING_DISTANCE_BIFILTRATION_H
-
-//// The value type of OutputIterator is Simplex_in_2D_filtration
-//template<typename OutputIterator>
-//void read_input(std::string filename, OutputIterator out)
-//{
-// std::ifstream ifstr;
-// ifstr.open(filename.c_str());
-// long n;
-// ifstr >> n; // number of simplices is the first number in file
-//
-// Index k; // used in loop
-// for (int i = 0; i < n; i++) {
-// Simplex_in_2D_filtration next;
-// next.index = i;
-// ifstr >> next.dim >> next.pos.x >> next.pos.y;
-// if (next.dim > 0) {
-// for (int j = 0; j <= next.dim; j++) {
-// ifstr >> k;
-// next.bd.push_back(k);
-// }
-// }
-// *out++ = next;
-// }
-//}
diff --git a/matching/src/bifiltration.cpp b/matching/include/bifiltration.hpp
index 44b12cf..9e2a82e 100644
--- a/matching/src/bifiltration.cpp
+++ b/matching/include/bifiltration.hpp
@@ -1,35 +1,20 @@
-#include <iostream>
-#include <fstream>
-#include <sstream>
-#include <cassert>
-
-#include<phat/boundary_matrix.h>
-#include<phat/compute_persistence_pairs.h>
-
-#include "spdlog/spdlog.h"
-#include "spdlog/fmt/fmt.h"
-#include "spdlog/fmt/ostr.h"
-
-#include "common_util.h"
-#include "bifiltration.h"
-
-namespace spd = spdlog;
-
namespace md {
- void Bifiltration::init()
+ template<class Real>
+ void Bifiltration<Real>::init()
{
- Point lower_left = max_point();
- Point upper_right = min_point();
+ auto lower_left = max_point<Real>();
+ auto upper_right = min_point<Real>();
for(const auto& simplex : simplices_) {
- lower_left = greatest_lower_bound(lower_left, simplex.position());
- upper_right = least_upper_bound(upper_right, simplex.position());
+ lower_left = greatest_lower_bound<>(lower_left, simplex.position());
+ upper_right = least_upper_bound<>(upper_right, simplex.position());
maximal_dim_ = std::max(maximal_dim_, simplex.dim());
}
- bounding_box_ = Box(lower_left, upper_right);
+ bounding_box_ = Box<Real>(lower_left, upper_right);
}
- Bifiltration::Bifiltration(const std::string& fname )
+ template<class Real>
+ Bifiltration<Real>::Bifiltration(const std::string& fname)
{
std::ifstream ifstr {fname.c_str()};
if (!ifstr.good()) {
@@ -69,12 +54,13 @@ namespace md {
init();
}
- void Bifiltration::rivet_format_reader(std::ifstream& ifstr)
+ template<class Real>
+ void Bifiltration<Real>::rivet_format_reader(std::ifstream& ifstr)
{
std::string s;
- // read axes names
- std::getline(ifstr, parameter_1_name_);
- std::getline(ifstr, parameter_2_name_);
+ // read axes names, ignore them
+ std::getline(ifstr, s);
+ std::getline(ifstr, s);
Index index = 0;
while(std::getline(ifstr, s)) {
@@ -84,7 +70,8 @@ namespace md {
}
}
- void Bifiltration::phat_like_format_reader(std::ifstream& ifstr)
+ template<class Real>
+ void Bifiltration<Real>::phat_like_format_reader(std::ifstream& ifstr)
{
spd::debug("Enter phat_like_format_reader");
// read stream line by line; do not use >> operator
@@ -105,7 +92,8 @@ namespace md {
spd::debug("Read {} simplices from file", n_simplices);
}
- void Bifiltration::scale(Real lambda)
+ template<class Real>
+ void Bifiltration<Real>::scale(Real lambda)
{
for(auto& s : simplices_) {
s.scale(lambda);
@@ -113,10 +101,11 @@ namespace md {
init();
}
- void Bifiltration::sanity_check() const
+ template<class Real>
+ void Bifiltration<Real>::sanity_check() const
{
#ifdef DEBUG
- spd::debug("Enter Bifiltration::sanity_check");
+ spd::debug("Enter Bifiltration<Real>::sanity_check");
// check that boundary has correct number of simplices,
// each bounding simplex has correct dim
// and appears in the filtration before the simplex it bounds
@@ -129,16 +118,17 @@ namespace md {
assert(bdry_simplex.position().is_less(s.position(), false));
}
}
- spd::debug("Exit Bifiltration::sanity_check");
+ spd::debug("Exit Bifiltration<Real>::sanity_check");
#endif
}
- Diagram Bifiltration::weighted_slice_diagram(const DualPoint& line, int dim) const
+ template<class Real>
+ Diagram<Real> Bifiltration<Real>::weighted_slice_diagram(const DualPoint<Real>& line, int dim) const
{
- DiagramKeeper dgm;
+ DiagramKeeper<Real> dgm;
// make a copy for now; I want slice_diagram to be const
- std::vector<Simplex> simplices(simplices_);
+ std::vector<Simplex<Real>> simplices(simplices_);
// std::vector<Simplex> simplices;
// simplices.reserve(simplices_.size() / 2);
@@ -156,7 +146,7 @@ namespace md {
}
std::sort(simplices.begin(), simplices.end(),
- [](const Simplex& a, const Simplex& b) { return a.value() < b.value(); });
+ [](const Simplex<Real>& a, const Simplex<Real>& b) { return a.value() < b.value(); });
std::map<Index, Index> index_map;
for(Index i = 0; i < (int) simplices.size(); i++) {
index_map[simplices[i].id()] = i;
@@ -202,17 +192,20 @@ namespace md {
return dgm.get_diagram(dim);
}
- Box Bifiltration::bounding_box() const
+ template<class Real>
+ Box<Real> Bifiltration<Real>::bounding_box() const
{
return bounding_box_;
}
- Real Bifiltration::minimal_coordinate() const
+ template<class Real>
+ Real Bifiltration<Real>::minimal_coordinate() const
{
return std::min(bounding_box_.lower_left().x, bounding_box_.lower_left().y);
}
- void Bifiltration::translate(Real a)
+ template<class Real>
+ void Bifiltration<Real>::translate(Real a)
{
bounding_box_.translate(a);
for(auto& simplex : simplices_) {
@@ -220,7 +213,8 @@ namespace md {
}
}
- Real Bifiltration::max_x() const
+ template<class Real>
+ Real Bifiltration<Real>::max_x() const
{
if (simplices_.empty())
return 1;
@@ -230,7 +224,8 @@ namespace md {
return me->position().x;
}
- Real Bifiltration::max_y() const
+ template<class Real>
+ Real Bifiltration<Real>::max_y() const
{
if (simplices_.empty())
return 1;
@@ -240,7 +235,8 @@ namespace md {
return me->position().y;
}
- Real Bifiltration::min_x() const
+ template<class Real>
+ Real Bifiltration<Real>::min_x() const
{
if (simplices_.empty())
return 0;
@@ -250,7 +246,8 @@ namespace md {
return me->position().x;
}
- Real Bifiltration::min_y() const
+ template<class Real>
+ Real Bifiltration<Real>::min_y() const
{
if (simplices_.empty())
return 0;
@@ -260,12 +257,14 @@ namespace md {
return me->position().y;
}
- void Bifiltration::add_simplex(md::Index _id, md::Point birth, int _dim, const md::Column& _bdry)
+ template<class Real>
+ void Bifiltration<Real>::add_simplex(Index _id, Point<Real> birth, int _dim, const Column& _bdry)
{
simplices_.emplace_back(_id, birth, _dim, _bdry);
}
- void Bifiltration::save(const std::string& filename, md::BifiltrationFormat format)
+ template<class Real>
+ void Bifiltration<Real>::save(const std::string& filename, md::BifiltrationFormat format)
{
switch(format) {
case BifiltrationFormat::rivet:
@@ -292,7 +291,8 @@ namespace md {
}
}
- void Bifiltration::postprocess_rivet_format()
+ template<class Real>
+ void Bifiltration<Real>::postprocess_rivet_format()
{
std::map<Column, Index> facets_to_ids;
@@ -324,16 +324,19 @@ namespace md {
} // loop over simplices
}
- std::ostream& operator<<(std::ostream& os, const Bifiltration& bif)
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const Bifiltration<Real>& bif)
{
- os << "Bifiltration, axes = " << bif.parameter_1_name_ << ", " << bif.parameter_2_name_ << std::endl;
+ os << "Bifiltration [" << std::endl;
for(const auto& s : bif.simplices()) {
os << s << std::endl;
}
+ os << "]" << std::endl;
return os;
}
- BifiltrationProxy::BifiltrationProxy(const md::Bifiltration& bif, int dim)
+ template<class Real>
+ BifiltrationProxy<Real>::BifiltrationProxy(const Bifiltration<Real>& bif, int dim)
:
dim_(dim),
bif_(bif)
@@ -341,7 +344,8 @@ namespace md {
cache_positions();
}
- void BifiltrationProxy::cache_positions() const
+ template<class Real>
+ void BifiltrationProxy<Real>::cache_positions() const
{
cached_positions_.clear();
for(const auto& simplex : bif_.simplices()) {
@@ -350,7 +354,9 @@ namespace md {
}
}
- PointVec BifiltrationProxy::positions() const
+ template<class Real>
+ PointVec<Real>
+ BifiltrationProxy<Real>::positions() const
{
if (cached_positions_.empty()) {
cache_positions();
@@ -359,46 +365,54 @@ namespace md {
}
// translate all points by vector (a,a)
- void BifiltrationProxy::translate(Real a)
+ template<class Real>
+ void BifiltrationProxy<Real>::translate(Real a)
{
bif_.translate(a);
}
// return minimal value of x- and y-coordinates
// among all simplices
- Real BifiltrationProxy::minimal_coordinate() const
+ template<class Real>
+ Real BifiltrationProxy<Real>::minimal_coordinate() const
{
return bif_.minimal_coordinate();
}
// return box that contains positions of all simplices
- Box BifiltrationProxy::bounding_box() const
+ template<class Real>
+ Box<Real> BifiltrationProxy<Real>::bounding_box() const
{
return bif_.bounding_box();
}
- Real BifiltrationProxy::max_x() const
+ template<class Real>
+ Real BifiltrationProxy<Real>::max_x() const
{
return bif_.max_x();
}
- Real BifiltrationProxy::max_y() const
+ template<class Real>
+ Real BifiltrationProxy<Real>::max_y() const
{
return bif_.max_y();
}
- Real BifiltrationProxy::min_x() const
+ template<class Real>
+ Real BifiltrationProxy<Real>::min_x() const
{
return bif_.min_x();
}
- Real BifiltrationProxy::min_y() const
+ template<class Real>
+ Real BifiltrationProxy<Real>::min_y() const
{
return bif_.min_y();
}
- Diagram BifiltrationProxy::weighted_slice_diagram(const DualPoint& slice) const
+ template<class Real>
+ Diagram<Real> BifiltrationProxy<Real>::weighted_slice_diagram(const DualPoint<Real>& slice) const
{
return bif_.weighted_slice_diagram(slice, dim_);
}
diff --git a/matching/include/box.h b/matching/include/box.h
index 2990fba..4243667 100644
--- a/matching/include/box.h
+++ b/matching/include/box.h
@@ -8,20 +8,23 @@
namespace md {
+ template<class Real_>
struct Box {
+ public:
+ using Real = Real_;
private:
- Point ll;
- Point ur;
+ Point<Real> ll;
+ Point<Real> ur;
public:
- Box(Point ll = Point(), Point ur = Point())
+ Box(Point<Real> ll = Point<Real>(), Point<Real> ur = Point<Real>())
:ll(ll), ur(ur)
{
}
- Box(Point center, Real width, Real height) :
- ll(Point(center.x - 0.5 * width, center.y - 0.5 * height)),
- ur(Point(center.x + 0.5 * width, center.y + 0.5 * height))
+ Box(Point<Real> center, Real width, Real height) :
+ ll(Point<Real>(center.x - 0.5 * width, center.y - 0.5 * height)),
+ ur(Point<Real>(center.x + 0.5 * width, center.y + 0.5 * height))
{
}
@@ -30,11 +33,9 @@ namespace md {
inline double height() const { return ur.y - ll.y; }
- inline Point lower_left() const { return ll; }
- inline Point upper_right() const { return ur; }
- inline Point center() const { return Point((ll.x + ur.x) / 2, (ll.y + ur.y) / 2); }
-
-// bool inside(Point& p) const { return ll.x <= p.x && ll.y <= p.y && ur.x >= p.x && ur.y >= p.y; }
+ inline Point<Real> lower_left() const { return ll; }
+ inline Point<Real> upper_right() const { return ur; }
+ inline Point<Real> center() const { return Point<Real>((ll.x + ur.x) / 2, (ll.y + ur.y) / 2); }
inline bool operator==(const Box& p)
{
@@ -43,58 +44,16 @@ namespace md {
std::vector<Box> refine() const;
- std::vector<Point> corners() const;
+ std::vector<Point<Real>> corners() const;
void translate(Real a);
-
- // return minimal and maximal value of func
- // on the corners of the box
- template<typename F>
- std::pair<Real, Real> min_max_on_corners(const F& func) const;
-
- friend std::ostream& operator<<(std::ostream& os, const Box& box);
};
- std::ostream& operator<<(std::ostream& os, const Box& box);
-// template<typename InputIterator>
-// Box compute_bounding_box(InputIterator simplices_begin, InputIterator simplices_end)
-// {
-// if (simplices_begin == simplices_end) {
-// return Box();
-// }
-// Box bb;
-// bb.ll = bb.ur = simplices_begin->pos;
-// for (InputIterator it = simplices_begin; it != simplices_end; it++) {
-// Point& pos = it->pos;
-// if (pos.x < bb.ll.x) {
-// bb.ll.x = pos.x;
-// }
-// if (pos.y < bb.ll.y) {
-// bb.ll.y = pos.y;
-// }
-// if (pos.x > bb.ur.x) {
-// bb.ur.x = pos.x;
-// }
-// if (pos.y > bb.ur.y) {
-// bb.ur.y = pos.y;
-// }
-// }
-// return bb;
-// }
-
- Box get_enclosing_box(const Box& box_a, const Box& box_b);
-
- template<typename F>
- std::pair<Real, Real> Box::min_max_on_corners(const F& func) const
- {
- std::pair<Real, Real> min_max { std::numeric_limits<Real>::max(), -std::numeric_limits<Real>::max() };
- for(Point p : corners()) {
- Real value = func(p);
- min_max.first = std::min(min_max.first, value);
- min_max.second = std::max(min_max.second, value);
- }
- return min_max;
- };
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const Box<Real>& box);
+
} // namespace md
+#include "box.hpp"
+
#endif //MATCHING_DISTANCE_BOX_H
diff --git a/matching/include/box.hpp b/matching/include/box.hpp
new file mode 100644
index 0000000..f551d84
--- /dev/null
+++ b/matching/include/box.hpp
@@ -0,0 +1,52 @@
+namespace md {
+
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const Box<Real>& box)
+ {
+ os << "Box(lower_left = " << box.lower_left() << ", upper_right = " << box.upper_right() << ")";
+ return os;
+ }
+
+ template<class Real>
+ void Box<Real>::translate(Real a)
+ {
+ ll.x += a;
+ ll.y += a;
+ ur.x += a;
+ ur.y += a;
+ }
+
+ template<class Real>
+ std::vector<Box<Real>> Box<Real>::refine() const
+ {
+ std::vector<Box<Real>> result;
+
+// 1 | 2
+// 0 | 3
+
+ Point<Real> new_ll = lower_left();
+ Point<Real> new_ur = center();
+ result.emplace_back(new_ll, new_ur);
+
+ new_ll.y = center().y;
+ new_ur.y = ur.y;
+ result.emplace_back(new_ll, new_ur);
+
+ new_ll = center();
+ new_ur = upper_right();
+ result.emplace_back(new_ll, new_ur);
+
+ new_ll.y = ll.y;
+ new_ur.y = center().y;
+ result.emplace_back(new_ll, new_ur);
+
+ return result;
+ }
+
+ template<class Real>
+ std::vector<Point<Real>> Box<Real>::corners() const
+ {
+ return {ll, Point<Real>(ll.x, ur.y), ur, Point<Real>(ur.x, ll.y)};
+ };
+
+}
diff --git a/matching/include/cell_with_value.h b/matching/include/cell_with_value.h
index 25644d1..3548a11 100644
--- a/matching/include/cell_with_value.h
+++ b/matching/include/cell_with_value.h
@@ -1,7 +1,3 @@
-//
-// Created by narn on 16.07.19.
-//
-
#ifndef MATCHING_DISTANCE_CELL_WITH_VALUE_H
#define MATCHING_DISTANCE_CELL_WITH_VALUE_H
@@ -21,7 +17,29 @@ namespace md {
upper_right
};
- std::ostream& operator<<(std::ostream& os, const ValuePoint& vp);
+ inline std::ostream& operator<<(std::ostream& os, const ValuePoint& vp)
+ {
+ switch(vp) {
+ case ValuePoint::upper_left :
+ os << "upper_left";
+ break;
+ case ValuePoint::upper_right :
+ os << "upper_right";
+ break;
+ case ValuePoint::lower_left :
+ os << "lower_left";
+ break;
+ case ValuePoint::lower_right :
+ os << "lower_right";
+ break;
+ case ValuePoint::center:
+ os << "center";
+ break;
+ default:
+ os << "FORGOTTEN ValuePoint";
+ }
+ return os;
+ }
const std::vector<ValuePoint> k_all_vps = {ValuePoint::center, ValuePoint::lower_left, ValuePoint::upper_left,
ValuePoint::upper_right, ValuePoint::lower_right};
@@ -31,8 +49,10 @@ namespace md {
// represents a cell in the dual space with the value
// of the weighted bottleneck distance
+ template<class Real_>
class CellWithValue {
public:
+ using Real = Real_;
CellWithValue() = default;
@@ -44,18 +64,18 @@ namespace md {
CellWithValue& operator=(CellWithValue&& other) = default;
- CellWithValue(const DualBox& b, int level)
+ CellWithValue(const DualBox<Real>& b, int level)
:dual_box_(b), level_(level) { }
- DualBox dual_box() const { return dual_box_; }
+ DualBox<Real> dual_box() const { return dual_box_; }
- DualPoint center() const { return dual_box_.center(); }
+ DualPoint<Real> center() const { return dual_box_.center(); }
Real value_at(ValuePoint vp) const;
bool has_value_at(ValuePoint vp) const;
- DualPoint value_point(ValuePoint vp) const;
+ DualPoint<Real> value_point(ValuePoint vp) const;
int level() const { return level_; }
@@ -73,8 +93,6 @@ namespace md {
std::vector<CellWithValue> get_refined_cells() const;
- friend std::ostream& operator<<(std::ostream&, const CellWithValue&);
-
void set_max_possible_value(Real new_upper_bound);
int num_values() const;
@@ -100,7 +118,7 @@ namespace md {
bool has_upper_right_value() const { return upper_right_value_ >= 0; }
- DualBox dual_box_;
+ DualBox<Real> dual_box_;
Real central_value_ {-1.0};
Real lower_left_value_ {-1.0};
Real lower_right_value_ {-1.0};
@@ -114,7 +132,10 @@ namespace md {
bool has_max_possible_value_ {false};
};
- std::ostream& operator<<(std::ostream& os, const CellWithValue& cell);
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const CellWithValue<Real>& cell);
} // namespace md
+#include "cell_with_value.hpp"
+
#endif //MATCHING_DISTANCE_CELL_WITH_VALUE_H
diff --git a/matching/src/cell_with_value.cpp b/matching/include/cell_with_value.hpp
index d8fd7d4..88b2569 100644
--- a/matching/src/cell_with_value.cpp
+++ b/matching/include/cell_with_value.hpp
@@ -1,17 +1,11 @@
-#include <spdlog/spdlog.h>
-#include <spdlog/fmt/ostr.h>
-
-namespace spd = spdlog;
-
-#include "cell_with_value.h"
-
namespace md {
#ifdef MD_DEBUG
- long long int CellWithValue::max_id = 0;
+ long long int CellWithValue<Real>::max_id = 0;
#endif
- Real CellWithValue::value_at(ValuePoint vp) const
+ template<class Real>
+ Real CellWithValue<Real>::value_at(ValuePoint vp) const
{
switch(vp) {
case ValuePoint::upper_left :
@@ -29,7 +23,8 @@ namespace md {
return 1.0 / 0.0;
}
- bool CellWithValue::has_value_at(ValuePoint vp) const
+ template<class Real>
+ bool CellWithValue<Real>::has_value_at(ValuePoint vp) const
{
switch(vp) {
case ValuePoint::upper_left :
@@ -45,9 +40,10 @@ namespace md {
}
// to shut up compiler warning
return 1.0 / 0.0;
- }
+ }
- DualPoint CellWithValue::value_point(md::ValuePoint vp) const
+ template<class Real>
+ DualPoint<Real> CellWithValue<Real>::value_point(md::ValuePoint vp) const
{
switch(vp) {
case ValuePoint::upper_left :
@@ -62,27 +58,31 @@ namespace md {
return dual_box().center();
}
// to shut up compiler warning
- return DualPoint();
- }
+ return DualPoint<Real>();
+ }
- bool CellWithValue::has_corner_value() const
+ template<class Real>
+ bool CellWithValue<Real>::has_corner_value() const
{
return has_lower_left_value() or has_lower_right_value() or has_upper_left_value()
or has_upper_right_value();
}
- Real CellWithValue::stored_upper_bound() const
+ template<class Real>
+ Real CellWithValue<Real>::stored_upper_bound() const
{
assert(has_max_possible_value_);
return max_possible_value_;
}
- Real CellWithValue::max_corner_value() const
+ template<class Real>
+ Real CellWithValue<Real>::max_corner_value() const
{
return std::max({lower_left_value_, lower_right_value_, upper_left_value_, upper_right_value_});
}
- Real CellWithValue::min_value() const
+ template<class Real>
+ Real CellWithValue<Real>::min_value() const
{
Real result = std::numeric_limits<Real>::max();
for(auto vp : k_all_vps) {
@@ -94,13 +94,14 @@ namespace md {
return result;
}
- std::vector<CellWithValue> CellWithValue::get_refined_cells() const
+ template<class Real>
+ std::vector<CellWithValue<Real>> CellWithValue<Real>::get_refined_cells() const
{
- std::vector<CellWithValue> result;
+ std::vector<CellWithValue<Real>> result;
result.reserve(4);
for(const auto& refined_box : dual_box_.refine()) {
- CellWithValue refined_cell(refined_box, level() + 1);
+ CellWithValue<Real> refined_cell(refined_box, level() + 1);
#ifdef MD_DEBUG
refined_cell.parent_ids = parent_ids;
@@ -142,10 +143,11 @@ namespace md {
return result;
}
- void CellWithValue::set_value_at(md::ValuePoint vp, md::Real new_value)
+ template<class Real>
+ void CellWithValue<Real>::set_value_at(ValuePoint vp, Real new_value)
{
if (has_value_at(vp))
- spd::error("CellWithValue: trying to re-assign value!, this = {}, vp = {}", *this, vp);
+ spd::error("CellWithValue<Real>: trying to re-assign value!, this = {}, vp = {}", *this, vp);
switch(vp) {
case ValuePoint::upper_left :
@@ -164,11 +166,10 @@ namespace md {
central_value_ = new_value;
break;
}
-
-
}
-
- int CellWithValue::num_values() const
+
+ template<class Real>
+ int CellWithValue<Real>::num_values() const
{
int result = 0;
for(ValuePoint vp : k_all_vps) {
@@ -177,8 +178,9 @@ namespace md {
return result;
}
-
- void CellWithValue::set_max_possible_value(Real new_upper_bound)
+
+ template<class Real>
+ void CellWithValue<Real>::set_max_possible_value(Real new_upper_bound)
{
assert(new_upper_bound >= central_value_);
assert(new_upper_bound >= lower_left_value_);
@@ -189,33 +191,9 @@ namespace md {
max_possible_value_ = new_upper_bound;
}
- std::ostream& operator<<(std::ostream& os, const ValuePoint& vp)
- {
- switch(vp) {
- case ValuePoint::upper_left :
- os << "upper_left";
- break;
- case ValuePoint::upper_right :
- os << "upper_right";
- break;
- case ValuePoint::lower_left :
- os << "lower_left";
- break;
- case ValuePoint::lower_right :
- os << "lower_right";
- break;
- case ValuePoint::center:
- os << "center";
- break;
- default:
- os << "FORGOTTEN ValuePoint";
- }
- return os;
- }
-
-
- std::ostream& operator<<(std::ostream& os, const CellWithValue& cell)
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const CellWithValue<Real>& cell)
{
os << "CellWithValue(box = " << cell.dual_box() << ", ";
@@ -244,4 +222,3 @@ namespace md {
}
} // namespace md
-
diff --git a/matching/include/common_defs.h b/matching/include/common_defs.h
index 3f3d937..8d01325 100644
--- a/matching/include/common_defs.h
+++ b/matching/include/common_defs.h
@@ -1,8 +1,8 @@
#ifndef MATCHING_DISTANCE_DEF_DEBUG_H
#define MATCHING_DISTANCE_DEF_DEBUG_H
-//#define EXPERIMENTAL_TIMING
-//#define PRINT_HEAT_MAP
+//#define MD_EXPERIMENTAL_TIMING
+//#define MD_PRINT_HEAT_MAP
//#define MD_DEBUG
//#define MD_DO_CHECKS
//#define MD_DO_FULL_CHECK
diff --git a/matching/include/common_util.h b/matching/include/common_util.h
index 2d8dcb0..778536f 100644
--- a/matching/include/common_util.h
+++ b/matching/include/common_util.h
@@ -11,22 +11,24 @@
#include <map>
#include <functional>
+#include <spdlog/spdlog.h>
+#include <spdlog/fmt/ostr.h>
+
+namespace spd = spdlog;
+
#include "common_defs.h"
#include "phat/helpers/misc.h"
-
namespace md {
-
- using Real = double;
- using RealVec = std::vector<Real>;
using Index = phat::index;
using IndexVec = std::vector<Index>;
- static constexpr Real pi = M_PI;
+ //static constexpr Real pi = M_PI;
using Column = std::vector<Index>;
+ template<class Real>
struct Point {
Real x;
Real y;
@@ -71,59 +73,56 @@ namespace md {
};
- using PointVec = std::vector<Point>;
-
- Point operator+(const Point& u, const Point& v);
+ template<class Real>
+ using PointVec = std::vector<Point<Real>>;
- Point operator-(const Point& u, const Point& v);
+ template<class Real>
+ Point<Real> operator+(const Point<Real>& u, const Point<Real>& v);
- Point least_upper_bound(const Point& u, const Point& v);
+ template<class Real>
+ Point<Real> operator-(const Point<Real>& u, const Point<Real>& v);
- Point greatest_lower_bound(const Point& u, const Point& v);
- Point max_point();
+ template<class Real>
+ Point<Real> least_upper_bound(const Point<Real>& u, const Point<Real>& v);
- Point min_point();
+ template<class Real>
+ Point<Real> greatest_lower_bound(const Point<Real>& u, const Point<Real>& v);
- std::ostream& operator<<(std::ostream& ostr, const Point& vec);
+ template<class Real>
+ Point<Real> max_point();
- Real L_infty(const Point& v);
+ template<class Real>
+ Point<Real> min_point();
- Real l_2_norm(const Point& v);
+ template<class Real>
+ std::ostream& operator<<(std::ostream& ostr, const Point<Real>& vec);
- Real l_2_dist(const Point& x, const Point& y);
+ template<class Real>
+ using DiagramPoint = std::pair<Real, Real>;
- Real l_infty_dist(const Point& x, const Point& y);
+ template<class Real>
+ using Diagram = std::vector<DiagramPoint<Real>>;
- using Interval = std::pair<Real, Real>;
-
- // return minimal interval that contains both a and b
- inline Interval minimal_covering_interval(Interval a, Interval b)
- {
- return {std::min(a.first, b.first), std::max(a.second, b.second)};
- }
// to keep diagrams in all dimensions
// TODO: store in Hera format?
+ template<class Real>
class DiagramKeeper {
public:
- using DiagramPoint = std::pair<Real, Real>;
- using Diagram = std::vector<DiagramPoint>;
DiagramKeeper() { };
void add_point(int dim, Real birth, Real death);
- Diagram get_diagram(int dim) const;
+ Diagram<Real> get_diagram(int dim) const;
void clear() { data_.clear(); }
private:
- std::map<int, Diagram> data_;
+ std::map<int, Diagram<Real>> data_;
};
- using Diagram = std::vector<std::pair<Real, Real>>;
-
template<typename C>
std::string container_to_string(const C& cont)
{
@@ -140,42 +139,18 @@ namespace md {
return ss.str();
}
- int gcd(int a, int b);
-
- struct Rational {
- int numerator {0};
- int denominator {1};
- Rational() = default;
- Rational(int n, int d) : numerator(n / gcd(n, d)), denominator(d / gcd(n, d)) {}
- Rational(std::pair<int, int> p) : Rational(p.first, p.second) {}
- Rational(int n) : numerator(n), denominator(1) {}
- Real to_real() const { return (Real)numerator / (Real)denominator; }
- void reduce();
- Rational& operator+=(const Rational& rhs);
- Rational& operator-=(const Rational& rhs);
- Rational& operator*=(const Rational& rhs);
- Rational& operator/=(const Rational& rhs);
- };
-
- using namespace std::rel_ops;
-
- bool operator==(const Rational& a, const Rational& b);
- bool operator<(const Rational& a, const Rational& b);
- std::ostream& operator<<(std::ostream& os, const Rational& a);
-
- // arithmetic
- Rational operator+(Rational a, const Rational& b);
- Rational operator-(Rational a, const Rational& b);
- Rational operator*(Rational a, const Rational& b);
- Rational operator/(Rational a, const Rational& b);
-
- Rational reduce(Rational frac);
-
- Rational midpoint(Rational a, Rational b);
-
// return true, if s is empty or starts with # (commented out line)
// whitespaces in the beginning of s are ignored
- bool ignore_line(const std::string& s);
+ inline bool ignore_line(const std::string& s)
+ {
+ for(auto c : s) {
+ if (isspace(c))
+ continue;
+ return (c == '#');
+ }
+ return true;
+ }
+
// split string by delimeter
template<typename Out>
@@ -195,10 +170,10 @@ namespace md {
}
namespace std {
- template<>
- struct hash<md::Point>
+ template<class Real>
+ struct hash<md::Point<Real>>
{
- std::size_t operator()(const md::Point& p) const
+ std::size_t operator()(const md::Point<Real>& p) const
{
auto hx = std::hash<decltype(p.x)>()(p.x);
auto hy = std::hash<decltype(p.y)>()(p.y);
@@ -207,5 +182,7 @@ namespace std {
};
};
+#include "common_util.hpp"
+
#endif //MATCHING_DISTANCE_COMMON_UTIL_H
diff --git a/matching/include/common_util.hpp b/matching/include/common_util.hpp
new file mode 100644
index 0000000..76d97af
--- /dev/null
+++ b/matching/include/common_util.hpp
@@ -0,0 +1,96 @@
+#include <vector>
+#include <utility>
+#include <cmath>
+#include <ostream>
+#include <limits>
+#include <algorithm>
+
+#include <common_util.h>
+
+#include "spdlog/spdlog.h"
+#include "spdlog/fmt/ostr.h"
+
+namespace md {
+
+ template<class Real>
+ Point<Real> operator+(const Point<Real>& u, const Point<Real>& v)
+ {
+ return Point<Real>(u.x + v.x, u.y + v.y);
+ }
+
+ template<class Real>
+ Point<Real> operator-(const Point<Real>& u, const Point<Real>& v)
+ {
+ return Point<Real>(u.x - v.x, u.y - v.y);
+ }
+
+ template<class Real>
+ Point<Real> least_upper_bound(const Point<Real>& u, const Point<Real>& v)
+ {
+ return Point<Real>(std::max(u.x, v.x), std::max(u.y, v.y));
+ }
+
+ template<class Real>
+ Point<Real> greatest_lower_bound(const Point<Real>& u, const Point<Real>& v)
+ {
+ return Point<Real>(std::min(u.x, v.x), std::min(u.y, v.y));
+ }
+
+ template<class Real>
+ Point<Real> max_point()
+ {
+ return Point<Real>(std::numeric_limits<Real>::max(), std::numeric_limits<Real>::min());
+ }
+
+ template<class Real>
+ Point<Real> min_point()
+ {
+ return Point<Real>(-std::numeric_limits<Real>::max(), -std::numeric_limits<Real>::min());
+ }
+
+ template<class Real>
+ std::ostream& operator<<(std::ostream& ostr, const Point<Real>& vec)
+ {
+ ostr << "(" << vec.x << ", " << vec.y << ")";
+ return ostr;
+ }
+
+ template<class Real>
+ Real l_infty_norm(const Point<Real>& v)
+ {
+ return std::max(std::abs(v.x), std::abs(v.y));
+ }
+
+ template<class Real>
+ Real l_2_norm(const Point<Real>& v)
+ {
+ return v.norm();
+ }
+
+ template<class Real>
+ Real l_2_dist(const Point<Real>& x, const Point<Real>& y)
+ {
+ return l_2_norm(x - y);
+ }
+
+ template<class Real>
+ Real l_infty_dist(const Point<Real>& x, const Point<Real>& y)
+ {
+ return l_infty_norm(x - y);
+ }
+
+ template<class Real>
+ void DiagramKeeper<Real>::add_point(int dim, Real birth, Real death)
+ {
+ data_[dim].emplace_back(birth, death);
+ }
+
+ template<class Real>
+ Diagram<Real> DiagramKeeper<Real>::get_diagram(int dim) const
+ {
+ if (data_.count(dim) == 1)
+ return data_.at(dim);
+ else
+ return Diagram<Real>();
+ }
+}
diff --git a/matching/include/dual_box.h b/matching/include/dual_box.h
index ce0384d..0e4f4d5 100644
--- a/matching/include/dual_box.h
+++ b/matching/include/dual_box.h
@@ -4,16 +4,23 @@
#include <ostream>
#include <limits>
#include <vector>
+#include <random>
+
+#include "spdlog/spdlog.h"
+#include "spdlog/fmt/ostr.h"
+
#include "common_util.h"
#include "dual_point.h"
namespace md {
+
+ template<class Real>
class DualBox {
public:
- DualBox(DualPoint ll, DualPoint ur);
+ DualBox(DualPoint<Real> ll, DualPoint<Real> ur);
DualBox() = default;
DualBox(const DualBox&) = default;
@@ -23,12 +30,12 @@ namespace md {
DualBox& operator=(DualBox&& other) = default;
- DualPoint center() const { return midpoint(lower_left_, upper_right_); }
- DualPoint lower_left() const { return lower_left_; }
- DualPoint upper_right() const { return upper_right_; }
+ DualPoint<Real> center() const { return midpoint(lower_left_, upper_right_); }
+ DualPoint<Real> lower_left() const { return lower_left_; }
+ DualPoint<Real> upper_right() const { return upper_right_; }
- DualPoint lower_right() const;
- DualPoint upper_left() const;
+ DualPoint<Real> lower_right() const;
+ DualPoint<Real> upper_left() const;
AxisType axis_type() const { return lower_left_.axis_type(); }
AngleType angle_type() const { return lower_left_.angle_type(); }
@@ -42,66 +49,35 @@ namespace md {
bool is_flat() const { return upper_right_.is_flat(); }
bool is_steep() const { return lower_left_.is_steep(); }
- // return minimal and maximal value of func
- // on the corners of the box
- template<typename F>
- std::pair<Real, Real> min_max_on_corners(const F& func) const;
-
- template<typename F>
- Real max_abs_value(const F& func) const;
-
-
std::vector<DualBox> refine() const;
- std::vector<DualPoint> corners() const;
- std::vector<DualPoint> critical_points(const Point& p) const;
+ std::vector<DualPoint<Real>> corners() const;
+ std::vector<DualPoint<Real>> critical_points(const Point<Real>& p) const;
// sample n points from the box uniformly; for tests
- std::vector<DualPoint> random_points(int n) const;
+ std::vector<DualPoint<Real>> random_points(int n) const;
// return 2 dual points at the boundary
// where push changes from horizontal to vertical
- std::vector<DualPoint> push_change_points(const Point& p) const;
-
- friend std::ostream& operator<<(std::ostream& os, const DualBox& db);
+ std::vector<DualPoint<Real>> push_change_points(const Point<Real>& p) const;
// check that a has same sign, angles are all flat or all steep
bool sanity_check() const;
- bool contains(const DualPoint& dp) const;
+ bool contains(const DualPoint<Real>& dp) const;
bool operator==(const DualBox& other) const;
private:
- DualPoint lower_left_;
- DualPoint upper_right_;
+ DualPoint<Real> lower_left_;
+ DualPoint<Real> upper_right_;
};
- std::ostream& operator<<(std::ostream& os, const DualBox& db);
-
- template<typename F>
- std::pair<Real, Real> DualBox::min_max_on_corners(const F& func) const
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const DualBox<Real>& db)
{
- std::pair<Real, Real> min_max { std::numeric_limits<Real>::max(), -std::numeric_limits<Real>::max() };
- for(auto p : corners()) {
- Real value = func(p);
- min_max.first = std::min(min_max.first, value);
- min_max.second = std::max(min_max.second, value);
- }
- return min_max;
- };
-
-
- template<typename F>
- Real DualBox::max_abs_value(const F& func) const
- {
- Real result = 0;
- for(auto p_1 : corners()) {
- for(auto p_2 : corners()) {
- Real value = fabs(func(p_1, p_2));
- result = std::max(value, result);
- }
- }
- return result;
- };
-
+ os << "DualBox(" << db.lower_left() << ", " << db.upper_right() << ")";
+ return os;
+ }
}
+#include "dual_box.hpp"
+
#endif //MATCHING_DISTANCE_DUAL_BOX_H
diff --git a/matching/src/dual_box.cpp b/matching/include/dual_box.hpp
index ff4d30c..85f7f27 100644
--- a/matching/src/dual_box.cpp
+++ b/matching/include/dual_box.hpp
@@ -1,36 +1,24 @@
-#include <random>
-
-#include "spdlog/spdlog.h"
-#include "spdlog/fmt/ostr.h"
-
-namespace spd = spdlog;
-
-#include "dual_box.h"
-
namespace md {
- std::ostream& operator<<(std::ostream& os, const DualBox& db)
- {
- os << "DualBox(" << db.lower_left_ << ", " << db.upper_right_ << ")";
- return os;
- }
-
- DualBox::DualBox(DualPoint ll, DualPoint ur)
+ template<class Real>
+ DualBox<Real>::DualBox(DualPoint<Real> ll, DualPoint<Real> ur)
:lower_left_(ll), upper_right_(ur)
{
}
- std::vector<DualPoint> DualBox::corners() const
+ template<class Real>
+ std::vector<DualPoint<Real>> DualBox<Real>::corners() const
{
return {lower_left_,
- DualPoint(axis_type(), angle_type(), lower_left_.lambda(), upper_right_.mu()),
+ DualPoint<Real>(axis_type(), angle_type(), lower_left_.lambda(), upper_right_.mu()),
upper_right_,
- DualPoint(axis_type(), angle_type(), upper_right_.lambda(), lower_left_.mu())};
+ DualPoint<Real>(axis_type(), angle_type(), upper_right_.lambda(), lower_left_.mu())};
}
- std::vector<DualPoint> DualBox::push_change_points(const Point& p) const
+ template<class Real>
+ std::vector<DualPoint<Real>> DualBox<Real>::push_change_points(const Point<Real>& p) const
{
- std::vector<DualPoint> result;
+ std::vector<DualPoint<Real>> result;
result.reserve(2);
bool is_y_type = lower_left_.is_y_type();
@@ -38,13 +26,13 @@ namespace md {
auto mu_from_lambda = [p, is_y_type, is_flat](Real lambda) {
bool is_x_type = not is_y_type, is_steep = not is_flat;
- if (is_y_type and is_flat) {
+ if (is_y_type && is_flat) {
return p.y - lambda * p.x;
- } else if (is_y_type and is_steep) {
+ } else if (is_y_type && is_steep) {
return p.y - p.x / lambda;
- } else if (is_x_type and is_flat) {
+ } else if (is_x_type && is_flat) {
return p.x - p.y / lambda;
- } else if (is_x_type and is_steep) {
+ } else if (is_x_type && is_steep) {
return p.x - lambda * p.y;
}
// to shut up compiler warning
@@ -53,13 +41,13 @@ namespace md {
auto lambda_from_mu = [p, is_y_type, is_flat](Real mu) {
bool is_x_type = not is_y_type, is_steep = not is_flat;
- if (is_y_type and is_flat) {
+ if (is_y_type && is_flat) {
return (p.y - mu) / p.x;
- } else if (is_y_type and is_steep) {
+ } else if (is_y_type && is_steep) {
return p.x / (p.y - mu);
- } else if (is_x_type and is_flat) {
+ } else if (is_x_type && is_flat) {
return p.y / (p.x - mu);
- } else if (is_x_type and is_steep) {
+ } else if (is_x_type && is_steep) {
return (p.x - mu) / p.y;
}
// to shut up compiler warning
@@ -67,7 +55,7 @@ namespace md {
};
// all inequalities below are strict: equality means it is a corner
- // and critical_points() returns corners anyway
+ // && critical_points() returns corners anyway
Real mu_intersect_min = mu_from_lambda(lambda_min());
@@ -99,22 +87,24 @@ namespace md {
return result;
}
- std::vector<DualPoint> DualBox::critical_points(const Point& /*p*/) const
+ template<class Real>
+ std::vector<DualPoint<Real>> DualBox<Real>::critical_points(const Point<Real>& /*p*/) const
{
// maximal difference is attained at corners
return corners();
-// std::vector<DualPoint> result;
+// std::vector<DualPoint<Real>> result;
// result.reserve(6);
// for(auto dp : corners()) result.push_back(dp);
// for(auto dp : push_change_points(p)) result.push_back(dp);
// return result;
}
- std::vector<DualPoint> DualBox::random_points(int n) const
+ template<class Real>
+ std::vector<DualPoint<Real>> DualBox<Real>::random_points(int n) const
{
assert(n >= 0);
std::mt19937_64 gen(1);
- std::vector<DualPoint> result;
+ std::vector<DualPoint<Real>> result;
result.reserve(n);
std::uniform_real_distribution<Real> mu_distr(mu_min(), mu_max());
std::uniform_real_distribution<Real> lambda_distr(lambda_min(), lambda_max());
@@ -124,7 +114,8 @@ namespace md {
return result;
}
- bool DualBox::sanity_check() const
+ template<class Real>
+ bool DualBox<Real>::sanity_check() const
{
lower_left_.sanity_check();
upper_right_.sanity_check();
@@ -144,51 +135,56 @@ namespace md {
return true;
}
- std::vector<DualBox> DualBox::refine() const
+ template<class Real>
+ std::vector<DualBox<Real>> DualBox<Real>::refine() const
{
- std::vector<DualBox> result;
+ std::vector<DualBox<Real>> result;
result.reserve(4);
Real lambda_middle = (lower_left().lambda() + upper_right().lambda()) / 2.0;
Real mu_middle = (lower_left().mu() + upper_right().mu()) / 2.0;
- DualPoint refinement_center(axis_type(), angle_type(), lambda_middle, mu_middle);
+ DualPoint<Real> refinement_center(axis_type(), angle_type(), lambda_middle, mu_middle);
result.emplace_back(lower_left_, refinement_center);
- result.emplace_back(DualPoint(axis_type(), angle_type(), lambda_middle, mu_min()),
- DualPoint(axis_type(), angle_type(), lambda_max(), mu_middle));
+ result.emplace_back(DualPoint<Real>(axis_type(), angle_type(), lambda_middle, mu_min()),
+ DualPoint<Real>(axis_type(), angle_type(), lambda_max(), mu_middle));
result.emplace_back(refinement_center, upper_right_);
- result.emplace_back(DualPoint(axis_type(), angle_type(), lambda_min(), mu_middle),
- DualPoint(axis_type(), angle_type(), lambda_middle, mu_max()));
+ result.emplace_back(DualPoint<Real>(axis_type(), angle_type(), lambda_min(), mu_middle),
+ DualPoint<Real>(axis_type(), angle_type(), lambda_middle, mu_max()));
return result;
}
- bool DualBox::operator==(const DualBox& other) const
+ template<class Real>
+ bool DualBox<Real>::operator==(const DualBox& other) const
{
- return lower_left() == other.lower_left() and
+ return lower_left() == other.lower_left() &&
upper_right() == other.upper_right();
}
- bool DualBox::contains(const DualPoint& dp) const
+ template<class Real>
+ bool DualBox<Real>::contains(const DualPoint<Real>& dp) const
{
- return dp.angle_type() == angle_type() and dp.axis_type() == axis_type() and
- mu_max() >= dp.mu() and
- mu_min() <= dp.mu() and
- lambda_min() <= dp.lambda() and
+ return dp.angle_type() == angle_type() && dp.axis_type() == axis_type() &&
+ mu_max() >= dp.mu() &&
+ mu_min() <= dp.mu() &&
+ lambda_min() <= dp.lambda() &&
lambda_max() >= dp.lambda();
}
- DualPoint DualBox::lower_right() const
+ template<class Real>
+ DualPoint<Real> DualBox<Real>::lower_right() const
{
- return DualPoint(lower_left_.axis_type(), lower_left_.angle_type(), lambda_max(), mu_min());
+ return DualPoint<Real>(lower_left_.axis_type(), lower_left_.angle_type(), lambda_max(), mu_min());
}
- DualPoint DualBox::upper_left() const
+ template<class Real>
+ DualPoint<Real> DualBox<Real>::upper_left() const
{
- return DualPoint(lower_left_.axis_type(), lower_left_.angle_type(), lambda_min(), mu_max());
+ return DualPoint<Real>(lower_left_.axis_type(), lower_left_.angle_type(), lambda_min(), mu_max());
}
}
diff --git a/matching/include/dual_point.h b/matching/include/dual_point.h
index db32f1a..8438860 100644
--- a/matching/include/dual_point.h
+++ b/matching/include/dual_point.h
@@ -1,12 +1,9 @@
-//
-// Created by narn on 12.02.19.
-//
-
#ifndef MATCHING_DISTANCE_DUAL_POINT_H
#define MATCHING_DISTANCE_DUAL_POINT_H
#include <vector>
#include <ostream>
+#include <tuple>
#include "common_util.h"
#include "box.h"
@@ -25,9 +22,10 @@ namespace md {
// so, e.g., line y = x has 4 different non-equal representation.
// we are unlikely to ever need this, because 4 cases are
// always treated separately.
+ template<class Real_>
class DualPoint {
public:
- using Real = md::Real;
+ using Real = Real_;
DualPoint() = default;
@@ -56,7 +54,6 @@ namespace md {
bool is_y_type() const { return axis_type_ == AxisType::y_type; }
- friend std::ostream& operator<<(std::ostream& os, const DualPoint& dp);
bool operator<(const DualPoint& rhs) const;
AxisType axis_type() const { return axis_type_; }
@@ -66,16 +63,16 @@ namespace md {
// return true otherwise
bool sanity_check() const;
- Real weighted_push(Point p) const;
- Point push(Point p) const;
+ Real weighted_push(Point<Real> p) const;
+ Point<Real> push(Point<Real> p) const;
bool is_horizontal() const;
bool is_vertical() const;
- bool goes_below(Point p) const;
- bool goes_above(Point p) const;
+ bool goes_below(Point<Real> p) const;
+ bool goes_above(Point<Real> p) const;
- bool contains(Point p) const;
+ bool contains(Point<Real> p) const;
Real x_slope() const;
Real y_slope() const;
@@ -98,9 +95,13 @@ namespace md {
Real mu_ {-1.0};
};
- std::ostream& operator<<(std::ostream& os, const DualPoint& dp);
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const DualPoint<Real>& dp);
- DualPoint midpoint(DualPoint x, DualPoint y);
+ template<class Real>
+ DualPoint<Real> midpoint(DualPoint<Real> x, DualPoint<Real> y);
};
+#include "dual_point.hpp"
+
#endif //MATCHING_DISTANCE_DUAL_POINT_H
diff --git a/matching/src/dual_point.cpp b/matching/include/dual_point.hpp
index 1c00b58..04e25f2 100644
--- a/matching/src/dual_point.cpp
+++ b/matching/include/dual_point.hpp
@@ -1,10 +1,6 @@
-#include <tuple>
-
-#include "dual_point.h"
-
namespace md {
- std::ostream& operator<<(std::ostream& os, const AxisType& at)
+ inline std::ostream& operator<<(std::ostream& os, const AxisType& at)
{
if (at == AxisType::x_type)
os << "x-type";
@@ -13,7 +9,7 @@ namespace md {
return os;
}
- std::ostream& operator<<(std::ostream& os, const AngleType& at)
+ inline std::ostream& operator<<(std::ostream& os, const AngleType& at)
{
if (at == AngleType::flat)
os << "flat";
@@ -22,7 +18,8 @@ namespace md {
return os;
}
- std::ostream& operator<<(std::ostream& os, const DualPoint& dp)
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const DualPoint<Real>& dp)
{
os << "Line(" << dp.axis_type() << ", ";
os << dp.angle_type() << ", ";
@@ -37,13 +34,15 @@ namespace md {
return os;
}
- bool DualPoint::operator<(const DualPoint& rhs) const
+ template<class Real>
+ bool DualPoint<Real>::operator<(const DualPoint<Real>& rhs) const
{
return std::tie(axis_type_, angle_type_, lambda_, mu_)
< std::tie(rhs.axis_type_, rhs.angle_type_, rhs.lambda_, rhs.mu_);
}
- DualPoint::DualPoint(AxisType axis_type, AngleType angle_type, Real lambda, Real mu)
+ template<class Real>
+ DualPoint<Real>::DualPoint(AxisType axis_type, AngleType angle_type, Real lambda, Real mu)
:
axis_type_(axis_type),
angle_type_(angle_type),
@@ -53,7 +52,8 @@ namespace md {
assert(sanity_check());
}
- bool DualPoint::sanity_check() const
+ template<class Real>
+ bool DualPoint<Real>::sanity_check() const
{
if (lambda_ < 0.0)
throw std::runtime_error("Invalid line, negative lambda");
@@ -64,7 +64,8 @@ namespace md {
return true;
}
- Real DualPoint::gamma() const
+ template<class Real>
+ Real DualPoint<Real>::gamma() const
{
if (is_steep())
return atan(Real(1.0) / lambda_);
@@ -72,17 +73,19 @@ namespace md {
return atan(lambda_);
}
- DualPoint midpoint(DualPoint x, DualPoint y)
+ template<class Real>
+ DualPoint<Real> midpoint(DualPoint<Real> x, DualPoint<Real> y)
{
assert(x.angle_type() == y.angle_type() and x.axis_type() == y.axis_type());
Real lambda_mid = (x.lambda() + y.lambda()) / 2;
Real mu_mid = (x.mu() + y.mu()) / 2;
- return DualPoint(x.axis_type(), x.angle_type(), lambda_mid, mu_mid);
+ return DualPoint<Real>(x.axis_type(), x.angle_type(), lambda_mid, mu_mid);
}
// return k in the line equation y = kx + b
- Real DualPoint::y_slope() const
+ template<class Real>
+ Real DualPoint<Real>::y_slope() const
{
if (is_flat())
return lambda();
@@ -91,7 +94,8 @@ namespace md {
}
// return k in the line equation x = ky + b
- Real DualPoint::x_slope() const
+ template<class Real>
+ Real DualPoint<Real>::x_slope() const
{
if (is_flat())
return Real(1.0) / lambda();
@@ -100,7 +104,8 @@ namespace md {
}
// return b in the line equation y = kx + b
- Real DualPoint::y_intercept() const
+ template<class Real>
+ Real DualPoint<Real>::y_intercept() const
{
if (is_y_type()) {
return mu();
@@ -112,7 +117,8 @@ namespace md {
}
// return k in the line equation x = ky + b
- Real DualPoint::x_intercept() const
+ template<class Real>
+ Real DualPoint<Real>::x_intercept() const
{
if (is_x_type()) {
return mu();
@@ -123,7 +129,8 @@ namespace md {
}
}
- Real DualPoint::x_from_y(Real y) const
+ template<class Real>
+ Real DualPoint<Real>::x_from_y(Real y) const
{
if (is_horizontal())
throw std::runtime_error("x_from_y called on horizontal line");
@@ -131,7 +138,8 @@ namespace md {
return x_slope() * y + x_intercept();
}
- Real DualPoint::y_from_x(Real x) const
+ template<class Real>
+ Real DualPoint<Real>::y_from_x(Real x) const
{
if (is_vertical())
throw std::runtime_error("x_from_y called on horizontal line");
@@ -139,17 +147,20 @@ namespace md {
return y_slope() * x + y_intercept();
}
- bool DualPoint::is_horizontal() const
+ template<class Real>
+ bool DualPoint<Real>::is_horizontal() const
{
return is_flat() and lambda() == 0;
}
- bool DualPoint::is_vertical() const
+ template<class Real>
+ bool DualPoint<Real>::is_vertical() const
{
return is_steep() and lambda() == 0;
}
-
- bool DualPoint::contains(Point p) const
+
+ template<class Real>
+ bool DualPoint<Real>::contains(Point<Real> p) const
{
if (is_vertical())
return p.x == x_from_y(p.y);
@@ -157,7 +168,8 @@ namespace md {
return p.y == y_from_x(p.x);
}
- bool DualPoint::goes_below(Point p) const
+ template<class Real>
+ bool DualPoint<Real>::goes_below(Point<Real> p) const
{
if (is_vertical())
return p.x <= x_from_y(p.y);
@@ -165,7 +177,8 @@ namespace md {
return p.y >= y_from_x(p.x);
}
- bool DualPoint::goes_above(Point p) const
+ template<class Real>
+ bool DualPoint<Real>::goes_above(Point<Real> p) const
{
if (is_vertical())
return p.x >= x_from_y(p.y);
@@ -173,9 +186,10 @@ namespace md {
return p.y <= y_from_x(p.x);
}
- Point DualPoint::push(Point p) const
+ template<class Real>
+ Point<Real> DualPoint<Real>::push(Point<Real> p) const
{
- Point result;
+ Point<Real> result;
// if line is below p, we push horizontally
bool horizontal_push = goes_below(p);
if (is_x_type()) {
@@ -225,7 +239,8 @@ namespace md {
return result;
}
- Real DualPoint::weighted_push(Point p) const
+ template<class Real>
+ Real DualPoint<Real>::weighted_push(Point<Real> p) const
{
// if line is below p, we push horizontally
bool horizontal_push = goes_below(p);
@@ -267,7 +282,8 @@ namespace md {
}
}
- bool DualPoint::operator==(const DualPoint& other) const
+ template<class Real>
+ bool DualPoint<Real>::operator==(const DualPoint<Real>& other) const
{
return axis_type() == other.axis_type() and
angle_type() == other.angle_type() and
@@ -275,7 +291,8 @@ namespace md {
lambda() == other.lambda();
}
- Real DualPoint::weight() const
+ template<class Real>
+ Real DualPoint<Real>::weight() const
{
return lambda_ / sqrt(1 + lambda_ * lambda_);
}
diff --git a/matching/include/matching_distance.h b/matching/include/matching_distance.h
index bb10203..5be34c7 100644
--- a/matching/include/matching_distance.h
+++ b/matching/include/matching_distance.h
@@ -4,9 +4,10 @@
#include <limits>
#include <utility>
#include <ostream>
+#include <chrono>
+#include <tuple>
+#include <algorithm>
-#include "spdlog/spdlog.h"
-#include "spdlog/fmt/ostr.h"
#include "common_defs.h"
#include "cell_with_value.h"
@@ -17,12 +18,15 @@
#include "bifiltration.h"
#include "bottleneck.h"
-namespace spd = spdlog;
-
namespace md {
- using HeatMap = std::map<DualPoint, Real>;
- using HeatMaps = std::map<int, HeatMap>;
+#ifdef MD_PRINT_HEAT_MAP
+ template<class Real>
+ using HeatMap = std::map<DualPoint<Real>, Real>;
+
+ template<class Real>
+ using HeatMaps = std::map<int, HeatMap<Real>>;
+#endif
enum class BoundStrategy {
bruteforce,
@@ -39,18 +43,107 @@ namespace md {
upper_bound
};
- std::ostream& operator<<(std::ostream& os, const BoundStrategy& s);
-
- std::ostream& operator<<(std::ostream& os, const TraverseStrategy& s);
-
- std::istream& operator>>(std::istream& is, BoundStrategy& s);
-
- std::istream& operator>>(std::istream& is, TraverseStrategy& s);
-
- BoundStrategy bs_from_string(std::string s);
-
- TraverseStrategy ts_from_string(std::string s);
-
+ inline std::ostream& operator<<(std::ostream& os, const BoundStrategy& s)
+ {
+ switch(s) {
+ case BoundStrategy::bruteforce :
+ os << "bruteforce";
+ break;
+ case BoundStrategy::local_dual_bound :
+ os << "local_grob";
+ break;
+ case BoundStrategy::local_combined :
+ os << "local_combined";
+ break;
+ case BoundStrategy::local_dual_bound_refined :
+ os << "local_refined";
+ break;
+ case BoundStrategy::local_dual_bound_for_each_point :
+ os << "local_for_each_point";
+ break;
+ default:
+ os << "FORGOTTEN BOUND STRATEGY";
+ }
+ return os;
+ }
+
+ inline std::ostream& operator<<(std::ostream& os, const TraverseStrategy& s)
+ {
+ switch(s) {
+ case TraverseStrategy::depth_first :
+ os << "DFS";
+ break;
+ case TraverseStrategy::breadth_first :
+ os << "BFS";
+ break;
+ case TraverseStrategy::breadth_first_value :
+ os << "BFS-VAL";
+ break;
+ case TraverseStrategy::upper_bound :
+ os << "UB";
+ break;
+ default:
+ os << "FORGOTTEN TRAVERSE STRATEGY";
+ }
+ return os;
+ }
+
+ inline std::istream& operator>>(std::istream& is, TraverseStrategy& s)
+ {
+ std::string ss;
+ is >> ss;
+ if (ss == "DFS") {
+ s = TraverseStrategy::depth_first;
+ } else if (ss == "BFS") {
+ s = TraverseStrategy::breadth_first;
+ } else if (ss == "BFS-VAL") {
+ s = TraverseStrategy::breadth_first_value;
+ } else if (ss == "UB") {
+ s = TraverseStrategy::upper_bound;
+ } else {
+ throw std::runtime_error("UNKNOWN TRAVERSE STRATEGY");
+ }
+ return is;
+ }
+
+
+ inline std::istream& operator>>(std::istream& is, BoundStrategy& s)
+ {
+ std::string ss;
+ is >> ss;
+ if (ss == "bruteforce") {
+ s = BoundStrategy::bruteforce;
+ } else if (ss == "local_grob") {
+ s = BoundStrategy::local_dual_bound;
+ } else if (ss == "local_combined") {
+ s = BoundStrategy::local_combined;
+ } else if (ss == "local_refined") {
+ s = BoundStrategy::local_dual_bound_refined;
+ } else if (ss == "local_for_each_point") {
+ s = BoundStrategy::local_dual_bound_for_each_point;
+ } else {
+ throw std::runtime_error("UNKNOWN BOUND STRATEGY");
+ }
+ return is;
+ }
+
+ inline BoundStrategy bs_from_string(std::string s)
+ {
+ std::stringstream ss(s);
+ BoundStrategy result;
+ ss >> result;
+ return result;
+ }
+
+ inline TraverseStrategy ts_from_string(std::string s)
+ {
+ std::stringstream ss(s);
+ TraverseStrategy result;
+ ss >> result;
+ return result;
+ }
+
+ template<class Real>
struct CalculationParams {
static constexpr int ALL_DIMENSIONS = -1;
@@ -75,22 +168,22 @@ namespace md {
// print statistics on each quad-tree level
bool print_stats { false };
-#ifdef PRINT_HEAT_MAP
+#ifdef MD_PRINT_HEAT_MAP
HeatMaps heat_maps;
#endif
};
- template<class DiagramProvider>
+ template<class Real_, class DiagramProvider>
class DistanceCalculator {
- using DualBox = md::DualBox;
- using CellValueVector = std::vector<CellWithValue>;
+ using Real = Real_;
+ using CellValueVector = std::vector<CellWithValue<Real>>;
public:
DistanceCalculator(const DiagramProvider& a,
const DiagramProvider& b,
- CalculationParams& params);
+ CalculationParams<Real>& params);
Real distance();
@@ -100,7 +193,7 @@ namespace md {
DiagramProvider module_a_;
DiagramProvider module_b_;
- CalculationParams& params_;
+ CalculationParams<Real>& params_;
int n_hera_calls_;
std::map<int, int> n_hera_calls_per_level_;
@@ -112,65 +205,83 @@ namespace md {
CellValueVector get_initial_dual_grid(Real& lower_bound);
+#ifdef MD_PRINT_HEAT_MAP
void heatmap_in_dimension(int dim, int depth);
+#endif
Real get_max_x(int module) const;
Real get_max_y(int module) const;
- void set_cell_central_value(CellWithValue& dual_cell);
+ void set_cell_central_value(CellWithValue<Real>& dual_cell);
Real get_distance();
Real get_distance_pq();
- // temporary, to try priority queue
- Real get_max_possible_value(const CellWithValue* first_cell_ptr, int n_cells);
+ Real get_max_possible_value(const CellWithValue<Real>* first_cell_ptr, int n_cells);
- Real get_upper_bound(const CellWithValue& dual_cell, Real good_enough_upper_bound) const;
+ Real get_upper_bound(const CellWithValue<Real>& dual_cell, Real good_enough_upper_bound) const;
- Real get_single_dgm_bound(const CellWithValue& dual_cell, ValuePoint vp, int module,
+ Real get_single_dgm_bound(const CellWithValue<Real>& dual_cell, ValuePoint vp, int module,
Real good_enough_value) const;
// this bound depends only on dual box
- Real get_local_dual_bound(int module, const DualBox& dual_box) const;
+ Real get_local_dual_bound(int module, const DualBox<Real>& dual_box) const;
- Real get_local_dual_bound(const DualBox& dual_box) const;
+ Real get_local_dual_bound(const DualBox<Real>& dual_box) const;
// this bound depends only on dual box, is more accurate
- Real get_local_refined_bound(int module, const md::DualBox& dual_box) const;
+ Real get_local_refined_bound(int module, const DualBox<Real>& dual_box) const;
- Real get_local_refined_bound(const md::DualBox& dual_box) const;
+ Real get_local_refined_bound(const DualBox<Real>& dual_box) const;
Real get_good_enough_upper_bound(Real lower_bound) const;
- Real
- get_max_displacement_single_point(const CellWithValue& dual_cell, ValuePoint value_point, const Point& p) const;
+ Real get_max_displacement_single_point(const CellWithValue<Real>& dual_cell, ValuePoint value_point,
+ const Point<Real>& p) const;
- void check_upper_bound(const CellWithValue& dual_cell) const;
+ void check_upper_bound(const CellWithValue<Real>& dual_cell) const;
- Real distance_on_line(DualPoint line);
- Real distance_on_line_const(DualPoint line) const;
+ Real distance_on_line(DualPoint<Real> line);
+ Real distance_on_line_const(DualPoint<Real> line) const;
Real current_error(Real lower_bound, Real upper_bound);
};
- Real matching_distance(const Bifiltration& bif_a, const Bifiltration& bif_b, CalculationParams& params);
+ template<class Real>
+ Real matching_distance(const Bifiltration<Real>& bif_a, const Bifiltration<Real>& bif_b,
+ CalculationParams<Real>& params);
- Real matching_distance(const ModulePresentation& mod_a, const ModulePresentation& mod_b, CalculationParams& params);
+ template<class Real>
+ Real matching_distance(const ModulePresentation<Real>& mod_a, const ModulePresentation<Real>& mod_b,
+ CalculationParams<Real>& params);
// for upper bound experiment
struct UbExperimentRecord {
- Real error;
- Real lower_bound;
- Real upper_bound;
- CellWithValue cell;
+ double error;
+ double lower_bound;
+ double upper_bound;
+ CellWithValue<double> cell;
long long int time;
long long int n_hera_calls;
};
- std::ostream& operator<<(std::ostream& os, const UbExperimentRecord& r);
+ inline std::ostream& operator<<(std::ostream& os, const UbExperimentRecord& r)
+ {
+ os << r.time << "\t" << r.n_hera_calls << "\t" << r.error << "\t" << r.lower_bound << "\t" << r.upper_bound;
+ return os;
+ }
+
+
+ template<class K, class V>
+ void print_map(const std::map<K, V>& dic)
+ {
+ for(const auto kv : dic) {
+ fmt::print("{} -> {}\n", kv.first, kv.second);
+ }
+ }
-}
+} // namespace md
#include "matching_distance.hpp"
diff --git a/matching/include/matching_distance.hpp b/matching/include/matching_distance.hpp
index d2d2fbc..48c8464 100644
--- a/matching/include/matching_distance.hpp
+++ b/matching/include/matching_distance.hpp
@@ -1,34 +1,26 @@
namespace md {
- template<class K, class V>
- void print_map(const std::map<K, V>& dic)
- {
- for(const auto kv : dic) {
- fmt::print("{} -> {}\n", kv.first, kv.second);
- }
- }
-
- template<class T>
- void DistanceCalculator<T>::check_upper_bound(const CellWithValue& dual_cell) const
+ template<class R, class T>
+ void DistanceCalculator<R, T>::check_upper_bound(const CellWithValue<R>& dual_cell) const
{
spd::debug("Enter check_get_max_delta_on_cell");
const int n_samples_lambda = 100;
const int n_samples_mu = 100;
- DualBox db = dual_cell.dual_box();
- Real min_lambda = db.lambda_min();
- Real max_lambda = db.lambda_max();
- Real min_mu = db.mu_min();
- Real max_mu = db.mu_max();
-
- Real h_lambda = (max_lambda - min_lambda) / n_samples_lambda;
- Real h_mu = (max_mu - min_mu) / n_samples_mu;
+ DualBox<R> db = dual_cell.dual_box();
+ R min_lambda = db.lambda_min();
+ R max_lambda = db.lambda_max();
+ R min_mu = db.mu_min();
+ R max_mu = db.mu_max();
+
+ R h_lambda = (max_lambda - min_lambda) / n_samples_lambda;
+ R h_mu = (max_mu - min_mu) / n_samples_mu;
for(int i = 1; i < n_samples_lambda; ++i) {
for(int j = 1; j < n_samples_mu; ++j) {
- Real lambda = min_lambda + i * h_lambda;
- Real mu = min_mu + j * h_mu;
- DualPoint l(db.axis_type(), db.angle_type(), lambda, mu);
- Real other_result = distance_on_line_const(l);
- Real diff = fabs(dual_cell.stored_upper_bound() - other_result);
+ R lambda = min_lambda + i * h_lambda;
+ R mu = min_mu + j * h_mu;
+ DualPoint<R> l(db.axis_type(), db.angle_type(), lambda, mu);
+ R other_result = distance_on_line_const(l);
+ R diff = fabs(dual_cell.stored_upper_bound() - other_result);
if (other_result > dual_cell.stored_upper_bound()) {
spd::error(
"in check_upper_bound, upper_bound = {}, other_result = {}, diff = {}\ndual_cell = {}",
@@ -42,10 +34,10 @@ namespace md {
// for all lines l, l' inside dual box,
// find the upper bound on the difference of weighted pushes of p
- template<class T>
- Real
- DistanceCalculator<T>::get_max_displacement_single_point(const CellWithValue& dual_cell, ValuePoint vp,
- const Point& p) const
+ template<class R, class T>
+ R
+ DistanceCalculator<R, T>::get_max_displacement_single_point(const CellWithValue<R>& dual_cell, ValuePoint vp,
+ const Point<R>& p) const
{
assert(p.x >= 0 && p.y >= 0);
@@ -53,15 +45,15 @@ namespace md {
std::vector<long long int> debug_ids = {3, 13, 54, 218, 350, 382, 484, 795, 2040, 8415, 44076};
bool debug = false; // std::find(debug_ids.begin(), debug_ids.end(), dual_cell.id) != debug_ids.end();
#endif
- DualPoint line = dual_cell.value_point(vp);
- const Real base_value = line.weighted_push(p);
+ DualPoint<R> line = dual_cell.value_point(vp);
+ const R base_value = line.weighted_push(p);
spd::debug("Enter get_max_displacement_single_point, p = {},\ndual_cell = {},\nline = {}, base_value = {}\n", p,
dual_cell, line, base_value);
- Real result = 0.0;
- for(DualPoint dp : dual_cell.dual_box().critical_points(p)) {
- Real dp_value = dp.weighted_push(p);
+ R result = 0.0;
+ for(DualPoint<R> dp : dual_cell.dual_box().critical_points(p)) {
+ R dp_value = dp.weighted_push(p);
spd::debug(
"In get_max_displacement_single_point, p = {}, critical dp = {},\ndp_value = {}, diff = {},\ndual_cell = {}\n",
p, dp, dp_value, fabs(base_value - dp_value), dual_cell);
@@ -69,15 +61,15 @@ namespace md {
}
#ifdef MD_DO_FULL_CHECK
- DualBox db = dual_cell.dual_box();
- std::uniform_real_distribution<Real> dlambda(db.lambda_min(), db.lambda_max());
- std::uniform_real_distribution<Real> dmu(db.mu_min(), db.mu_max());
+ auto db = dual_cell.dual_box();
+ std::uniform_real_distribution<R> dlambda(db.lambda_min(), db.lambda_max());
+ std::uniform_real_distribution<R> dmu(db.mu_min(), db.mu_max());
std::mt19937 gen(1);
for(int i = 0; i < 1000; ++i) {
- Real lambda = dlambda(gen);
- Real mu = dmu(gen);
- DualPoint dp_random { db.axis_type(), db.angle_type(), lambda, mu };
- Real dp_value = dp_random.weighted_push(p);
+ R lambda = dlambda(gen);
+ R mu = dmu(gen);
+ DualPoint<R> dp_random { db.axis_type(), db.angle_type(), lambda, mu };
+ R dp_value = dp_random.weighted_push(p);
if (fabs(base_value - dp_value) > result) {
spd::error("in get_max_displacement_single_point, p = {}, vp = {}\ndb = {}\nresult = {}, base_value = {}, dp_value = {}, dp_random = {}",
p, vp, db, result, base_value, dp_value, dp_random);
@@ -89,12 +81,12 @@ namespace md {
return result;
}
- template<class T>
- typename DistanceCalculator<T>::CellValueVector DistanceCalculator<T>::get_initial_dual_grid(Real& lower_bound)
+ template<class R, class T>
+ typename DistanceCalculator<R, T>::CellValueVector DistanceCalculator<R, T>::get_initial_dual_grid(R& lower_bound)
{
CellValueVector result = get_refined_grid(params_.initialization_depth, false, true);
- lower_bound = -1.0;
+ lower_bound = -1;
for(const auto& dc : result) {
lower_bound = std::max(lower_bound, dc.max_corner_value());
}
@@ -102,8 +94,8 @@ namespace md {
assert(lower_bound >= 0);
for(auto& dual_cell : result) {
- Real good_enough_ub = get_good_enough_upper_bound(lower_bound);
- Real max_value_on_cell = get_upper_bound(dual_cell, good_enough_ub);
+ R good_enough_ub = get_good_enough_upper_bound(lower_bound);
+ R max_value_on_cell = get_upper_bound(dual_cell, good_enough_ub);
dual_cell.set_max_possible_value(max_value_on_cell);
#ifdef MD_DO_FULL_CHECK
@@ -116,39 +108,39 @@ namespace md {
return result;
}
- template<class T>
- typename DistanceCalculator<T>::CellValueVector
- DistanceCalculator<T>::get_refined_grid(int init_depth, bool calculate_on_intermediate, bool calculate_on_last)
+ template<class R, class T>
+ typename DistanceCalculator<R, T>::CellValueVector
+ DistanceCalculator<R, T>::get_refined_grid(int init_depth, bool calculate_on_intermediate, bool calculate_on_last)
{
- const Real y_max = std::max(module_a_.max_y(), module_b_.max_y());
- const Real x_max = std::max(module_a_.max_x(), module_b_.max_x());
+ const R y_max = std::max(module_a_.max_y(), module_b_.max_y());
+ const R x_max = std::max(module_a_.max_x(), module_b_.max_x());
- const Real lambda_min = 0;
- const Real lambda_max = 1;
+ const R lambda_min = 0;
+ const R lambda_max = 1;
- const Real mu_min = 0;
+ const R mu_min = 0;
- DualBox x_flat(DualPoint(AxisType::x_type, AngleType::flat, lambda_min, mu_min),
- DualPoint(AxisType::x_type, AngleType::flat, lambda_max, x_max));
+ DualBox<R> x_flat(DualPoint<R>(AxisType::x_type, AngleType::flat, lambda_min, mu_min),
+ DualPoint<R>(AxisType::x_type, AngleType::flat, lambda_max, x_max));
- DualBox x_steep(DualPoint(AxisType::x_type, AngleType::steep, lambda_min, mu_min),
- DualPoint(AxisType::x_type, AngleType::steep, lambda_max, x_max));
+ DualBox<R> x_steep(DualPoint<R>(AxisType::x_type, AngleType::steep, lambda_min, mu_min),
+ DualPoint<R>(AxisType::x_type, AngleType::steep, lambda_max, x_max));
- DualBox y_flat(DualPoint(AxisType::y_type, AngleType::flat, lambda_min, mu_min),
- DualPoint(AxisType::y_type, AngleType::flat, lambda_max, y_max));
+ DualBox<R> y_flat(DualPoint<R>(AxisType::y_type, AngleType::flat, lambda_min, mu_min),
+ DualPoint<R>(AxisType::y_type, AngleType::flat, lambda_max, y_max));
- DualBox y_steep(DualPoint(AxisType::y_type, AngleType::steep, lambda_min, mu_min),
- DualPoint(AxisType::y_type, AngleType::steep, lambda_max, y_max));
+ DualBox<R> y_steep(DualPoint<R>(AxisType::y_type, AngleType::steep, lambda_min, mu_min),
+ DualPoint<R>(AxisType::y_type, AngleType::steep, lambda_max, y_max));
- CellWithValue x_flat_cell(x_flat, 0);
- CellWithValue x_steep_cell(x_steep, 0);
- CellWithValue y_flat_cell(y_flat, 0);
- CellWithValue y_steep_cell(y_steep, 0);
+ CellWithValue<R> x_flat_cell(x_flat, 0);
+ CellWithValue<R> x_steep_cell(x_steep, 0);
+ CellWithValue<R> y_flat_cell(y_flat, 0);
+ CellWithValue<R> y_steep_cell(y_steep, 0);
if (init_depth == 0) {
- DualPoint diagonal_x_flat(AxisType::x_type, AngleType::flat, 1, 0);
+ DualPoint<R> diagonal_x_flat(AxisType::x_type, AngleType::flat, 1, 0);
- Real diagonal_value = distance_on_line(diagonal_x_flat);
+ R diagonal_value = distance_on_line(diagonal_x_flat);
n_hera_calls_per_level_[0]++;
x_flat_cell.set_value_at(ValuePoint::lower_right, diagonal_value);
@@ -162,7 +154,7 @@ namespace md {
x_steep_cell.id = 2;
y_flat_cell.id = 3;
y_steep_cell.id = 4;
- CellWithValue::max_id = 4;
+ CellWithValue<R>::max_id = 4;
#endif
CellValueVector result {x_flat_cell, x_steep_cell, y_flat_cell, y_steep_cell};
@@ -189,10 +181,10 @@ namespace md {
return result;
}
- template<class T>
- DistanceCalculator<T>::DistanceCalculator(const T& a,
+ template<class R, class T>
+ DistanceCalculator<R, T>::DistanceCalculator(const T& a,
const T& b,
- CalculationParams& params)
+ CalculationParams<R>& params)
:
module_a_(a),
module_b_(b),
@@ -213,33 +205,33 @@ namespace md {
module_a_.max_x(), module_a_.max_y(), module_b_.max_x(), module_b_.max_y());
}
- template<class T>
- Real DistanceCalculator<T>::get_max_x(int module) const
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_max_x(int module) const
{
return (module == 0) ? module_a_.max_x() : module_b_.max_x();
}
- template<class T>
- Real DistanceCalculator<T>::get_max_y(int module) const
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_max_y(int module) const
{
return (module == 0) ? module_a_.max_y() : module_b_.max_y();
}
- template<class T>
- Real
- DistanceCalculator<T>::get_local_refined_bound(const md::DualBox& dual_box) const
+ template<class R, class T>
+ R
+ DistanceCalculator<R, T>::get_local_refined_bound(const DualBox<R>& dual_box) const
{
return get_local_refined_bound(0, dual_box) + get_local_refined_bound(1, dual_box);
}
- template<class T>
- Real
- DistanceCalculator<T>::get_local_refined_bound(int module, const md::DualBox& dual_box) const
+ template<class R, class T>
+ R
+ DistanceCalculator<R, T>::get_local_refined_bound(int module, const DualBox<R>& dual_box) const
{
spd::debug("Enter get_local_refined_bound, dual_box = {}", dual_box);
- Real d_lambda = dual_box.lambda_max() - dual_box.lambda_min();
- Real d_mu = dual_box.mu_max() - dual_box.mu_min();
- Real result;
+ R d_lambda = dual_box.lambda_max() - dual_box.lambda_min();
+ R d_mu = dual_box.mu_max() - dual_box.mu_min();
+ R result;
if (dual_box.axis_type() == AxisType::x_type) {
if (dual_box.is_flat()) {
result = dual_box.lambda_max() * d_mu + (get_max_x(module) - dual_box.mu_min()) * d_lambda;
@@ -258,11 +250,11 @@ namespace md {
return result;
}
- template<class T>
- Real DistanceCalculator<T>::get_local_dual_bound(int module, const md::DualBox& dual_box) const
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_local_dual_bound(int module, const DualBox<R>& dual_box) const
{
- Real dlambda = dual_box.lambda_max() - dual_box.lambda_min();
- Real dmu = dual_box.mu_max() - dual_box.mu_min();
+ R dlambda = dual_box.lambda_max() - dual_box.lambda_min();
+ R dmu = dual_box.mu_max() - dual_box.mu_min();
if (dual_box.is_flat()) {
return get_max_x(module) * dlambda + dmu;
@@ -271,20 +263,20 @@ namespace md {
}
}
- template<class T>
- Real DistanceCalculator<T>::get_local_dual_bound(const md::DualBox& dual_box) const
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_local_dual_bound(const DualBox<R>& dual_box) const
{
return get_local_dual_bound(0, dual_box) + get_local_dual_bound(1, dual_box);
}
- template<class T>
- Real DistanceCalculator<T>::get_upper_bound(const CellWithValue& dual_cell, Real good_enough_ub) const
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_upper_bound(const CellWithValue<R>& dual_cell, R good_enough_ub) const
{
assert(good_enough_ub >= 0);
switch(params_.bound_strategy) {
case BoundStrategy::bruteforce:
- return std::numeric_limits<Real>::max();
+ return std::numeric_limits<R>::max();
case BoundStrategy::local_dual_bound:
return dual_cell.min_value() + get_local_dual_bound(dual_cell.dual_box());
@@ -293,7 +285,7 @@ namespace md {
return dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box());
case BoundStrategy::local_combined: {
- Real cheap_upper_bound = dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box());
+ R cheap_upper_bound = dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box());
if (cheap_upper_bound < good_enough_ub) {
return cheap_upper_bound;
} else {
@@ -302,14 +294,14 @@ namespace md {
}
case BoundStrategy::local_dual_bound_for_each_point: {
- Real result = std::numeric_limits<Real>::max();
+ R result = std::numeric_limits<R>::max();
for(ValuePoint vp : k_corner_vps) {
if (not dual_cell.has_value_at(vp)) {
continue;
}
- Real base_value = dual_cell.value_at(vp);
- Real bound_dgm_a = get_single_dgm_bound(dual_cell, vp, 0, good_enough_ub);
+ R base_value = dual_cell.value_at(vp);
+ R bound_dgm_a = get_single_dgm_bound(dual_cell, vp, 0, good_enough_ub);
if (params_.stop_asap and bound_dgm_a + base_value >= good_enough_ub) {
// we want to return a valid upper bound, not just something that will prevent discarding the cell
@@ -318,8 +310,8 @@ namespace md {
return dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box());
}
- Real bound_dgm_b = get_single_dgm_bound(dual_cell, vp, 1,
- std::max(Real(0), good_enough_ub - bound_dgm_a));
+ R bound_dgm_b = get_single_dgm_bound(dual_cell, vp, 1,
+ std::max(R(0), good_enough_ub - bound_dgm_a));
result = std::min(result, base_value + bound_dgm_a + bound_dgm_b);
@@ -336,19 +328,19 @@ namespace md {
}
}
// to suppress compiler warning
- return std::numeric_limits<Real>::max();
+ return std::numeric_limits<R>::max();
}
// find maximal displacement of weighted points of m for all lines in dual_box
- template<class T>
- Real
- DistanceCalculator<T>::get_single_dgm_bound(const CellWithValue& dual_cell,
+ template<class R, class T>
+ R
+ DistanceCalculator<R, T>::get_single_dgm_bound(const CellWithValue<R>& dual_cell,
ValuePoint vp,
int module,
- [[maybe_unused]] Real good_enough_value) const
+ R good_enough_value) const
{
- Real result = 0;
- Point max_point;
+ R result = 0;
+ Point<R> max_point;
spd::debug(
"Enter get_single_dgm_bound, module = {}, dual_cell = {}, vp = {}, good_enough_value = {}, stop_asap = {}\n",
@@ -358,7 +350,7 @@ namespace md {
for(const auto& position : m.positions()) {
spd::debug("in get_single_dgm_bound, simplex = {}\n", position);
- Real x = get_max_displacement_single_point(dual_cell, vp, position);
+ R x = get_max_displacement_single_point(dual_cell, vp, position);
spd::debug("In get_single_dgm_bound, point = {}, displacement = {}", position, x);
@@ -385,30 +377,30 @@ namespace md {
return result;
}
- template<class T>
- Real DistanceCalculator<T>::distance()
+ template<class R, class T>
+ R DistanceCalculator<R, T>::distance()
{
return get_distance_pq();
}
// calculate weighted bottleneneck distance between slices on line
// increments hera calls counter
- template<class T>
- Real DistanceCalculator<T>::distance_on_line(DualPoint line)
+ template<class R, class T>
+ R DistanceCalculator<R, T>::distance_on_line(DualPoint<R> line)
{
++n_hera_calls_;
- Real result = distance_on_line_const(line);
+ R result = distance_on_line_const(line);
return result;
}
- template<class T>
- Real DistanceCalculator<T>::distance_on_line_const(DualPoint line) const
+ template<class R, class T>
+ R DistanceCalculator<R, T>::distance_on_line_const(DualPoint<R> line) const
{
// TODO: think about this - how to call Hera
auto dgm_a = module_a_.weighted_slice_diagram(line);
auto dgm_b = module_b_.weighted_slice_diagram(line);
- Real result;
- if (params_.hera_epsilon > static_cast<Real>(0)) {
+ R result;
+ if (params_.hera_epsilon > static_cast<R>(0)) {
result = hera::bottleneckDistApprox(dgm_a, dgm_b, params_.hera_epsilon) / ( params_.hera_epsilon + 1);
} else {
result = hera::bottleneckDistExact(dgm_a, dgm_b);
@@ -423,10 +415,10 @@ namespace md {
return result;
}
- template<class T>
- Real DistanceCalculator<T>::get_good_enough_upper_bound(Real lower_bound) const
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_good_enough_upper_bound(R lower_bound) const
{
- Real result;
+ R result;
// in upper_bound strategy we only prune cells if they cannot improve the lower bound,
// otherwise the experiment is supposed to run indefinitely
if (params_.traverse_strategy == TraverseStrategy::upper_bound) {
@@ -440,14 +432,14 @@ namespace md {
// helper function
// calculate weighted bt distance on cell center,
// assign distance value to cell, keep it in heat_map, and return
- template<class T>
- void DistanceCalculator<T>::set_cell_central_value(CellWithValue& dual_cell)
+ template<class R, class T>
+ void DistanceCalculator<R, T>::set_cell_central_value(CellWithValue<R>& dual_cell)
{
- DualPoint central_line {dual_cell.center()};
+ DualPoint<R> central_line {dual_cell.center()};
spd::debug("In set_cell_central_value, processing dual cell = {}, line = {}", dual_cell.dual_box(),
central_line);
- Real new_value = distance_on_line(central_line);
+ R new_value = distance_on_line(central_line);
n_hera_calls_per_level_[dual_cell.level() + 1]++;
dual_cell.set_value_at(ValuePoint::center, new_value);
params_.actual_max_depth = std::max(params_.actual_max_depth, dual_cell.level() + 1);
@@ -472,10 +464,10 @@ namespace md {
// assumes that the underlying container is vector!
// cell_ptr: pointer to the first element in queue
// n_cells: queue size
- template<class T>
- Real DistanceCalculator<T>::get_max_possible_value(const CellWithValue* cell_ptr, int n_cells)
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_max_possible_value(const CellWithValue<R>* cell_ptr, int n_cells)
{
- Real result = (n_cells > 0) ? cell_ptr->stored_upper_bound() : 0;
+ R result = (n_cells > 0) ? cell_ptr->stored_upper_bound() : 0;
for(int i = 0; i < n_cells; ++i, ++cell_ptr) {
result = std::max(result, cell_ptr->stored_upper_bound());
}
@@ -485,11 +477,11 @@ namespace md {
// helper function:
// return current error from lower and upper bounds
// and save it in params_ (hence not const)
- template<class T>
- Real DistanceCalculator<T>::current_error(Real lower_bound, Real upper_bound)
+ template<class R, class T>
+ R DistanceCalculator<R, T>::current_error(R lower_bound, R upper_bound)
{
- Real current_error = (lower_bound > 0.0) ? (upper_bound - lower_bound) / lower_bound
- : std::numeric_limits<Real>::max();
+ R current_error = (lower_bound > 0.0) ? (upper_bound - lower_bound) / lower_bound
+ : std::numeric_limits<R>::max();
params_.actual_error = current_error;
@@ -505,8 +497,8 @@ namespace md {
// use priority queue to store dual cells
// comparison function depends on the strategies in params_
// ressets hera calls counter
- template<class T>
- Real DistanceCalculator<T>::get_distance_pq()
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_distance_pq()
{
std::map<int, long> n_cells_considered;
std::map<int, long> n_cells_pushed_into_queue;
@@ -527,26 +519,26 @@ namespace md {
// if cell is too deep and is not pushed into queue,
// we still need to take its max value into account;
// the max over such cells is stored in max_result_on_too_fine_cells
- Real upper_bound_on_deep_cells = -1;
+ R upper_bound_on_deep_cells = -1;
spd::debug("Started iterations in dual space, delta = {}, bound_strategy = {}", params_.delta,
params_.bound_strategy);
// user-defined less lambda function
// to regulate priority queue depending on strategy
- auto dual_cell_less = [this](const CellWithValue& a, const CellWithValue& b) {
+ auto dual_cell_less = [this](const CellWithValue<R>& a, const CellWithValue<R>& b) {
int a_level = a.level();
int b_level = b.level();
- Real a_value = a.max_corner_value();
- Real b_value = b.max_corner_value();
- Real a_ub = a.stored_upper_bound();
- Real b_ub = b.stored_upper_bound();
+ R a_value = a.max_corner_value();
+ R b_value = b.max_corner_value();
+ R a_ub = a.stored_upper_bound();
+ R b_ub = b.stored_upper_bound();
if (this->params_.traverse_strategy == TraverseStrategy::upper_bound and
(not a.has_max_possible_value() or not b.has_max_possible_value())) {
throw std::runtime_error("no upper bound on cell");
}
- DualPoint a_lower_left = a.dual_box().lower_left();
- DualPoint b_lower_left = b.dual_box().lower_left();
+ DualPoint<R> a_lower_left = a.dual_box().lower_left();
+ DualPoint<R> b_lower_left = b.dual_box().lower_left();
switch(this->params_.traverse_strategy) {
// in both breadth_first searches we want coarser cells
@@ -569,24 +561,24 @@ namespace md {
}
};
- std::priority_queue<CellWithValue, CellValueVector, decltype(dual_cell_less)> dual_cells_queue(
+ std::priority_queue<CellWithValue<R>, CellValueVector, decltype(dual_cell_less)> dual_cells_queue(
dual_cell_less);
// weighted bt distance on the center of current cell
- Real lower_bound = std::numeric_limits<Real>::min();
+ R lower_bound = std::numeric_limits<R>::min();
// init pq and lower bound
for(auto& init_cell : get_initial_dual_grid(lower_bound)) {
dual_cells_queue.push(init_cell);
}
- Real upper_bound = get_max_possible_value(&dual_cells_queue.top(), dual_cells_queue.size());
+ R upper_bound = get_max_possible_value(&dual_cells_queue.top(), dual_cells_queue.size());
std::vector<UbExperimentRecord> ub_experiment_results;
while(not dual_cells_queue.empty()) {
- CellWithValue dual_cell = dual_cells_queue.top();
+ CellWithValue<R> dual_cell = dual_cells_queue.top();
dual_cells_queue.pop();
assert(dual_cell.has_corner_value()
and dual_cell.has_max_possible_value()
@@ -620,7 +612,7 @@ namespace md {
// until now, dual_cell knows its value in one of its corners
// new_value will be the weighted distance at its center
set_cell_central_value(dual_cell);
- Real new_value = dual_cell.value_at(ValuePoint::center);
+ R new_value = dual_cell.value_at(ValuePoint::center);
lower_bound = std::max(new_value, lower_bound);
spd::debug("Processed cell = {}, weighted value = {}, lower_bound = {}", dual_cell, new_value, lower_bound);
@@ -638,11 +630,11 @@ namespace md {
throw std::runtime_error("no value on cell");
// if delta is smaller than good_enough_value, it allows to prune cell
- Real good_enough_ub = get_good_enough_upper_bound(lower_bound);
+ R good_enough_ub = get_good_enough_upper_bound(lower_bound);
// upper bound of the parent holds for refined_cell
// and can sometimes be smaller!
- Real upper_bound_on_refined_cell = std::min(dual_cell.stored_upper_bound(),
+ R upper_bound_on_refined_cell = std::min(dual_cell.stored_upper_bound(),
get_upper_bound(refined_cell, good_enough_ub));
spd::debug("upper_bound_on_refined_cell = {}, dual_cell.stored_upper_bound = {}, get_upper_bound = {}",
@@ -774,10 +766,46 @@ namespace md {
return lower_bound;
}
- template<class T>
- int DistanceCalculator<T>::get_hera_calls_number() const
+ template<class R, class T>
+ int DistanceCalculator<R, T>::get_hera_calls_number() const
{
return n_hera_calls_;
}
-} \ No newline at end of file
+ template<class R>
+ R matching_distance(const Bifiltration<R>& bif_a, const Bifiltration<R>& bif_b,
+ CalculationParams<R>& params)
+ {
+ R result;
+ // compute distance only in one dimension
+ if (params.dim != CalculationParams<R>::ALL_DIMENSIONS) {
+ BifiltrationProxy<R> bifp_a(bif_a, params.dim);
+ BifiltrationProxy<R> bifp_b(bif_b, params.dim);
+ DistanceCalculator<R, BifiltrationProxy<R>> runner(bifp_a, bifp_b, params);
+ result = runner.distance();
+ params.n_hera_calls = runner.get_hera_calls_number();
+ } else {
+ // compute distance in all dimensions, return maximal
+ result = -1;
+ for(int dim = 0; dim < std::max(bif_a.maximal_dim(), bif_b.maximal_dim()); ++dim) {
+ BifiltrationProxy<R> bifp_a(bif_a, params.dim);
+ BifiltrationProxy<R> bifp_b(bif_a, params.dim);
+ DistanceCalculator<R, BifiltrationProxy<R>> runner(bifp_a, bifp_b, params);
+ result = std::max(result, runner.distance());
+ params.n_hera_calls += runner.get_hera_calls_number();
+ }
+ }
+ return result;
+ }
+
+
+ template<class R>
+ R matching_distance(const ModulePresentation<R>& mod_a, const ModulePresentation<R>& mod_b,
+ CalculationParams<R>& params)
+ {
+ DistanceCalculator<R, ModulePresentation<R>> runner(mod_a, mod_b, params);
+ R result = runner.distance();
+ params.n_hera_calls = runner.get_hera_calls_number();
+ return result;
+ }
+} // namespace md
diff --git a/matching/include/persistence_module.h b/matching/include/persistence_module.h
index a1fc67e..e99771f 100644
--- a/matching/include/persistence_module.h
+++ b/matching/include/persistence_module.h
@@ -5,6 +5,12 @@
#include <vector>
#include <utility>
#include <string>
+#include <numeric>
+#include <algorithm>
+#include <unordered_set>
+
+#include "phat/boundary_matrix.h"
+#include "phat/compute_persistence_pairs.h"
#include "common_util.h"
#include "dual_point.h"
@@ -28,17 +34,20 @@ namespace md {
*/
+ template<class Real>
class ModulePresentation {
public:
+ using RealVec = std::vector<Real>;
+
enum Format { rivet_firep };
struct Relation {
- Point position_;
+ Point<Real> position_;
IndexVec components_;
Relation() {}
- Relation(const Point& _pos, const IndexVec& _components);
+ Relation(const Point<Real>& _pos, const IndexVec& _components);
Real get_x() const { return position_.x; }
Real get_y() const { return position_.y; }
@@ -48,9 +57,9 @@ namespace md {
ModulePresentation() {}
- ModulePresentation(const PointVec& _generators, const RelVec& _relations);
+ ModulePresentation(const PointVec<Real>& _generators, const RelVec& _relations);
- Diagram weighted_slice_diagram(const DualPoint& line) const;
+ Diagram<Real> weighted_slice_diagram(const DualPoint<Real>& line) const;
// translate all points by vector (a,a)
void translate(Real a);
@@ -59,9 +68,7 @@ namespace md {
Real minimal_coordinate() const { return std::min(min_x(), min_y()); }
// return box that contains all positions of all simplices
- Box bounding_box() const;
-
- friend std::ostream& operator<<(std::ostream& os, const ModulePresentation& mp);
+ Box<Real> bounding_box() const;
Real max_x() const { return max_x_; }
@@ -71,26 +78,27 @@ namespace md {
Real min_y() const { return min_y_; }
- PointVec positions() const;
+ PointVec<Real> positions() const;
private:
- PointVec generators_;
+ PointVec<Real> generators_;
std::vector<Relation> relations_;
- PointVec positions_;
+ PointVec<Real> positions_;
Real max_x_ { std::numeric_limits<Real>::max() };
Real max_y_ { std::numeric_limits<Real>::max() };
Real min_x_ { -std::numeric_limits<Real>::max() };
Real min_y_ { -std::numeric_limits<Real>::max() };
- Box bounding_box_;
+ Box<Real> bounding_box_;
void init_boundaries();
- void project_generators(const DualPoint& slice, IndexVec& sorted_indices, RealVec& projections) const;
- void project_relations(const DualPoint& slice, IndexVec& sorted_indices, RealVec& projections) const;
+ void project_generators(const DualPoint<Real>& slice, IndexVec& sorted_indices, RealVec& projections) const;
+ void project_relations(const DualPoint<Real>& slice, IndexVec& sorted_indices, RealVec& projections) const;
};
} // namespace md
+#include "persistence_module.hpp"
#endif //MATCHING_DISTANCE_PERSISTENCE_MODULE_H
diff --git a/matching/src/persistence_module.cpp b/matching/include/persistence_module.hpp
index efb20ef..6e49b2e 100644
--- a/matching/src/persistence_module.cpp
+++ b/matching/include/persistence_module.hpp
@@ -1,12 +1,3 @@
-#include <numeric>
-#include <algorithm>
-#include <unordered_set>
-
-#include <phat/boundary_matrix.h>
-#include <phat/compute_persistence_pairs.h>
-
-#include "persistence_module.h"
-
namespace md {
/**
@@ -17,7 +8,7 @@ namespace md {
* 2) a_1,...,a_n is a permutation of 1,..,n
*/
- template<typename T>
+ template<class T>
IndexVec get_sorted_indices(const std::vector<T>& values)
{
IndexVec result(values.size());
@@ -28,18 +19,20 @@ namespace md {
}
// helper function to initialize const member positions_ in ModulePresentation
- PointVec
- concat_gen_and_rel_positions(const PointVec& generators, const ModulePresentation::RelVec& relations)
+ template<class Real>
+ PointVec<Real> concat_gen_and_rel_positions(const PointVec<Real>& generators,
+ const typename ModulePresentation<Real>::RelVec& relations)
{
- std::unordered_set<Point> ps(generators.begin(), generators.end());
+ std::unordered_set<Point<Real>> ps(generators.begin(), generators.end());
for(const auto& rel : relations) {
ps.insert(rel.position_);
}
- return PointVec(ps.begin(), ps.end());
+ return PointVec<Real>(ps.begin(), ps.end());
}
- void ModulePresentation::init_boundaries()
+ template<class Real>
+ void ModulePresentation<Real>::init_boundaries()
{
max_x_ = std::numeric_limits<Real>::max();
max_y_ = std::numeric_limits<Real>::max();
@@ -53,18 +46,20 @@ namespace md {
max_y_ = std::max(gen.y, max_y_);
}
- bounding_box_ = Box(Point(min_x_, min_y_), Point(max_x_, max_y_));
+ bounding_box_ = Box<Real>(Point<Real>(min_x_, min_y_), Point<Real>(max_x_, max_y_));
}
- ModulePresentation::ModulePresentation(const PointVec& _generators, const RelVec& _relations) :
+ template<class Real>
+ ModulePresentation<Real>::ModulePresentation(const PointVec<Real>& _generators, const RelVec& _relations) :
generators_(_generators),
relations_(_relations)
{
init_boundaries();
}
- void ModulePresentation::translate(md::Real a)
+ template<class Real>
+ void ModulePresentation<Real>::translate(Real a)
{
for(auto& g : generators_) {
g.translate(a);
@@ -86,8 +81,9 @@ namespace md {
* @param projections sorted weighted pushes of generators
*/
- void
- ModulePresentation::project_generators(const DualPoint& slice, IndexVec& sorted_indices, RealVec& projections) const
+ template<class Real>
+ void ModulePresentation<Real>::project_generators(const DualPoint<Real>& slice,
+ IndexVec& sorted_indices, RealVec& projections) const
{
size_t num_gens = generators_.size();
@@ -104,7 +100,8 @@ namespace md {
}
}
- void ModulePresentation::project_relations(const DualPoint& slice, IndexVec& sorted_rel_indices,
+ template<class Real>
+ void ModulePresentation<Real>::project_relations(const DualPoint<Real>& slice, IndexVec& sorted_rel_indices,
RealVec& projections) const
{
size_t num_rels = relations_.size();
@@ -122,7 +119,8 @@ namespace md {
}
}
- Diagram ModulePresentation::weighted_slice_diagram(const DualPoint& slice) const
+ template<class Real>
+ Diagram<Real> ModulePresentation<Real>::weighted_slice_diagram(const DualPoint<Real>& slice) const
{
IndexVec sorted_gen_indices, sorted_rel_indices;
RealVec gen_projections, rel_projections;
@@ -147,7 +145,7 @@ namespace md {
phat::persistence_pairs phat_persistence_pairs;
phat::compute_persistence_pairs<phat::twist_reduction>(phat_persistence_pairs, phat_matrix);
- Diagram dgm;
+ Diagram<Real> dgm;
constexpr Real real_inf = std::numeric_limits<Real>::infinity();
@@ -164,14 +162,16 @@ namespace md {
return dgm;
}
- PointVec ModulePresentation::positions() const
+ template<class Real>
+ PointVec<Real> ModulePresentation<Real>::positions() const
{
return positions_;
}
- Box ModulePresentation::bounding_box() const
+ template<class Real>
+ Box<Real> ModulePresentation<Real>::bounding_box() const
{
return bounding_box_;
}
-}
+} // namespace md
diff --git a/matching/include/simplex.h b/matching/include/simplex.h
index e9d0e30..75bbcae 100644
--- a/matching/include/simplex.h
+++ b/matching/include/simplex.h
@@ -9,6 +9,7 @@
namespace md {
+ template<class Real>
class Bifiltration;
enum class BifiltrationFormat {
@@ -38,11 +39,21 @@ namespace md {
int dim() const { return vertices_.size() - 1; }
- void push_back(int v);
+ void push_back(int v)
+ {
+ vertices_.push_back(v);
+ std::sort(vertices_.begin(), vertices_.end());
+ }
AbstractSimplex() { }
- AbstractSimplex(std::vector<int> vertices, bool sort = true);
+ AbstractSimplex(std::vector<int> vertices, bool sort = true)
+ :vertices_(vertices)
+ {
+ if (sort)
+ std::sort(vertices_.begin(), vertices_.end());
+ }
+
template<class Iter>
AbstractSimplex(Iter beg_iter, Iter end_iter, bool sort = true)
@@ -53,22 +64,51 @@ namespace md {
std::sort(vertices_.begin(), end());
}
- std::vector<AbstractSimplex> facets() const;
+ std::vector<AbstractSimplex> facets() const
+ {
+ std::vector<AbstractSimplex> result;
+ for (int i = 0; i < static_cast<int>(vertices_.size()); ++i) {
+ std::vector<int> facet_vertices;
+ facet_vertices.reserve(dim());
+ for (int j = 0; j < static_cast<int>(vertices_.size()); ++j) {
+ if (j != i)
+ facet_vertices.push_back(vertices_[j]);
+ }
+ if (!facet_vertices.empty()) {
+ result.emplace_back(facet_vertices, false);
+ }
+ }
+ return result;
+ }
friend std::ostream& operator<<(std::ostream& os, const AbstractSimplex& s);
-
// compare by vertices_ only
friend bool operator==(const AbstractSimplex& s1, const AbstractSimplex& s2);
friend bool operator<(const AbstractSimplex&, const AbstractSimplex&);
};
- std::ostream& operator<<(std::ostream& os, const AbstractSimplex& s);
+ inline std::ostream& operator<<(std::ostream& os, const AbstractSimplex& s)
+ {
+ os << "AbstractSimplex(id = " << s.id << ", vertices_ = " << container_to_string(s.vertices_) << ")";
+ return os;
+ }
+
+ inline bool operator<(const AbstractSimplex& a, const AbstractSimplex& b)
+ {
+ return a.vertices_ < b.vertices_;
+ }
+
+ inline bool operator==(const AbstractSimplex& s1, const AbstractSimplex& s2)
+ {
+ return s1.vertices_ == s2.vertices_;
+ }
+ template<class Real>
class Simplex {
private:
Index id_;
- Point pos_;
+ Point<Real> pos_;
int dim_;
// in our format we use facet indices,
// this is the fastest representation for homology
@@ -77,11 +117,11 @@ namespace md {
// conversion routines are in Bifiltration
Column facet_indices_;
Column vertices_;
- Real v {0.0}; // used when constructed a filtration for a slice
+ Real v {0}; // used when constructed a filtration for a slice
public:
Simplex(Index _id, std::string s, BifiltrationFormat input_format);
- Simplex(Index _id, Point birth, int _dim, const Column& _bdry);
+ Simplex(Index _id, Point<Real> birth, int _dim, const Column& _bdry);
void init_rivet(std::string s);
@@ -96,9 +136,9 @@ namespace md {
Real value() const { return v; }
// assumes 1-criticality
- Point position() const { return pos_; }
+ Point<Real> position() const { return pos_; }
- void set_position(const Point& new_pos) { pos_ = new_pos; }
+ void set_position(const Point<Real>& new_pos) { pos_ = new_pos; }
void scale(Real lambda)
{
@@ -110,12 +150,14 @@ namespace md {
void set_value(Real new_val) { v = new_val; }
- friend std::ostream& operator<<(std::ostream& os, const Simplex& s);
-
- friend Bifiltration;
+ friend Bifiltration<Real>;
};
- std::ostream& operator<<(std::ostream& os, const Simplex& s);
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const Simplex<Real>& s);
}
+
+#include "simplex.hpp"
+
#endif //MATCHING_DISTANCE_SIMPLEX_H
diff --git a/matching/include/simplex.hpp b/matching/include/simplex.hpp
new file mode 100644
index 0000000..ce0e30f
--- /dev/null
+++ b/matching/include/simplex.hpp
@@ -0,0 +1,79 @@
+namespace md {
+
+ template<class Real>
+ Simplex<Real>::Simplex(Index id, Point<Real> birth, int dim, const Column& bdry)
+ :
+ id_(id),
+ pos_(birth),
+ dim_(dim),
+ facet_indices_(bdry) { }
+
+ template<class Real>
+ void Simplex<Real>::translate(Real a)
+ {
+ pos_.translate(a);
+ }
+
+ template<class Real>
+ void Simplex<Real>::init_rivet(std::string s)
+ {
+ auto delim_pos = s.find_first_of(";");
+ assert(delim_pos > 0);
+ std::string vertices_str = s.substr(0, delim_pos);
+ std::string pos_str = s.substr(delim_pos + 1);
+ assert(not vertices_str.empty() and not pos_str.empty());
+ // get vertices
+ std::stringstream vertices_ss(vertices_str);
+ int dim = 0;
+ int vertex;
+ while (vertices_ss >> vertex) {
+ dim++;
+ vertices_.push_back(vertex);
+ }
+ //
+ std::sort(vertices_.begin(), vertices_.end());
+ assert(dim > 0);
+
+ std::stringstream pos_ss(pos_str);
+ // TODO: get rid of 1-criticaltiy assumption
+ pos_ss >> pos_.x >> pos_.y;
+ }
+
+ template<class Real>
+ void Simplex<Real>::init_phat_like(std::string s)
+ {
+ facet_indices_.clear();
+ std::stringstream ss(s);
+ ss >> dim_ >> pos_.x >> pos_.y;
+ if (dim_ > 0) {
+ facet_indices_.reserve(dim_ + 1);
+ for (int j = 0; j <= dim_; j++) {
+ Index k;
+ ss >> k;
+ facet_indices_.push_back(k);
+ }
+ }
+ }
+
+ template<class Real>
+ Simplex<Real>::Simplex(Index _id, std::string s, BifiltrationFormat input_format)
+ :id_(_id)
+ {
+ switch (input_format) {
+ case BifiltrationFormat::phat_like :
+ init_phat_like(s);
+ break;
+ case BifiltrationFormat::rivet :
+ init_rivet(s);
+ break;
+ }
+ }
+
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const Simplex<Real>& x)
+ {
+ os << "Simplex<Real>(id = " << x.id() << ", dim = " << x.dim();
+ os << ", boundary = " << container_to_string(x.boundary()) << ", pos = " << x.position() << ")";
+ return os;
+ }
+}
diff --git a/matching/src/box.cpp b/matching/src/box.cpp
deleted file mode 100644
index c128698..0000000
--- a/matching/src/box.cpp
+++ /dev/null
@@ -1,61 +0,0 @@
-
-#include "box.h"
-
-namespace md {
-
- std::ostream& operator<<(std::ostream& os, const Box& box)
- {
- os << "Box(lower_left = " << box.lower_left() << ", upper_right = " << box.upper_right() << ")";
- return os;
- }
-
- Box get_enclosing_box(const Box& box_a, const Box& box_b)
- {
- Point lower_left(std::min(box_a.lower_left().x, box_b.lower_left().x),
- std::min(box_a.lower_left().y, box_b.lower_left().y));
- Point upper_right(std::max(box_a.upper_right().x, box_b.upper_right().x),
- std::max(box_a.upper_right().y, box_b.upper_right().y));
- return Box(lower_left, upper_right);
- }
-
- void Box::translate(md::Real a)
- {
- ll.x += a;
- ll.y += a;
- ur.x += a;
- ur.y += a;
- }
-
- std::vector<Box> Box::refine() const
- {
- std::vector<Box> result;
-
-// 1 | 2
-// 0 | 3
-
- Point new_ll = lower_left();
- Point new_ur = center();
- result.emplace_back(new_ll, new_ur);
-
- new_ll.y = center().y;
- new_ur.y = ur.y;
- result.emplace_back(new_ll, new_ur);
-
- new_ll = center();
- new_ur = upper_right();
- result.emplace_back(new_ll, new_ur);
-
- new_ll.y = ll.y;
- new_ur.y = center().y;
- result.emplace_back(new_ll, new_ur);
-
- return result;
- }
-
- std::vector<Point> Box::corners() const
- {
- return {ll, Point(ll.x, ur.y), ur, Point(ur.x, ll.y)};
- };
-
-
-}
diff --git a/matching/src/common_util.cpp b/matching/src/common_util.cpp
deleted file mode 100644
index 96c3388..0000000
--- a/matching/src/common_util.cpp
+++ /dev/null
@@ -1,243 +0,0 @@
-#include <vector>
-#include <utility>
-#include <cmath>
-#include <ostream>
-#include <limits>
-#include <algorithm>
-
-#include <common_util.h>
-
-#include "spdlog/spdlog.h"
-#include "spdlog/fmt/ostr.h"
-
-namespace md {
-
-
- int gcd(int a, int b)
- {
- assert(a != 0 or b != 0);
- // make b <= a
- std::tie(b, a) = std::minmax({ abs(a), abs(b) });
- if (b == 0)
- return a;
- while((a = a % b)) {
- std::swap(a, b);
- }
- return b;
- }
-
- int signum(int a)
- {
- if (a < 0)
- return -1;
- else if (a > 0)
- return 1;
- else
- return 0;
- }
-
- Rational reduce(Rational frac)
- {
- int d = gcd(frac.numerator, frac.denominator);
- frac.numerator /= d;
- frac.denominator /= d;
- return frac;
- }
-
- void Rational::reduce() { *this = md::reduce(*this); }
-
-
- Rational& Rational::operator*=(const md::Rational& rhs)
- {
- numerator *= rhs.numerator;
- denominator *= rhs.denominator;
- reduce();
- return *this;
- }
-
- Rational& Rational::operator/=(const md::Rational& rhs)
- {
- numerator *= rhs.denominator;
- denominator *= rhs.numerator;
- reduce();
- return *this;
- }
-
- Rational& Rational::operator+=(const md::Rational& rhs)
- {
- numerator = numerator * rhs.denominator + denominator * rhs.numerator;
- denominator *= rhs.denominator;
- reduce();
- return *this;
- }
-
- Rational& Rational::operator-=(const md::Rational& rhs)
- {
- numerator = numerator * rhs.denominator - denominator * rhs.numerator;
- denominator *= rhs.denominator;
- reduce();
- return *this;
- }
-
-
- Rational midpoint(Rational a, Rational b)
- {
- return reduce({a.numerator * b.denominator + a.denominator * b.numerator, 2 * a.denominator * b.denominator });
- }
-
- Rational operator+(Rational a, const Rational& b)
- {
- a += b;
- return a;
- }
-
- Rational operator-(Rational a, const Rational& b)
- {
- a -= b;
- return a;
- }
-
- Rational operator*(Rational a, const Rational& b)
- {
- a *= b;
- return a;
- }
-
- Rational operator/(Rational a, const Rational& b)
- {
- a /= b;
- return a;
- }
-
- bool is_less(Rational a, Rational b)
- {
- // compute a - b = a_1 / a_2 - b_1 / b_2
- long numer = a.numerator * b.denominator - a.denominator * b.numerator;
- long denom = a.denominator * b.denominator;
- assert(denom != 0);
- return signum(numer) * signum(denom) < 0;
- }
-
- bool operator==(const Rational& a, const Rational& b)
- {
- return std::tie(a.numerator, a.denominator) == std::tie(b.numerator, b.denominator);
- }
-
- bool operator<(const Rational& a, const Rational& b)
- {
- // do not remove signum - overflow
- long numer = a.numerator * b.denominator - a.denominator * b.numerator;
- long denom = a.denominator * b.denominator;
- assert(denom != 0);
-// spdlog::debug("a = {}, b = {}, numer = {}, denom = {}, result = {}", a, b, numer, denom, signum(numer) * signum(denom) <= 0);
- return signum(numer) * signum(denom) < 0;
- }
-
- bool is_leq(Rational a, Rational b)
- {
- // compute a - b = a_1 / a_2 - b_1 / b_2
- long numer = a.numerator * b.denominator - a.denominator * b.numerator;
- long denom = a.denominator * b.denominator;
- assert(denom != 0);
- return signum(numer) * signum(denom) <= 0;
- }
-
- bool is_greater(Rational a, Rational b)
- {
- return not is_leq(a, b);
- }
-
- bool is_geq(Rational a, Rational b)
- {
- return not is_less(a, b);
- }
-
- Point operator+(const Point& u, const Point& v)
- {
- return Point(u.x + v.x, u.y + v.y);
- }
-
- Point operator-(const Point& u, const Point& v)
- {
- return Point(u.x - v.x, u.y - v.y);
- }
-
- Point least_upper_bound(const Point& u, const Point& v)
- {
- return Point(std::max(u.x, v.x), std::max(u.y, v.y));
- }
-
- Point greatest_lower_bound(const Point& u, const Point& v)
- {
- return Point(std::min(u.x, v.x), std::min(u.y, v.y));
- }
-
- Point max_point()
- {
- return Point(std::numeric_limits<Real>::max(), std::numeric_limits<Real>::min());
- }
-
- Point min_point()
- {
- return Point(-std::numeric_limits<Real>::max(), -std::numeric_limits<Real>::min());
- }
-
- std::ostream& operator<<(std::ostream& ostr, const Point& vec)
- {
- ostr << "(" << vec.x << ", " << vec.y << ")";
- return ostr;
- }
-
- Real l_infty_norm(const Point& v)
- {
- return std::max(std::abs(v.x), std::abs(v.y));
- }
-
- Real l_2_norm(const Point& v)
- {
- return v.norm();
- }
-
- Real l_2_dist(const Point& x, const Point& y)
- {
- return l_2_norm(x - y);
- }
-
- Real l_infty_dist(const Point& x, const Point& y)
- {
- return l_infty_norm(x - y);
- }
-
- void DiagramKeeper::add_point(int dim, md::Real birth, md::Real death)
- {
- data_[dim].emplace_back(birth, death);
- }
-
- DiagramKeeper::Diagram DiagramKeeper::get_diagram(int dim) const
- {
- if (data_.count(dim) == 1)
- return data_.at(dim);
- else
- return DiagramKeeper::Diagram();
- }
-
- // return true, if line starts with #
- // or contains only spaces
- bool ignore_line(const std::string& s)
- {
- for(auto c : s) {
- if (isspace(c))
- continue;
- return (c == '#');
- }
- return true;
- }
-
-
-
- std::ostream& operator<<(std::ostream& os, const Rational& a)
- {
- os << a.numerator << " / " << a.denominator;
- return os;
- }
-}
diff --git a/matching/src/main.cpp b/matching/src/main.cpp
index f1472be..2093457 100644
--- a/matching/src/main.cpp
+++ b/matching/src/main.cpp
@@ -18,12 +18,20 @@
#include "box.h"
#include "matching_distance.h"
+using Real = double;
+
using namespace md;
namespace fs = std::experimental::filesystem;
+void force_instantiation()
+{
+ DualBox<Real> db;
+ std::cout << db;
+}
+
#ifdef PRINT_HEAT_MAP
-void print_heat_map(const md::HeatMaps& hms, std::string fname, const CalculationParams& params)
+void print_heat_map(const md::HeatMaps<Real>& hms, std::string fname, const CalculationParams<Real>& params)
{
spd::debug("Entered print_heat_map");
std::set<Real> mu_vals, lambda_vals;
@@ -143,7 +151,7 @@ int main(int argc, char** argv)
bool help = false;
bool no_stop_asap = false;
- CalculationParams params;
+ CalculationParams<Real> params;
#ifdef PRINT_HEAT_MAP
bool heatmap_only = false;
@@ -178,8 +186,8 @@ int main(int argc, char** argv)
auto bounds_list = split_by_delim(bounds_list_str, ',');
auto traverse_list = split_by_delim(traverse_list_str, ',');
- Bifiltration bif_a(fname_a);
- Bifiltration bif_b(fname_b);
+ Bifiltration<Real> bif_a(fname_a);
+ Bifiltration<Real> bif_b(fname_b);
bif_a.sanity_check();
bif_b.sanity_check();
@@ -207,11 +215,11 @@ int main(int argc, char** argv)
}
struct ExperimentResult {
- CalculationParams params {CalculationParams()};
+ CalculationParams<Real> params {CalculationParams()};
int n_hera_calls {0};
double total_milliseconds_elapsed {0};
- double distance {0};
- double actual_error {std::numeric_limits<double>::max()};
+ Real distance {0};
+ Real actual_error {std::numeric_limits<double>::max()};
int actual_max_depth {0};
int x_wins {0};
@@ -250,7 +258,7 @@ int main(int argc, char** argv)
ExperimentResult() { }
- ExperimentResult(CalculationParams p, int nhc, double tme, double d)
+ ExperimentResult(CalculationParams<Real> p, int nhc, double tme, double d)
:
params(p), n_hera_calls(nhc), total_milliseconds_elapsed(tme), distance(d) { }
};
@@ -267,7 +275,7 @@ int main(int argc, char** argv)
std::map<std::tuple<BoundStrategy, TraverseStrategy>, ExperimentResult> results;
for(BoundStrategy bound_strategy : bound_strategies) {
for(TraverseStrategy traverse_strategy : traverse_strategies) {
- CalculationParams params_experiment;
+ CalculationParams<Real> params_experiment;
params_experiment.bound_strategy = bound_strategy;
params_experiment.traverse_strategy = traverse_strategy;
params_experiment.max_depth = params.max_depth;
@@ -366,8 +374,9 @@ int main(int argc, char** argv)
spd::debug("Will use {} bound, {} traverse strategy", params.bound_strategy, params.traverse_strategy);
- Real dist = matching_distance(bif_a, bif_b, params);
+ Real dist = matching_distance<Real>(bif_a, bif_b, params);
std::cout << dist << std::endl;
#endif
+ force_instantiation();
return 0;
}
diff --git a/matching/src/matching_distance.cpp b/matching/src/matching_distance.cpp
deleted file mode 100644
index e53233f..0000000
--- a/matching/src/matching_distance.cpp
+++ /dev/null
@@ -1,150 +0,0 @@
-#include <chrono>
-#include <tuple>
-#include <algorithm>
-
-#include "common_defs.h"
-
-#include "matching_distance.h"
-
-namespace md {
-
- Real matching_distance(const Bifiltration& bif_a, const Bifiltration& bif_b,
- CalculationParams& params)
- {
- Real result;
- // compute distance only in one dimension
- if (params.dim != CalculationParams::ALL_DIMENSIONS) {
- BifiltrationProxy bifp_a(bif_a, params.dim);
- BifiltrationProxy bifp_b(bif_b, params.dim);
- DistanceCalculator<BifiltrationProxy> runner(bifp_a, bifp_b, params);
- result = runner.distance();
- params.n_hera_calls = runner.get_hera_calls_number();
- } else {
- // compute distance in all dimensions, return maximal
- result = -1;
- for(int dim = 0; dim < std::max(bif_a.maximal_dim(), bif_b.maximal_dim()); ++dim) {
- BifiltrationProxy bifp_a(bif_a, params.dim);
- BifiltrationProxy bifp_b(bif_a, params.dim);
- DistanceCalculator<BifiltrationProxy> runner(bifp_a, bifp_b, params);
- result = std::max(result, runner.distance());
- params.n_hera_calls += runner.get_hera_calls_number();
- }
- }
- return result;
- }
-
-
- Real matching_distance(const ModulePresentation& mod_a, const ModulePresentation& mod_b,
- CalculationParams& params)
- {
- DistanceCalculator<ModulePresentation> runner(mod_a, mod_b, params);
- Real result = runner.distance();
- params.n_hera_calls = runner.get_hera_calls_number();
- return result;
- }
-
- std::istream& operator>>(std::istream& is, BoundStrategy& s)
- {
- std::string ss;
- is >> ss;
- if (ss == "bruteforce") {
- s = BoundStrategy::bruteforce;
- } else if (ss == "local_grob") {
- s = BoundStrategy::local_dual_bound;
- } else if (ss == "local_combined") {
- s = BoundStrategy::local_combined;
- } else if (ss == "local_refined") {
- s = BoundStrategy::local_dual_bound_refined;
- } else if (ss == "local_for_each_point") {
- s = BoundStrategy::local_dual_bound_for_each_point;
- } else {
- throw std::runtime_error("UNKNOWN BOUND STRATEGY");
- }
- return is;
- }
-
- BoundStrategy bs_from_string(std::string s)
- {
- std::stringstream ss(s);
- BoundStrategy result;
- ss >> result;
- return result;
- }
-
- TraverseStrategy ts_from_string(std::string s)
- {
- std::stringstream ss(s);
- TraverseStrategy result;
- ss >> result;
- return result;
- }
-
- std::istream& operator>>(std::istream& is, TraverseStrategy& s)
- {
- std::string ss;
- is >> ss;
- if (ss == "DFS") {
- s = TraverseStrategy::depth_first;
- } else if (ss == "BFS") {
- s = TraverseStrategy::breadth_first;
- } else if (ss == "BFS-VAL") {
- s = TraverseStrategy::breadth_first_value;
- } else if (ss == "UB") {
- s = TraverseStrategy::upper_bound;
- } else {
- throw std::runtime_error("UNKNOWN TRAVERSE STRATEGY");
- }
- return is;
- }
-
- std::ostream& operator<<(std::ostream& os, const UbExperimentRecord& r)
- {
- os << r.time << "\t" << r.n_hera_calls << "\t" << r.error << "\t" << r.lower_bound << "\t" << r.upper_bound;
- return os;
- }
-
- std::ostream& operator<<(std::ostream& os, const BoundStrategy& s)
- {
- switch(s) {
- case BoundStrategy::bruteforce :
- os << "bruteforce";
- break;
- case BoundStrategy::local_dual_bound :
- os << "local_grob";
- break;
- case BoundStrategy::local_combined :
- os << "local_combined";
- break;
- case BoundStrategy::local_dual_bound_refined :
- os << "local_refined";
- break;
- case BoundStrategy::local_dual_bound_for_each_point :
- os << "local_for_each_point";
- break;
- default:
- os << "FORGOTTEN BOUND STRATEGY";
- }
- return os;
- }
-
- std::ostream& operator<<(std::ostream& os, const TraverseStrategy& s)
- {
- switch(s) {
- case TraverseStrategy::depth_first :
- os << "DFS";
- break;
- case TraverseStrategy::breadth_first :
- os << "BFS";
- break;
- case TraverseStrategy::breadth_first_value :
- os << "BFS-VAL";
- break;
- case TraverseStrategy::upper_bound :
- os << "UB";
- break;
- default:
- os << "FORGOTTEN TRAVERSE STRATEGY";
- }
- return os;
- }
-}
diff --git a/matching/src/simplex.cpp b/matching/src/simplex.cpp
deleted file mode 100644
index 6b53680..0000000
--- a/matching/src/simplex.cpp
+++ /dev/null
@@ -1,121 +0,0 @@
-#include "simplex.h"
-
-namespace md {
-
- std::ostream& operator<<(std::ostream& os, const AbstractSimplex& s)
- {
- os << "AbstractSimplex(id = " << s.id << ", vertices_ = " << container_to_string(s.vertices_) << ")";
- return os;
- }
-
- bool operator<(const AbstractSimplex& a, const AbstractSimplex& b)
- {
- return a.vertices_ < b.vertices_;
- }
-
- bool operator==(const AbstractSimplex& s1, const AbstractSimplex& s2)
- {
- return s1.vertices_ == s2.vertices_;
- }
-
- void AbstractSimplex::push_back(int v)
- {
- vertices_.push_back(v);
- std::sort(vertices_.begin(), vertices_.end());
- }
-
- AbstractSimplex::AbstractSimplex(std::vector<int> vertices, bool sort)
- :vertices_(vertices)
- {
- if (sort)
- std::sort(vertices_.begin(), vertices_.end());
- }
-
- std::vector<AbstractSimplex> AbstractSimplex::facets() const
- {
- std::vector<AbstractSimplex> result;
- for (int i = 0; i < static_cast<int>(vertices_.size()); ++i) {
- std::vector<int> facet_vertices;
- facet_vertices.reserve(dim());
- for (int j = 0; j < static_cast<int>(vertices_.size()); ++j) {
- if (j != i)
- facet_vertices.push_back(vertices_[j]);
- }
- if (!facet_vertices.empty()) {
- result.emplace_back(facet_vertices, false);
- }
- }
- return result;
- }
-
- Simplex::Simplex(md::Index id, md::Point birth, int dim, const md::Column& bdry)
- :
- id_(id),
- pos_(birth),
- dim_(dim),
- facet_indices_(bdry) { }
-
- void Simplex::translate(Real a)
- {
- pos_.translate(a);
- }
-
- void Simplex::init_rivet(std::string s)
- {
- auto delim_pos = s.find_first_of(";");
- assert(delim_pos > 0);
- std::string vertices_str = s.substr(0, delim_pos);
- std::string pos_str = s.substr(delim_pos + 1);
- assert(not vertices_str.empty() and not pos_str.empty());
- // get vertices
- std::stringstream vertices_ss(vertices_str);
- int dim = 0;
- int vertex;
- while (vertices_ss >> vertex) {
- dim++;
- vertices_.push_back(vertex);
- }
- //
- std::sort(vertices_.begin(), vertices_.end());
- assert(dim > 0);
-
- std::stringstream pos_ss(pos_str);
- // TODO: get rid of 1-criticaltiy assumption
- pos_ss >> pos_.x >> pos_.y;
- }
-
- void Simplex::init_phat_like(std::string s)
- {
- facet_indices_.clear();
- std::stringstream ss(s);
- ss >> dim_ >> pos_.x >> pos_.y;
- if (dim_ > 0) {
- facet_indices_.reserve(dim_ + 1);
- for (int j = 0; j <= dim_; j++) {
- Index k;
- ss >> k;
- facet_indices_.push_back(k);
- }
- }
- }
-
- Simplex::Simplex(Index _id, std::string s, BifiltrationFormat input_format)
- :id_(_id)
- {
- switch (input_format) {
- case BifiltrationFormat::phat_like :
- init_phat_like(s);
- break;
- case BifiltrationFormat::rivet :
- init_rivet(s);
- break;
- }
- }
-
- std::ostream& operator<<(std::ostream& os, const Simplex& x)
- {
- os << "Simplex(id = " << x.id() << ", dim = " << x.dim();
- os << ", boundary = " << container_to_string(x.boundary()) << ", pos = " << x.position() << ")";
- return os;
- }
-}
diff --git a/matching/src/test_generator.cpp b/matching/src/test_generator.cpp
index e8f128f..a2f0625 100644
--- a/matching/src/test_generator.cpp
+++ b/matching/src/test_generator.cpp
@@ -11,9 +11,12 @@
#include "common_util.h"
#include "bifiltration.h"
+using Real = double;
using Index = md::Index;
-using Point = md::Point;
+using Point = md::Point<Real>;
+using Bifiltration = md::Bifiltration<Real>;
using Column = md::Column;
+using Simplex = md::Simplex<Real>;
int g_max_coord = 100;
@@ -100,7 +103,7 @@ void generate_positions(const ASimplex& s, ASimplexToBirthMap& simplex_to_birth,
}
}
-md::Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_simplices)
+Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_simplices)
{
ASimplexToBirthMap simplex_to_birth;
@@ -122,13 +125,13 @@ md::Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_
add_if_top(candidate_simplex, top_simplices);
}
- Point upper_bound{static_cast<md::Real>(g_max_coord), static_cast<md::Real>(g_max_coord)};
+ Point upper_bound{static_cast<Real>(g_max_coord), static_cast<Real>(g_max_coord)};
for(const auto& top_simplex : top_simplices) {
generate_positions(top_simplex, simplex_to_birth, upper_bound);
}
std::vector<std::pair<ASimplex, Point>> simplex_birth_pairs{simplex_to_birth.begin(), simplex_to_birth.end()};
- std::vector<md::Column> boundaries{simplex_to_birth.size(), md::Column()};
+ std::vector<Column> boundaries{simplex_to_birth.size(), Column()};
// assign ids and save boundaries
int id = 0;
@@ -138,7 +141,7 @@ md::Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_
ASimplex& simplex = simplex_birth_pairs[i].first;
if (simplex.dim() == dim) {
simplex.id = id++;
- md::Column bdry;
+ Column bdry;
for(auto& facet : simplex.facets()) {
auto facet_iter = std::find_if(simplex_birth_pairs.begin(), simplex_birth_pairs.end(),
[facet](const std::pair<ASimplex, Point>& sbp) { return facet == sbp.first; });
@@ -153,7 +156,7 @@ md::Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_
}
// create vector of Simplex-es
- std::vector<md::Simplex> simplices;
+ std::vector<Simplex> simplices;
for(int i = 0; i < (int) simplex_birth_pairs.size(); ++i) {
int id = simplex_birth_pairs[i].first.id;
int dim = simplex_birth_pairs[i].first.dim();
@@ -164,13 +167,13 @@ md::Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_
// sort by id
std::sort(simplices.begin(), simplices.end(),
- [](const md::Simplex& s1, const md::Simplex& s2) { return s1.id() < s2.id(); });
+ [](const Simplex& s1, const Simplex& s2) { return s1.id() < s2.id(); });
for(int i = 0; i < (int)simplices.size(); ++i) {
assert(simplices[i].id() == i);
assert(i == 0 || simplices[i].dim() >= simplices[i - 1].dim());
}
- return md::Bifiltration(simplices.begin(), simplices.end());
+ return Bifiltration(simplices.begin(), simplices.end());
}
int main(int argc, char** argv)
diff --git a/matching/src/tests/test_common.cpp b/matching/src/tests/test_common.cpp
index c55577e..9079a56 100644
--- a/matching/src/tests/test_common.cpp
+++ b/matching/src/tests/test_common.cpp
@@ -8,56 +8,24 @@
#include "simplex.h"
#include "matching_distance.h"
-using namespace md;
+//using namespace md;
+using Real = double;
+using Point = md::Point<Real>;
+using Bifiltration = md::Bifiltration<Real>;
+using BifiltrationProxy = md::BifiltrationProxy<Real>;
+using CalculationParams = md::CalculationParams<Real>;
+using CellWithValue = md::CellWithValue<Real>;
+using DualPoint = md::DualPoint<Real>;
+using DualBox = md::DualBox<Real>;
+using Simplex = md::Simplex<Real>;
+using AbstractSimplex = md::AbstractSimplex;
+using BoundStrategy = md::BoundStrategy;
+using TraverseStrategy = md::TraverseStrategy;
+using AxisType = md::AxisType;
+using AngleType = md::AngleType;
+using ValuePoint = md::ValuePoint;
+using Column = md::Column;
-TEST_CASE("Rational", "[common_utils][rational]")
-{
- // gcd
- REQUIRE(gcd(10, 5) == 5);
- REQUIRE(gcd(5, 10) == 5);
- REQUIRE(gcd(5, 7) == 1);
- REQUIRE(gcd(7, 5) == 1);
- REQUIRE(gcd(13, 0) == 13);
- REQUIRE(gcd(0, 13) == 13);
- REQUIRE(gcd(16, 24) == 8);
- REQUIRE(gcd(24, 16) == 8);
- REQUIRE(gcd(16, 32) == 16);
- REQUIRE(gcd(32, 16) == 16);
-
-
- // reduce
- REQUIRE(reduce({2, 1}) == std::make_pair(2, 1));
- REQUIRE(reduce({1, 2}) == std::make_pair(1, 2));
- REQUIRE(reduce({2, 2}) == std::make_pair(1, 1));
- REQUIRE(reduce({0, 2}) == std::make_pair(0, 1));
- REQUIRE(reduce({0, 20}) == std::make_pair(0, 1));
- REQUIRE(reduce({35, 49}) == std::make_pair(5, 7));
- REQUIRE(reduce({35, 25}) == std::make_pair(7, 5));
-
- // midpoint
- REQUIRE(midpoint(Rational {0, 1}, Rational {1, 2}) == std::make_pair(1, 4));
- REQUIRE(midpoint(Rational {1, 4}, Rational {1, 2}) == std::make_pair(3, 8));
- REQUIRE(midpoint(Rational {1, 2}, Rational {1, 2}) == std::make_pair(1, 2));
- REQUIRE(midpoint(Rational {1, 2}, Rational {1, 1}) == std::make_pair(3, 4));
- REQUIRE(midpoint(Rational {3, 7}, Rational {5, 14}) == std::make_pair(11, 28));
-
-
- // arithmetic
-
- REQUIRE(Rational(1, 2) + Rational(3, 5) == Rational(11, 10));
- REQUIRE(Rational(2, 5) - Rational(3, 10) == Rational(1, 10));
- REQUIRE(Rational(2, 3) * Rational(4, 7) == Rational(8, 21));
- REQUIRE(Rational(2, 3) * Rational(3, 2) == Rational(1));
- REQUIRE(Rational(2, 3) / Rational(3, 2) == Rational(4, 9));
- REQUIRE(Rational(1, 2) * Rational(3, 5) == Rational(3, 10));
-
- // comparison
- REQUIRE(Rational(100000, 2000000) < Rational(100001, 2000000));
- REQUIRE(!(Rational(100001, 2000000) < Rational(100000, 2000000)));
- REQUIRE(!(Rational(100000, 2000000) < Rational(100000, 2000000)));
- REQUIRE(Rational(-100000, 2000000) < Rational(100001, 2000000));
- REQUIRE(Rational(-100001, 2000000) < Rational(100000, 2000000));
-};
TEST_CASE("AbstractSimplex", "[abstract_simplex]")
{
diff --git a/matching/src/tests/test_matching_distance.cpp b/matching/src/tests/test_matching_distance.cpp
index df9345e..82da530 100644
--- a/matching/src/tests/test_matching_distance.cpp
+++ b/matching/src/tests/test_matching_distance.cpp
@@ -11,7 +11,25 @@
#include "simplex.h"
#include "matching_distance.h"
-using namespace md;
+using Real = double;
+using Point = md::Point<Real>;
+using Bifiltration = md::Bifiltration<Real>;
+using BifiltrationProxy = md::BifiltrationProxy<Real>;
+using CalculationParams = md::CalculationParams<Real>;
+using CellWithValue = md::CellWithValue<Real>;
+using DualPoint = md::DualPoint<Real>;
+using DualBox = md::DualBox<Real>;
+using Simplex = md::Simplex<Real>;
+using AbstractSimplex = md::AbstractSimplex;
+using BoundStrategy = md::BoundStrategy;
+using TraverseStrategy = md::TraverseStrategy;
+using AxisType = md::AxisType;
+using AngleType = md::AngleType;
+using ValuePoint = md::ValuePoint;
+using Column = md::Column;
+
+using md::k_corner_vps;
+
namespace spd = spdlog;
TEST_CASE("Different bounds", "[bounds]")
@@ -40,7 +58,7 @@ TEST_CASE("Different bounds", "[bounds]")
BifiltrationProxy bifp_a(bif_a, params.dim);
BifiltrationProxy bifp_b(bif_b, params.dim);
- DistanceCalculator<BifiltrationProxy> calc(bifp_a, bifp_b, params);
+ md::DistanceCalculator<Real, BifiltrationProxy> calc(bifp_a, bifp_b, params);
// REQUIRE(calc.max_x_ == Approx(max_x));
// REQUIRE(calc.max_y_ == Approx(max_y));