diff options
Diffstat (limited to 'matching/include/matching_distance.hpp')
-rw-r--r-- | matching/include/matching_distance.hpp | 326 |
1 files changed, 177 insertions, 149 deletions
diff --git a/matching/include/matching_distance.hpp b/matching/include/matching_distance.hpp index d2d2fbc..48c8464 100644 --- a/matching/include/matching_distance.hpp +++ b/matching/include/matching_distance.hpp @@ -1,34 +1,26 @@ namespace md { - template<class K, class V> - void print_map(const std::map<K, V>& dic) - { - for(const auto kv : dic) { - fmt::print("{} -> {}\n", kv.first, kv.second); - } - } - - template<class T> - void DistanceCalculator<T>::check_upper_bound(const CellWithValue& dual_cell) const + template<class R, class T> + void DistanceCalculator<R, T>::check_upper_bound(const CellWithValue<R>& dual_cell) const { spd::debug("Enter check_get_max_delta_on_cell"); const int n_samples_lambda = 100; const int n_samples_mu = 100; - DualBox db = dual_cell.dual_box(); - Real min_lambda = db.lambda_min(); - Real max_lambda = db.lambda_max(); - Real min_mu = db.mu_min(); - Real max_mu = db.mu_max(); - - Real h_lambda = (max_lambda - min_lambda) / n_samples_lambda; - Real h_mu = (max_mu - min_mu) / n_samples_mu; + DualBox<R> db = dual_cell.dual_box(); + R min_lambda = db.lambda_min(); + R max_lambda = db.lambda_max(); + R min_mu = db.mu_min(); + R max_mu = db.mu_max(); + + R h_lambda = (max_lambda - min_lambda) / n_samples_lambda; + R h_mu = (max_mu - min_mu) / n_samples_mu; for(int i = 1; i < n_samples_lambda; ++i) { for(int j = 1; j < n_samples_mu; ++j) { - Real lambda = min_lambda + i * h_lambda; - Real mu = min_mu + j * h_mu; - DualPoint l(db.axis_type(), db.angle_type(), lambda, mu); - Real other_result = distance_on_line_const(l); - Real diff = fabs(dual_cell.stored_upper_bound() - other_result); + R lambda = min_lambda + i * h_lambda; + R mu = min_mu + j * h_mu; + DualPoint<R> l(db.axis_type(), db.angle_type(), lambda, mu); + R other_result = distance_on_line_const(l); + R diff = fabs(dual_cell.stored_upper_bound() - other_result); if (other_result > dual_cell.stored_upper_bound()) { spd::error( "in check_upper_bound, upper_bound = {}, other_result = {}, diff = {}\ndual_cell = {}", @@ -42,10 +34,10 @@ namespace md { // for all lines l, l' inside dual box, // find the upper bound on the difference of weighted pushes of p - template<class T> - Real - DistanceCalculator<T>::get_max_displacement_single_point(const CellWithValue& dual_cell, ValuePoint vp, - const Point& p) const + template<class R, class T> + R + DistanceCalculator<R, T>::get_max_displacement_single_point(const CellWithValue<R>& dual_cell, ValuePoint vp, + const Point<R>& p) const { assert(p.x >= 0 && p.y >= 0); @@ -53,15 +45,15 @@ namespace md { std::vector<long long int> debug_ids = {3, 13, 54, 218, 350, 382, 484, 795, 2040, 8415, 44076}; bool debug = false; // std::find(debug_ids.begin(), debug_ids.end(), dual_cell.id) != debug_ids.end(); #endif - DualPoint line = dual_cell.value_point(vp); - const Real base_value = line.weighted_push(p); + DualPoint<R> line = dual_cell.value_point(vp); + const R base_value = line.weighted_push(p); spd::debug("Enter get_max_displacement_single_point, p = {},\ndual_cell = {},\nline = {}, base_value = {}\n", p, dual_cell, line, base_value); - Real result = 0.0; - for(DualPoint dp : dual_cell.dual_box().critical_points(p)) { - Real dp_value = dp.weighted_push(p); + R result = 0.0; + for(DualPoint<R> dp : dual_cell.dual_box().critical_points(p)) { + R dp_value = dp.weighted_push(p); spd::debug( "In get_max_displacement_single_point, p = {}, critical dp = {},\ndp_value = {}, diff = {},\ndual_cell = {}\n", p, dp, dp_value, fabs(base_value - dp_value), dual_cell); @@ -69,15 +61,15 @@ namespace md { } #ifdef MD_DO_FULL_CHECK - DualBox db = dual_cell.dual_box(); - std::uniform_real_distribution<Real> dlambda(db.lambda_min(), db.lambda_max()); - std::uniform_real_distribution<Real> dmu(db.mu_min(), db.mu_max()); + auto db = dual_cell.dual_box(); + std::uniform_real_distribution<R> dlambda(db.lambda_min(), db.lambda_max()); + std::uniform_real_distribution<R> dmu(db.mu_min(), db.mu_max()); std::mt19937 gen(1); for(int i = 0; i < 1000; ++i) { - Real lambda = dlambda(gen); - Real mu = dmu(gen); - DualPoint dp_random { db.axis_type(), db.angle_type(), lambda, mu }; - Real dp_value = dp_random.weighted_push(p); + R lambda = dlambda(gen); + R mu = dmu(gen); + DualPoint<R> dp_random { db.axis_type(), db.angle_type(), lambda, mu }; + R dp_value = dp_random.weighted_push(p); if (fabs(base_value - dp_value) > result) { spd::error("in get_max_displacement_single_point, p = {}, vp = {}\ndb = {}\nresult = {}, base_value = {}, dp_value = {}, dp_random = {}", p, vp, db, result, base_value, dp_value, dp_random); @@ -89,12 +81,12 @@ namespace md { return result; } - template<class T> - typename DistanceCalculator<T>::CellValueVector DistanceCalculator<T>::get_initial_dual_grid(Real& lower_bound) + template<class R, class T> + typename DistanceCalculator<R, T>::CellValueVector DistanceCalculator<R, T>::get_initial_dual_grid(R& lower_bound) { CellValueVector result = get_refined_grid(params_.initialization_depth, false, true); - lower_bound = -1.0; + lower_bound = -1; for(const auto& dc : result) { lower_bound = std::max(lower_bound, dc.max_corner_value()); } @@ -102,8 +94,8 @@ namespace md { assert(lower_bound >= 0); for(auto& dual_cell : result) { - Real good_enough_ub = get_good_enough_upper_bound(lower_bound); - Real max_value_on_cell = get_upper_bound(dual_cell, good_enough_ub); + R good_enough_ub = get_good_enough_upper_bound(lower_bound); + R max_value_on_cell = get_upper_bound(dual_cell, good_enough_ub); dual_cell.set_max_possible_value(max_value_on_cell); #ifdef MD_DO_FULL_CHECK @@ -116,39 +108,39 @@ namespace md { return result; } - template<class T> - typename DistanceCalculator<T>::CellValueVector - DistanceCalculator<T>::get_refined_grid(int init_depth, bool calculate_on_intermediate, bool calculate_on_last) + template<class R, class T> + typename DistanceCalculator<R, T>::CellValueVector + DistanceCalculator<R, T>::get_refined_grid(int init_depth, bool calculate_on_intermediate, bool calculate_on_last) { - const Real y_max = std::max(module_a_.max_y(), module_b_.max_y()); - const Real x_max = std::max(module_a_.max_x(), module_b_.max_x()); + const R y_max = std::max(module_a_.max_y(), module_b_.max_y()); + const R x_max = std::max(module_a_.max_x(), module_b_.max_x()); - const Real lambda_min = 0; - const Real lambda_max = 1; + const R lambda_min = 0; + const R lambda_max = 1; - const Real mu_min = 0; + const R mu_min = 0; - DualBox x_flat(DualPoint(AxisType::x_type, AngleType::flat, lambda_min, mu_min), - DualPoint(AxisType::x_type, AngleType::flat, lambda_max, x_max)); + DualBox<R> x_flat(DualPoint<R>(AxisType::x_type, AngleType::flat, lambda_min, mu_min), + DualPoint<R>(AxisType::x_type, AngleType::flat, lambda_max, x_max)); - DualBox x_steep(DualPoint(AxisType::x_type, AngleType::steep, lambda_min, mu_min), - DualPoint(AxisType::x_type, AngleType::steep, lambda_max, x_max)); + DualBox<R> x_steep(DualPoint<R>(AxisType::x_type, AngleType::steep, lambda_min, mu_min), + DualPoint<R>(AxisType::x_type, AngleType::steep, lambda_max, x_max)); - DualBox y_flat(DualPoint(AxisType::y_type, AngleType::flat, lambda_min, mu_min), - DualPoint(AxisType::y_type, AngleType::flat, lambda_max, y_max)); + DualBox<R> y_flat(DualPoint<R>(AxisType::y_type, AngleType::flat, lambda_min, mu_min), + DualPoint<R>(AxisType::y_type, AngleType::flat, lambda_max, y_max)); - DualBox y_steep(DualPoint(AxisType::y_type, AngleType::steep, lambda_min, mu_min), - DualPoint(AxisType::y_type, AngleType::steep, lambda_max, y_max)); + DualBox<R> y_steep(DualPoint<R>(AxisType::y_type, AngleType::steep, lambda_min, mu_min), + DualPoint<R>(AxisType::y_type, AngleType::steep, lambda_max, y_max)); - CellWithValue x_flat_cell(x_flat, 0); - CellWithValue x_steep_cell(x_steep, 0); - CellWithValue y_flat_cell(y_flat, 0); - CellWithValue y_steep_cell(y_steep, 0); + CellWithValue<R> x_flat_cell(x_flat, 0); + CellWithValue<R> x_steep_cell(x_steep, 0); + CellWithValue<R> y_flat_cell(y_flat, 0); + CellWithValue<R> y_steep_cell(y_steep, 0); if (init_depth == 0) { - DualPoint diagonal_x_flat(AxisType::x_type, AngleType::flat, 1, 0); + DualPoint<R> diagonal_x_flat(AxisType::x_type, AngleType::flat, 1, 0); - Real diagonal_value = distance_on_line(diagonal_x_flat); + R diagonal_value = distance_on_line(diagonal_x_flat); n_hera_calls_per_level_[0]++; x_flat_cell.set_value_at(ValuePoint::lower_right, diagonal_value); @@ -162,7 +154,7 @@ namespace md { x_steep_cell.id = 2; y_flat_cell.id = 3; y_steep_cell.id = 4; - CellWithValue::max_id = 4; + CellWithValue<R>::max_id = 4; #endif CellValueVector result {x_flat_cell, x_steep_cell, y_flat_cell, y_steep_cell}; @@ -189,10 +181,10 @@ namespace md { return result; } - template<class T> - DistanceCalculator<T>::DistanceCalculator(const T& a, + template<class R, class T> + DistanceCalculator<R, T>::DistanceCalculator(const T& a, const T& b, - CalculationParams& params) + CalculationParams<R>& params) : module_a_(a), module_b_(b), @@ -213,33 +205,33 @@ namespace md { module_a_.max_x(), module_a_.max_y(), module_b_.max_x(), module_b_.max_y()); } - template<class T> - Real DistanceCalculator<T>::get_max_x(int module) const + template<class R, class T> + R DistanceCalculator<R, T>::get_max_x(int module) const { return (module == 0) ? module_a_.max_x() : module_b_.max_x(); } - template<class T> - Real DistanceCalculator<T>::get_max_y(int module) const + template<class R, class T> + R DistanceCalculator<R, T>::get_max_y(int module) const { return (module == 0) ? module_a_.max_y() : module_b_.max_y(); } - template<class T> - Real - DistanceCalculator<T>::get_local_refined_bound(const md::DualBox& dual_box) const + template<class R, class T> + R + DistanceCalculator<R, T>::get_local_refined_bound(const DualBox<R>& dual_box) const { return get_local_refined_bound(0, dual_box) + get_local_refined_bound(1, dual_box); } - template<class T> - Real - DistanceCalculator<T>::get_local_refined_bound(int module, const md::DualBox& dual_box) const + template<class R, class T> + R + DistanceCalculator<R, T>::get_local_refined_bound(int module, const DualBox<R>& dual_box) const { spd::debug("Enter get_local_refined_bound, dual_box = {}", dual_box); - Real d_lambda = dual_box.lambda_max() - dual_box.lambda_min(); - Real d_mu = dual_box.mu_max() - dual_box.mu_min(); - Real result; + R d_lambda = dual_box.lambda_max() - dual_box.lambda_min(); + R d_mu = dual_box.mu_max() - dual_box.mu_min(); + R result; if (dual_box.axis_type() == AxisType::x_type) { if (dual_box.is_flat()) { result = dual_box.lambda_max() * d_mu + (get_max_x(module) - dual_box.mu_min()) * d_lambda; @@ -258,11 +250,11 @@ namespace md { return result; } - template<class T> - Real DistanceCalculator<T>::get_local_dual_bound(int module, const md::DualBox& dual_box) const + template<class R, class T> + R DistanceCalculator<R, T>::get_local_dual_bound(int module, const DualBox<R>& dual_box) const { - Real dlambda = dual_box.lambda_max() - dual_box.lambda_min(); - Real dmu = dual_box.mu_max() - dual_box.mu_min(); + R dlambda = dual_box.lambda_max() - dual_box.lambda_min(); + R dmu = dual_box.mu_max() - dual_box.mu_min(); if (dual_box.is_flat()) { return get_max_x(module) * dlambda + dmu; @@ -271,20 +263,20 @@ namespace md { } } - template<class T> - Real DistanceCalculator<T>::get_local_dual_bound(const md::DualBox& dual_box) const + template<class R, class T> + R DistanceCalculator<R, T>::get_local_dual_bound(const DualBox<R>& dual_box) const { return get_local_dual_bound(0, dual_box) + get_local_dual_bound(1, dual_box); } - template<class T> - Real DistanceCalculator<T>::get_upper_bound(const CellWithValue& dual_cell, Real good_enough_ub) const + template<class R, class T> + R DistanceCalculator<R, T>::get_upper_bound(const CellWithValue<R>& dual_cell, R good_enough_ub) const { assert(good_enough_ub >= 0); switch(params_.bound_strategy) { case BoundStrategy::bruteforce: - return std::numeric_limits<Real>::max(); + return std::numeric_limits<R>::max(); case BoundStrategy::local_dual_bound: return dual_cell.min_value() + get_local_dual_bound(dual_cell.dual_box()); @@ -293,7 +285,7 @@ namespace md { return dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box()); case BoundStrategy::local_combined: { - Real cheap_upper_bound = dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box()); + R cheap_upper_bound = dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box()); if (cheap_upper_bound < good_enough_ub) { return cheap_upper_bound; } else { @@ -302,14 +294,14 @@ namespace md { } case BoundStrategy::local_dual_bound_for_each_point: { - Real result = std::numeric_limits<Real>::max(); + R result = std::numeric_limits<R>::max(); for(ValuePoint vp : k_corner_vps) { if (not dual_cell.has_value_at(vp)) { continue; } - Real base_value = dual_cell.value_at(vp); - Real bound_dgm_a = get_single_dgm_bound(dual_cell, vp, 0, good_enough_ub); + R base_value = dual_cell.value_at(vp); + R bound_dgm_a = get_single_dgm_bound(dual_cell, vp, 0, good_enough_ub); if (params_.stop_asap and bound_dgm_a + base_value >= good_enough_ub) { // we want to return a valid upper bound, not just something that will prevent discarding the cell @@ -318,8 +310,8 @@ namespace md { return dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box()); } - Real bound_dgm_b = get_single_dgm_bound(dual_cell, vp, 1, - std::max(Real(0), good_enough_ub - bound_dgm_a)); + R bound_dgm_b = get_single_dgm_bound(dual_cell, vp, 1, + std::max(R(0), good_enough_ub - bound_dgm_a)); result = std::min(result, base_value + bound_dgm_a + bound_dgm_b); @@ -336,19 +328,19 @@ namespace md { } } // to suppress compiler warning - return std::numeric_limits<Real>::max(); + return std::numeric_limits<R>::max(); } // find maximal displacement of weighted points of m for all lines in dual_box - template<class T> - Real - DistanceCalculator<T>::get_single_dgm_bound(const CellWithValue& dual_cell, + template<class R, class T> + R + DistanceCalculator<R, T>::get_single_dgm_bound(const CellWithValue<R>& dual_cell, ValuePoint vp, int module, - [[maybe_unused]] Real good_enough_value) const + R good_enough_value) const { - Real result = 0; - Point max_point; + R result = 0; + Point<R> max_point; spd::debug( "Enter get_single_dgm_bound, module = {}, dual_cell = {}, vp = {}, good_enough_value = {}, stop_asap = {}\n", @@ -358,7 +350,7 @@ namespace md { for(const auto& position : m.positions()) { spd::debug("in get_single_dgm_bound, simplex = {}\n", position); - Real x = get_max_displacement_single_point(dual_cell, vp, position); + R x = get_max_displacement_single_point(dual_cell, vp, position); spd::debug("In get_single_dgm_bound, point = {}, displacement = {}", position, x); @@ -385,30 +377,30 @@ namespace md { return result; } - template<class T> - Real DistanceCalculator<T>::distance() + template<class R, class T> + R DistanceCalculator<R, T>::distance() { return get_distance_pq(); } // calculate weighted bottleneneck distance between slices on line // increments hera calls counter - template<class T> - Real DistanceCalculator<T>::distance_on_line(DualPoint line) + template<class R, class T> + R DistanceCalculator<R, T>::distance_on_line(DualPoint<R> line) { ++n_hera_calls_; - Real result = distance_on_line_const(line); + R result = distance_on_line_const(line); return result; } - template<class T> - Real DistanceCalculator<T>::distance_on_line_const(DualPoint line) const + template<class R, class T> + R DistanceCalculator<R, T>::distance_on_line_const(DualPoint<R> line) const { // TODO: think about this - how to call Hera auto dgm_a = module_a_.weighted_slice_diagram(line); auto dgm_b = module_b_.weighted_slice_diagram(line); - Real result; - if (params_.hera_epsilon > static_cast<Real>(0)) { + R result; + if (params_.hera_epsilon > static_cast<R>(0)) { result = hera::bottleneckDistApprox(dgm_a, dgm_b, params_.hera_epsilon) / ( params_.hera_epsilon + 1); } else { result = hera::bottleneckDistExact(dgm_a, dgm_b); @@ -423,10 +415,10 @@ namespace md { return result; } - template<class T> - Real DistanceCalculator<T>::get_good_enough_upper_bound(Real lower_bound) const + template<class R, class T> + R DistanceCalculator<R, T>::get_good_enough_upper_bound(R lower_bound) const { - Real result; + R result; // in upper_bound strategy we only prune cells if they cannot improve the lower bound, // otherwise the experiment is supposed to run indefinitely if (params_.traverse_strategy == TraverseStrategy::upper_bound) { @@ -440,14 +432,14 @@ namespace md { // helper function // calculate weighted bt distance on cell center, // assign distance value to cell, keep it in heat_map, and return - template<class T> - void DistanceCalculator<T>::set_cell_central_value(CellWithValue& dual_cell) + template<class R, class T> + void DistanceCalculator<R, T>::set_cell_central_value(CellWithValue<R>& dual_cell) { - DualPoint central_line {dual_cell.center()}; + DualPoint<R> central_line {dual_cell.center()}; spd::debug("In set_cell_central_value, processing dual cell = {}, line = {}", dual_cell.dual_box(), central_line); - Real new_value = distance_on_line(central_line); + R new_value = distance_on_line(central_line); n_hera_calls_per_level_[dual_cell.level() + 1]++; dual_cell.set_value_at(ValuePoint::center, new_value); params_.actual_max_depth = std::max(params_.actual_max_depth, dual_cell.level() + 1); @@ -472,10 +464,10 @@ namespace md { // assumes that the underlying container is vector! // cell_ptr: pointer to the first element in queue // n_cells: queue size - template<class T> - Real DistanceCalculator<T>::get_max_possible_value(const CellWithValue* cell_ptr, int n_cells) + template<class R, class T> + R DistanceCalculator<R, T>::get_max_possible_value(const CellWithValue<R>* cell_ptr, int n_cells) { - Real result = (n_cells > 0) ? cell_ptr->stored_upper_bound() : 0; + R result = (n_cells > 0) ? cell_ptr->stored_upper_bound() : 0; for(int i = 0; i < n_cells; ++i, ++cell_ptr) { result = std::max(result, cell_ptr->stored_upper_bound()); } @@ -485,11 +477,11 @@ namespace md { // helper function: // return current error from lower and upper bounds // and save it in params_ (hence not const) - template<class T> - Real DistanceCalculator<T>::current_error(Real lower_bound, Real upper_bound) + template<class R, class T> + R DistanceCalculator<R, T>::current_error(R lower_bound, R upper_bound) { - Real current_error = (lower_bound > 0.0) ? (upper_bound - lower_bound) / lower_bound - : std::numeric_limits<Real>::max(); + R current_error = (lower_bound > 0.0) ? (upper_bound - lower_bound) / lower_bound + : std::numeric_limits<R>::max(); params_.actual_error = current_error; @@ -505,8 +497,8 @@ namespace md { // use priority queue to store dual cells // comparison function depends on the strategies in params_ // ressets hera calls counter - template<class T> - Real DistanceCalculator<T>::get_distance_pq() + template<class R, class T> + R DistanceCalculator<R, T>::get_distance_pq() { std::map<int, long> n_cells_considered; std::map<int, long> n_cells_pushed_into_queue; @@ -527,26 +519,26 @@ namespace md { // if cell is too deep and is not pushed into queue, // we still need to take its max value into account; // the max over such cells is stored in max_result_on_too_fine_cells - Real upper_bound_on_deep_cells = -1; + R upper_bound_on_deep_cells = -1; spd::debug("Started iterations in dual space, delta = {}, bound_strategy = {}", params_.delta, params_.bound_strategy); // user-defined less lambda function // to regulate priority queue depending on strategy - auto dual_cell_less = [this](const CellWithValue& a, const CellWithValue& b) { + auto dual_cell_less = [this](const CellWithValue<R>& a, const CellWithValue<R>& b) { int a_level = a.level(); int b_level = b.level(); - Real a_value = a.max_corner_value(); - Real b_value = b.max_corner_value(); - Real a_ub = a.stored_upper_bound(); - Real b_ub = b.stored_upper_bound(); + R a_value = a.max_corner_value(); + R b_value = b.max_corner_value(); + R a_ub = a.stored_upper_bound(); + R b_ub = b.stored_upper_bound(); if (this->params_.traverse_strategy == TraverseStrategy::upper_bound and (not a.has_max_possible_value() or not b.has_max_possible_value())) { throw std::runtime_error("no upper bound on cell"); } - DualPoint a_lower_left = a.dual_box().lower_left(); - DualPoint b_lower_left = b.dual_box().lower_left(); + DualPoint<R> a_lower_left = a.dual_box().lower_left(); + DualPoint<R> b_lower_left = b.dual_box().lower_left(); switch(this->params_.traverse_strategy) { // in both breadth_first searches we want coarser cells @@ -569,24 +561,24 @@ namespace md { } }; - std::priority_queue<CellWithValue, CellValueVector, decltype(dual_cell_less)> dual_cells_queue( + std::priority_queue<CellWithValue<R>, CellValueVector, decltype(dual_cell_less)> dual_cells_queue( dual_cell_less); // weighted bt distance on the center of current cell - Real lower_bound = std::numeric_limits<Real>::min(); + R lower_bound = std::numeric_limits<R>::min(); // init pq and lower bound for(auto& init_cell : get_initial_dual_grid(lower_bound)) { dual_cells_queue.push(init_cell); } - Real upper_bound = get_max_possible_value(&dual_cells_queue.top(), dual_cells_queue.size()); + R upper_bound = get_max_possible_value(&dual_cells_queue.top(), dual_cells_queue.size()); std::vector<UbExperimentRecord> ub_experiment_results; while(not dual_cells_queue.empty()) { - CellWithValue dual_cell = dual_cells_queue.top(); + CellWithValue<R> dual_cell = dual_cells_queue.top(); dual_cells_queue.pop(); assert(dual_cell.has_corner_value() and dual_cell.has_max_possible_value() @@ -620,7 +612,7 @@ namespace md { // until now, dual_cell knows its value in one of its corners // new_value will be the weighted distance at its center set_cell_central_value(dual_cell); - Real new_value = dual_cell.value_at(ValuePoint::center); + R new_value = dual_cell.value_at(ValuePoint::center); lower_bound = std::max(new_value, lower_bound); spd::debug("Processed cell = {}, weighted value = {}, lower_bound = {}", dual_cell, new_value, lower_bound); @@ -638,11 +630,11 @@ namespace md { throw std::runtime_error("no value on cell"); // if delta is smaller than good_enough_value, it allows to prune cell - Real good_enough_ub = get_good_enough_upper_bound(lower_bound); + R good_enough_ub = get_good_enough_upper_bound(lower_bound); // upper bound of the parent holds for refined_cell // and can sometimes be smaller! - Real upper_bound_on_refined_cell = std::min(dual_cell.stored_upper_bound(), + R upper_bound_on_refined_cell = std::min(dual_cell.stored_upper_bound(), get_upper_bound(refined_cell, good_enough_ub)); spd::debug("upper_bound_on_refined_cell = {}, dual_cell.stored_upper_bound = {}, get_upper_bound = {}", @@ -774,10 +766,46 @@ namespace md { return lower_bound; } - template<class T> - int DistanceCalculator<T>::get_hera_calls_number() const + template<class R, class T> + int DistanceCalculator<R, T>::get_hera_calls_number() const { return n_hera_calls_; } -}
\ No newline at end of file + template<class R> + R matching_distance(const Bifiltration<R>& bif_a, const Bifiltration<R>& bif_b, + CalculationParams<R>& params) + { + R result; + // compute distance only in one dimension + if (params.dim != CalculationParams<R>::ALL_DIMENSIONS) { + BifiltrationProxy<R> bifp_a(bif_a, params.dim); + BifiltrationProxy<R> bifp_b(bif_b, params.dim); + DistanceCalculator<R, BifiltrationProxy<R>> runner(bifp_a, bifp_b, params); + result = runner.distance(); + params.n_hera_calls = runner.get_hera_calls_number(); + } else { + // compute distance in all dimensions, return maximal + result = -1; + for(int dim = 0; dim < std::max(bif_a.maximal_dim(), bif_b.maximal_dim()); ++dim) { + BifiltrationProxy<R> bifp_a(bif_a, params.dim); + BifiltrationProxy<R> bifp_b(bif_a, params.dim); + DistanceCalculator<R, BifiltrationProxy<R>> runner(bifp_a, bifp_b, params); + result = std::max(result, runner.distance()); + params.n_hera_calls += runner.get_hera_calls_number(); + } + } + return result; + } + + + template<class R> + R matching_distance(const ModulePresentation<R>& mod_a, const ModulePresentation<R>& mod_b, + CalculationParams<R>& params) + { + DistanceCalculator<R, ModulePresentation<R>> runner(mod_a, mod_b, params); + R result = runner.distance(); + params.n_hera_calls = runner.get_hera_calls_number(); + return result; + } +} // namespace md |