summaryrefslogtreecommitdiff
path: root/matching/include
diff options
context:
space:
mode:
Diffstat (limited to 'matching/include')
-rw-r--r--matching/include/bifiltration.h72
-rw-r--r--matching/include/bifiltration.hpp421
-rw-r--r--matching/include/box.h77
-rw-r--r--matching/include/box.hpp52
-rw-r--r--matching/include/cell_with_value.h47
-rw-r--r--matching/include/cell_with_value.hpp224
-rw-r--r--matching/include/common_defs.h4
-rw-r--r--matching/include/common_util.h113
-rw-r--r--matching/include/common_util.hpp96
-rw-r--r--matching/include/dual_box.h78
-rw-r--r--matching/include/dual_box.hpp190
-rw-r--r--matching/include/dual_point.h27
-rw-r--r--matching/include/dual_point.hpp299
-rw-r--r--matching/include/matching_distance.h203
-rw-r--r--matching/include/matching_distance.hpp326
-rw-r--r--matching/include/persistence_module.h34
-rw-r--r--matching/include/persistence_module.hpp177
-rw-r--r--matching/include/simplex.h70
-rw-r--r--matching/include/simplex.hpp79
19 files changed, 2117 insertions, 472 deletions
diff --git a/matching/include/bifiltration.h b/matching/include/bifiltration.h
index f505ed9..4dd8662 100644
--- a/matching/include/bifiltration.h
+++ b/matching/include/bifiltration.h
@@ -3,19 +3,30 @@
#include <string>
#include <ostream>
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <cassert>
#include "common_util.h"
#include "box.h"
#include "simplex.h"
#include "dual_point.h"
+#include "phat/boundary_matrix.h"
+#include "phat/compute_persistence_pairs.h"
+
+#include "spdlog/spdlog.h"
+#include "spdlog/fmt/fmt.h"
+#include "spdlog/fmt/ostr.h"
+
+#include "common_util.h"
namespace md {
+ template<class Real>
class Bifiltration {
public:
- using Diagram = std::vector<std::pair<Real, Real>>;
- using Box = md::Box;
- using SimplexVector = std::vector<Simplex>;
+ using SimplexVector = std::vector<Simplex<Real>>;
Bifiltration() = default;
@@ -36,7 +47,7 @@ namespace md {
init();
}
- Diagram weighted_slice_diagram(const DualPoint& line, int dim) const;
+ Diagram<Real> weighted_slice_diagram(const DualPoint<Real>& line, int dim) const;
SimplexVector simplices() const { return simplices_; }
@@ -48,14 +59,12 @@ namespace md {
Real minimal_coordinate() const;
// return box that contains positions of all simplices
- Box bounding_box() const;
+ Box<Real> bounding_box() const;
void sanity_check() const;
int maximal_dim() const { return maximal_dim_; }
- friend std::ostream& operator<<(std::ostream& os, const Bifiltration& bif);
-
Real max_x() const;
Real max_y() const;
@@ -64,7 +73,7 @@ namespace md {
Real min_y() const;
- void add_simplex(Index _id, Point birth, int _dim, const Column& _bdry);
+ void add_simplex(Index _id, Point<Real> birth, int _dim, const Column& _bdry);
void save(const std::string& filename, BifiltrationFormat format = BifiltrationFormat::rivet); // save to file
@@ -72,11 +81,8 @@ namespace md {
private:
SimplexVector simplices_;
- // axes names, for rivet bifiltration format only
- std::string parameter_1_name_ {"axis_1"};
- std::string parameter_2_name_ {"axis_2"};
- Box bounding_box_;
+ Box<Real> bounding_box_;
int maximal_dim_ {-1};
void init();
@@ -97,13 +103,15 @@ namespace md {
};
- std::ostream& operator<<(std::ostream& os, const Bifiltration& bif);
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const Bifiltration<Real>& bif);
+ template<class Real>
class BifiltrationProxy {
public:
- BifiltrationProxy(const Bifiltration& bif, int dim = 0);
+ BifiltrationProxy(const Bifiltration<Real>& bif, int dim = 0);
// return critical values of simplices that are important for current dimension (dim and dim+1)
- PointVec positions() const;
+ PointVec<Real> positions() const;
// set current dimension
int set_dim(int new_dim);
@@ -111,46 +119,22 @@ namespace md {
int maximal_dim() const;
void translate(Real a);
Real minimal_coordinate() const;
- Box bounding_box() const;
+ Box<Real> bounding_box() const;
Real max_x() const;
Real max_y() const;
Real min_x() const;
Real min_y() const;
- Diagram weighted_slice_diagram(const DualPoint& slice) const;
+ Diagram<Real> weighted_slice_diagram(const DualPoint<Real>& slice) const;
private:
int dim_ { 0 };
- mutable PointVec cached_positions_;
- Bifiltration bif_;
+ mutable PointVec<Real> cached_positions_;
+ Bifiltration<Real> bif_;
void cache_positions() const;
};
}
-
+#include "bifiltration.hpp"
#endif //MATCHING_DISTANCE_BIFILTRATION_H
-
-//// The value type of OutputIterator is Simplex_in_2D_filtration
-//template<typename OutputIterator>
-//void read_input(std::string filename, OutputIterator out)
-//{
-// std::ifstream ifstr;
-// ifstr.open(filename.c_str());
-// long n;
-// ifstr >> n; // number of simplices is the first number in file
-//
-// Index k; // used in loop
-// for (int i = 0; i < n; i++) {
-// Simplex_in_2D_filtration next;
-// next.index = i;
-// ifstr >> next.dim >> next.pos.x >> next.pos.y;
-// if (next.dim > 0) {
-// for (int j = 0; j <= next.dim; j++) {
-// ifstr >> k;
-// next.bd.push_back(k);
-// }
-// }
-// *out++ = next;
-// }
-//}
diff --git a/matching/include/bifiltration.hpp b/matching/include/bifiltration.hpp
new file mode 100644
index 0000000..9e2a82e
--- /dev/null
+++ b/matching/include/bifiltration.hpp
@@ -0,0 +1,421 @@
+namespace md {
+
+ template<class Real>
+ void Bifiltration<Real>::init()
+ {
+ auto lower_left = max_point<Real>();
+ auto upper_right = min_point<Real>();
+ for(const auto& simplex : simplices_) {
+ lower_left = greatest_lower_bound<>(lower_left, simplex.position());
+ upper_right = least_upper_bound<>(upper_right, simplex.position());
+ maximal_dim_ = std::max(maximal_dim_, simplex.dim());
+ }
+ bounding_box_ = Box<Real>(lower_left, upper_right);
+ }
+
+ template<class Real>
+ Bifiltration<Real>::Bifiltration(const std::string& fname)
+ {
+ std::ifstream ifstr {fname.c_str()};
+ if (!ifstr.good()) {
+ std::string error_message = fmt::format("Cannot open file {0}", fname);
+ std::cerr << error_message << std::endl;
+ throw std::runtime_error(error_message);
+ }
+
+ BifiltrationFormat input_format;
+
+ std::string s;
+
+ while(ignore_line(s)) {
+ std::getline(ifstr, s);
+ }
+
+ if (s == "bifiltration") {
+ input_format = BifiltrationFormat::rivet;
+ } else if (s == "bifiltration_phat_like") {
+ input_format = BifiltrationFormat::phat_like;
+ } else {
+ std::cerr << "Unknown format: '" << s << "' in file " << fname << std::endl;
+ throw std::runtime_error("unknown bifiltration format");
+ }
+
+ switch(input_format) {
+ case BifiltrationFormat::rivet :
+ rivet_format_reader(ifstr);
+ break;
+ case BifiltrationFormat::phat_like :
+ phat_like_format_reader(ifstr);
+ break;
+ }
+
+ ifstr.close();
+
+ init();
+ }
+
+ template<class Real>
+ void Bifiltration<Real>::rivet_format_reader(std::ifstream& ifstr)
+ {
+ std::string s;
+ // read axes names, ignore them
+ std::getline(ifstr, s);
+ std::getline(ifstr, s);
+
+ Index index = 0;
+ while(std::getline(ifstr, s)) {
+ if (!ignore_line(s)) {
+ simplices_.emplace_back(index++, s, BifiltrationFormat::rivet);
+ }
+ }
+ }
+
+ template<class Real>
+ void Bifiltration<Real>::phat_like_format_reader(std::ifstream& ifstr)
+ {
+ spd::debug("Enter phat_like_format_reader");
+ // read stream line by line; do not use >> operator
+ std::string s;
+ std::getline(ifstr, s);
+
+ // first line contains number of simplices
+ long n_simplices = std::stol(s);
+
+ // all other lines represent a simplex
+ Index index = 0;
+ while(index < n_simplices) {
+ std::getline(ifstr, s);
+ if (!ignore_line(s)) {
+ simplices_.emplace_back(index++, s, BifiltrationFormat::phat_like);
+ }
+ }
+ spd::debug("Read {} simplices from file", n_simplices);
+ }
+
+ template<class Real>
+ void Bifiltration<Real>::scale(Real lambda)
+ {
+ for(auto& s : simplices_) {
+ s.scale(lambda);
+ }
+ init();
+ }
+
+ template<class Real>
+ void Bifiltration<Real>::sanity_check() const
+ {
+#ifdef DEBUG
+ spd::debug("Enter Bifiltration<Real>::sanity_check");
+ // check that boundary has correct number of simplices,
+ // each bounding simplex has correct dim
+ // and appears in the filtration before the simplex it bounds
+ for(const auto& s : simplices_) {
+ assert(s.dim() >= 0);
+ assert(s.dim() == 0 or s.dim() + 1 == (int) s.boundary().size());
+ for(auto bdry_idx : s.boundary()) {
+ Simplex bdry_simplex = simplices()[bdry_idx];
+ assert(bdry_simplex.dim() == s.dim() - 1);
+ assert(bdry_simplex.position().is_less(s.position(), false));
+ }
+ }
+ spd::debug("Exit Bifiltration<Real>::sanity_check");
+#endif
+ }
+
+ template<class Real>
+ Diagram<Real> Bifiltration<Real>::weighted_slice_diagram(const DualPoint<Real>& line, int dim) const
+ {
+ DiagramKeeper<Real> dgm;
+
+ // make a copy for now; I want slice_diagram to be const
+ std::vector<Simplex<Real>> simplices(simplices_);
+
+// std::vector<Simplex> simplices;
+// simplices.reserve(simplices_.size() / 2);
+// for(const auto& s : simplices_) {
+// if (s.dim() <= dim + 1 and s.dim() >= dim)
+// simplices.emplace_back(s);
+// }
+
+ spd::debug("Enter slice diagram, line = {}, simplices.size = {}", line, simplices.size());
+
+ for(auto& simplex : simplices) {
+ Real value = line.weighted_push(simplex.position());
+// spd::debug("in slice_diagram, simplex = {}, value = {}\n", simplex, value);
+ simplex.set_value(value);
+ }
+
+ std::sort(simplices.begin(), simplices.end(),
+ [](const Simplex<Real>& a, const Simplex<Real>& b) { return a.value() < b.value(); });
+ std::map<Index, Index> index_map;
+ for(Index i = 0; i < (int) simplices.size(); i++) {
+ index_map[simplices[i].id()] = i;
+ }
+
+ phat::boundary_matrix<> phat_matrix;
+ phat_matrix.set_num_cols(simplices.size());
+ std::vector<Index> bd_in_slice_filtration;
+ for(Index i = 0; i < (int) simplices.size(); i++) {
+ phat_matrix.set_dim(i, simplices[i].dim());
+ bd_in_slice_filtration.clear();
+ //std::cout << "new col" << i << std::endl;
+ for(int j = 0; j < (int) simplices[i].boundary().size(); j++) {
+ // F[i] contains the indices of its facet wrt to the
+ // original filtration. We have to express it, however,
+ // wrt to the filtration along the slice. That is why
+ // we need the index_map
+ //std::cout << "Found " << F[i].bd[j] << ", returning " << index_map[F[i].bd[j]] << std::endl;
+ bd_in_slice_filtration.push_back(index_map[simplices[i].boundary()[j]]);
+ }
+ std::sort(bd_in_slice_filtration.begin(), bd_in_slice_filtration.end());
+ phat_matrix.set_col(i, bd_in_slice_filtration);
+ }
+ phat::persistence_pairs phat_persistence_pairs;
+ phat::compute_persistence_pairs<phat::twist_reduction>(phat_persistence_pairs, phat_matrix);
+
+ dgm.clear();
+ constexpr Real real_inf = std::numeric_limits<Real>::infinity();
+ for(long i = 0; i < (long) phat_persistence_pairs.get_num_pairs(); i++) {
+ std::pair<phat::index, phat::index> new_pair = phat_persistence_pairs.get_pair(i);
+ bool is_finite_pair = new_pair.second != phat::k_infinity_index;
+ Real birth = simplices.at(new_pair.first).value();
+ Real death = is_finite_pair ? simplices.at(new_pair.second).value() : real_inf;
+ int dim = simplices[new_pair.first].dim();
+ assert(dim + 1 == simplices[new_pair.second].dim());
+ if (birth != death) {
+ dgm.add_point(dim, birth, death);
+ }
+ }
+
+ spdlog::debug("Exiting slice_diagram, #dgm[0] = {}", dgm.get_diagram(0).size());
+
+ return dgm.get_diagram(dim);
+ }
+
+ template<class Real>
+ Box<Real> Bifiltration<Real>::bounding_box() const
+ {
+ return bounding_box_;
+ }
+
+ template<class Real>
+ Real Bifiltration<Real>::minimal_coordinate() const
+ {
+ return std::min(bounding_box_.lower_left().x, bounding_box_.lower_left().y);
+ }
+
+ template<class Real>
+ void Bifiltration<Real>::translate(Real a)
+ {
+ bounding_box_.translate(a);
+ for(auto& simplex : simplices_) {
+ simplex.translate(a);
+ }
+ }
+
+ template<class Real>
+ Real Bifiltration<Real>::max_x() const
+ {
+ if (simplices_.empty())
+ return 1;
+ auto me = std::max_element(simplices_.cbegin(), simplices_.cend(),
+ [](const auto& s_a, const auto& s_b) { return s_a.position().x < s_b.position().x; });
+ assert(me != simplices_.cend());
+ return me->position().x;
+ }
+
+ template<class Real>
+ Real Bifiltration<Real>::max_y() const
+ {
+ if (simplices_.empty())
+ return 1;
+ auto me = std::max_element(simplices_.cbegin(), simplices_.cend(),
+ [](const auto& s_a, const auto& s_b) { return s_a.position().y < s_b.position().y; });
+ assert(me != simplices_.cend());
+ return me->position().y;
+ }
+
+ template<class Real>
+ Real Bifiltration<Real>::min_x() const
+ {
+ if (simplices_.empty())
+ return 0;
+ auto me = std::min_element(simplices_.cbegin(), simplices_.cend(),
+ [](const auto& s_a, const auto& s_b) { return s_a.position().x < s_b.position().x; });
+ assert(me != simplices_.cend());
+ return me->position().x;
+ }
+
+ template<class Real>
+ Real Bifiltration<Real>::min_y() const
+ {
+ if (simplices_.empty())
+ return 0;
+ auto me = std::min_element(simplices_.cbegin(), simplices_.cend(),
+ [](const auto& s_a, const auto& s_b) { return s_a.position().y < s_b.position().y; });
+ assert(me != simplices_.cend());
+ return me->position().y;
+ }
+
+ template<class Real>
+ void Bifiltration<Real>::add_simplex(Index _id, Point<Real> birth, int _dim, const Column& _bdry)
+ {
+ simplices_.emplace_back(_id, birth, _dim, _bdry);
+ }
+
+ template<class Real>
+ void Bifiltration<Real>::save(const std::string& filename, md::BifiltrationFormat format)
+ {
+ switch(format) {
+ case BifiltrationFormat::rivet:
+ throw std::runtime_error("Not implemented");
+ break;
+ case BifiltrationFormat::phat_like: {
+ std::ofstream f(filename);
+ if (not f.good()) {
+ std::cerr << "Bifiltration::save: cannot open file " << filename << std::endl;
+ throw std::runtime_error("Cannot open file for writing ");
+ }
+ f << simplices_.size() << "\n";
+
+ for(const auto& s : simplices_) {
+ f << s.dim() << " " << s.position().x << " " << s.position().y << " ";
+ for(int b : s.boundary()) {
+ f << b << " ";
+ }
+ f << std::endl;
+ }
+
+ }
+ break;
+ }
+ }
+
+ template<class Real>
+ void Bifiltration<Real>::postprocess_rivet_format()
+ {
+ std::map<Column, Index> facets_to_ids;
+
+ // fill the map
+ for(Index i = 0; i < (Index) simplices_.size(); ++i) {
+ assert(simplices_[i].id() == i);
+ facets_to_ids[simplices_[i].vertices_] = i;
+ }
+
+// for(const auto& s : simplices_) {
+// facets_to_ids[s] = s.id();
+// }
+
+ // main loop
+ for(auto& s : simplices_) {
+ assert(not s.vertices_.empty());
+ assert(s.facet_indices_.empty());
+ Column facet_indices;
+ for(Index i = 0; i <= s.dim(); ++i) {
+ Column facet;
+ for(Index j : s.vertices_) {
+ if (j != i)
+ facet.push_back(j);
+ }
+ auto facet_index = facets_to_ids.at(facet);
+ facet_indices.push_back(facet_index);
+ }
+ s.facet_indices_ = facet_indices;
+ } // loop over simplices
+ }
+
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const Bifiltration<Real>& bif)
+ {
+ os << "Bifiltration [" << std::endl;
+ for(const auto& s : bif.simplices()) {
+ os << s << std::endl;
+ }
+ os << "]" << std::endl;
+ return os;
+ }
+
+ template<class Real>
+ BifiltrationProxy<Real>::BifiltrationProxy(const Bifiltration<Real>& bif, int dim)
+ :
+ dim_(dim),
+ bif_(bif)
+ {
+ cache_positions();
+ }
+
+ template<class Real>
+ void BifiltrationProxy<Real>::cache_positions() const
+ {
+ cached_positions_.clear();
+ for(const auto& simplex : bif_.simplices()) {
+ if (simplex.dim() == dim_ or simplex.dim() == dim_ + 1)
+ cached_positions_.push_back(simplex.position());
+ }
+ }
+
+ template<class Real>
+ PointVec<Real>
+ BifiltrationProxy<Real>::positions() const
+ {
+ if (cached_positions_.empty()) {
+ cache_positions();
+ }
+ return cached_positions_;
+ }
+
+ // translate all points by vector (a,a)
+ template<class Real>
+ void BifiltrationProxy<Real>::translate(Real a)
+ {
+ bif_.translate(a);
+ }
+
+ // return minimal value of x- and y-coordinates
+ // among all simplices
+ template<class Real>
+ Real BifiltrationProxy<Real>::minimal_coordinate() const
+ {
+ return bif_.minimal_coordinate();
+ }
+
+ // return box that contains positions of all simplices
+ template<class Real>
+ Box<Real> BifiltrationProxy<Real>::bounding_box() const
+ {
+ return bif_.bounding_box();
+ }
+
+ template<class Real>
+ Real BifiltrationProxy<Real>::max_x() const
+ {
+ return bif_.max_x();
+ }
+
+ template<class Real>
+ Real BifiltrationProxy<Real>::max_y() const
+ {
+ return bif_.max_y();
+ }
+
+ template<class Real>
+ Real BifiltrationProxy<Real>::min_x() const
+ {
+ return bif_.min_x();
+ }
+
+ template<class Real>
+ Real BifiltrationProxy<Real>::min_y() const
+ {
+ return bif_.min_y();
+ }
+
+
+ template<class Real>
+ Diagram<Real> BifiltrationProxy<Real>::weighted_slice_diagram(const DualPoint<Real>& slice) const
+ {
+ return bif_.weighted_slice_diagram(slice, dim_);
+ }
+
+}
+
diff --git a/matching/include/box.h b/matching/include/box.h
index 2990fba..4243667 100644
--- a/matching/include/box.h
+++ b/matching/include/box.h
@@ -8,20 +8,23 @@
namespace md {
+ template<class Real_>
struct Box {
+ public:
+ using Real = Real_;
private:
- Point ll;
- Point ur;
+ Point<Real> ll;
+ Point<Real> ur;
public:
- Box(Point ll = Point(), Point ur = Point())
+ Box(Point<Real> ll = Point<Real>(), Point<Real> ur = Point<Real>())
:ll(ll), ur(ur)
{
}
- Box(Point center, Real width, Real height) :
- ll(Point(center.x - 0.5 * width, center.y - 0.5 * height)),
- ur(Point(center.x + 0.5 * width, center.y + 0.5 * height))
+ Box(Point<Real> center, Real width, Real height) :
+ ll(Point<Real>(center.x - 0.5 * width, center.y - 0.5 * height)),
+ ur(Point<Real>(center.x + 0.5 * width, center.y + 0.5 * height))
{
}
@@ -30,11 +33,9 @@ namespace md {
inline double height() const { return ur.y - ll.y; }
- inline Point lower_left() const { return ll; }
- inline Point upper_right() const { return ur; }
- inline Point center() const { return Point((ll.x + ur.x) / 2, (ll.y + ur.y) / 2); }
-
-// bool inside(Point& p) const { return ll.x <= p.x && ll.y <= p.y && ur.x >= p.x && ur.y >= p.y; }
+ inline Point<Real> lower_left() const { return ll; }
+ inline Point<Real> upper_right() const { return ur; }
+ inline Point<Real> center() const { return Point<Real>((ll.x + ur.x) / 2, (ll.y + ur.y) / 2); }
inline bool operator==(const Box& p)
{
@@ -43,58 +44,16 @@ namespace md {
std::vector<Box> refine() const;
- std::vector<Point> corners() const;
+ std::vector<Point<Real>> corners() const;
void translate(Real a);
-
- // return minimal and maximal value of func
- // on the corners of the box
- template<typename F>
- std::pair<Real, Real> min_max_on_corners(const F& func) const;
-
- friend std::ostream& operator<<(std::ostream& os, const Box& box);
};
- std::ostream& operator<<(std::ostream& os, const Box& box);
-// template<typename InputIterator>
-// Box compute_bounding_box(InputIterator simplices_begin, InputIterator simplices_end)
-// {
-// if (simplices_begin == simplices_end) {
-// return Box();
-// }
-// Box bb;
-// bb.ll = bb.ur = simplices_begin->pos;
-// for (InputIterator it = simplices_begin; it != simplices_end; it++) {
-// Point& pos = it->pos;
-// if (pos.x < bb.ll.x) {
-// bb.ll.x = pos.x;
-// }
-// if (pos.y < bb.ll.y) {
-// bb.ll.y = pos.y;
-// }
-// if (pos.x > bb.ur.x) {
-// bb.ur.x = pos.x;
-// }
-// if (pos.y > bb.ur.y) {
-// bb.ur.y = pos.y;
-// }
-// }
-// return bb;
-// }
-
- Box get_enclosing_box(const Box& box_a, const Box& box_b);
-
- template<typename F>
- std::pair<Real, Real> Box::min_max_on_corners(const F& func) const
- {
- std::pair<Real, Real> min_max { std::numeric_limits<Real>::max(), -std::numeric_limits<Real>::max() };
- for(Point p : corners()) {
- Real value = func(p);
- min_max.first = std::min(min_max.first, value);
- min_max.second = std::max(min_max.second, value);
- }
- return min_max;
- };
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const Box<Real>& box);
+
} // namespace md
+#include "box.hpp"
+
#endif //MATCHING_DISTANCE_BOX_H
diff --git a/matching/include/box.hpp b/matching/include/box.hpp
new file mode 100644
index 0000000..f551d84
--- /dev/null
+++ b/matching/include/box.hpp
@@ -0,0 +1,52 @@
+namespace md {
+
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const Box<Real>& box)
+ {
+ os << "Box(lower_left = " << box.lower_left() << ", upper_right = " << box.upper_right() << ")";
+ return os;
+ }
+
+ template<class Real>
+ void Box<Real>::translate(Real a)
+ {
+ ll.x += a;
+ ll.y += a;
+ ur.x += a;
+ ur.y += a;
+ }
+
+ template<class Real>
+ std::vector<Box<Real>> Box<Real>::refine() const
+ {
+ std::vector<Box<Real>> result;
+
+// 1 | 2
+// 0 | 3
+
+ Point<Real> new_ll = lower_left();
+ Point<Real> new_ur = center();
+ result.emplace_back(new_ll, new_ur);
+
+ new_ll.y = center().y;
+ new_ur.y = ur.y;
+ result.emplace_back(new_ll, new_ur);
+
+ new_ll = center();
+ new_ur = upper_right();
+ result.emplace_back(new_ll, new_ur);
+
+ new_ll.y = ll.y;
+ new_ur.y = center().y;
+ result.emplace_back(new_ll, new_ur);
+
+ return result;
+ }
+
+ template<class Real>
+ std::vector<Point<Real>> Box<Real>::corners() const
+ {
+ return {ll, Point<Real>(ll.x, ur.y), ur, Point<Real>(ur.x, ll.y)};
+ };
+
+}
diff --git a/matching/include/cell_with_value.h b/matching/include/cell_with_value.h
index 25644d1..3548a11 100644
--- a/matching/include/cell_with_value.h
+++ b/matching/include/cell_with_value.h
@@ -1,7 +1,3 @@
-//
-// Created by narn on 16.07.19.
-//
-
#ifndef MATCHING_DISTANCE_CELL_WITH_VALUE_H
#define MATCHING_DISTANCE_CELL_WITH_VALUE_H
@@ -21,7 +17,29 @@ namespace md {
upper_right
};
- std::ostream& operator<<(std::ostream& os, const ValuePoint& vp);
+ inline std::ostream& operator<<(std::ostream& os, const ValuePoint& vp)
+ {
+ switch(vp) {
+ case ValuePoint::upper_left :
+ os << "upper_left";
+ break;
+ case ValuePoint::upper_right :
+ os << "upper_right";
+ break;
+ case ValuePoint::lower_left :
+ os << "lower_left";
+ break;
+ case ValuePoint::lower_right :
+ os << "lower_right";
+ break;
+ case ValuePoint::center:
+ os << "center";
+ break;
+ default:
+ os << "FORGOTTEN ValuePoint";
+ }
+ return os;
+ }
const std::vector<ValuePoint> k_all_vps = {ValuePoint::center, ValuePoint::lower_left, ValuePoint::upper_left,
ValuePoint::upper_right, ValuePoint::lower_right};
@@ -31,8 +49,10 @@ namespace md {
// represents a cell in the dual space with the value
// of the weighted bottleneck distance
+ template<class Real_>
class CellWithValue {
public:
+ using Real = Real_;
CellWithValue() = default;
@@ -44,18 +64,18 @@ namespace md {
CellWithValue& operator=(CellWithValue&& other) = default;
- CellWithValue(const DualBox& b, int level)
+ CellWithValue(const DualBox<Real>& b, int level)
:dual_box_(b), level_(level) { }
- DualBox dual_box() const { return dual_box_; }
+ DualBox<Real> dual_box() const { return dual_box_; }
- DualPoint center() const { return dual_box_.center(); }
+ DualPoint<Real> center() const { return dual_box_.center(); }
Real value_at(ValuePoint vp) const;
bool has_value_at(ValuePoint vp) const;
- DualPoint value_point(ValuePoint vp) const;
+ DualPoint<Real> value_point(ValuePoint vp) const;
int level() const { return level_; }
@@ -73,8 +93,6 @@ namespace md {
std::vector<CellWithValue> get_refined_cells() const;
- friend std::ostream& operator<<(std::ostream&, const CellWithValue&);
-
void set_max_possible_value(Real new_upper_bound);
int num_values() const;
@@ -100,7 +118,7 @@ namespace md {
bool has_upper_right_value() const { return upper_right_value_ >= 0; }
- DualBox dual_box_;
+ DualBox<Real> dual_box_;
Real central_value_ {-1.0};
Real lower_left_value_ {-1.0};
Real lower_right_value_ {-1.0};
@@ -114,7 +132,10 @@ namespace md {
bool has_max_possible_value_ {false};
};
- std::ostream& operator<<(std::ostream& os, const CellWithValue& cell);
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const CellWithValue<Real>& cell);
} // namespace md
+#include "cell_with_value.hpp"
+
#endif //MATCHING_DISTANCE_CELL_WITH_VALUE_H
diff --git a/matching/include/cell_with_value.hpp b/matching/include/cell_with_value.hpp
new file mode 100644
index 0000000..88b2569
--- /dev/null
+++ b/matching/include/cell_with_value.hpp
@@ -0,0 +1,224 @@
+namespace md {
+
+#ifdef MD_DEBUG
+ long long int CellWithValue<Real>::max_id = 0;
+#endif
+
+ template<class Real>
+ Real CellWithValue<Real>::value_at(ValuePoint vp) const
+ {
+ switch(vp) {
+ case ValuePoint::upper_left :
+ return upper_left_value_;
+ case ValuePoint::upper_right :
+ return upper_right_value_;
+ case ValuePoint::lower_left :
+ return lower_left_value_;
+ case ValuePoint::lower_right :
+ return lower_right_value_;
+ case ValuePoint::center:
+ return central_value_;
+ }
+ // to shut up compiler warning
+ return 1.0 / 0.0;
+ }
+
+ template<class Real>
+ bool CellWithValue<Real>::has_value_at(ValuePoint vp) const
+ {
+ switch(vp) {
+ case ValuePoint::upper_left :
+ return upper_left_value_ >= 0;
+ case ValuePoint::upper_right :
+ return upper_right_value_ >= 0;
+ case ValuePoint::lower_left :
+ return lower_left_value_ >= 0;
+ case ValuePoint::lower_right :
+ return lower_right_value_ >= 0;
+ case ValuePoint::center:
+ return central_value_ >= 0;
+ }
+ // to shut up compiler warning
+ return 1.0 / 0.0;
+ }
+
+ template<class Real>
+ DualPoint<Real> CellWithValue<Real>::value_point(md::ValuePoint vp) const
+ {
+ switch(vp) {
+ case ValuePoint::upper_left :
+ return dual_box().upper_left();
+ case ValuePoint::upper_right :
+ return dual_box().upper_right();
+ case ValuePoint::lower_left :
+ return dual_box().lower_left();
+ case ValuePoint::lower_right :
+ return dual_box().lower_right();
+ case ValuePoint::center:
+ return dual_box().center();
+ }
+ // to shut up compiler warning
+ return DualPoint<Real>();
+ }
+
+ template<class Real>
+ bool CellWithValue<Real>::has_corner_value() const
+ {
+ return has_lower_left_value() or has_lower_right_value() or has_upper_left_value()
+ or has_upper_right_value();
+ }
+
+ template<class Real>
+ Real CellWithValue<Real>::stored_upper_bound() const
+ {
+ assert(has_max_possible_value_);
+ return max_possible_value_;
+ }
+
+ template<class Real>
+ Real CellWithValue<Real>::max_corner_value() const
+ {
+ return std::max({lower_left_value_, lower_right_value_, upper_left_value_, upper_right_value_});
+ }
+
+ template<class Real>
+ Real CellWithValue<Real>::min_value() const
+ {
+ Real result = std::numeric_limits<Real>::max();
+ for(auto vp : k_all_vps) {
+ if (not has_value_at(vp)) {
+ continue;
+ }
+ result = std::min(result, value_at(vp));
+ }
+ return result;
+ }
+
+ template<class Real>
+ std::vector<CellWithValue<Real>> CellWithValue<Real>::get_refined_cells() const
+ {
+ std::vector<CellWithValue<Real>> result;
+ result.reserve(4);
+ for(const auto& refined_box : dual_box_.refine()) {
+
+ CellWithValue<Real> refined_cell(refined_box, level() + 1);
+
+#ifdef MD_DEBUG
+ refined_cell.parent_ids = parent_ids;
+ refined_cell.parent_ids.push_back(id);
+ refined_cell.id = ++max_id;
+#endif
+
+ if (refined_box.lower_left() == dual_box_.lower_left()) {
+ // _|_
+ // H|_
+
+ refined_cell.set_value_at(ValuePoint::lower_left, lower_left_value_);
+ refined_cell.set_value_at(ValuePoint::upper_right, central_value_);
+
+ } else if (refined_box.upper_right() == dual_box_.upper_right()) {
+ // _|H
+ // _|_
+
+ refined_cell.set_value_at(ValuePoint::lower_left, central_value_);
+ refined_cell.set_value_at(ValuePoint::upper_right, upper_right_value_);
+
+ } else if (refined_box.lower_right() == dual_box_.lower_right()) {
+ // _|_
+ // _|H
+
+ refined_cell.set_value_at(ValuePoint::lower_right, lower_right_value_);
+ refined_cell.set_value_at(ValuePoint::upper_left, central_value_);
+
+ } else if (refined_box.upper_left() == dual_box_.upper_left()) {
+
+ // H|_
+ // _|_
+
+ refined_cell.set_value_at(ValuePoint::lower_right, central_value_);
+ refined_cell.set_value_at(ValuePoint::upper_left, upper_left_value_);
+ }
+ result.emplace_back(refined_cell);
+ }
+ return result;
+ }
+
+ template<class Real>
+ void CellWithValue<Real>::set_value_at(ValuePoint vp, Real new_value)
+ {
+ if (has_value_at(vp))
+ spd::error("CellWithValue<Real>: trying to re-assign value!, this = {}, vp = {}", *this, vp);
+
+ switch(vp) {
+ case ValuePoint::upper_left :
+ upper_left_value_ = new_value;
+ break;
+ case ValuePoint::upper_right :
+ upper_right_value_ = new_value;
+ break;
+ case ValuePoint::lower_left :
+ lower_left_value_ = new_value;
+ break;
+ case ValuePoint::lower_right :
+ lower_right_value_ = new_value;
+ break;
+ case ValuePoint::center:
+ central_value_ = new_value;
+ break;
+ }
+ }
+
+ template<class Real>
+ int CellWithValue<Real>::num_values() const
+ {
+ int result = 0;
+ for(ValuePoint vp : k_all_vps) {
+ result += has_value_at(vp);
+ }
+ return result;
+ }
+
+
+ template<class Real>
+ void CellWithValue<Real>::set_max_possible_value(Real new_upper_bound)
+ {
+ assert(new_upper_bound >= central_value_);
+ assert(new_upper_bound >= lower_left_value_);
+ assert(new_upper_bound >= lower_right_value_);
+ assert(new_upper_bound >= upper_left_value_);
+ assert(new_upper_bound >= upper_right_value_);
+ has_max_possible_value_ = true;
+ max_possible_value_ = new_upper_bound;
+ }
+
+
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const CellWithValue<Real>& cell)
+ {
+ os << "CellWithValue(box = " << cell.dual_box() << ", ";
+
+#ifdef MD_DEBUG
+ os << "id = " << cell.id;
+ if (not cell.parent_ids.empty())
+ os << ", parent_ids = " << container_to_string(cell.parent_ids) << ", ";
+#endif
+
+ for(ValuePoint vp : k_all_vps) {
+ if (cell.has_value_at(vp)) {
+ os << "value = " << cell.value_at(vp);
+ os << ", at " << vp << " " << cell.value_point(vp);
+ }
+ }
+
+ os << ", max_corner_value = ";
+ if (cell.has_max_possible_value()) {
+ os << cell.stored_upper_bound();
+ } else {
+ os << "-";
+ }
+
+ os << ", level = " << cell.level() << ")";
+ return os;
+ }
+
+} // namespace md
diff --git a/matching/include/common_defs.h b/matching/include/common_defs.h
index 3f3d937..8d01325 100644
--- a/matching/include/common_defs.h
+++ b/matching/include/common_defs.h
@@ -1,8 +1,8 @@
#ifndef MATCHING_DISTANCE_DEF_DEBUG_H
#define MATCHING_DISTANCE_DEF_DEBUG_H
-//#define EXPERIMENTAL_TIMING
-//#define PRINT_HEAT_MAP
+//#define MD_EXPERIMENTAL_TIMING
+//#define MD_PRINT_HEAT_MAP
//#define MD_DEBUG
//#define MD_DO_CHECKS
//#define MD_DO_FULL_CHECK
diff --git a/matching/include/common_util.h b/matching/include/common_util.h
index 2d8dcb0..778536f 100644
--- a/matching/include/common_util.h
+++ b/matching/include/common_util.h
@@ -11,22 +11,24 @@
#include <map>
#include <functional>
+#include <spdlog/spdlog.h>
+#include <spdlog/fmt/ostr.h>
+
+namespace spd = spdlog;
+
#include "common_defs.h"
#include "phat/helpers/misc.h"
-
namespace md {
-
- using Real = double;
- using RealVec = std::vector<Real>;
using Index = phat::index;
using IndexVec = std::vector<Index>;
- static constexpr Real pi = M_PI;
+ //static constexpr Real pi = M_PI;
using Column = std::vector<Index>;
+ template<class Real>
struct Point {
Real x;
Real y;
@@ -71,59 +73,56 @@ namespace md {
};
- using PointVec = std::vector<Point>;
-
- Point operator+(const Point& u, const Point& v);
+ template<class Real>
+ using PointVec = std::vector<Point<Real>>;
- Point operator-(const Point& u, const Point& v);
+ template<class Real>
+ Point<Real> operator+(const Point<Real>& u, const Point<Real>& v);
- Point least_upper_bound(const Point& u, const Point& v);
+ template<class Real>
+ Point<Real> operator-(const Point<Real>& u, const Point<Real>& v);
- Point greatest_lower_bound(const Point& u, const Point& v);
- Point max_point();
+ template<class Real>
+ Point<Real> least_upper_bound(const Point<Real>& u, const Point<Real>& v);
- Point min_point();
+ template<class Real>
+ Point<Real> greatest_lower_bound(const Point<Real>& u, const Point<Real>& v);
- std::ostream& operator<<(std::ostream& ostr, const Point& vec);
+ template<class Real>
+ Point<Real> max_point();
- Real L_infty(const Point& v);
+ template<class Real>
+ Point<Real> min_point();
- Real l_2_norm(const Point& v);
+ template<class Real>
+ std::ostream& operator<<(std::ostream& ostr, const Point<Real>& vec);
- Real l_2_dist(const Point& x, const Point& y);
+ template<class Real>
+ using DiagramPoint = std::pair<Real, Real>;
- Real l_infty_dist(const Point& x, const Point& y);
+ template<class Real>
+ using Diagram = std::vector<DiagramPoint<Real>>;
- using Interval = std::pair<Real, Real>;
-
- // return minimal interval that contains both a and b
- inline Interval minimal_covering_interval(Interval a, Interval b)
- {
- return {std::min(a.first, b.first), std::max(a.second, b.second)};
- }
// to keep diagrams in all dimensions
// TODO: store in Hera format?
+ template<class Real>
class DiagramKeeper {
public:
- using DiagramPoint = std::pair<Real, Real>;
- using Diagram = std::vector<DiagramPoint>;
DiagramKeeper() { };
void add_point(int dim, Real birth, Real death);
- Diagram get_diagram(int dim) const;
+ Diagram<Real> get_diagram(int dim) const;
void clear() { data_.clear(); }
private:
- std::map<int, Diagram> data_;
+ std::map<int, Diagram<Real>> data_;
};
- using Diagram = std::vector<std::pair<Real, Real>>;
-
template<typename C>
std::string container_to_string(const C& cont)
{
@@ -140,42 +139,18 @@ namespace md {
return ss.str();
}
- int gcd(int a, int b);
-
- struct Rational {
- int numerator {0};
- int denominator {1};
- Rational() = default;
- Rational(int n, int d) : numerator(n / gcd(n, d)), denominator(d / gcd(n, d)) {}
- Rational(std::pair<int, int> p) : Rational(p.first, p.second) {}
- Rational(int n) : numerator(n), denominator(1) {}
- Real to_real() const { return (Real)numerator / (Real)denominator; }
- void reduce();
- Rational& operator+=(const Rational& rhs);
- Rational& operator-=(const Rational& rhs);
- Rational& operator*=(const Rational& rhs);
- Rational& operator/=(const Rational& rhs);
- };
-
- using namespace std::rel_ops;
-
- bool operator==(const Rational& a, const Rational& b);
- bool operator<(const Rational& a, const Rational& b);
- std::ostream& operator<<(std::ostream& os, const Rational& a);
-
- // arithmetic
- Rational operator+(Rational a, const Rational& b);
- Rational operator-(Rational a, const Rational& b);
- Rational operator*(Rational a, const Rational& b);
- Rational operator/(Rational a, const Rational& b);
-
- Rational reduce(Rational frac);
-
- Rational midpoint(Rational a, Rational b);
-
// return true, if s is empty or starts with # (commented out line)
// whitespaces in the beginning of s are ignored
- bool ignore_line(const std::string& s);
+ inline bool ignore_line(const std::string& s)
+ {
+ for(auto c : s) {
+ if (isspace(c))
+ continue;
+ return (c == '#');
+ }
+ return true;
+ }
+
// split string by delimeter
template<typename Out>
@@ -195,10 +170,10 @@ namespace md {
}
namespace std {
- template<>
- struct hash<md::Point>
+ template<class Real>
+ struct hash<md::Point<Real>>
{
- std::size_t operator()(const md::Point& p) const
+ std::size_t operator()(const md::Point<Real>& p) const
{
auto hx = std::hash<decltype(p.x)>()(p.x);
auto hy = std::hash<decltype(p.y)>()(p.y);
@@ -207,5 +182,7 @@ namespace std {
};
};
+#include "common_util.hpp"
+
#endif //MATCHING_DISTANCE_COMMON_UTIL_H
diff --git a/matching/include/common_util.hpp b/matching/include/common_util.hpp
new file mode 100644
index 0000000..76d97af
--- /dev/null
+++ b/matching/include/common_util.hpp
@@ -0,0 +1,96 @@
+#include <vector>
+#include <utility>
+#include <cmath>
+#include <ostream>
+#include <limits>
+#include <algorithm>
+
+#include <common_util.h>
+
+#include "spdlog/spdlog.h"
+#include "spdlog/fmt/ostr.h"
+
+namespace md {
+
+ template<class Real>
+ Point<Real> operator+(const Point<Real>& u, const Point<Real>& v)
+ {
+ return Point<Real>(u.x + v.x, u.y + v.y);
+ }
+
+ template<class Real>
+ Point<Real> operator-(const Point<Real>& u, const Point<Real>& v)
+ {
+ return Point<Real>(u.x - v.x, u.y - v.y);
+ }
+
+ template<class Real>
+ Point<Real> least_upper_bound(const Point<Real>& u, const Point<Real>& v)
+ {
+ return Point<Real>(std::max(u.x, v.x), std::max(u.y, v.y));
+ }
+
+ template<class Real>
+ Point<Real> greatest_lower_bound(const Point<Real>& u, const Point<Real>& v)
+ {
+ return Point<Real>(std::min(u.x, v.x), std::min(u.y, v.y));
+ }
+
+ template<class Real>
+ Point<Real> max_point()
+ {
+ return Point<Real>(std::numeric_limits<Real>::max(), std::numeric_limits<Real>::min());
+ }
+
+ template<class Real>
+ Point<Real> min_point()
+ {
+ return Point<Real>(-std::numeric_limits<Real>::max(), -std::numeric_limits<Real>::min());
+ }
+
+ template<class Real>
+ std::ostream& operator<<(std::ostream& ostr, const Point<Real>& vec)
+ {
+ ostr << "(" << vec.x << ", " << vec.y << ")";
+ return ostr;
+ }
+
+ template<class Real>
+ Real l_infty_norm(const Point<Real>& v)
+ {
+ return std::max(std::abs(v.x), std::abs(v.y));
+ }
+
+ template<class Real>
+ Real l_2_norm(const Point<Real>& v)
+ {
+ return v.norm();
+ }
+
+ template<class Real>
+ Real l_2_dist(const Point<Real>& x, const Point<Real>& y)
+ {
+ return l_2_norm(x - y);
+ }
+
+ template<class Real>
+ Real l_infty_dist(const Point<Real>& x, const Point<Real>& y)
+ {
+ return l_infty_norm(x - y);
+ }
+
+ template<class Real>
+ void DiagramKeeper<Real>::add_point(int dim, Real birth, Real death)
+ {
+ data_[dim].emplace_back(birth, death);
+ }
+
+ template<class Real>
+ Diagram<Real> DiagramKeeper<Real>::get_diagram(int dim) const
+ {
+ if (data_.count(dim) == 1)
+ return data_.at(dim);
+ else
+ return Diagram<Real>();
+ }
+}
diff --git a/matching/include/dual_box.h b/matching/include/dual_box.h
index ce0384d..0e4f4d5 100644
--- a/matching/include/dual_box.h
+++ b/matching/include/dual_box.h
@@ -4,16 +4,23 @@
#include <ostream>
#include <limits>
#include <vector>
+#include <random>
+
+#include "spdlog/spdlog.h"
+#include "spdlog/fmt/ostr.h"
+
#include "common_util.h"
#include "dual_point.h"
namespace md {
+
+ template<class Real>
class DualBox {
public:
- DualBox(DualPoint ll, DualPoint ur);
+ DualBox(DualPoint<Real> ll, DualPoint<Real> ur);
DualBox() = default;
DualBox(const DualBox&) = default;
@@ -23,12 +30,12 @@ namespace md {
DualBox& operator=(DualBox&& other) = default;
- DualPoint center() const { return midpoint(lower_left_, upper_right_); }
- DualPoint lower_left() const { return lower_left_; }
- DualPoint upper_right() const { return upper_right_; }
+ DualPoint<Real> center() const { return midpoint(lower_left_, upper_right_); }
+ DualPoint<Real> lower_left() const { return lower_left_; }
+ DualPoint<Real> upper_right() const { return upper_right_; }
- DualPoint lower_right() const;
- DualPoint upper_left() const;
+ DualPoint<Real> lower_right() const;
+ DualPoint<Real> upper_left() const;
AxisType axis_type() const { return lower_left_.axis_type(); }
AngleType angle_type() const { return lower_left_.angle_type(); }
@@ -42,66 +49,35 @@ namespace md {
bool is_flat() const { return upper_right_.is_flat(); }
bool is_steep() const { return lower_left_.is_steep(); }
- // return minimal and maximal value of func
- // on the corners of the box
- template<typename F>
- std::pair<Real, Real> min_max_on_corners(const F& func) const;
-
- template<typename F>
- Real max_abs_value(const F& func) const;
-
-
std::vector<DualBox> refine() const;
- std::vector<DualPoint> corners() const;
- std::vector<DualPoint> critical_points(const Point& p) const;
+ std::vector<DualPoint<Real>> corners() const;
+ std::vector<DualPoint<Real>> critical_points(const Point<Real>& p) const;
// sample n points from the box uniformly; for tests
- std::vector<DualPoint> random_points(int n) const;
+ std::vector<DualPoint<Real>> random_points(int n) const;
// return 2 dual points at the boundary
// where push changes from horizontal to vertical
- std::vector<DualPoint> push_change_points(const Point& p) const;
-
- friend std::ostream& operator<<(std::ostream& os, const DualBox& db);
+ std::vector<DualPoint<Real>> push_change_points(const Point<Real>& p) const;
// check that a has same sign, angles are all flat or all steep
bool sanity_check() const;
- bool contains(const DualPoint& dp) const;
+ bool contains(const DualPoint<Real>& dp) const;
bool operator==(const DualBox& other) const;
private:
- DualPoint lower_left_;
- DualPoint upper_right_;
+ DualPoint<Real> lower_left_;
+ DualPoint<Real> upper_right_;
};
- std::ostream& operator<<(std::ostream& os, const DualBox& db);
-
- template<typename F>
- std::pair<Real, Real> DualBox::min_max_on_corners(const F& func) const
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const DualBox<Real>& db)
{
- std::pair<Real, Real> min_max { std::numeric_limits<Real>::max(), -std::numeric_limits<Real>::max() };
- for(auto p : corners()) {
- Real value = func(p);
- min_max.first = std::min(min_max.first, value);
- min_max.second = std::max(min_max.second, value);
- }
- return min_max;
- };
-
-
- template<typename F>
- Real DualBox::max_abs_value(const F& func) const
- {
- Real result = 0;
- for(auto p_1 : corners()) {
- for(auto p_2 : corners()) {
- Real value = fabs(func(p_1, p_2));
- result = std::max(value, result);
- }
- }
- return result;
- };
-
+ os << "DualBox(" << db.lower_left() << ", " << db.upper_right() << ")";
+ return os;
+ }
}
+#include "dual_box.hpp"
+
#endif //MATCHING_DISTANCE_DUAL_BOX_H
diff --git a/matching/include/dual_box.hpp b/matching/include/dual_box.hpp
new file mode 100644
index 0000000..85f7f27
--- /dev/null
+++ b/matching/include/dual_box.hpp
@@ -0,0 +1,190 @@
+namespace md {
+
+ template<class Real>
+ DualBox<Real>::DualBox(DualPoint<Real> ll, DualPoint<Real> ur)
+ :lower_left_(ll), upper_right_(ur)
+ {
+ }
+
+ template<class Real>
+ std::vector<DualPoint<Real>> DualBox<Real>::corners() const
+ {
+ return {lower_left_,
+ DualPoint<Real>(axis_type(), angle_type(), lower_left_.lambda(), upper_right_.mu()),
+ upper_right_,
+ DualPoint<Real>(axis_type(), angle_type(), upper_right_.lambda(), lower_left_.mu())};
+ }
+
+ template<class Real>
+ std::vector<DualPoint<Real>> DualBox<Real>::push_change_points(const Point<Real>& p) const
+ {
+ std::vector<DualPoint<Real>> result;
+ result.reserve(2);
+
+ bool is_y_type = lower_left_.is_y_type();
+ bool is_flat = lower_left_.is_flat();
+
+ auto mu_from_lambda = [p, is_y_type, is_flat](Real lambda) {
+ bool is_x_type = not is_y_type, is_steep = not is_flat;
+ if (is_y_type && is_flat) {
+ return p.y - lambda * p.x;
+ } else if (is_y_type && is_steep) {
+ return p.y - p.x / lambda;
+ } else if (is_x_type && is_flat) {
+ return p.x - p.y / lambda;
+ } else if (is_x_type && is_steep) {
+ return p.x - lambda * p.y;
+ }
+ // to shut up compiler warning
+ return static_cast<Real>(1.0 / 0.0);
+ };
+
+ auto lambda_from_mu = [p, is_y_type, is_flat](Real mu) {
+ bool is_x_type = not is_y_type, is_steep = not is_flat;
+ if (is_y_type && is_flat) {
+ return (p.y - mu) / p.x;
+ } else if (is_y_type && is_steep) {
+ return p.x / (p.y - mu);
+ } else if (is_x_type && is_flat) {
+ return p.y / (p.x - mu);
+ } else if (is_x_type && is_steep) {
+ return (p.x - mu) / p.y;
+ }
+ // to shut up compiler warning
+ return static_cast<Real>(1.0 / 0.0);
+ };
+
+ // all inequalities below are strict: equality means it is a corner
+ // && critical_points() returns corners anyway
+
+ Real mu_intersect_min = mu_from_lambda(lambda_min());
+
+ if (mu_min() < mu_intersect_min && mu_intersect_min < mu_max())
+ result.emplace_back(axis_type(), angle_type(), lambda_min(), mu_intersect_min);
+
+ Real mu_intersect_max = mu_from_lambda(lambda_max());
+
+ if (mu_max() < mu_intersect_max && mu_intersect_max < mu_max())
+ result.emplace_back(axis_type(), angle_type(), lambda_max(), mu_intersect_max);
+
+ Real lambda_intersect_min = lambda_from_mu(mu_min());
+
+ if (lambda_min() < lambda_intersect_min && lambda_intersect_min < lambda_max())
+ result.emplace_back(axis_type(), angle_type(), lambda_intersect_min, mu_min());
+
+ Real lambda_intersect_max = lambda_from_mu(mu_max());
+ if (lambda_min() < lambda_intersect_max && lambda_intersect_max < lambda_max())
+ result.emplace_back(axis_type(), angle_type(), lambda_intersect_max, mu_max());
+
+ assert(result.size() <= 2);
+
+ if (result.size() > 2) {
+ fmt::print("Error in push_change_points, p = {}, dual_box = {}, result = {}\n", p, *this,
+ container_to_string(result));
+ throw std::runtime_error("push_change_points returned more than 2 points");
+ }
+
+ return result;
+ }
+
+ template<class Real>
+ std::vector<DualPoint<Real>> DualBox<Real>::critical_points(const Point<Real>& /*p*/) const
+ {
+ // maximal difference is attained at corners
+ return corners();
+// std::vector<DualPoint<Real>> result;
+// result.reserve(6);
+// for(auto dp : corners()) result.push_back(dp);
+// for(auto dp : push_change_points(p)) result.push_back(dp);
+// return result;
+ }
+
+ template<class Real>
+ std::vector<DualPoint<Real>> DualBox<Real>::random_points(int n) const
+ {
+ assert(n >= 0);
+ std::mt19937_64 gen(1);
+ std::vector<DualPoint<Real>> result;
+ result.reserve(n);
+ std::uniform_real_distribution<Real> mu_distr(mu_min(), mu_max());
+ std::uniform_real_distribution<Real> lambda_distr(lambda_min(), lambda_max());
+ for(int i = 0; i < n; ++i) {
+ result.emplace_back(axis_type(), angle_type(), lambda_distr(gen), mu_distr(gen));
+ }
+ return result;
+ }
+
+ template<class Real>
+ bool DualBox<Real>::sanity_check() const
+ {
+ lower_left_.sanity_check();
+ upper_right_.sanity_check();
+
+ if (lower_left_.angle_type() != upper_right_.angle_type())
+ throw std::runtime_error("angle types differ");
+
+ if (lower_left_.axis_type() != upper_right_.axis_type())
+ throw std::runtime_error("axis types differ");
+
+ if (lower_left_.lambda() >= upper_right_.lambda())
+ throw std::runtime_error("lambda of lower_left_ greater than lambda of upper_right ");
+
+ if (lower_left_.mu() >= upper_right_.mu())
+ throw std::runtime_error("mu of lower_left_ greater than mu of upper_right ");
+
+ return true;
+ }
+
+ template<class Real>
+ std::vector<DualBox<Real>> DualBox<Real>::refine() const
+ {
+ std::vector<DualBox<Real>> result;
+
+ result.reserve(4);
+
+ Real lambda_middle = (lower_left().lambda() + upper_right().lambda()) / 2.0;
+ Real mu_middle = (lower_left().mu() + upper_right().mu()) / 2.0;
+
+ DualPoint<Real> refinement_center(axis_type(), angle_type(), lambda_middle, mu_middle);
+
+ result.emplace_back(lower_left_, refinement_center);
+
+ result.emplace_back(DualPoint<Real>(axis_type(), angle_type(), lambda_middle, mu_min()),
+ DualPoint<Real>(axis_type(), angle_type(), lambda_max(), mu_middle));
+
+ result.emplace_back(refinement_center, upper_right_);
+
+ result.emplace_back(DualPoint<Real>(axis_type(), angle_type(), lambda_min(), mu_middle),
+ DualPoint<Real>(axis_type(), angle_type(), lambda_middle, mu_max()));
+ return result;
+ }
+
+ template<class Real>
+ bool DualBox<Real>::operator==(const DualBox& other) const
+ {
+ return lower_left() == other.lower_left() &&
+ upper_right() == other.upper_right();
+ }
+
+ template<class Real>
+ bool DualBox<Real>::contains(const DualPoint<Real>& dp) const
+ {
+ return dp.angle_type() == angle_type() && dp.axis_type() == axis_type() &&
+ mu_max() >= dp.mu() &&
+ mu_min() <= dp.mu() &&
+ lambda_min() <= dp.lambda() &&
+ lambda_max() >= dp.lambda();
+ }
+
+ template<class Real>
+ DualPoint<Real> DualBox<Real>::lower_right() const
+ {
+ return DualPoint<Real>(lower_left_.axis_type(), lower_left_.angle_type(), lambda_max(), mu_min());
+ }
+
+ template<class Real>
+ DualPoint<Real> DualBox<Real>::upper_left() const
+ {
+ return DualPoint<Real>(lower_left_.axis_type(), lower_left_.angle_type(), lambda_min(), mu_max());
+ }
+}
diff --git a/matching/include/dual_point.h b/matching/include/dual_point.h
index db32f1a..8438860 100644
--- a/matching/include/dual_point.h
+++ b/matching/include/dual_point.h
@@ -1,12 +1,9 @@
-//
-// Created by narn on 12.02.19.
-//
-
#ifndef MATCHING_DISTANCE_DUAL_POINT_H
#define MATCHING_DISTANCE_DUAL_POINT_H
#include <vector>
#include <ostream>
+#include <tuple>
#include "common_util.h"
#include "box.h"
@@ -25,9 +22,10 @@ namespace md {
// so, e.g., line y = x has 4 different non-equal representation.
// we are unlikely to ever need this, because 4 cases are
// always treated separately.
+ template<class Real_>
class DualPoint {
public:
- using Real = md::Real;
+ using Real = Real_;
DualPoint() = default;
@@ -56,7 +54,6 @@ namespace md {
bool is_y_type() const { return axis_type_ == AxisType::y_type; }
- friend std::ostream& operator<<(std::ostream& os, const DualPoint& dp);
bool operator<(const DualPoint& rhs) const;
AxisType axis_type() const { return axis_type_; }
@@ -66,16 +63,16 @@ namespace md {
// return true otherwise
bool sanity_check() const;
- Real weighted_push(Point p) const;
- Point push(Point p) const;
+ Real weighted_push(Point<Real> p) const;
+ Point<Real> push(Point<Real> p) const;
bool is_horizontal() const;
bool is_vertical() const;
- bool goes_below(Point p) const;
- bool goes_above(Point p) const;
+ bool goes_below(Point<Real> p) const;
+ bool goes_above(Point<Real> p) const;
- bool contains(Point p) const;
+ bool contains(Point<Real> p) const;
Real x_slope() const;
Real y_slope() const;
@@ -98,9 +95,13 @@ namespace md {
Real mu_ {-1.0};
};
- std::ostream& operator<<(std::ostream& os, const DualPoint& dp);
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const DualPoint<Real>& dp);
- DualPoint midpoint(DualPoint x, DualPoint y);
+ template<class Real>
+ DualPoint<Real> midpoint(DualPoint<Real> x, DualPoint<Real> y);
};
+#include "dual_point.hpp"
+
#endif //MATCHING_DISTANCE_DUAL_POINT_H
diff --git a/matching/include/dual_point.hpp b/matching/include/dual_point.hpp
new file mode 100644
index 0000000..04e25f2
--- /dev/null
+++ b/matching/include/dual_point.hpp
@@ -0,0 +1,299 @@
+namespace md {
+
+ inline std::ostream& operator<<(std::ostream& os, const AxisType& at)
+ {
+ if (at == AxisType::x_type)
+ os << "x-type";
+ else
+ os << "y-type";
+ return os;
+ }
+
+ inline std::ostream& operator<<(std::ostream& os, const AngleType& at)
+ {
+ if (at == AngleType::flat)
+ os << "flat";
+ else
+ os << "steep";
+ return os;
+ }
+
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const DualPoint<Real>& dp)
+ {
+ os << "Line(" << dp.axis_type() << ", ";
+ os << dp.angle_type() << ", ";
+ os << dp.lambda() << ", ";
+ os << dp.mu() << ", equation: ";
+ if (not dp.is_vertical()) {
+ os << "y = " << dp.y_slope() << " x + " << dp.y_intercept();
+ } else {
+ os << "x = " << dp.x_intercept();
+ }
+ os << " )";
+ return os;
+ }
+
+ template<class Real>
+ bool DualPoint<Real>::operator<(const DualPoint<Real>& rhs) const
+ {
+ return std::tie(axis_type_, angle_type_, lambda_, mu_)
+ < std::tie(rhs.axis_type_, rhs.angle_type_, rhs.lambda_, rhs.mu_);
+ }
+
+ template<class Real>
+ DualPoint<Real>::DualPoint(AxisType axis_type, AngleType angle_type, Real lambda, Real mu)
+ :
+ axis_type_(axis_type),
+ angle_type_(angle_type),
+ lambda_(lambda),
+ mu_(mu)
+ {
+ assert(sanity_check());
+ }
+
+ template<class Real>
+ bool DualPoint<Real>::sanity_check() const
+ {
+ if (lambda_ < 0.0)
+ throw std::runtime_error("Invalid line, negative lambda");
+ if (lambda_ > 1.0)
+ throw std::runtime_error("Invalid line, lambda > 1");
+ if (mu_ < 0.0)
+ throw std::runtime_error("Invalid line, negative mu");
+ return true;
+ }
+
+ template<class Real>
+ Real DualPoint<Real>::gamma() const
+ {
+ if (is_steep())
+ return atan(Real(1.0) / lambda_);
+ else
+ return atan(lambda_);
+ }
+
+ template<class Real>
+ DualPoint<Real> midpoint(DualPoint<Real> x, DualPoint<Real> y)
+ {
+ assert(x.angle_type() == y.angle_type() and x.axis_type() == y.axis_type());
+ Real lambda_mid = (x.lambda() + y.lambda()) / 2;
+ Real mu_mid = (x.mu() + y.mu()) / 2;
+ return DualPoint<Real>(x.axis_type(), x.angle_type(), lambda_mid, mu_mid);
+
+ }
+
+ // return k in the line equation y = kx + b
+ template<class Real>
+ Real DualPoint<Real>::y_slope() const
+ {
+ if (is_flat())
+ return lambda();
+ else
+ return Real(1.0) / lambda();
+ }
+
+ // return k in the line equation x = ky + b
+ template<class Real>
+ Real DualPoint<Real>::x_slope() const
+ {
+ if (is_flat())
+ return Real(1.0) / lambda();
+ else
+ return lambda();
+ }
+
+ // return b in the line equation y = kx + b
+ template<class Real>
+ Real DualPoint<Real>::y_intercept() const
+ {
+ if (is_y_type()) {
+ return mu();
+ } else {
+ // x = x_slope * y + mu = x_slope * (y + mu / x_slope)
+ // x-intercept is -mu/x_slope = -mu * y_slope
+ return -mu() * y_slope();
+ }
+ }
+
+ // return k in the line equation x = ky + b
+ template<class Real>
+ Real DualPoint<Real>::x_intercept() const
+ {
+ if (is_x_type()) {
+ return mu();
+ } else {
+ // y = y_slope * x + mu = y_slope (x + mu / y_slope)
+ // x_intercept is -mu/y_slope = -mu * x_slope
+ return -mu() * x_slope();
+ }
+ }
+
+ template<class Real>
+ Real DualPoint<Real>::x_from_y(Real y) const
+ {
+ if (is_horizontal())
+ throw std::runtime_error("x_from_y called on horizontal line");
+ else
+ return x_slope() * y + x_intercept();
+ }
+
+ template<class Real>
+ Real DualPoint<Real>::y_from_x(Real x) const
+ {
+ if (is_vertical())
+ throw std::runtime_error("x_from_y called on horizontal line");
+ else
+ return y_slope() * x + y_intercept();
+ }
+
+ template<class Real>
+ bool DualPoint<Real>::is_horizontal() const
+ {
+ return is_flat() and lambda() == 0;
+ }
+
+ template<class Real>
+ bool DualPoint<Real>::is_vertical() const
+ {
+ return is_steep() and lambda() == 0;
+ }
+
+ template<class Real>
+ bool DualPoint<Real>::contains(Point<Real> p) const
+ {
+ if (is_vertical())
+ return p.x == x_from_y(p.y);
+ else
+ return p.y == y_from_x(p.x);
+ }
+
+ template<class Real>
+ bool DualPoint<Real>::goes_below(Point<Real> p) const
+ {
+ if (is_vertical())
+ return p.x <= x_from_y(p.y);
+ else
+ return p.y >= y_from_x(p.x);
+ }
+
+ template<class Real>
+ bool DualPoint<Real>::goes_above(Point<Real> p) const
+ {
+ if (is_vertical())
+ return p.x >= x_from_y(p.y);
+ else
+ return p.y <= y_from_x(p.x);
+ }
+
+ template<class Real>
+ Point<Real> DualPoint<Real>::push(Point<Real> p) const
+ {
+ Point<Real> result;
+ // if line is below p, we push horizontally
+ bool horizontal_push = goes_below(p);
+ if (is_x_type()) {
+ if (is_flat()) {
+ if (horizontal_push) {
+ result.x = p.y / lambda() + mu();
+ result.y = p.y;
+ } else {
+ // vertical push
+ result.x = p.x;
+ result.y = lambda() * (p.x - mu());
+ }
+ } else {
+ // steep
+ if (horizontal_push) {
+ result.x = lambda() * p.y + mu();
+ result.y = p.y;
+ } else {
+ // vertical push
+ result.x = p.x;
+ result.y = (p.x - mu()) / lambda();
+ }
+ }
+ } else {
+ // y-type
+ if (is_flat()) {
+ if (horizontal_push) {
+ result.x = (p.y - mu()) / lambda();
+ result.y = p.y;
+ } else {
+ // vertical push
+ result.x = p.x;
+ result.y = lambda() * p.x + mu();
+ }
+ } else {
+ // steep
+ if (horizontal_push) {
+ result.x = (p.y - mu()) * lambda();
+ result.y = p.y;
+ } else {
+ // vertical push
+ result.x = p.x;
+ result.y = p.x / lambda() + mu();
+ }
+ }
+ }
+ return result;
+ }
+
+ template<class Real>
+ Real DualPoint<Real>::weighted_push(Point<Real> p) const
+ {
+ // if line is below p, we push horizontally
+ bool horizontal_push = goes_below(p);
+ if (is_x_type()) {
+ if (is_flat()) {
+ if (horizontal_push) {
+ return p.y;
+ } else {
+ // vertical push
+ return lambda() * (p.x - mu());
+ }
+ } else {
+ // steep
+ if (horizontal_push) {
+ return lambda() * p.y;
+ } else {
+ // vertical push
+ return (p.x - mu());
+ }
+ }
+ } else {
+ // y-type
+ if (is_flat()) {
+ if (horizontal_push) {
+ return p.y - mu();
+ } else {
+ // vertical push
+ return lambda() * p.x;
+ }
+ } else {
+ // steep
+ if (horizontal_push) {
+ return lambda() * (p.y - mu());
+ } else {
+ // vertical push
+ return p.x;
+ }
+ }
+ }
+ }
+
+ template<class Real>
+ bool DualPoint<Real>::operator==(const DualPoint<Real>& other) const
+ {
+ return axis_type() == other.axis_type() and
+ angle_type() == other.angle_type() and
+ mu() == other.mu() and
+ lambda() == other.lambda();
+ }
+
+ template<class Real>
+ Real DualPoint<Real>::weight() const
+ {
+ return lambda_ / sqrt(1 + lambda_ * lambda_);
+ }
+} // namespace md
diff --git a/matching/include/matching_distance.h b/matching/include/matching_distance.h
index bb10203..5be34c7 100644
--- a/matching/include/matching_distance.h
+++ b/matching/include/matching_distance.h
@@ -4,9 +4,10 @@
#include <limits>
#include <utility>
#include <ostream>
+#include <chrono>
+#include <tuple>
+#include <algorithm>
-#include "spdlog/spdlog.h"
-#include "spdlog/fmt/ostr.h"
#include "common_defs.h"
#include "cell_with_value.h"
@@ -17,12 +18,15 @@
#include "bifiltration.h"
#include "bottleneck.h"
-namespace spd = spdlog;
-
namespace md {
- using HeatMap = std::map<DualPoint, Real>;
- using HeatMaps = std::map<int, HeatMap>;
+#ifdef MD_PRINT_HEAT_MAP
+ template<class Real>
+ using HeatMap = std::map<DualPoint<Real>, Real>;
+
+ template<class Real>
+ using HeatMaps = std::map<int, HeatMap<Real>>;
+#endif
enum class BoundStrategy {
bruteforce,
@@ -39,18 +43,107 @@ namespace md {
upper_bound
};
- std::ostream& operator<<(std::ostream& os, const BoundStrategy& s);
-
- std::ostream& operator<<(std::ostream& os, const TraverseStrategy& s);
-
- std::istream& operator>>(std::istream& is, BoundStrategy& s);
-
- std::istream& operator>>(std::istream& is, TraverseStrategy& s);
-
- BoundStrategy bs_from_string(std::string s);
-
- TraverseStrategy ts_from_string(std::string s);
-
+ inline std::ostream& operator<<(std::ostream& os, const BoundStrategy& s)
+ {
+ switch(s) {
+ case BoundStrategy::bruteforce :
+ os << "bruteforce";
+ break;
+ case BoundStrategy::local_dual_bound :
+ os << "local_grob";
+ break;
+ case BoundStrategy::local_combined :
+ os << "local_combined";
+ break;
+ case BoundStrategy::local_dual_bound_refined :
+ os << "local_refined";
+ break;
+ case BoundStrategy::local_dual_bound_for_each_point :
+ os << "local_for_each_point";
+ break;
+ default:
+ os << "FORGOTTEN BOUND STRATEGY";
+ }
+ return os;
+ }
+
+ inline std::ostream& operator<<(std::ostream& os, const TraverseStrategy& s)
+ {
+ switch(s) {
+ case TraverseStrategy::depth_first :
+ os << "DFS";
+ break;
+ case TraverseStrategy::breadth_first :
+ os << "BFS";
+ break;
+ case TraverseStrategy::breadth_first_value :
+ os << "BFS-VAL";
+ break;
+ case TraverseStrategy::upper_bound :
+ os << "UB";
+ break;
+ default:
+ os << "FORGOTTEN TRAVERSE STRATEGY";
+ }
+ return os;
+ }
+
+ inline std::istream& operator>>(std::istream& is, TraverseStrategy& s)
+ {
+ std::string ss;
+ is >> ss;
+ if (ss == "DFS") {
+ s = TraverseStrategy::depth_first;
+ } else if (ss == "BFS") {
+ s = TraverseStrategy::breadth_first;
+ } else if (ss == "BFS-VAL") {
+ s = TraverseStrategy::breadth_first_value;
+ } else if (ss == "UB") {
+ s = TraverseStrategy::upper_bound;
+ } else {
+ throw std::runtime_error("UNKNOWN TRAVERSE STRATEGY");
+ }
+ return is;
+ }
+
+
+ inline std::istream& operator>>(std::istream& is, BoundStrategy& s)
+ {
+ std::string ss;
+ is >> ss;
+ if (ss == "bruteforce") {
+ s = BoundStrategy::bruteforce;
+ } else if (ss == "local_grob") {
+ s = BoundStrategy::local_dual_bound;
+ } else if (ss == "local_combined") {
+ s = BoundStrategy::local_combined;
+ } else if (ss == "local_refined") {
+ s = BoundStrategy::local_dual_bound_refined;
+ } else if (ss == "local_for_each_point") {
+ s = BoundStrategy::local_dual_bound_for_each_point;
+ } else {
+ throw std::runtime_error("UNKNOWN BOUND STRATEGY");
+ }
+ return is;
+ }
+
+ inline BoundStrategy bs_from_string(std::string s)
+ {
+ std::stringstream ss(s);
+ BoundStrategy result;
+ ss >> result;
+ return result;
+ }
+
+ inline TraverseStrategy ts_from_string(std::string s)
+ {
+ std::stringstream ss(s);
+ TraverseStrategy result;
+ ss >> result;
+ return result;
+ }
+
+ template<class Real>
struct CalculationParams {
static constexpr int ALL_DIMENSIONS = -1;
@@ -75,22 +168,22 @@ namespace md {
// print statistics on each quad-tree level
bool print_stats { false };
-#ifdef PRINT_HEAT_MAP
+#ifdef MD_PRINT_HEAT_MAP
HeatMaps heat_maps;
#endif
};
- template<class DiagramProvider>
+ template<class Real_, class DiagramProvider>
class DistanceCalculator {
- using DualBox = md::DualBox;
- using CellValueVector = std::vector<CellWithValue>;
+ using Real = Real_;
+ using CellValueVector = std::vector<CellWithValue<Real>>;
public:
DistanceCalculator(const DiagramProvider& a,
const DiagramProvider& b,
- CalculationParams& params);
+ CalculationParams<Real>& params);
Real distance();
@@ -100,7 +193,7 @@ namespace md {
DiagramProvider module_a_;
DiagramProvider module_b_;
- CalculationParams& params_;
+ CalculationParams<Real>& params_;
int n_hera_calls_;
std::map<int, int> n_hera_calls_per_level_;
@@ -112,65 +205,83 @@ namespace md {
CellValueVector get_initial_dual_grid(Real& lower_bound);
+#ifdef MD_PRINT_HEAT_MAP
void heatmap_in_dimension(int dim, int depth);
+#endif
Real get_max_x(int module) const;
Real get_max_y(int module) const;
- void set_cell_central_value(CellWithValue& dual_cell);
+ void set_cell_central_value(CellWithValue<Real>& dual_cell);
Real get_distance();
Real get_distance_pq();
- // temporary, to try priority queue
- Real get_max_possible_value(const CellWithValue* first_cell_ptr, int n_cells);
+ Real get_max_possible_value(const CellWithValue<Real>* first_cell_ptr, int n_cells);
- Real get_upper_bound(const CellWithValue& dual_cell, Real good_enough_upper_bound) const;
+ Real get_upper_bound(const CellWithValue<Real>& dual_cell, Real good_enough_upper_bound) const;
- Real get_single_dgm_bound(const CellWithValue& dual_cell, ValuePoint vp, int module,
+ Real get_single_dgm_bound(const CellWithValue<Real>& dual_cell, ValuePoint vp, int module,
Real good_enough_value) const;
// this bound depends only on dual box
- Real get_local_dual_bound(int module, const DualBox& dual_box) const;
+ Real get_local_dual_bound(int module, const DualBox<Real>& dual_box) const;
- Real get_local_dual_bound(const DualBox& dual_box) const;
+ Real get_local_dual_bound(const DualBox<Real>& dual_box) const;
// this bound depends only on dual box, is more accurate
- Real get_local_refined_bound(int module, const md::DualBox& dual_box) const;
+ Real get_local_refined_bound(int module, const DualBox<Real>& dual_box) const;
- Real get_local_refined_bound(const md::DualBox& dual_box) const;
+ Real get_local_refined_bound(const DualBox<Real>& dual_box) const;
Real get_good_enough_upper_bound(Real lower_bound) const;
- Real
- get_max_displacement_single_point(const CellWithValue& dual_cell, ValuePoint value_point, const Point& p) const;
+ Real get_max_displacement_single_point(const CellWithValue<Real>& dual_cell, ValuePoint value_point,
+ const Point<Real>& p) const;
- void check_upper_bound(const CellWithValue& dual_cell) const;
+ void check_upper_bound(const CellWithValue<Real>& dual_cell) const;
- Real distance_on_line(DualPoint line);
- Real distance_on_line_const(DualPoint line) const;
+ Real distance_on_line(DualPoint<Real> line);
+ Real distance_on_line_const(DualPoint<Real> line) const;
Real current_error(Real lower_bound, Real upper_bound);
};
- Real matching_distance(const Bifiltration& bif_a, const Bifiltration& bif_b, CalculationParams& params);
+ template<class Real>
+ Real matching_distance(const Bifiltration<Real>& bif_a, const Bifiltration<Real>& bif_b,
+ CalculationParams<Real>& params);
- Real matching_distance(const ModulePresentation& mod_a, const ModulePresentation& mod_b, CalculationParams& params);
+ template<class Real>
+ Real matching_distance(const ModulePresentation<Real>& mod_a, const ModulePresentation<Real>& mod_b,
+ CalculationParams<Real>& params);
// for upper bound experiment
struct UbExperimentRecord {
- Real error;
- Real lower_bound;
- Real upper_bound;
- CellWithValue cell;
+ double error;
+ double lower_bound;
+ double upper_bound;
+ CellWithValue<double> cell;
long long int time;
long long int n_hera_calls;
};
- std::ostream& operator<<(std::ostream& os, const UbExperimentRecord& r);
+ inline std::ostream& operator<<(std::ostream& os, const UbExperimentRecord& r)
+ {
+ os << r.time << "\t" << r.n_hera_calls << "\t" << r.error << "\t" << r.lower_bound << "\t" << r.upper_bound;
+ return os;
+ }
+
+
+ template<class K, class V>
+ void print_map(const std::map<K, V>& dic)
+ {
+ for(const auto kv : dic) {
+ fmt::print("{} -> {}\n", kv.first, kv.second);
+ }
+ }
-}
+} // namespace md
#include "matching_distance.hpp"
diff --git a/matching/include/matching_distance.hpp b/matching/include/matching_distance.hpp
index d2d2fbc..48c8464 100644
--- a/matching/include/matching_distance.hpp
+++ b/matching/include/matching_distance.hpp
@@ -1,34 +1,26 @@
namespace md {
- template<class K, class V>
- void print_map(const std::map<K, V>& dic)
- {
- for(const auto kv : dic) {
- fmt::print("{} -> {}\n", kv.first, kv.second);
- }
- }
-
- template<class T>
- void DistanceCalculator<T>::check_upper_bound(const CellWithValue& dual_cell) const
+ template<class R, class T>
+ void DistanceCalculator<R, T>::check_upper_bound(const CellWithValue<R>& dual_cell) const
{
spd::debug("Enter check_get_max_delta_on_cell");
const int n_samples_lambda = 100;
const int n_samples_mu = 100;
- DualBox db = dual_cell.dual_box();
- Real min_lambda = db.lambda_min();
- Real max_lambda = db.lambda_max();
- Real min_mu = db.mu_min();
- Real max_mu = db.mu_max();
-
- Real h_lambda = (max_lambda - min_lambda) / n_samples_lambda;
- Real h_mu = (max_mu - min_mu) / n_samples_mu;
+ DualBox<R> db = dual_cell.dual_box();
+ R min_lambda = db.lambda_min();
+ R max_lambda = db.lambda_max();
+ R min_mu = db.mu_min();
+ R max_mu = db.mu_max();
+
+ R h_lambda = (max_lambda - min_lambda) / n_samples_lambda;
+ R h_mu = (max_mu - min_mu) / n_samples_mu;
for(int i = 1; i < n_samples_lambda; ++i) {
for(int j = 1; j < n_samples_mu; ++j) {
- Real lambda = min_lambda + i * h_lambda;
- Real mu = min_mu + j * h_mu;
- DualPoint l(db.axis_type(), db.angle_type(), lambda, mu);
- Real other_result = distance_on_line_const(l);
- Real diff = fabs(dual_cell.stored_upper_bound() - other_result);
+ R lambda = min_lambda + i * h_lambda;
+ R mu = min_mu + j * h_mu;
+ DualPoint<R> l(db.axis_type(), db.angle_type(), lambda, mu);
+ R other_result = distance_on_line_const(l);
+ R diff = fabs(dual_cell.stored_upper_bound() - other_result);
if (other_result > dual_cell.stored_upper_bound()) {
spd::error(
"in check_upper_bound, upper_bound = {}, other_result = {}, diff = {}\ndual_cell = {}",
@@ -42,10 +34,10 @@ namespace md {
// for all lines l, l' inside dual box,
// find the upper bound on the difference of weighted pushes of p
- template<class T>
- Real
- DistanceCalculator<T>::get_max_displacement_single_point(const CellWithValue& dual_cell, ValuePoint vp,
- const Point& p) const
+ template<class R, class T>
+ R
+ DistanceCalculator<R, T>::get_max_displacement_single_point(const CellWithValue<R>& dual_cell, ValuePoint vp,
+ const Point<R>& p) const
{
assert(p.x >= 0 && p.y >= 0);
@@ -53,15 +45,15 @@ namespace md {
std::vector<long long int> debug_ids = {3, 13, 54, 218, 350, 382, 484, 795, 2040, 8415, 44076};
bool debug = false; // std::find(debug_ids.begin(), debug_ids.end(), dual_cell.id) != debug_ids.end();
#endif
- DualPoint line = dual_cell.value_point(vp);
- const Real base_value = line.weighted_push(p);
+ DualPoint<R> line = dual_cell.value_point(vp);
+ const R base_value = line.weighted_push(p);
spd::debug("Enter get_max_displacement_single_point, p = {},\ndual_cell = {},\nline = {}, base_value = {}\n", p,
dual_cell, line, base_value);
- Real result = 0.0;
- for(DualPoint dp : dual_cell.dual_box().critical_points(p)) {
- Real dp_value = dp.weighted_push(p);
+ R result = 0.0;
+ for(DualPoint<R> dp : dual_cell.dual_box().critical_points(p)) {
+ R dp_value = dp.weighted_push(p);
spd::debug(
"In get_max_displacement_single_point, p = {}, critical dp = {},\ndp_value = {}, diff = {},\ndual_cell = {}\n",
p, dp, dp_value, fabs(base_value - dp_value), dual_cell);
@@ -69,15 +61,15 @@ namespace md {
}
#ifdef MD_DO_FULL_CHECK
- DualBox db = dual_cell.dual_box();
- std::uniform_real_distribution<Real> dlambda(db.lambda_min(), db.lambda_max());
- std::uniform_real_distribution<Real> dmu(db.mu_min(), db.mu_max());
+ auto db = dual_cell.dual_box();
+ std::uniform_real_distribution<R> dlambda(db.lambda_min(), db.lambda_max());
+ std::uniform_real_distribution<R> dmu(db.mu_min(), db.mu_max());
std::mt19937 gen(1);
for(int i = 0; i < 1000; ++i) {
- Real lambda = dlambda(gen);
- Real mu = dmu(gen);
- DualPoint dp_random { db.axis_type(), db.angle_type(), lambda, mu };
- Real dp_value = dp_random.weighted_push(p);
+ R lambda = dlambda(gen);
+ R mu = dmu(gen);
+ DualPoint<R> dp_random { db.axis_type(), db.angle_type(), lambda, mu };
+ R dp_value = dp_random.weighted_push(p);
if (fabs(base_value - dp_value) > result) {
spd::error("in get_max_displacement_single_point, p = {}, vp = {}\ndb = {}\nresult = {}, base_value = {}, dp_value = {}, dp_random = {}",
p, vp, db, result, base_value, dp_value, dp_random);
@@ -89,12 +81,12 @@ namespace md {
return result;
}
- template<class T>
- typename DistanceCalculator<T>::CellValueVector DistanceCalculator<T>::get_initial_dual_grid(Real& lower_bound)
+ template<class R, class T>
+ typename DistanceCalculator<R, T>::CellValueVector DistanceCalculator<R, T>::get_initial_dual_grid(R& lower_bound)
{
CellValueVector result = get_refined_grid(params_.initialization_depth, false, true);
- lower_bound = -1.0;
+ lower_bound = -1;
for(const auto& dc : result) {
lower_bound = std::max(lower_bound, dc.max_corner_value());
}
@@ -102,8 +94,8 @@ namespace md {
assert(lower_bound >= 0);
for(auto& dual_cell : result) {
- Real good_enough_ub = get_good_enough_upper_bound(lower_bound);
- Real max_value_on_cell = get_upper_bound(dual_cell, good_enough_ub);
+ R good_enough_ub = get_good_enough_upper_bound(lower_bound);
+ R max_value_on_cell = get_upper_bound(dual_cell, good_enough_ub);
dual_cell.set_max_possible_value(max_value_on_cell);
#ifdef MD_DO_FULL_CHECK
@@ -116,39 +108,39 @@ namespace md {
return result;
}
- template<class T>
- typename DistanceCalculator<T>::CellValueVector
- DistanceCalculator<T>::get_refined_grid(int init_depth, bool calculate_on_intermediate, bool calculate_on_last)
+ template<class R, class T>
+ typename DistanceCalculator<R, T>::CellValueVector
+ DistanceCalculator<R, T>::get_refined_grid(int init_depth, bool calculate_on_intermediate, bool calculate_on_last)
{
- const Real y_max = std::max(module_a_.max_y(), module_b_.max_y());
- const Real x_max = std::max(module_a_.max_x(), module_b_.max_x());
+ const R y_max = std::max(module_a_.max_y(), module_b_.max_y());
+ const R x_max = std::max(module_a_.max_x(), module_b_.max_x());
- const Real lambda_min = 0;
- const Real lambda_max = 1;
+ const R lambda_min = 0;
+ const R lambda_max = 1;
- const Real mu_min = 0;
+ const R mu_min = 0;
- DualBox x_flat(DualPoint(AxisType::x_type, AngleType::flat, lambda_min, mu_min),
- DualPoint(AxisType::x_type, AngleType::flat, lambda_max, x_max));
+ DualBox<R> x_flat(DualPoint<R>(AxisType::x_type, AngleType::flat, lambda_min, mu_min),
+ DualPoint<R>(AxisType::x_type, AngleType::flat, lambda_max, x_max));
- DualBox x_steep(DualPoint(AxisType::x_type, AngleType::steep, lambda_min, mu_min),
- DualPoint(AxisType::x_type, AngleType::steep, lambda_max, x_max));
+ DualBox<R> x_steep(DualPoint<R>(AxisType::x_type, AngleType::steep, lambda_min, mu_min),
+ DualPoint<R>(AxisType::x_type, AngleType::steep, lambda_max, x_max));
- DualBox y_flat(DualPoint(AxisType::y_type, AngleType::flat, lambda_min, mu_min),
- DualPoint(AxisType::y_type, AngleType::flat, lambda_max, y_max));
+ DualBox<R> y_flat(DualPoint<R>(AxisType::y_type, AngleType::flat, lambda_min, mu_min),
+ DualPoint<R>(AxisType::y_type, AngleType::flat, lambda_max, y_max));
- DualBox y_steep(DualPoint(AxisType::y_type, AngleType::steep, lambda_min, mu_min),
- DualPoint(AxisType::y_type, AngleType::steep, lambda_max, y_max));
+ DualBox<R> y_steep(DualPoint<R>(AxisType::y_type, AngleType::steep, lambda_min, mu_min),
+ DualPoint<R>(AxisType::y_type, AngleType::steep, lambda_max, y_max));
- CellWithValue x_flat_cell(x_flat, 0);
- CellWithValue x_steep_cell(x_steep, 0);
- CellWithValue y_flat_cell(y_flat, 0);
- CellWithValue y_steep_cell(y_steep, 0);
+ CellWithValue<R> x_flat_cell(x_flat, 0);
+ CellWithValue<R> x_steep_cell(x_steep, 0);
+ CellWithValue<R> y_flat_cell(y_flat, 0);
+ CellWithValue<R> y_steep_cell(y_steep, 0);
if (init_depth == 0) {
- DualPoint diagonal_x_flat(AxisType::x_type, AngleType::flat, 1, 0);
+ DualPoint<R> diagonal_x_flat(AxisType::x_type, AngleType::flat, 1, 0);
- Real diagonal_value = distance_on_line(diagonal_x_flat);
+ R diagonal_value = distance_on_line(diagonal_x_flat);
n_hera_calls_per_level_[0]++;
x_flat_cell.set_value_at(ValuePoint::lower_right, diagonal_value);
@@ -162,7 +154,7 @@ namespace md {
x_steep_cell.id = 2;
y_flat_cell.id = 3;
y_steep_cell.id = 4;
- CellWithValue::max_id = 4;
+ CellWithValue<R>::max_id = 4;
#endif
CellValueVector result {x_flat_cell, x_steep_cell, y_flat_cell, y_steep_cell};
@@ -189,10 +181,10 @@ namespace md {
return result;
}
- template<class T>
- DistanceCalculator<T>::DistanceCalculator(const T& a,
+ template<class R, class T>
+ DistanceCalculator<R, T>::DistanceCalculator(const T& a,
const T& b,
- CalculationParams& params)
+ CalculationParams<R>& params)
:
module_a_(a),
module_b_(b),
@@ -213,33 +205,33 @@ namespace md {
module_a_.max_x(), module_a_.max_y(), module_b_.max_x(), module_b_.max_y());
}
- template<class T>
- Real DistanceCalculator<T>::get_max_x(int module) const
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_max_x(int module) const
{
return (module == 0) ? module_a_.max_x() : module_b_.max_x();
}
- template<class T>
- Real DistanceCalculator<T>::get_max_y(int module) const
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_max_y(int module) const
{
return (module == 0) ? module_a_.max_y() : module_b_.max_y();
}
- template<class T>
- Real
- DistanceCalculator<T>::get_local_refined_bound(const md::DualBox& dual_box) const
+ template<class R, class T>
+ R
+ DistanceCalculator<R, T>::get_local_refined_bound(const DualBox<R>& dual_box) const
{
return get_local_refined_bound(0, dual_box) + get_local_refined_bound(1, dual_box);
}
- template<class T>
- Real
- DistanceCalculator<T>::get_local_refined_bound(int module, const md::DualBox& dual_box) const
+ template<class R, class T>
+ R
+ DistanceCalculator<R, T>::get_local_refined_bound(int module, const DualBox<R>& dual_box) const
{
spd::debug("Enter get_local_refined_bound, dual_box = {}", dual_box);
- Real d_lambda = dual_box.lambda_max() - dual_box.lambda_min();
- Real d_mu = dual_box.mu_max() - dual_box.mu_min();
- Real result;
+ R d_lambda = dual_box.lambda_max() - dual_box.lambda_min();
+ R d_mu = dual_box.mu_max() - dual_box.mu_min();
+ R result;
if (dual_box.axis_type() == AxisType::x_type) {
if (dual_box.is_flat()) {
result = dual_box.lambda_max() * d_mu + (get_max_x(module) - dual_box.mu_min()) * d_lambda;
@@ -258,11 +250,11 @@ namespace md {
return result;
}
- template<class T>
- Real DistanceCalculator<T>::get_local_dual_bound(int module, const md::DualBox& dual_box) const
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_local_dual_bound(int module, const DualBox<R>& dual_box) const
{
- Real dlambda = dual_box.lambda_max() - dual_box.lambda_min();
- Real dmu = dual_box.mu_max() - dual_box.mu_min();
+ R dlambda = dual_box.lambda_max() - dual_box.lambda_min();
+ R dmu = dual_box.mu_max() - dual_box.mu_min();
if (dual_box.is_flat()) {
return get_max_x(module) * dlambda + dmu;
@@ -271,20 +263,20 @@ namespace md {
}
}
- template<class T>
- Real DistanceCalculator<T>::get_local_dual_bound(const md::DualBox& dual_box) const
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_local_dual_bound(const DualBox<R>& dual_box) const
{
return get_local_dual_bound(0, dual_box) + get_local_dual_bound(1, dual_box);
}
- template<class T>
- Real DistanceCalculator<T>::get_upper_bound(const CellWithValue& dual_cell, Real good_enough_ub) const
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_upper_bound(const CellWithValue<R>& dual_cell, R good_enough_ub) const
{
assert(good_enough_ub >= 0);
switch(params_.bound_strategy) {
case BoundStrategy::bruteforce:
- return std::numeric_limits<Real>::max();
+ return std::numeric_limits<R>::max();
case BoundStrategy::local_dual_bound:
return dual_cell.min_value() + get_local_dual_bound(dual_cell.dual_box());
@@ -293,7 +285,7 @@ namespace md {
return dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box());
case BoundStrategy::local_combined: {
- Real cheap_upper_bound = dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box());
+ R cheap_upper_bound = dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box());
if (cheap_upper_bound < good_enough_ub) {
return cheap_upper_bound;
} else {
@@ -302,14 +294,14 @@ namespace md {
}
case BoundStrategy::local_dual_bound_for_each_point: {
- Real result = std::numeric_limits<Real>::max();
+ R result = std::numeric_limits<R>::max();
for(ValuePoint vp : k_corner_vps) {
if (not dual_cell.has_value_at(vp)) {
continue;
}
- Real base_value = dual_cell.value_at(vp);
- Real bound_dgm_a = get_single_dgm_bound(dual_cell, vp, 0, good_enough_ub);
+ R base_value = dual_cell.value_at(vp);
+ R bound_dgm_a = get_single_dgm_bound(dual_cell, vp, 0, good_enough_ub);
if (params_.stop_asap and bound_dgm_a + base_value >= good_enough_ub) {
// we want to return a valid upper bound, not just something that will prevent discarding the cell
@@ -318,8 +310,8 @@ namespace md {
return dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box());
}
- Real bound_dgm_b = get_single_dgm_bound(dual_cell, vp, 1,
- std::max(Real(0), good_enough_ub - bound_dgm_a));
+ R bound_dgm_b = get_single_dgm_bound(dual_cell, vp, 1,
+ std::max(R(0), good_enough_ub - bound_dgm_a));
result = std::min(result, base_value + bound_dgm_a + bound_dgm_b);
@@ -336,19 +328,19 @@ namespace md {
}
}
// to suppress compiler warning
- return std::numeric_limits<Real>::max();
+ return std::numeric_limits<R>::max();
}
// find maximal displacement of weighted points of m for all lines in dual_box
- template<class T>
- Real
- DistanceCalculator<T>::get_single_dgm_bound(const CellWithValue& dual_cell,
+ template<class R, class T>
+ R
+ DistanceCalculator<R, T>::get_single_dgm_bound(const CellWithValue<R>& dual_cell,
ValuePoint vp,
int module,
- [[maybe_unused]] Real good_enough_value) const
+ R good_enough_value) const
{
- Real result = 0;
- Point max_point;
+ R result = 0;
+ Point<R> max_point;
spd::debug(
"Enter get_single_dgm_bound, module = {}, dual_cell = {}, vp = {}, good_enough_value = {}, stop_asap = {}\n",
@@ -358,7 +350,7 @@ namespace md {
for(const auto& position : m.positions()) {
spd::debug("in get_single_dgm_bound, simplex = {}\n", position);
- Real x = get_max_displacement_single_point(dual_cell, vp, position);
+ R x = get_max_displacement_single_point(dual_cell, vp, position);
spd::debug("In get_single_dgm_bound, point = {}, displacement = {}", position, x);
@@ -385,30 +377,30 @@ namespace md {
return result;
}
- template<class T>
- Real DistanceCalculator<T>::distance()
+ template<class R, class T>
+ R DistanceCalculator<R, T>::distance()
{
return get_distance_pq();
}
// calculate weighted bottleneneck distance between slices on line
// increments hera calls counter
- template<class T>
- Real DistanceCalculator<T>::distance_on_line(DualPoint line)
+ template<class R, class T>
+ R DistanceCalculator<R, T>::distance_on_line(DualPoint<R> line)
{
++n_hera_calls_;
- Real result = distance_on_line_const(line);
+ R result = distance_on_line_const(line);
return result;
}
- template<class T>
- Real DistanceCalculator<T>::distance_on_line_const(DualPoint line) const
+ template<class R, class T>
+ R DistanceCalculator<R, T>::distance_on_line_const(DualPoint<R> line) const
{
// TODO: think about this - how to call Hera
auto dgm_a = module_a_.weighted_slice_diagram(line);
auto dgm_b = module_b_.weighted_slice_diagram(line);
- Real result;
- if (params_.hera_epsilon > static_cast<Real>(0)) {
+ R result;
+ if (params_.hera_epsilon > static_cast<R>(0)) {
result = hera::bottleneckDistApprox(dgm_a, dgm_b, params_.hera_epsilon) / ( params_.hera_epsilon + 1);
} else {
result = hera::bottleneckDistExact(dgm_a, dgm_b);
@@ -423,10 +415,10 @@ namespace md {
return result;
}
- template<class T>
- Real DistanceCalculator<T>::get_good_enough_upper_bound(Real lower_bound) const
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_good_enough_upper_bound(R lower_bound) const
{
- Real result;
+ R result;
// in upper_bound strategy we only prune cells if they cannot improve the lower bound,
// otherwise the experiment is supposed to run indefinitely
if (params_.traverse_strategy == TraverseStrategy::upper_bound) {
@@ -440,14 +432,14 @@ namespace md {
// helper function
// calculate weighted bt distance on cell center,
// assign distance value to cell, keep it in heat_map, and return
- template<class T>
- void DistanceCalculator<T>::set_cell_central_value(CellWithValue& dual_cell)
+ template<class R, class T>
+ void DistanceCalculator<R, T>::set_cell_central_value(CellWithValue<R>& dual_cell)
{
- DualPoint central_line {dual_cell.center()};
+ DualPoint<R> central_line {dual_cell.center()};
spd::debug("In set_cell_central_value, processing dual cell = {}, line = {}", dual_cell.dual_box(),
central_line);
- Real new_value = distance_on_line(central_line);
+ R new_value = distance_on_line(central_line);
n_hera_calls_per_level_[dual_cell.level() + 1]++;
dual_cell.set_value_at(ValuePoint::center, new_value);
params_.actual_max_depth = std::max(params_.actual_max_depth, dual_cell.level() + 1);
@@ -472,10 +464,10 @@ namespace md {
// assumes that the underlying container is vector!
// cell_ptr: pointer to the first element in queue
// n_cells: queue size
- template<class T>
- Real DistanceCalculator<T>::get_max_possible_value(const CellWithValue* cell_ptr, int n_cells)
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_max_possible_value(const CellWithValue<R>* cell_ptr, int n_cells)
{
- Real result = (n_cells > 0) ? cell_ptr->stored_upper_bound() : 0;
+ R result = (n_cells > 0) ? cell_ptr->stored_upper_bound() : 0;
for(int i = 0; i < n_cells; ++i, ++cell_ptr) {
result = std::max(result, cell_ptr->stored_upper_bound());
}
@@ -485,11 +477,11 @@ namespace md {
// helper function:
// return current error from lower and upper bounds
// and save it in params_ (hence not const)
- template<class T>
- Real DistanceCalculator<T>::current_error(Real lower_bound, Real upper_bound)
+ template<class R, class T>
+ R DistanceCalculator<R, T>::current_error(R lower_bound, R upper_bound)
{
- Real current_error = (lower_bound > 0.0) ? (upper_bound - lower_bound) / lower_bound
- : std::numeric_limits<Real>::max();
+ R current_error = (lower_bound > 0.0) ? (upper_bound - lower_bound) / lower_bound
+ : std::numeric_limits<R>::max();
params_.actual_error = current_error;
@@ -505,8 +497,8 @@ namespace md {
// use priority queue to store dual cells
// comparison function depends on the strategies in params_
// ressets hera calls counter
- template<class T>
- Real DistanceCalculator<T>::get_distance_pq()
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_distance_pq()
{
std::map<int, long> n_cells_considered;
std::map<int, long> n_cells_pushed_into_queue;
@@ -527,26 +519,26 @@ namespace md {
// if cell is too deep and is not pushed into queue,
// we still need to take its max value into account;
// the max over such cells is stored in max_result_on_too_fine_cells
- Real upper_bound_on_deep_cells = -1;
+ R upper_bound_on_deep_cells = -1;
spd::debug("Started iterations in dual space, delta = {}, bound_strategy = {}", params_.delta,
params_.bound_strategy);
// user-defined less lambda function
// to regulate priority queue depending on strategy
- auto dual_cell_less = [this](const CellWithValue& a, const CellWithValue& b) {
+ auto dual_cell_less = [this](const CellWithValue<R>& a, const CellWithValue<R>& b) {
int a_level = a.level();
int b_level = b.level();
- Real a_value = a.max_corner_value();
- Real b_value = b.max_corner_value();
- Real a_ub = a.stored_upper_bound();
- Real b_ub = b.stored_upper_bound();
+ R a_value = a.max_corner_value();
+ R b_value = b.max_corner_value();
+ R a_ub = a.stored_upper_bound();
+ R b_ub = b.stored_upper_bound();
if (this->params_.traverse_strategy == TraverseStrategy::upper_bound and
(not a.has_max_possible_value() or not b.has_max_possible_value())) {
throw std::runtime_error("no upper bound on cell");
}
- DualPoint a_lower_left = a.dual_box().lower_left();
- DualPoint b_lower_left = b.dual_box().lower_left();
+ DualPoint<R> a_lower_left = a.dual_box().lower_left();
+ DualPoint<R> b_lower_left = b.dual_box().lower_left();
switch(this->params_.traverse_strategy) {
// in both breadth_first searches we want coarser cells
@@ -569,24 +561,24 @@ namespace md {
}
};
- std::priority_queue<CellWithValue, CellValueVector, decltype(dual_cell_less)> dual_cells_queue(
+ std::priority_queue<CellWithValue<R>, CellValueVector, decltype(dual_cell_less)> dual_cells_queue(
dual_cell_less);
// weighted bt distance on the center of current cell
- Real lower_bound = std::numeric_limits<Real>::min();
+ R lower_bound = std::numeric_limits<R>::min();
// init pq and lower bound
for(auto& init_cell : get_initial_dual_grid(lower_bound)) {
dual_cells_queue.push(init_cell);
}
- Real upper_bound = get_max_possible_value(&dual_cells_queue.top(), dual_cells_queue.size());
+ R upper_bound = get_max_possible_value(&dual_cells_queue.top(), dual_cells_queue.size());
std::vector<UbExperimentRecord> ub_experiment_results;
while(not dual_cells_queue.empty()) {
- CellWithValue dual_cell = dual_cells_queue.top();
+ CellWithValue<R> dual_cell = dual_cells_queue.top();
dual_cells_queue.pop();
assert(dual_cell.has_corner_value()
and dual_cell.has_max_possible_value()
@@ -620,7 +612,7 @@ namespace md {
// until now, dual_cell knows its value in one of its corners
// new_value will be the weighted distance at its center
set_cell_central_value(dual_cell);
- Real new_value = dual_cell.value_at(ValuePoint::center);
+ R new_value = dual_cell.value_at(ValuePoint::center);
lower_bound = std::max(new_value, lower_bound);
spd::debug("Processed cell = {}, weighted value = {}, lower_bound = {}", dual_cell, new_value, lower_bound);
@@ -638,11 +630,11 @@ namespace md {
throw std::runtime_error("no value on cell");
// if delta is smaller than good_enough_value, it allows to prune cell
- Real good_enough_ub = get_good_enough_upper_bound(lower_bound);
+ R good_enough_ub = get_good_enough_upper_bound(lower_bound);
// upper bound of the parent holds for refined_cell
// and can sometimes be smaller!
- Real upper_bound_on_refined_cell = std::min(dual_cell.stored_upper_bound(),
+ R upper_bound_on_refined_cell = std::min(dual_cell.stored_upper_bound(),
get_upper_bound(refined_cell, good_enough_ub));
spd::debug("upper_bound_on_refined_cell = {}, dual_cell.stored_upper_bound = {}, get_upper_bound = {}",
@@ -774,10 +766,46 @@ namespace md {
return lower_bound;
}
- template<class T>
- int DistanceCalculator<T>::get_hera_calls_number() const
+ template<class R, class T>
+ int DistanceCalculator<R, T>::get_hera_calls_number() const
{
return n_hera_calls_;
}
-} \ No newline at end of file
+ template<class R>
+ R matching_distance(const Bifiltration<R>& bif_a, const Bifiltration<R>& bif_b,
+ CalculationParams<R>& params)
+ {
+ R result;
+ // compute distance only in one dimension
+ if (params.dim != CalculationParams<R>::ALL_DIMENSIONS) {
+ BifiltrationProxy<R> bifp_a(bif_a, params.dim);
+ BifiltrationProxy<R> bifp_b(bif_b, params.dim);
+ DistanceCalculator<R, BifiltrationProxy<R>> runner(bifp_a, bifp_b, params);
+ result = runner.distance();
+ params.n_hera_calls = runner.get_hera_calls_number();
+ } else {
+ // compute distance in all dimensions, return maximal
+ result = -1;
+ for(int dim = 0; dim < std::max(bif_a.maximal_dim(), bif_b.maximal_dim()); ++dim) {
+ BifiltrationProxy<R> bifp_a(bif_a, params.dim);
+ BifiltrationProxy<R> bifp_b(bif_a, params.dim);
+ DistanceCalculator<R, BifiltrationProxy<R>> runner(bifp_a, bifp_b, params);
+ result = std::max(result, runner.distance());
+ params.n_hera_calls += runner.get_hera_calls_number();
+ }
+ }
+ return result;
+ }
+
+
+ template<class R>
+ R matching_distance(const ModulePresentation<R>& mod_a, const ModulePresentation<R>& mod_b,
+ CalculationParams<R>& params)
+ {
+ DistanceCalculator<R, ModulePresentation<R>> runner(mod_a, mod_b, params);
+ R result = runner.distance();
+ params.n_hera_calls = runner.get_hera_calls_number();
+ return result;
+ }
+} // namespace md
diff --git a/matching/include/persistence_module.h b/matching/include/persistence_module.h
index a1fc67e..e99771f 100644
--- a/matching/include/persistence_module.h
+++ b/matching/include/persistence_module.h
@@ -5,6 +5,12 @@
#include <vector>
#include <utility>
#include <string>
+#include <numeric>
+#include <algorithm>
+#include <unordered_set>
+
+#include "phat/boundary_matrix.h"
+#include "phat/compute_persistence_pairs.h"
#include "common_util.h"
#include "dual_point.h"
@@ -28,17 +34,20 @@ namespace md {
*/
+ template<class Real>
class ModulePresentation {
public:
+ using RealVec = std::vector<Real>;
+
enum Format { rivet_firep };
struct Relation {
- Point position_;
+ Point<Real> position_;
IndexVec components_;
Relation() {}
- Relation(const Point& _pos, const IndexVec& _components);
+ Relation(const Point<Real>& _pos, const IndexVec& _components);
Real get_x() const { return position_.x; }
Real get_y() const { return position_.y; }
@@ -48,9 +57,9 @@ namespace md {
ModulePresentation() {}
- ModulePresentation(const PointVec& _generators, const RelVec& _relations);
+ ModulePresentation(const PointVec<Real>& _generators, const RelVec& _relations);
- Diagram weighted_slice_diagram(const DualPoint& line) const;
+ Diagram<Real> weighted_slice_diagram(const DualPoint<Real>& line) const;
// translate all points by vector (a,a)
void translate(Real a);
@@ -59,9 +68,7 @@ namespace md {
Real minimal_coordinate() const { return std::min(min_x(), min_y()); }
// return box that contains all positions of all simplices
- Box bounding_box() const;
-
- friend std::ostream& operator<<(std::ostream& os, const ModulePresentation& mp);
+ Box<Real> bounding_box() const;
Real max_x() const { return max_x_; }
@@ -71,26 +78,27 @@ namespace md {
Real min_y() const { return min_y_; }
- PointVec positions() const;
+ PointVec<Real> positions() const;
private:
- PointVec generators_;
+ PointVec<Real> generators_;
std::vector<Relation> relations_;
- PointVec positions_;
+ PointVec<Real> positions_;
Real max_x_ { std::numeric_limits<Real>::max() };
Real max_y_ { std::numeric_limits<Real>::max() };
Real min_x_ { -std::numeric_limits<Real>::max() };
Real min_y_ { -std::numeric_limits<Real>::max() };
- Box bounding_box_;
+ Box<Real> bounding_box_;
void init_boundaries();
- void project_generators(const DualPoint& slice, IndexVec& sorted_indices, RealVec& projections) const;
- void project_relations(const DualPoint& slice, IndexVec& sorted_indices, RealVec& projections) const;
+ void project_generators(const DualPoint<Real>& slice, IndexVec& sorted_indices, RealVec& projections) const;
+ void project_relations(const DualPoint<Real>& slice, IndexVec& sorted_indices, RealVec& projections) const;
};
} // namespace md
+#include "persistence_module.hpp"
#endif //MATCHING_DISTANCE_PERSISTENCE_MODULE_H
diff --git a/matching/include/persistence_module.hpp b/matching/include/persistence_module.hpp
new file mode 100644
index 0000000..6e49b2e
--- /dev/null
+++ b/matching/include/persistence_module.hpp
@@ -0,0 +1,177 @@
+namespace md {
+
+ /**
+ *
+ * @param values vector of length n
+ * @return [a_1,...,a_n] such that
+ * 1) values[a_1] <= values[a_2] <= ... <= values[a_n]
+ * 2) a_1,...,a_n is a permutation of 1,..,n
+ */
+
+ template<class T>
+ IndexVec get_sorted_indices(const std::vector<T>& values)
+ {
+ IndexVec result(values.size());
+ std::iota(result.begin(), result.end(), 0);
+ std::sort(result.begin(), result.end(),
+ [&values](size_t a, size_t b) { return values[a] < values[b]; });
+ return result;
+ }
+
+ // helper function to initialize const member positions_ in ModulePresentation
+ template<class Real>
+ PointVec<Real> concat_gen_and_rel_positions(const PointVec<Real>& generators,
+ const typename ModulePresentation<Real>::RelVec& relations)
+ {
+ std::unordered_set<Point<Real>> ps(generators.begin(), generators.end());
+ for(const auto& rel : relations) {
+ ps.insert(rel.position_);
+ }
+ return PointVec<Real>(ps.begin(), ps.end());
+ }
+
+
+ template<class Real>
+ void ModulePresentation<Real>::init_boundaries()
+ {
+ max_x_ = std::numeric_limits<Real>::max();
+ max_y_ = std::numeric_limits<Real>::max();
+ min_x_ = -std::numeric_limits<Real>::max();
+ min_y_ = -std::numeric_limits<Real>::max();
+
+ for(const auto& gen : positions_) {
+ min_x_ = std::min(gen.x, min_x_);
+ min_y_ = std::min(gen.y, min_y_);
+ max_x_ = std::max(gen.x, max_x_);
+ max_y_ = std::max(gen.y, max_y_);
+ }
+
+ bounding_box_ = Box<Real>(Point<Real>(min_x_, min_y_), Point<Real>(max_x_, max_y_));
+ }
+
+
+ template<class Real>
+ ModulePresentation<Real>::ModulePresentation(const PointVec<Real>& _generators, const RelVec& _relations) :
+ generators_(_generators),
+ relations_(_relations)
+ {
+ init_boundaries();
+ }
+
+ template<class Real>
+ void ModulePresentation<Real>::translate(Real a)
+ {
+ for(auto& g : generators_) {
+ g.translate(a);
+ }
+
+ for(auto& r : relations_) {
+ r.position_.translate(a);
+ }
+
+ positions_ = concat_gen_and_rel_positions(generators_, relations_);
+ init_boundaries();
+ }
+
+
+ /**
+ *
+ * @param slice line on which generators are projected
+ * @param sorted_indices [a_1,...,a_n] s.t. wpush(generator[a_1]) <= wpush(generator[a_2]) <= ..
+ * @param projections sorted weighted pushes of generators
+ */
+
+ template<class Real>
+ void ModulePresentation<Real>::project_generators(const DualPoint<Real>& slice,
+ IndexVec& sorted_indices, RealVec& projections) const
+ {
+ size_t num_gens = generators_.size();
+
+ RealVec gen_values;
+ gen_values.reserve(num_gens);
+ for(const auto& pos : generators_) {
+ gen_values.push_back(slice.weighted_push(pos));
+ }
+ sorted_indices = get_sorted_indices(gen_values);
+ projections.clear();
+ projections.reserve(num_gens);
+ for(auto i : sorted_indices) {
+ projections.push_back(gen_values[i]);
+ }
+ }
+
+ template<class Real>
+ void ModulePresentation<Real>::project_relations(const DualPoint<Real>& slice, IndexVec& sorted_rel_indices,
+ RealVec& projections) const
+ {
+ size_t num_rels = relations_.size();
+
+ RealVec rel_values;
+ rel_values.reserve(num_rels);
+ for(const auto& rel : relations_) {
+ rel_values.push_back(slice.weighted_push(rel.position_));
+ }
+ sorted_rel_indices = get_sorted_indices(rel_values);
+ projections.clear();
+ projections.reserve(num_rels);
+ for(auto i : sorted_rel_indices) {
+ projections.push_back(rel_values[i]);
+ }
+ }
+
+ template<class Real>
+ Diagram<Real> ModulePresentation<Real>::weighted_slice_diagram(const DualPoint<Real>& slice) const
+ {
+ IndexVec sorted_gen_indices, sorted_rel_indices;
+ RealVec gen_projections, rel_projections;
+
+ project_generators(slice, sorted_gen_indices, gen_projections);
+ project_relations(slice, sorted_rel_indices, rel_projections);
+
+ phat::boundary_matrix<> phat_matrix;
+
+ phat_matrix.set_num_cols(relations_.size());
+
+ for(Index i = 0; i < (Index) relations_.size(); i++) {
+ IndexVec current_relation = relations_[sorted_rel_indices[i]].components_;
+ for(auto& j : current_relation) {
+ j = sorted_gen_indices[j];
+ }
+ std::sort(current_relation.begin(), current_relation.end());
+ phat_matrix.set_dim(i, current_relation.size());
+ phat_matrix.set_col(i, current_relation);
+ }
+
+ phat::persistence_pairs phat_persistence_pairs;
+ phat::compute_persistence_pairs<phat::twist_reduction>(phat_persistence_pairs, phat_matrix);
+
+ Diagram<Real> dgm;
+
+ constexpr Real real_inf = std::numeric_limits<Real>::infinity();
+
+ for(Index i = 0; i < (Index) phat_persistence_pairs.get_num_pairs(); i++) {
+ std::pair<phat::index, phat::index> new_pair = phat_persistence_pairs.get_pair(i);
+ bool is_finite_pair = new_pair.second != phat::k_infinity_index;
+ Real birth = gen_projections.at(new_pair.first);
+ Real death = is_finite_pair ? rel_projections.at(new_pair.second) : real_inf;
+ if (birth != death) {
+ dgm.emplace_back(birth, death);
+ }
+ }
+
+ return dgm;
+ }
+
+ template<class Real>
+ PointVec<Real> ModulePresentation<Real>::positions() const
+ {
+ return positions_;
+ }
+
+ template<class Real>
+ Box<Real> ModulePresentation<Real>::bounding_box() const
+ {
+ return bounding_box_;
+ }
+
+} // namespace md
diff --git a/matching/include/simplex.h b/matching/include/simplex.h
index e9d0e30..75bbcae 100644
--- a/matching/include/simplex.h
+++ b/matching/include/simplex.h
@@ -9,6 +9,7 @@
namespace md {
+ template<class Real>
class Bifiltration;
enum class BifiltrationFormat {
@@ -38,11 +39,21 @@ namespace md {
int dim() const { return vertices_.size() - 1; }
- void push_back(int v);
+ void push_back(int v)
+ {
+ vertices_.push_back(v);
+ std::sort(vertices_.begin(), vertices_.end());
+ }
AbstractSimplex() { }
- AbstractSimplex(std::vector<int> vertices, bool sort = true);
+ AbstractSimplex(std::vector<int> vertices, bool sort = true)
+ :vertices_(vertices)
+ {
+ if (sort)
+ std::sort(vertices_.begin(), vertices_.end());
+ }
+
template<class Iter>
AbstractSimplex(Iter beg_iter, Iter end_iter, bool sort = true)
@@ -53,22 +64,51 @@ namespace md {
std::sort(vertices_.begin(), end());
}
- std::vector<AbstractSimplex> facets() const;
+ std::vector<AbstractSimplex> facets() const
+ {
+ std::vector<AbstractSimplex> result;
+ for (int i = 0; i < static_cast<int>(vertices_.size()); ++i) {
+ std::vector<int> facet_vertices;
+ facet_vertices.reserve(dim());
+ for (int j = 0; j < static_cast<int>(vertices_.size()); ++j) {
+ if (j != i)
+ facet_vertices.push_back(vertices_[j]);
+ }
+ if (!facet_vertices.empty()) {
+ result.emplace_back(facet_vertices, false);
+ }
+ }
+ return result;
+ }
friend std::ostream& operator<<(std::ostream& os, const AbstractSimplex& s);
-
// compare by vertices_ only
friend bool operator==(const AbstractSimplex& s1, const AbstractSimplex& s2);
friend bool operator<(const AbstractSimplex&, const AbstractSimplex&);
};
- std::ostream& operator<<(std::ostream& os, const AbstractSimplex& s);
+ inline std::ostream& operator<<(std::ostream& os, const AbstractSimplex& s)
+ {
+ os << "AbstractSimplex(id = " << s.id << ", vertices_ = " << container_to_string(s.vertices_) << ")";
+ return os;
+ }
+
+ inline bool operator<(const AbstractSimplex& a, const AbstractSimplex& b)
+ {
+ return a.vertices_ < b.vertices_;
+ }
+
+ inline bool operator==(const AbstractSimplex& s1, const AbstractSimplex& s2)
+ {
+ return s1.vertices_ == s2.vertices_;
+ }
+ template<class Real>
class Simplex {
private:
Index id_;
- Point pos_;
+ Point<Real> pos_;
int dim_;
// in our format we use facet indices,
// this is the fastest representation for homology
@@ -77,11 +117,11 @@ namespace md {
// conversion routines are in Bifiltration
Column facet_indices_;
Column vertices_;
- Real v {0.0}; // used when constructed a filtration for a slice
+ Real v {0}; // used when constructed a filtration for a slice
public:
Simplex(Index _id, std::string s, BifiltrationFormat input_format);
- Simplex(Index _id, Point birth, int _dim, const Column& _bdry);
+ Simplex(Index _id, Point<Real> birth, int _dim, const Column& _bdry);
void init_rivet(std::string s);
@@ -96,9 +136,9 @@ namespace md {
Real value() const { return v; }
// assumes 1-criticality
- Point position() const { return pos_; }
+ Point<Real> position() const { return pos_; }
- void set_position(const Point& new_pos) { pos_ = new_pos; }
+ void set_position(const Point<Real>& new_pos) { pos_ = new_pos; }
void scale(Real lambda)
{
@@ -110,12 +150,14 @@ namespace md {
void set_value(Real new_val) { v = new_val; }
- friend std::ostream& operator<<(std::ostream& os, const Simplex& s);
-
- friend Bifiltration;
+ friend Bifiltration<Real>;
};
- std::ostream& operator<<(std::ostream& os, const Simplex& s);
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const Simplex<Real>& s);
}
+
+#include "simplex.hpp"
+
#endif //MATCHING_DISTANCE_SIMPLEX_H
diff --git a/matching/include/simplex.hpp b/matching/include/simplex.hpp
new file mode 100644
index 0000000..ce0e30f
--- /dev/null
+++ b/matching/include/simplex.hpp
@@ -0,0 +1,79 @@
+namespace md {
+
+ template<class Real>
+ Simplex<Real>::Simplex(Index id, Point<Real> birth, int dim, const Column& bdry)
+ :
+ id_(id),
+ pos_(birth),
+ dim_(dim),
+ facet_indices_(bdry) { }
+
+ template<class Real>
+ void Simplex<Real>::translate(Real a)
+ {
+ pos_.translate(a);
+ }
+
+ template<class Real>
+ void Simplex<Real>::init_rivet(std::string s)
+ {
+ auto delim_pos = s.find_first_of(";");
+ assert(delim_pos > 0);
+ std::string vertices_str = s.substr(0, delim_pos);
+ std::string pos_str = s.substr(delim_pos + 1);
+ assert(not vertices_str.empty() and not pos_str.empty());
+ // get vertices
+ std::stringstream vertices_ss(vertices_str);
+ int dim = 0;
+ int vertex;
+ while (vertices_ss >> vertex) {
+ dim++;
+ vertices_.push_back(vertex);
+ }
+ //
+ std::sort(vertices_.begin(), vertices_.end());
+ assert(dim > 0);
+
+ std::stringstream pos_ss(pos_str);
+ // TODO: get rid of 1-criticaltiy assumption
+ pos_ss >> pos_.x >> pos_.y;
+ }
+
+ template<class Real>
+ void Simplex<Real>::init_phat_like(std::string s)
+ {
+ facet_indices_.clear();
+ std::stringstream ss(s);
+ ss >> dim_ >> pos_.x >> pos_.y;
+ if (dim_ > 0) {
+ facet_indices_.reserve(dim_ + 1);
+ for (int j = 0; j <= dim_; j++) {
+ Index k;
+ ss >> k;
+ facet_indices_.push_back(k);
+ }
+ }
+ }
+
+ template<class Real>
+ Simplex<Real>::Simplex(Index _id, std::string s, BifiltrationFormat input_format)
+ :id_(_id)
+ {
+ switch (input_format) {
+ case BifiltrationFormat::phat_like :
+ init_phat_like(s);
+ break;
+ case BifiltrationFormat::rivet :
+ init_rivet(s);
+ break;
+ }
+ }
+
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const Simplex<Real>& x)
+ {
+ os << "Simplex<Real>(id = " << x.id() << ", dim = " << x.dim();
+ os << ", boundary = " << container_to_string(x.boundary()) << ", pos = " << x.position() << ")";
+ return os;
+ }
+}