summaryrefslogtreecommitdiff
path: root/matching/include/matching_distance.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'matching/include/matching_distance.hpp')
-rw-r--r--matching/include/matching_distance.hpp326
1 files changed, 177 insertions, 149 deletions
diff --git a/matching/include/matching_distance.hpp b/matching/include/matching_distance.hpp
index d2d2fbc..48c8464 100644
--- a/matching/include/matching_distance.hpp
+++ b/matching/include/matching_distance.hpp
@@ -1,34 +1,26 @@
namespace md {
- template<class K, class V>
- void print_map(const std::map<K, V>& dic)
- {
- for(const auto kv : dic) {
- fmt::print("{} -> {}\n", kv.first, kv.second);
- }
- }
-
- template<class T>
- void DistanceCalculator<T>::check_upper_bound(const CellWithValue& dual_cell) const
+ template<class R, class T>
+ void DistanceCalculator<R, T>::check_upper_bound(const CellWithValue<R>& dual_cell) const
{
spd::debug("Enter check_get_max_delta_on_cell");
const int n_samples_lambda = 100;
const int n_samples_mu = 100;
- DualBox db = dual_cell.dual_box();
- Real min_lambda = db.lambda_min();
- Real max_lambda = db.lambda_max();
- Real min_mu = db.mu_min();
- Real max_mu = db.mu_max();
-
- Real h_lambda = (max_lambda - min_lambda) / n_samples_lambda;
- Real h_mu = (max_mu - min_mu) / n_samples_mu;
+ DualBox<R> db = dual_cell.dual_box();
+ R min_lambda = db.lambda_min();
+ R max_lambda = db.lambda_max();
+ R min_mu = db.mu_min();
+ R max_mu = db.mu_max();
+
+ R h_lambda = (max_lambda - min_lambda) / n_samples_lambda;
+ R h_mu = (max_mu - min_mu) / n_samples_mu;
for(int i = 1; i < n_samples_lambda; ++i) {
for(int j = 1; j < n_samples_mu; ++j) {
- Real lambda = min_lambda + i * h_lambda;
- Real mu = min_mu + j * h_mu;
- DualPoint l(db.axis_type(), db.angle_type(), lambda, mu);
- Real other_result = distance_on_line_const(l);
- Real diff = fabs(dual_cell.stored_upper_bound() - other_result);
+ R lambda = min_lambda + i * h_lambda;
+ R mu = min_mu + j * h_mu;
+ DualPoint<R> l(db.axis_type(), db.angle_type(), lambda, mu);
+ R other_result = distance_on_line_const(l);
+ R diff = fabs(dual_cell.stored_upper_bound() - other_result);
if (other_result > dual_cell.stored_upper_bound()) {
spd::error(
"in check_upper_bound, upper_bound = {}, other_result = {}, diff = {}\ndual_cell = {}",
@@ -42,10 +34,10 @@ namespace md {
// for all lines l, l' inside dual box,
// find the upper bound on the difference of weighted pushes of p
- template<class T>
- Real
- DistanceCalculator<T>::get_max_displacement_single_point(const CellWithValue& dual_cell, ValuePoint vp,
- const Point& p) const
+ template<class R, class T>
+ R
+ DistanceCalculator<R, T>::get_max_displacement_single_point(const CellWithValue<R>& dual_cell, ValuePoint vp,
+ const Point<R>& p) const
{
assert(p.x >= 0 && p.y >= 0);
@@ -53,15 +45,15 @@ namespace md {
std::vector<long long int> debug_ids = {3, 13, 54, 218, 350, 382, 484, 795, 2040, 8415, 44076};
bool debug = false; // std::find(debug_ids.begin(), debug_ids.end(), dual_cell.id) != debug_ids.end();
#endif
- DualPoint line = dual_cell.value_point(vp);
- const Real base_value = line.weighted_push(p);
+ DualPoint<R> line = dual_cell.value_point(vp);
+ const R base_value = line.weighted_push(p);
spd::debug("Enter get_max_displacement_single_point, p = {},\ndual_cell = {},\nline = {}, base_value = {}\n", p,
dual_cell, line, base_value);
- Real result = 0.0;
- for(DualPoint dp : dual_cell.dual_box().critical_points(p)) {
- Real dp_value = dp.weighted_push(p);
+ R result = 0.0;
+ for(DualPoint<R> dp : dual_cell.dual_box().critical_points(p)) {
+ R dp_value = dp.weighted_push(p);
spd::debug(
"In get_max_displacement_single_point, p = {}, critical dp = {},\ndp_value = {}, diff = {},\ndual_cell = {}\n",
p, dp, dp_value, fabs(base_value - dp_value), dual_cell);
@@ -69,15 +61,15 @@ namespace md {
}
#ifdef MD_DO_FULL_CHECK
- DualBox db = dual_cell.dual_box();
- std::uniform_real_distribution<Real> dlambda(db.lambda_min(), db.lambda_max());
- std::uniform_real_distribution<Real> dmu(db.mu_min(), db.mu_max());
+ auto db = dual_cell.dual_box();
+ std::uniform_real_distribution<R> dlambda(db.lambda_min(), db.lambda_max());
+ std::uniform_real_distribution<R> dmu(db.mu_min(), db.mu_max());
std::mt19937 gen(1);
for(int i = 0; i < 1000; ++i) {
- Real lambda = dlambda(gen);
- Real mu = dmu(gen);
- DualPoint dp_random { db.axis_type(), db.angle_type(), lambda, mu };
- Real dp_value = dp_random.weighted_push(p);
+ R lambda = dlambda(gen);
+ R mu = dmu(gen);
+ DualPoint<R> dp_random { db.axis_type(), db.angle_type(), lambda, mu };
+ R dp_value = dp_random.weighted_push(p);
if (fabs(base_value - dp_value) > result) {
spd::error("in get_max_displacement_single_point, p = {}, vp = {}\ndb = {}\nresult = {}, base_value = {}, dp_value = {}, dp_random = {}",
p, vp, db, result, base_value, dp_value, dp_random);
@@ -89,12 +81,12 @@ namespace md {
return result;
}
- template<class T>
- typename DistanceCalculator<T>::CellValueVector DistanceCalculator<T>::get_initial_dual_grid(Real& lower_bound)
+ template<class R, class T>
+ typename DistanceCalculator<R, T>::CellValueVector DistanceCalculator<R, T>::get_initial_dual_grid(R& lower_bound)
{
CellValueVector result = get_refined_grid(params_.initialization_depth, false, true);
- lower_bound = -1.0;
+ lower_bound = -1;
for(const auto& dc : result) {
lower_bound = std::max(lower_bound, dc.max_corner_value());
}
@@ -102,8 +94,8 @@ namespace md {
assert(lower_bound >= 0);
for(auto& dual_cell : result) {
- Real good_enough_ub = get_good_enough_upper_bound(lower_bound);
- Real max_value_on_cell = get_upper_bound(dual_cell, good_enough_ub);
+ R good_enough_ub = get_good_enough_upper_bound(lower_bound);
+ R max_value_on_cell = get_upper_bound(dual_cell, good_enough_ub);
dual_cell.set_max_possible_value(max_value_on_cell);
#ifdef MD_DO_FULL_CHECK
@@ -116,39 +108,39 @@ namespace md {
return result;
}
- template<class T>
- typename DistanceCalculator<T>::CellValueVector
- DistanceCalculator<T>::get_refined_grid(int init_depth, bool calculate_on_intermediate, bool calculate_on_last)
+ template<class R, class T>
+ typename DistanceCalculator<R, T>::CellValueVector
+ DistanceCalculator<R, T>::get_refined_grid(int init_depth, bool calculate_on_intermediate, bool calculate_on_last)
{
- const Real y_max = std::max(module_a_.max_y(), module_b_.max_y());
- const Real x_max = std::max(module_a_.max_x(), module_b_.max_x());
+ const R y_max = std::max(module_a_.max_y(), module_b_.max_y());
+ const R x_max = std::max(module_a_.max_x(), module_b_.max_x());
- const Real lambda_min = 0;
- const Real lambda_max = 1;
+ const R lambda_min = 0;
+ const R lambda_max = 1;
- const Real mu_min = 0;
+ const R mu_min = 0;
- DualBox x_flat(DualPoint(AxisType::x_type, AngleType::flat, lambda_min, mu_min),
- DualPoint(AxisType::x_type, AngleType::flat, lambda_max, x_max));
+ DualBox<R> x_flat(DualPoint<R>(AxisType::x_type, AngleType::flat, lambda_min, mu_min),
+ DualPoint<R>(AxisType::x_type, AngleType::flat, lambda_max, x_max));
- DualBox x_steep(DualPoint(AxisType::x_type, AngleType::steep, lambda_min, mu_min),
- DualPoint(AxisType::x_type, AngleType::steep, lambda_max, x_max));
+ DualBox<R> x_steep(DualPoint<R>(AxisType::x_type, AngleType::steep, lambda_min, mu_min),
+ DualPoint<R>(AxisType::x_type, AngleType::steep, lambda_max, x_max));
- DualBox y_flat(DualPoint(AxisType::y_type, AngleType::flat, lambda_min, mu_min),
- DualPoint(AxisType::y_type, AngleType::flat, lambda_max, y_max));
+ DualBox<R> y_flat(DualPoint<R>(AxisType::y_type, AngleType::flat, lambda_min, mu_min),
+ DualPoint<R>(AxisType::y_type, AngleType::flat, lambda_max, y_max));
- DualBox y_steep(DualPoint(AxisType::y_type, AngleType::steep, lambda_min, mu_min),
- DualPoint(AxisType::y_type, AngleType::steep, lambda_max, y_max));
+ DualBox<R> y_steep(DualPoint<R>(AxisType::y_type, AngleType::steep, lambda_min, mu_min),
+ DualPoint<R>(AxisType::y_type, AngleType::steep, lambda_max, y_max));
- CellWithValue x_flat_cell(x_flat, 0);
- CellWithValue x_steep_cell(x_steep, 0);
- CellWithValue y_flat_cell(y_flat, 0);
- CellWithValue y_steep_cell(y_steep, 0);
+ CellWithValue<R> x_flat_cell(x_flat, 0);
+ CellWithValue<R> x_steep_cell(x_steep, 0);
+ CellWithValue<R> y_flat_cell(y_flat, 0);
+ CellWithValue<R> y_steep_cell(y_steep, 0);
if (init_depth == 0) {
- DualPoint diagonal_x_flat(AxisType::x_type, AngleType::flat, 1, 0);
+ DualPoint<R> diagonal_x_flat(AxisType::x_type, AngleType::flat, 1, 0);
- Real diagonal_value = distance_on_line(diagonal_x_flat);
+ R diagonal_value = distance_on_line(diagonal_x_flat);
n_hera_calls_per_level_[0]++;
x_flat_cell.set_value_at(ValuePoint::lower_right, diagonal_value);
@@ -162,7 +154,7 @@ namespace md {
x_steep_cell.id = 2;
y_flat_cell.id = 3;
y_steep_cell.id = 4;
- CellWithValue::max_id = 4;
+ CellWithValue<R>::max_id = 4;
#endif
CellValueVector result {x_flat_cell, x_steep_cell, y_flat_cell, y_steep_cell};
@@ -189,10 +181,10 @@ namespace md {
return result;
}
- template<class T>
- DistanceCalculator<T>::DistanceCalculator(const T& a,
+ template<class R, class T>
+ DistanceCalculator<R, T>::DistanceCalculator(const T& a,
const T& b,
- CalculationParams& params)
+ CalculationParams<R>& params)
:
module_a_(a),
module_b_(b),
@@ -213,33 +205,33 @@ namespace md {
module_a_.max_x(), module_a_.max_y(), module_b_.max_x(), module_b_.max_y());
}
- template<class T>
- Real DistanceCalculator<T>::get_max_x(int module) const
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_max_x(int module) const
{
return (module == 0) ? module_a_.max_x() : module_b_.max_x();
}
- template<class T>
- Real DistanceCalculator<T>::get_max_y(int module) const
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_max_y(int module) const
{
return (module == 0) ? module_a_.max_y() : module_b_.max_y();
}
- template<class T>
- Real
- DistanceCalculator<T>::get_local_refined_bound(const md::DualBox& dual_box) const
+ template<class R, class T>
+ R
+ DistanceCalculator<R, T>::get_local_refined_bound(const DualBox<R>& dual_box) const
{
return get_local_refined_bound(0, dual_box) + get_local_refined_bound(1, dual_box);
}
- template<class T>
- Real
- DistanceCalculator<T>::get_local_refined_bound(int module, const md::DualBox& dual_box) const
+ template<class R, class T>
+ R
+ DistanceCalculator<R, T>::get_local_refined_bound(int module, const DualBox<R>& dual_box) const
{
spd::debug("Enter get_local_refined_bound, dual_box = {}", dual_box);
- Real d_lambda = dual_box.lambda_max() - dual_box.lambda_min();
- Real d_mu = dual_box.mu_max() - dual_box.mu_min();
- Real result;
+ R d_lambda = dual_box.lambda_max() - dual_box.lambda_min();
+ R d_mu = dual_box.mu_max() - dual_box.mu_min();
+ R result;
if (dual_box.axis_type() == AxisType::x_type) {
if (dual_box.is_flat()) {
result = dual_box.lambda_max() * d_mu + (get_max_x(module) - dual_box.mu_min()) * d_lambda;
@@ -258,11 +250,11 @@ namespace md {
return result;
}
- template<class T>
- Real DistanceCalculator<T>::get_local_dual_bound(int module, const md::DualBox& dual_box) const
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_local_dual_bound(int module, const DualBox<R>& dual_box) const
{
- Real dlambda = dual_box.lambda_max() - dual_box.lambda_min();
- Real dmu = dual_box.mu_max() - dual_box.mu_min();
+ R dlambda = dual_box.lambda_max() - dual_box.lambda_min();
+ R dmu = dual_box.mu_max() - dual_box.mu_min();
if (dual_box.is_flat()) {
return get_max_x(module) * dlambda + dmu;
@@ -271,20 +263,20 @@ namespace md {
}
}
- template<class T>
- Real DistanceCalculator<T>::get_local_dual_bound(const md::DualBox& dual_box) const
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_local_dual_bound(const DualBox<R>& dual_box) const
{
return get_local_dual_bound(0, dual_box) + get_local_dual_bound(1, dual_box);
}
- template<class T>
- Real DistanceCalculator<T>::get_upper_bound(const CellWithValue& dual_cell, Real good_enough_ub) const
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_upper_bound(const CellWithValue<R>& dual_cell, R good_enough_ub) const
{
assert(good_enough_ub >= 0);
switch(params_.bound_strategy) {
case BoundStrategy::bruteforce:
- return std::numeric_limits<Real>::max();
+ return std::numeric_limits<R>::max();
case BoundStrategy::local_dual_bound:
return dual_cell.min_value() + get_local_dual_bound(dual_cell.dual_box());
@@ -293,7 +285,7 @@ namespace md {
return dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box());
case BoundStrategy::local_combined: {
- Real cheap_upper_bound = dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box());
+ R cheap_upper_bound = dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box());
if (cheap_upper_bound < good_enough_ub) {
return cheap_upper_bound;
} else {
@@ -302,14 +294,14 @@ namespace md {
}
case BoundStrategy::local_dual_bound_for_each_point: {
- Real result = std::numeric_limits<Real>::max();
+ R result = std::numeric_limits<R>::max();
for(ValuePoint vp : k_corner_vps) {
if (not dual_cell.has_value_at(vp)) {
continue;
}
- Real base_value = dual_cell.value_at(vp);
- Real bound_dgm_a = get_single_dgm_bound(dual_cell, vp, 0, good_enough_ub);
+ R base_value = dual_cell.value_at(vp);
+ R bound_dgm_a = get_single_dgm_bound(dual_cell, vp, 0, good_enough_ub);
if (params_.stop_asap and bound_dgm_a + base_value >= good_enough_ub) {
// we want to return a valid upper bound, not just something that will prevent discarding the cell
@@ -318,8 +310,8 @@ namespace md {
return dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box());
}
- Real bound_dgm_b = get_single_dgm_bound(dual_cell, vp, 1,
- std::max(Real(0), good_enough_ub - bound_dgm_a));
+ R bound_dgm_b = get_single_dgm_bound(dual_cell, vp, 1,
+ std::max(R(0), good_enough_ub - bound_dgm_a));
result = std::min(result, base_value + bound_dgm_a + bound_dgm_b);
@@ -336,19 +328,19 @@ namespace md {
}
}
// to suppress compiler warning
- return std::numeric_limits<Real>::max();
+ return std::numeric_limits<R>::max();
}
// find maximal displacement of weighted points of m for all lines in dual_box
- template<class T>
- Real
- DistanceCalculator<T>::get_single_dgm_bound(const CellWithValue& dual_cell,
+ template<class R, class T>
+ R
+ DistanceCalculator<R, T>::get_single_dgm_bound(const CellWithValue<R>& dual_cell,
ValuePoint vp,
int module,
- [[maybe_unused]] Real good_enough_value) const
+ R good_enough_value) const
{
- Real result = 0;
- Point max_point;
+ R result = 0;
+ Point<R> max_point;
spd::debug(
"Enter get_single_dgm_bound, module = {}, dual_cell = {}, vp = {}, good_enough_value = {}, stop_asap = {}\n",
@@ -358,7 +350,7 @@ namespace md {
for(const auto& position : m.positions()) {
spd::debug("in get_single_dgm_bound, simplex = {}\n", position);
- Real x = get_max_displacement_single_point(dual_cell, vp, position);
+ R x = get_max_displacement_single_point(dual_cell, vp, position);
spd::debug("In get_single_dgm_bound, point = {}, displacement = {}", position, x);
@@ -385,30 +377,30 @@ namespace md {
return result;
}
- template<class T>
- Real DistanceCalculator<T>::distance()
+ template<class R, class T>
+ R DistanceCalculator<R, T>::distance()
{
return get_distance_pq();
}
// calculate weighted bottleneneck distance between slices on line
// increments hera calls counter
- template<class T>
- Real DistanceCalculator<T>::distance_on_line(DualPoint line)
+ template<class R, class T>
+ R DistanceCalculator<R, T>::distance_on_line(DualPoint<R> line)
{
++n_hera_calls_;
- Real result = distance_on_line_const(line);
+ R result = distance_on_line_const(line);
return result;
}
- template<class T>
- Real DistanceCalculator<T>::distance_on_line_const(DualPoint line) const
+ template<class R, class T>
+ R DistanceCalculator<R, T>::distance_on_line_const(DualPoint<R> line) const
{
// TODO: think about this - how to call Hera
auto dgm_a = module_a_.weighted_slice_diagram(line);
auto dgm_b = module_b_.weighted_slice_diagram(line);
- Real result;
- if (params_.hera_epsilon > static_cast<Real>(0)) {
+ R result;
+ if (params_.hera_epsilon > static_cast<R>(0)) {
result = hera::bottleneckDistApprox(dgm_a, dgm_b, params_.hera_epsilon) / ( params_.hera_epsilon + 1);
} else {
result = hera::bottleneckDistExact(dgm_a, dgm_b);
@@ -423,10 +415,10 @@ namespace md {
return result;
}
- template<class T>
- Real DistanceCalculator<T>::get_good_enough_upper_bound(Real lower_bound) const
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_good_enough_upper_bound(R lower_bound) const
{
- Real result;
+ R result;
// in upper_bound strategy we only prune cells if they cannot improve the lower bound,
// otherwise the experiment is supposed to run indefinitely
if (params_.traverse_strategy == TraverseStrategy::upper_bound) {
@@ -440,14 +432,14 @@ namespace md {
// helper function
// calculate weighted bt distance on cell center,
// assign distance value to cell, keep it in heat_map, and return
- template<class T>
- void DistanceCalculator<T>::set_cell_central_value(CellWithValue& dual_cell)
+ template<class R, class T>
+ void DistanceCalculator<R, T>::set_cell_central_value(CellWithValue<R>& dual_cell)
{
- DualPoint central_line {dual_cell.center()};
+ DualPoint<R> central_line {dual_cell.center()};
spd::debug("In set_cell_central_value, processing dual cell = {}, line = {}", dual_cell.dual_box(),
central_line);
- Real new_value = distance_on_line(central_line);
+ R new_value = distance_on_line(central_line);
n_hera_calls_per_level_[dual_cell.level() + 1]++;
dual_cell.set_value_at(ValuePoint::center, new_value);
params_.actual_max_depth = std::max(params_.actual_max_depth, dual_cell.level() + 1);
@@ -472,10 +464,10 @@ namespace md {
// assumes that the underlying container is vector!
// cell_ptr: pointer to the first element in queue
// n_cells: queue size
- template<class T>
- Real DistanceCalculator<T>::get_max_possible_value(const CellWithValue* cell_ptr, int n_cells)
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_max_possible_value(const CellWithValue<R>* cell_ptr, int n_cells)
{
- Real result = (n_cells > 0) ? cell_ptr->stored_upper_bound() : 0;
+ R result = (n_cells > 0) ? cell_ptr->stored_upper_bound() : 0;
for(int i = 0; i < n_cells; ++i, ++cell_ptr) {
result = std::max(result, cell_ptr->stored_upper_bound());
}
@@ -485,11 +477,11 @@ namespace md {
// helper function:
// return current error from lower and upper bounds
// and save it in params_ (hence not const)
- template<class T>
- Real DistanceCalculator<T>::current_error(Real lower_bound, Real upper_bound)
+ template<class R, class T>
+ R DistanceCalculator<R, T>::current_error(R lower_bound, R upper_bound)
{
- Real current_error = (lower_bound > 0.0) ? (upper_bound - lower_bound) / lower_bound
- : std::numeric_limits<Real>::max();
+ R current_error = (lower_bound > 0.0) ? (upper_bound - lower_bound) / lower_bound
+ : std::numeric_limits<R>::max();
params_.actual_error = current_error;
@@ -505,8 +497,8 @@ namespace md {
// use priority queue to store dual cells
// comparison function depends on the strategies in params_
// ressets hera calls counter
- template<class T>
- Real DistanceCalculator<T>::get_distance_pq()
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_distance_pq()
{
std::map<int, long> n_cells_considered;
std::map<int, long> n_cells_pushed_into_queue;
@@ -527,26 +519,26 @@ namespace md {
// if cell is too deep and is not pushed into queue,
// we still need to take its max value into account;
// the max over such cells is stored in max_result_on_too_fine_cells
- Real upper_bound_on_deep_cells = -1;
+ R upper_bound_on_deep_cells = -1;
spd::debug("Started iterations in dual space, delta = {}, bound_strategy = {}", params_.delta,
params_.bound_strategy);
// user-defined less lambda function
// to regulate priority queue depending on strategy
- auto dual_cell_less = [this](const CellWithValue& a, const CellWithValue& b) {
+ auto dual_cell_less = [this](const CellWithValue<R>& a, const CellWithValue<R>& b) {
int a_level = a.level();
int b_level = b.level();
- Real a_value = a.max_corner_value();
- Real b_value = b.max_corner_value();
- Real a_ub = a.stored_upper_bound();
- Real b_ub = b.stored_upper_bound();
+ R a_value = a.max_corner_value();
+ R b_value = b.max_corner_value();
+ R a_ub = a.stored_upper_bound();
+ R b_ub = b.stored_upper_bound();
if (this->params_.traverse_strategy == TraverseStrategy::upper_bound and
(not a.has_max_possible_value() or not b.has_max_possible_value())) {
throw std::runtime_error("no upper bound on cell");
}
- DualPoint a_lower_left = a.dual_box().lower_left();
- DualPoint b_lower_left = b.dual_box().lower_left();
+ DualPoint<R> a_lower_left = a.dual_box().lower_left();
+ DualPoint<R> b_lower_left = b.dual_box().lower_left();
switch(this->params_.traverse_strategy) {
// in both breadth_first searches we want coarser cells
@@ -569,24 +561,24 @@ namespace md {
}
};
- std::priority_queue<CellWithValue, CellValueVector, decltype(dual_cell_less)> dual_cells_queue(
+ std::priority_queue<CellWithValue<R>, CellValueVector, decltype(dual_cell_less)> dual_cells_queue(
dual_cell_less);
// weighted bt distance on the center of current cell
- Real lower_bound = std::numeric_limits<Real>::min();
+ R lower_bound = std::numeric_limits<R>::min();
// init pq and lower bound
for(auto& init_cell : get_initial_dual_grid(lower_bound)) {
dual_cells_queue.push(init_cell);
}
- Real upper_bound = get_max_possible_value(&dual_cells_queue.top(), dual_cells_queue.size());
+ R upper_bound = get_max_possible_value(&dual_cells_queue.top(), dual_cells_queue.size());
std::vector<UbExperimentRecord> ub_experiment_results;
while(not dual_cells_queue.empty()) {
- CellWithValue dual_cell = dual_cells_queue.top();
+ CellWithValue<R> dual_cell = dual_cells_queue.top();
dual_cells_queue.pop();
assert(dual_cell.has_corner_value()
and dual_cell.has_max_possible_value()
@@ -620,7 +612,7 @@ namespace md {
// until now, dual_cell knows its value in one of its corners
// new_value will be the weighted distance at its center
set_cell_central_value(dual_cell);
- Real new_value = dual_cell.value_at(ValuePoint::center);
+ R new_value = dual_cell.value_at(ValuePoint::center);
lower_bound = std::max(new_value, lower_bound);
spd::debug("Processed cell = {}, weighted value = {}, lower_bound = {}", dual_cell, new_value, lower_bound);
@@ -638,11 +630,11 @@ namespace md {
throw std::runtime_error("no value on cell");
// if delta is smaller than good_enough_value, it allows to prune cell
- Real good_enough_ub = get_good_enough_upper_bound(lower_bound);
+ R good_enough_ub = get_good_enough_upper_bound(lower_bound);
// upper bound of the parent holds for refined_cell
// and can sometimes be smaller!
- Real upper_bound_on_refined_cell = std::min(dual_cell.stored_upper_bound(),
+ R upper_bound_on_refined_cell = std::min(dual_cell.stored_upper_bound(),
get_upper_bound(refined_cell, good_enough_ub));
spd::debug("upper_bound_on_refined_cell = {}, dual_cell.stored_upper_bound = {}, get_upper_bound = {}",
@@ -774,10 +766,46 @@ namespace md {
return lower_bound;
}
- template<class T>
- int DistanceCalculator<T>::get_hera_calls_number() const
+ template<class R, class T>
+ int DistanceCalculator<R, T>::get_hera_calls_number() const
{
return n_hera_calls_;
}
-} \ No newline at end of file
+ template<class R>
+ R matching_distance(const Bifiltration<R>& bif_a, const Bifiltration<R>& bif_b,
+ CalculationParams<R>& params)
+ {
+ R result;
+ // compute distance only in one dimension
+ if (params.dim != CalculationParams<R>::ALL_DIMENSIONS) {
+ BifiltrationProxy<R> bifp_a(bif_a, params.dim);
+ BifiltrationProxy<R> bifp_b(bif_b, params.dim);
+ DistanceCalculator<R, BifiltrationProxy<R>> runner(bifp_a, bifp_b, params);
+ result = runner.distance();
+ params.n_hera_calls = runner.get_hera_calls_number();
+ } else {
+ // compute distance in all dimensions, return maximal
+ result = -1;
+ for(int dim = 0; dim < std::max(bif_a.maximal_dim(), bif_b.maximal_dim()); ++dim) {
+ BifiltrationProxy<R> bifp_a(bif_a, params.dim);
+ BifiltrationProxy<R> bifp_b(bif_a, params.dim);
+ DistanceCalculator<R, BifiltrationProxy<R>> runner(bifp_a, bifp_b, params);
+ result = std::max(result, runner.distance());
+ params.n_hera_calls += runner.get_hera_calls_number();
+ }
+ }
+ return result;
+ }
+
+
+ template<class R>
+ R matching_distance(const ModulePresentation<R>& mod_a, const ModulePresentation<R>& mod_b,
+ CalculationParams<R>& params)
+ {
+ DistanceCalculator<R, ModulePresentation<R>> runner(mod_a, mod_b, params);
+ R result = runner.distance();
+ params.n_hera_calls = runner.get_hera_calls_number();
+ return result;
+ }
+} // namespace md