summaryrefslogtreecommitdiff
path: root/matching/include/matching_distance.h
diff options
context:
space:
mode:
Diffstat (limited to 'matching/include/matching_distance.h')
-rw-r--r--matching/include/matching_distance.h290
1 files changed, 290 insertions, 0 deletions
diff --git a/matching/include/matching_distance.h b/matching/include/matching_distance.h
new file mode 100644
index 0000000..e1679dc
--- /dev/null
+++ b/matching/include/matching_distance.h
@@ -0,0 +1,290 @@
+#pragma once
+
+#include <vector>
+#include <limits>
+#include <utility>
+#include <ostream>
+#include <chrono>
+#include <tuple>
+#include <algorithm>
+
+
+#include "common_defs.h"
+#include "cell_with_value.h"
+#include "box.h"
+#include "dual_point.h"
+#include "dual_box.h"
+#include "persistence_module.h"
+#include "bifiltration.h"
+#include "bottleneck.h"
+
+namespace md {
+
+#ifdef MD_PRINT_HEAT_MAP
+ template<class Real>
+ using HeatMap = std::map<DualPoint<Real>, Real>;
+
+ template<class Real>
+ using HeatMaps = std::map<int, HeatMap<Real>>;
+#endif
+
+ enum class BoundStrategy {
+ bruteforce,
+ local_dual_bound,
+ local_dual_bound_refined,
+ local_dual_bound_for_each_point,
+ local_combined
+ };
+
+ enum class TraverseStrategy {
+ depth_first,
+ breadth_first,
+ breadth_first_value,
+ upper_bound
+ };
+
+ inline std::ostream& operator<<(std::ostream& os, const BoundStrategy& s)
+ {
+ switch(s) {
+ case BoundStrategy::bruteforce :
+ os << "bruteforce";
+ break;
+ case BoundStrategy::local_dual_bound :
+ os << "local_grob";
+ break;
+ case BoundStrategy::local_combined :
+ os << "local_combined";
+ break;
+ case BoundStrategy::local_dual_bound_refined :
+ os << "local_refined";
+ break;
+ case BoundStrategy::local_dual_bound_for_each_point :
+ os << "local_for_each_point";
+ break;
+ default:
+ os << "FORGOTTEN BOUND STRATEGY";
+ }
+ return os;
+ }
+
+ inline std::ostream& operator<<(std::ostream& os, const TraverseStrategy& s)
+ {
+ switch(s) {
+ case TraverseStrategy::depth_first :
+ os << "DFS";
+ break;
+ case TraverseStrategy::breadth_first :
+ os << "BFS";
+ break;
+ case TraverseStrategy::breadth_first_value :
+ os << "BFS-VAL";
+ break;
+ case TraverseStrategy::upper_bound :
+ os << "UB";
+ break;
+ default:
+ os << "FORGOTTEN TRAVERSE STRATEGY";
+ }
+ return os;
+ }
+
+ inline std::istream& operator>>(std::istream& is, TraverseStrategy& s)
+ {
+ std::string ss;
+ is >> ss;
+ if (ss == "DFS") {
+ s = TraverseStrategy::depth_first;
+ } else if (ss == "BFS") {
+ s = TraverseStrategy::breadth_first;
+ } else if (ss == "BFS-VAL") {
+ s = TraverseStrategy::breadth_first_value;
+ } else if (ss == "UB") {
+ s = TraverseStrategy::upper_bound;
+ } else {
+ throw std::runtime_error("UNKNOWN TRAVERSE STRATEGY");
+ }
+ return is;
+ }
+
+
+ inline std::istream& operator>>(std::istream& is, BoundStrategy& s)
+ {
+ std::string ss;
+ is >> ss;
+ if (ss == "bruteforce") {
+ s = BoundStrategy::bruteforce;
+ } else if (ss == "local_grob") {
+ s = BoundStrategy::local_dual_bound;
+ } else if (ss == "local_combined") {
+ s = BoundStrategy::local_combined;
+ } else if (ss == "local_refined") {
+ s = BoundStrategy::local_dual_bound_refined;
+ } else if (ss == "local_for_each_point") {
+ s = BoundStrategy::local_dual_bound_for_each_point;
+ } else {
+ throw std::runtime_error("UNKNOWN BOUND STRATEGY");
+ }
+ return is;
+ }
+
+ inline BoundStrategy bs_from_string(std::string s)
+ {
+ std::stringstream ss(s);
+ BoundStrategy result;
+ ss >> result;
+ return result;
+ }
+
+ inline TraverseStrategy ts_from_string(std::string s)
+ {
+ std::stringstream ss(s);
+ TraverseStrategy result;
+ ss >> result;
+ return result;
+ }
+
+ template<class Real>
+ struct CalculationParams {
+ static constexpr int ALL_DIMENSIONS = -1;
+
+ Real hera_epsilon {0.001}; // relative error in hera call
+ Real delta {0.1}; // relative error for matching distance
+ int max_depth {8}; // maximal number of refinenemnts
+ int initialization_depth {2};
+ int dim {0}; // in which dim to calculate the distance; use ALL_DIMENSIONS to get max over all dims
+ BoundStrategy bound_strategy {BoundStrategy::local_combined};
+ TraverseStrategy traverse_strategy {TraverseStrategy::breadth_first};
+ bool tolerate_max_iter_exceeded {false};
+ Real actual_error {std::numeric_limits<Real>::max()};
+ int actual_max_depth {0};
+ int n_hera_calls {0}; // for experiments only; is set in matching_distance function, input value is ignored
+
+ // stop looping over points immediately, if current point's displacement is too large
+ // to prune the cell
+ // if true, cells are pruned immediately, and bounds may increase
+ // (just return something large enough to not prune the cell)
+ bool stop_asap { true };
+
+ // print statistics on each quad-tree level
+ bool print_stats { false };
+
+#ifdef MD_PRINT_HEAT_MAP
+ HeatMaps<Real> heat_maps;
+#endif
+ };
+
+
+ template<class Real_, class DiagramProvider>
+ class DistanceCalculator {
+
+ using Real = Real_;
+ using CellValueVector = std::vector<CellWithValue<Real>>;
+
+ public:
+ DistanceCalculator(const DiagramProvider& a,
+ const DiagramProvider& b,
+ CalculationParams<Real>& params);
+
+ Real distance();
+
+ int get_hera_calls_number() const;
+
+#ifndef MD_TEST_CODE
+ private:
+#endif
+
+ DiagramProvider module_a_;
+ DiagramProvider module_b_;
+
+ CalculationParams<Real>& params_;
+
+ int n_hera_calls_;
+ std::map<int, int> n_hera_calls_per_level_;
+ Real distance_;
+
+ // if calculate_on_intermediate, then weighted distance
+ // will be calculated on centers of each grid in between
+ CellValueVector get_refined_grid(int init_depth, bool calculate_on_intermediate, bool calculate_on_last = true);
+
+ CellValueVector get_initial_dual_grid(Real& lower_bound);
+
+#ifdef MD_PRINT_HEAT_MAP
+ void heatmap_in_dimension(int dim, int depth);
+#endif
+
+ Real get_max_x(int module) const;
+
+ Real get_max_y(int module) const;
+
+ void set_cell_central_value(CellWithValue<Real>& dual_cell);
+
+ Real get_distance();
+
+ Real get_distance_pq();
+
+ Real get_max_possible_value(const CellWithValue<Real>* first_cell_ptr, int n_cells);
+
+ Real get_upper_bound(const CellWithValue<Real>& dual_cell, Real good_enough_upper_bound) const;
+
+ Real get_single_dgm_bound(const CellWithValue<Real>& dual_cell, ValuePoint vp, int module,
+ Real good_enough_value) const;
+
+ // this bound depends only on dual box
+ Real get_local_dual_bound(int module, const DualBox<Real>& dual_box) const;
+
+ Real get_local_dual_bound(const DualBox<Real>& dual_box) const;
+
+ // this bound depends only on dual box, is more accurate
+ Real get_local_refined_bound(int module, const DualBox<Real>& dual_box) const;
+
+ Real get_local_refined_bound(const DualBox<Real>& dual_box) const;
+
+ Real get_good_enough_upper_bound(Real lower_bound) const;
+
+ Real get_max_displacement_single_point(const CellWithValue<Real>& dual_cell, ValuePoint value_point,
+ const Point<Real>& p) const;
+
+ void check_upper_bound(const CellWithValue<Real>& dual_cell) const;
+
+ Real distance_on_line(DualPoint<Real> line);
+ Real distance_on_line_const(DualPoint<Real> line) const;
+
+ Real current_error(Real lower_bound, Real upper_bound);
+ };
+
+ template<class Real>
+ Real matching_distance(const Bifiltration<Real>& bif_a, const Bifiltration<Real>& bif_b,
+ CalculationParams<Real>& params);
+
+ template<class Real>
+ Real matching_distance(const ModulePresentation<Real>& mod_a, const ModulePresentation<Real>& mod_b,
+ CalculationParams<Real>& params);
+
+ // for upper bound experiment
+ struct UbExperimentRecord {
+ double error;
+ double lower_bound;
+ double upper_bound;
+ CellWithValue<double> cell;
+ long long int time;
+ long long int n_hera_calls;
+ };
+
+ inline std::ostream& operator<<(std::ostream& os, const UbExperimentRecord& r)
+ {
+ os << r.time << "\t" << r.n_hera_calls << "\t" << r.error << "\t" << r.lower_bound << "\t" << r.upper_bound;
+ return os;
+ }
+
+
+ template<class K, class V>
+ void print_map(const std::map<K, V>& dic)
+ {
+ for(const auto kv : dic) {
+ fmt::print("{} -> {}\n", kv.first, kv.second);
+ }
+ }
+
+} // namespace md
+
+#include "matching_distance.hpp"