summaryrefslogtreecommitdiff
path: root/matching/include/matching_distance.h
diff options
context:
space:
mode:
Diffstat (limited to 'matching/include/matching_distance.h')
-rw-r--r--matching/include/matching_distance.h168
1 files changed, 168 insertions, 0 deletions
diff --git a/matching/include/matching_distance.h b/matching/include/matching_distance.h
new file mode 100644
index 0000000..603b535
--- /dev/null
+++ b/matching/include/matching_distance.h
@@ -0,0 +1,168 @@
+#pragma once
+
+#include <vector>
+#include <limits>
+#include <utility>
+#include <ostream>
+
+#include "spdlog/spdlog.h"
+
+#include "common_defs.h"
+#include "cell_with_value.h"
+#include "box.h"
+#include "dual_point.h"
+#include "dual_box.h"
+#include "persistence_module.h"
+#include "bifiltration.h"
+#include "bottleneck.h"
+
+namespace spd = spdlog;
+
+namespace md {
+
+ using HeatMap = std::map<DualPoint, Real>;
+ using HeatMaps = std::map<int, HeatMap>;
+
+ enum class BoundStrategy {
+ bruteforce,
+ local_dual_bound,
+ local_dual_bound_refined,
+ local_dual_bound_for_each_point,
+ local_combined
+ };
+
+ enum class TraverseStrategy {
+ depth_first,
+ breadth_first,
+ breadth_first_value,
+ upper_bound
+ };
+
+ std::ostream& operator<<(std::ostream& os, const BoundStrategy& s);
+
+ std::ostream& operator<<(std::ostream& os, const TraverseStrategy& s);
+
+ std::istream& operator>>(std::istream& is, BoundStrategy& s);
+
+ std::istream& operator>>(std::istream& is, TraverseStrategy& s);
+
+ BoundStrategy bs_from_string(std::string s);
+
+ TraverseStrategy ts_from_string(std::string s);
+
+ struct CalculationParams {
+ static constexpr int ALL_DIMENSIONS = -1;
+
+ Real hera_epsilon {0.01}; // relative error in hera call
+ Real delta {0.1}; // relative error for matching distance
+ int max_depth {6}; // maximal number of refinenemnts
+ int initialization_depth {3};
+ int dim {0}; // in which dim to calculate the distance; use ALL_DIMENSIONS to get max over all dims
+ BoundStrategy bound_strategy {BoundStrategy::bruteforce};
+ TraverseStrategy traverse_strategy {TraverseStrategy::breadth_first};
+ bool tolerate_max_iter_exceeded {true};
+ Real actual_error {std::numeric_limits<Real>::max()};
+ int actual_max_depth {0};
+ int n_hera_calls {0}; // for experiments only; is set in matching_distance function, input value is ignored
+
+ // stop looping over points immediately, if current point's displacement is too large
+ // to prune the cell
+ // if true, cells are pruned immediately, and bounds may be unreliable
+ // (we just return something large enough to prune the cell)
+ bool stop_asap { true };
+
+#ifdef PRINT_HEAT_MAP
+ HeatMaps heat_maps;
+#endif
+ };
+
+ class DistanceCalculator {
+
+ using DiagramProvider = md::Bifiltration;
+ using DualBox = md::DualBox;
+
+ using CellValueVector = std::vector<CellWithValue>;
+
+ public:
+ DistanceCalculator(const DiagramProvider& a,
+ const DiagramProvider& b,
+ CalculationParams& params);
+
+ Real distance();
+
+ int get_hera_calls_number() const
+ {
+ int result;
+ for(auto c : n_hera_calls_) result += c.second;
+ return result;
+ }
+
+ int get_hera_calls_number(int dim) const;
+
+ void clear_cache();
+
+// for tests - make everything p
+// private:
+ DiagramProvider module_a_;
+ DiagramProvider module_b_;
+
+ CalculationParams& params_;
+
+ int maximal_dim_ {0};
+
+ std::map<int, int> n_hera_calls_;
+ std::map<int, int> n_hera_calls_per_level_;
+ std::vector<Real> distances_; // indexed by dim
+
+ // if calculate_on_intermediate, then weighted distance
+ // will be calculated on centers of each grid in between
+ CellValueVector get_refined_grid(int init_depth, bool calculate_on_intermediate, bool calculate_on_last = true);
+
+ CellValueVector get_initial_dual_grid(Real& lower_bound);
+
+ void heatmap_in_dimension(int dim, int depth);
+
+ Real get_max_x(int module) const;
+
+ Real get_max_y(int module) const;
+
+ void set_cell_central_value(CellWithValue& dual_cell, int dim);
+
+ Real distance_in_dimension(int dim);
+
+ Real distance_in_dimension_pq(int dim);
+
+ // temporary, to try priority queue
+ Real get_max_possible_value(const CellWithValue* first_cell_ptr, int n_cells);
+
+ Real get_upper_bound(const CellWithValue& dual_cell, int dim, Real good_enough_upper_bound) const;
+
+ Real get_single_dgm_bound(const CellWithValue& dual_cell, ValuePoint vp, int module, int dim,
+ Real good_enough_value) const;
+
+ // this bound depends only on dual box
+ Real get_local_dual_bound(int module, const DualBox& dual_box) const;
+
+ Real get_local_dual_bound(const DualBox& dual_box) const;
+
+ // this bound depends only on dual box, is more accurate
+ Real get_local_refined_bound(int module, const md::DualBox& dual_box) const;
+
+ Real get_local_refined_bound(const md::DualBox& dual_box) const;
+
+ Real get_good_enough_upper_bound(Real lower_bound) const;
+
+ Real
+ get_max_displacement_single_point(const CellWithValue& dual_cell, ValuePoint value_point, const Point& p) const;
+
+ void check_upper_bound(const CellWithValue& dual_cell, int dim) const;
+
+ Real distance_on_line(int dim, DualPoint line);
+ Real distance_on_line_const(int dim, DualPoint line) const;
+
+ Real current_error(Real lower_bound, Real upper_bound);
+ };
+
+ Real matching_distance(const Bifiltration& bif_a, const Bifiltration& bif_b, CalculationParams& params);
+
+}