summaryrefslogtreecommitdiff
path: root/matching/include/matching_distance.h
diff options
context:
space:
mode:
Diffstat (limited to 'matching/include/matching_distance.h')
-rw-r--r--matching/include/matching_distance.h203
1 files changed, 157 insertions, 46 deletions
diff --git a/matching/include/matching_distance.h b/matching/include/matching_distance.h
index bb10203..5be34c7 100644
--- a/matching/include/matching_distance.h
+++ b/matching/include/matching_distance.h
@@ -4,9 +4,10 @@
#include <limits>
#include <utility>
#include <ostream>
+#include <chrono>
+#include <tuple>
+#include <algorithm>
-#include "spdlog/spdlog.h"
-#include "spdlog/fmt/ostr.h"
#include "common_defs.h"
#include "cell_with_value.h"
@@ -17,12 +18,15 @@
#include "bifiltration.h"
#include "bottleneck.h"
-namespace spd = spdlog;
-
namespace md {
- using HeatMap = std::map<DualPoint, Real>;
- using HeatMaps = std::map<int, HeatMap>;
+#ifdef MD_PRINT_HEAT_MAP
+ template<class Real>
+ using HeatMap = std::map<DualPoint<Real>, Real>;
+
+ template<class Real>
+ using HeatMaps = std::map<int, HeatMap<Real>>;
+#endif
enum class BoundStrategy {
bruteforce,
@@ -39,18 +43,107 @@ namespace md {
upper_bound
};
- std::ostream& operator<<(std::ostream& os, const BoundStrategy& s);
-
- std::ostream& operator<<(std::ostream& os, const TraverseStrategy& s);
-
- std::istream& operator>>(std::istream& is, BoundStrategy& s);
-
- std::istream& operator>>(std::istream& is, TraverseStrategy& s);
-
- BoundStrategy bs_from_string(std::string s);
-
- TraverseStrategy ts_from_string(std::string s);
-
+ inline std::ostream& operator<<(std::ostream& os, const BoundStrategy& s)
+ {
+ switch(s) {
+ case BoundStrategy::bruteforce :
+ os << "bruteforce";
+ break;
+ case BoundStrategy::local_dual_bound :
+ os << "local_grob";
+ break;
+ case BoundStrategy::local_combined :
+ os << "local_combined";
+ break;
+ case BoundStrategy::local_dual_bound_refined :
+ os << "local_refined";
+ break;
+ case BoundStrategy::local_dual_bound_for_each_point :
+ os << "local_for_each_point";
+ break;
+ default:
+ os << "FORGOTTEN BOUND STRATEGY";
+ }
+ return os;
+ }
+
+ inline std::ostream& operator<<(std::ostream& os, const TraverseStrategy& s)
+ {
+ switch(s) {
+ case TraverseStrategy::depth_first :
+ os << "DFS";
+ break;
+ case TraverseStrategy::breadth_first :
+ os << "BFS";
+ break;
+ case TraverseStrategy::breadth_first_value :
+ os << "BFS-VAL";
+ break;
+ case TraverseStrategy::upper_bound :
+ os << "UB";
+ break;
+ default:
+ os << "FORGOTTEN TRAVERSE STRATEGY";
+ }
+ return os;
+ }
+
+ inline std::istream& operator>>(std::istream& is, TraverseStrategy& s)
+ {
+ std::string ss;
+ is >> ss;
+ if (ss == "DFS") {
+ s = TraverseStrategy::depth_first;
+ } else if (ss == "BFS") {
+ s = TraverseStrategy::breadth_first;
+ } else if (ss == "BFS-VAL") {
+ s = TraverseStrategy::breadth_first_value;
+ } else if (ss == "UB") {
+ s = TraverseStrategy::upper_bound;
+ } else {
+ throw std::runtime_error("UNKNOWN TRAVERSE STRATEGY");
+ }
+ return is;
+ }
+
+
+ inline std::istream& operator>>(std::istream& is, BoundStrategy& s)
+ {
+ std::string ss;
+ is >> ss;
+ if (ss == "bruteforce") {
+ s = BoundStrategy::bruteforce;
+ } else if (ss == "local_grob") {
+ s = BoundStrategy::local_dual_bound;
+ } else if (ss == "local_combined") {
+ s = BoundStrategy::local_combined;
+ } else if (ss == "local_refined") {
+ s = BoundStrategy::local_dual_bound_refined;
+ } else if (ss == "local_for_each_point") {
+ s = BoundStrategy::local_dual_bound_for_each_point;
+ } else {
+ throw std::runtime_error("UNKNOWN BOUND STRATEGY");
+ }
+ return is;
+ }
+
+ inline BoundStrategy bs_from_string(std::string s)
+ {
+ std::stringstream ss(s);
+ BoundStrategy result;
+ ss >> result;
+ return result;
+ }
+
+ inline TraverseStrategy ts_from_string(std::string s)
+ {
+ std::stringstream ss(s);
+ TraverseStrategy result;
+ ss >> result;
+ return result;
+ }
+
+ template<class Real>
struct CalculationParams {
static constexpr int ALL_DIMENSIONS = -1;
@@ -75,22 +168,22 @@ namespace md {
// print statistics on each quad-tree level
bool print_stats { false };
-#ifdef PRINT_HEAT_MAP
+#ifdef MD_PRINT_HEAT_MAP
HeatMaps heat_maps;
#endif
};
- template<class DiagramProvider>
+ template<class Real_, class DiagramProvider>
class DistanceCalculator {
- using DualBox = md::DualBox;
- using CellValueVector = std::vector<CellWithValue>;
+ using Real = Real_;
+ using CellValueVector = std::vector<CellWithValue<Real>>;
public:
DistanceCalculator(const DiagramProvider& a,
const DiagramProvider& b,
- CalculationParams& params);
+ CalculationParams<Real>& params);
Real distance();
@@ -100,7 +193,7 @@ namespace md {
DiagramProvider module_a_;
DiagramProvider module_b_;
- CalculationParams& params_;
+ CalculationParams<Real>& params_;
int n_hera_calls_;
std::map<int, int> n_hera_calls_per_level_;
@@ -112,65 +205,83 @@ namespace md {
CellValueVector get_initial_dual_grid(Real& lower_bound);
+#ifdef MD_PRINT_HEAT_MAP
void heatmap_in_dimension(int dim, int depth);
+#endif
Real get_max_x(int module) const;
Real get_max_y(int module) const;
- void set_cell_central_value(CellWithValue& dual_cell);
+ void set_cell_central_value(CellWithValue<Real>& dual_cell);
Real get_distance();
Real get_distance_pq();
- // temporary, to try priority queue
- Real get_max_possible_value(const CellWithValue* first_cell_ptr, int n_cells);
+ Real get_max_possible_value(const CellWithValue<Real>* first_cell_ptr, int n_cells);
- Real get_upper_bound(const CellWithValue& dual_cell, Real good_enough_upper_bound) const;
+ Real get_upper_bound(const CellWithValue<Real>& dual_cell, Real good_enough_upper_bound) const;
- Real get_single_dgm_bound(const CellWithValue& dual_cell, ValuePoint vp, int module,
+ Real get_single_dgm_bound(const CellWithValue<Real>& dual_cell, ValuePoint vp, int module,
Real good_enough_value) const;
// this bound depends only on dual box
- Real get_local_dual_bound(int module, const DualBox& dual_box) const;
+ Real get_local_dual_bound(int module, const DualBox<Real>& dual_box) const;
- Real get_local_dual_bound(const DualBox& dual_box) const;
+ Real get_local_dual_bound(const DualBox<Real>& dual_box) const;
// this bound depends only on dual box, is more accurate
- Real get_local_refined_bound(int module, const md::DualBox& dual_box) const;
+ Real get_local_refined_bound(int module, const DualBox<Real>& dual_box) const;
- Real get_local_refined_bound(const md::DualBox& dual_box) const;
+ Real get_local_refined_bound(const DualBox<Real>& dual_box) const;
Real get_good_enough_upper_bound(Real lower_bound) const;
- Real
- get_max_displacement_single_point(const CellWithValue& dual_cell, ValuePoint value_point, const Point& p) const;
+ Real get_max_displacement_single_point(const CellWithValue<Real>& dual_cell, ValuePoint value_point,
+ const Point<Real>& p) const;
- void check_upper_bound(const CellWithValue& dual_cell) const;
+ void check_upper_bound(const CellWithValue<Real>& dual_cell) const;
- Real distance_on_line(DualPoint line);
- Real distance_on_line_const(DualPoint line) const;
+ Real distance_on_line(DualPoint<Real> line);
+ Real distance_on_line_const(DualPoint<Real> line) const;
Real current_error(Real lower_bound, Real upper_bound);
};
- Real matching_distance(const Bifiltration& bif_a, const Bifiltration& bif_b, CalculationParams& params);
+ template<class Real>
+ Real matching_distance(const Bifiltration<Real>& bif_a, const Bifiltration<Real>& bif_b,
+ CalculationParams<Real>& params);
- Real matching_distance(const ModulePresentation& mod_a, const ModulePresentation& mod_b, CalculationParams& params);
+ template<class Real>
+ Real matching_distance(const ModulePresentation<Real>& mod_a, const ModulePresentation<Real>& mod_b,
+ CalculationParams<Real>& params);
// for upper bound experiment
struct UbExperimentRecord {
- Real error;
- Real lower_bound;
- Real upper_bound;
- CellWithValue cell;
+ double error;
+ double lower_bound;
+ double upper_bound;
+ CellWithValue<double> cell;
long long int time;
long long int n_hera_calls;
};
- std::ostream& operator<<(std::ostream& os, const UbExperimentRecord& r);
+ inline std::ostream& operator<<(std::ostream& os, const UbExperimentRecord& r)
+ {
+ os << r.time << "\t" << r.n_hera_calls << "\t" << r.error << "\t" << r.lower_bound << "\t" << r.upper_bound;
+ return os;
+ }
+
+
+ template<class K, class V>
+ void print_map(const std::map<K, V>& dic)
+ {
+ for(const auto kv : dic) {
+ fmt::print("{} -> {}\n", kv.first, kv.second);
+ }
+ }
-}
+} // namespace md
#include "matching_distance.hpp"