diff options
author | Arnur Nigmetov <nigmetov@tugraz.at> | 2020-03-04 00:33:51 +0100 |
---|---|---|
committer | Arnur Nigmetov <nigmetov@tugraz.at> | 2020-03-04 00:33:51 +0100 |
commit | 3809e4071827a5959f27e472514eaed08ba6d15e (patch) | |
tree | 113fd1c373e4a04b568f8daf25324efafff6b107 /matching/include/matching_distance.h | |
parent | d56c07b093bfea1690a81ebbef41b8bb9c7c2464 (diff) |
Make matching distance header-only.
Diffstat (limited to 'matching/include/matching_distance.h')
-rw-r--r-- | matching/include/matching_distance.h | 203 |
1 files changed, 157 insertions, 46 deletions
diff --git a/matching/include/matching_distance.h b/matching/include/matching_distance.h index bb10203..5be34c7 100644 --- a/matching/include/matching_distance.h +++ b/matching/include/matching_distance.h @@ -4,9 +4,10 @@ #include <limits> #include <utility> #include <ostream> +#include <chrono> +#include <tuple> +#include <algorithm> -#include "spdlog/spdlog.h" -#include "spdlog/fmt/ostr.h" #include "common_defs.h" #include "cell_with_value.h" @@ -17,12 +18,15 @@ #include "bifiltration.h" #include "bottleneck.h" -namespace spd = spdlog; - namespace md { - using HeatMap = std::map<DualPoint, Real>; - using HeatMaps = std::map<int, HeatMap>; +#ifdef MD_PRINT_HEAT_MAP + template<class Real> + using HeatMap = std::map<DualPoint<Real>, Real>; + + template<class Real> + using HeatMaps = std::map<int, HeatMap<Real>>; +#endif enum class BoundStrategy { bruteforce, @@ -39,18 +43,107 @@ namespace md { upper_bound }; - std::ostream& operator<<(std::ostream& os, const BoundStrategy& s); - - std::ostream& operator<<(std::ostream& os, const TraverseStrategy& s); - - std::istream& operator>>(std::istream& is, BoundStrategy& s); - - std::istream& operator>>(std::istream& is, TraverseStrategy& s); - - BoundStrategy bs_from_string(std::string s); - - TraverseStrategy ts_from_string(std::string s); - + inline std::ostream& operator<<(std::ostream& os, const BoundStrategy& s) + { + switch(s) { + case BoundStrategy::bruteforce : + os << "bruteforce"; + break; + case BoundStrategy::local_dual_bound : + os << "local_grob"; + break; + case BoundStrategy::local_combined : + os << "local_combined"; + break; + case BoundStrategy::local_dual_bound_refined : + os << "local_refined"; + break; + case BoundStrategy::local_dual_bound_for_each_point : + os << "local_for_each_point"; + break; + default: + os << "FORGOTTEN BOUND STRATEGY"; + } + return os; + } + + inline std::ostream& operator<<(std::ostream& os, const TraverseStrategy& s) + { + switch(s) { + case TraverseStrategy::depth_first : + os << "DFS"; + break; + case TraverseStrategy::breadth_first : + os << "BFS"; + break; + case TraverseStrategy::breadth_first_value : + os << "BFS-VAL"; + break; + case TraverseStrategy::upper_bound : + os << "UB"; + break; + default: + os << "FORGOTTEN TRAVERSE STRATEGY"; + } + return os; + } + + inline std::istream& operator>>(std::istream& is, TraverseStrategy& s) + { + std::string ss; + is >> ss; + if (ss == "DFS") { + s = TraverseStrategy::depth_first; + } else if (ss == "BFS") { + s = TraverseStrategy::breadth_first; + } else if (ss == "BFS-VAL") { + s = TraverseStrategy::breadth_first_value; + } else if (ss == "UB") { + s = TraverseStrategy::upper_bound; + } else { + throw std::runtime_error("UNKNOWN TRAVERSE STRATEGY"); + } + return is; + } + + + inline std::istream& operator>>(std::istream& is, BoundStrategy& s) + { + std::string ss; + is >> ss; + if (ss == "bruteforce") { + s = BoundStrategy::bruteforce; + } else if (ss == "local_grob") { + s = BoundStrategy::local_dual_bound; + } else if (ss == "local_combined") { + s = BoundStrategy::local_combined; + } else if (ss == "local_refined") { + s = BoundStrategy::local_dual_bound_refined; + } else if (ss == "local_for_each_point") { + s = BoundStrategy::local_dual_bound_for_each_point; + } else { + throw std::runtime_error("UNKNOWN BOUND STRATEGY"); + } + return is; + } + + inline BoundStrategy bs_from_string(std::string s) + { + std::stringstream ss(s); + BoundStrategy result; + ss >> result; + return result; + } + + inline TraverseStrategy ts_from_string(std::string s) + { + std::stringstream ss(s); + TraverseStrategy result; + ss >> result; + return result; + } + + template<class Real> struct CalculationParams { static constexpr int ALL_DIMENSIONS = -1; @@ -75,22 +168,22 @@ namespace md { // print statistics on each quad-tree level bool print_stats { false }; -#ifdef PRINT_HEAT_MAP +#ifdef MD_PRINT_HEAT_MAP HeatMaps heat_maps; #endif }; - template<class DiagramProvider> + template<class Real_, class DiagramProvider> class DistanceCalculator { - using DualBox = md::DualBox; - using CellValueVector = std::vector<CellWithValue>; + using Real = Real_; + using CellValueVector = std::vector<CellWithValue<Real>>; public: DistanceCalculator(const DiagramProvider& a, const DiagramProvider& b, - CalculationParams& params); + CalculationParams<Real>& params); Real distance(); @@ -100,7 +193,7 @@ namespace md { DiagramProvider module_a_; DiagramProvider module_b_; - CalculationParams& params_; + CalculationParams<Real>& params_; int n_hera_calls_; std::map<int, int> n_hera_calls_per_level_; @@ -112,65 +205,83 @@ namespace md { CellValueVector get_initial_dual_grid(Real& lower_bound); +#ifdef MD_PRINT_HEAT_MAP void heatmap_in_dimension(int dim, int depth); +#endif Real get_max_x(int module) const; Real get_max_y(int module) const; - void set_cell_central_value(CellWithValue& dual_cell); + void set_cell_central_value(CellWithValue<Real>& dual_cell); Real get_distance(); Real get_distance_pq(); - // temporary, to try priority queue - Real get_max_possible_value(const CellWithValue* first_cell_ptr, int n_cells); + Real get_max_possible_value(const CellWithValue<Real>* first_cell_ptr, int n_cells); - Real get_upper_bound(const CellWithValue& dual_cell, Real good_enough_upper_bound) const; + Real get_upper_bound(const CellWithValue<Real>& dual_cell, Real good_enough_upper_bound) const; - Real get_single_dgm_bound(const CellWithValue& dual_cell, ValuePoint vp, int module, + Real get_single_dgm_bound(const CellWithValue<Real>& dual_cell, ValuePoint vp, int module, Real good_enough_value) const; // this bound depends only on dual box - Real get_local_dual_bound(int module, const DualBox& dual_box) const; + Real get_local_dual_bound(int module, const DualBox<Real>& dual_box) const; - Real get_local_dual_bound(const DualBox& dual_box) const; + Real get_local_dual_bound(const DualBox<Real>& dual_box) const; // this bound depends only on dual box, is more accurate - Real get_local_refined_bound(int module, const md::DualBox& dual_box) const; + Real get_local_refined_bound(int module, const DualBox<Real>& dual_box) const; - Real get_local_refined_bound(const md::DualBox& dual_box) const; + Real get_local_refined_bound(const DualBox<Real>& dual_box) const; Real get_good_enough_upper_bound(Real lower_bound) const; - Real - get_max_displacement_single_point(const CellWithValue& dual_cell, ValuePoint value_point, const Point& p) const; + Real get_max_displacement_single_point(const CellWithValue<Real>& dual_cell, ValuePoint value_point, + const Point<Real>& p) const; - void check_upper_bound(const CellWithValue& dual_cell) const; + void check_upper_bound(const CellWithValue<Real>& dual_cell) const; - Real distance_on_line(DualPoint line); - Real distance_on_line_const(DualPoint line) const; + Real distance_on_line(DualPoint<Real> line); + Real distance_on_line_const(DualPoint<Real> line) const; Real current_error(Real lower_bound, Real upper_bound); }; - Real matching_distance(const Bifiltration& bif_a, const Bifiltration& bif_b, CalculationParams& params); + template<class Real> + Real matching_distance(const Bifiltration<Real>& bif_a, const Bifiltration<Real>& bif_b, + CalculationParams<Real>& params); - Real matching_distance(const ModulePresentation& mod_a, const ModulePresentation& mod_b, CalculationParams& params); + template<class Real> + Real matching_distance(const ModulePresentation<Real>& mod_a, const ModulePresentation<Real>& mod_b, + CalculationParams<Real>& params); // for upper bound experiment struct UbExperimentRecord { - Real error; - Real lower_bound; - Real upper_bound; - CellWithValue cell; + double error; + double lower_bound; + double upper_bound; + CellWithValue<double> cell; long long int time; long long int n_hera_calls; }; - std::ostream& operator<<(std::ostream& os, const UbExperimentRecord& r); + inline std::ostream& operator<<(std::ostream& os, const UbExperimentRecord& r) + { + os << r.time << "\t" << r.n_hera_calls << "\t" << r.error << "\t" << r.lower_bound << "\t" << r.upper_bound; + return os; + } + + + template<class K, class V> + void print_map(const std::map<K, V>& dic) + { + for(const auto kv : dic) { + fmt::print("{} -> {}\n", kv.first, kv.second); + } + } -} +} // namespace md #include "matching_distance.hpp" |