summaryrefslogtreecommitdiff
path: root/matching/src
diff options
context:
space:
mode:
Diffstat (limited to 'matching/src')
-rw-r--r--matching/src/bifiltration.cpp407
-rw-r--r--matching/src/box.cpp61
-rw-r--r--matching/src/cell_with_value.cpp247
-rw-r--r--matching/src/common_util.cpp243
-rw-r--r--matching/src/dual_box.cpp194
-rw-r--r--matching/src/dual_point.cpp282
-rw-r--r--matching/src/main.cpp29
-rw-r--r--matching/src/matching_distance.cpp150
-rw-r--r--matching/src/persistence_module.cpp177
-rw-r--r--matching/src/simplex.cpp121
-rw-r--r--matching/src/test_generator.cpp19
-rw-r--r--matching/src/tests/test_common.cpp66
-rw-r--r--matching/src/tests/test_matching_distance.cpp22
13 files changed, 67 insertions, 1951 deletions
diff --git a/matching/src/bifiltration.cpp b/matching/src/bifiltration.cpp
deleted file mode 100644
index 44b12cf..0000000
--- a/matching/src/bifiltration.cpp
+++ /dev/null
@@ -1,407 +0,0 @@
-#include <iostream>
-#include <fstream>
-#include <sstream>
-#include <cassert>
-
-#include<phat/boundary_matrix.h>
-#include<phat/compute_persistence_pairs.h>
-
-#include "spdlog/spdlog.h"
-#include "spdlog/fmt/fmt.h"
-#include "spdlog/fmt/ostr.h"
-
-#include "common_util.h"
-#include "bifiltration.h"
-
-namespace spd = spdlog;
-
-namespace md {
-
- void Bifiltration::init()
- {
- Point lower_left = max_point();
- Point upper_right = min_point();
- for(const auto& simplex : simplices_) {
- lower_left = greatest_lower_bound(lower_left, simplex.position());
- upper_right = least_upper_bound(upper_right, simplex.position());
- maximal_dim_ = std::max(maximal_dim_, simplex.dim());
- }
- bounding_box_ = Box(lower_left, upper_right);
- }
-
- Bifiltration::Bifiltration(const std::string& fname )
- {
- std::ifstream ifstr {fname.c_str()};
- if (!ifstr.good()) {
- std::string error_message = fmt::format("Cannot open file {0}", fname);
- std::cerr << error_message << std::endl;
- throw std::runtime_error(error_message);
- }
-
- BifiltrationFormat input_format;
-
- std::string s;
-
- while(ignore_line(s)) {
- std::getline(ifstr, s);
- }
-
- if (s == "bifiltration") {
- input_format = BifiltrationFormat::rivet;
- } else if (s == "bifiltration_phat_like") {
- input_format = BifiltrationFormat::phat_like;
- } else {
- std::cerr << "Unknown format: '" << s << "' in file " << fname << std::endl;
- throw std::runtime_error("unknown bifiltration format");
- }
-
- switch(input_format) {
- case BifiltrationFormat::rivet :
- rivet_format_reader(ifstr);
- break;
- case BifiltrationFormat::phat_like :
- phat_like_format_reader(ifstr);
- break;
- }
-
- ifstr.close();
-
- init();
- }
-
- void Bifiltration::rivet_format_reader(std::ifstream& ifstr)
- {
- std::string s;
- // read axes names
- std::getline(ifstr, parameter_1_name_);
- std::getline(ifstr, parameter_2_name_);
-
- Index index = 0;
- while(std::getline(ifstr, s)) {
- if (!ignore_line(s)) {
- simplices_.emplace_back(index++, s, BifiltrationFormat::rivet);
- }
- }
- }
-
- void Bifiltration::phat_like_format_reader(std::ifstream& ifstr)
- {
- spd::debug("Enter phat_like_format_reader");
- // read stream line by line; do not use >> operator
- std::string s;
- std::getline(ifstr, s);
-
- // first line contains number of simplices
- long n_simplices = std::stol(s);
-
- // all other lines represent a simplex
- Index index = 0;
- while(index < n_simplices) {
- std::getline(ifstr, s);
- if (!ignore_line(s)) {
- simplices_.emplace_back(index++, s, BifiltrationFormat::phat_like);
- }
- }
- spd::debug("Read {} simplices from file", n_simplices);
- }
-
- void Bifiltration::scale(Real lambda)
- {
- for(auto& s : simplices_) {
- s.scale(lambda);
- }
- init();
- }
-
- void Bifiltration::sanity_check() const
- {
-#ifdef DEBUG
- spd::debug("Enter Bifiltration::sanity_check");
- // check that boundary has correct number of simplices,
- // each bounding simplex has correct dim
- // and appears in the filtration before the simplex it bounds
- for(const auto& s : simplices_) {
- assert(s.dim() >= 0);
- assert(s.dim() == 0 or s.dim() + 1 == (int) s.boundary().size());
- for(auto bdry_idx : s.boundary()) {
- Simplex bdry_simplex = simplices()[bdry_idx];
- assert(bdry_simplex.dim() == s.dim() - 1);
- assert(bdry_simplex.position().is_less(s.position(), false));
- }
- }
- spd::debug("Exit Bifiltration::sanity_check");
-#endif
- }
-
- Diagram Bifiltration::weighted_slice_diagram(const DualPoint& line, int dim) const
- {
- DiagramKeeper dgm;
-
- // make a copy for now; I want slice_diagram to be const
- std::vector<Simplex> simplices(simplices_);
-
-// std::vector<Simplex> simplices;
-// simplices.reserve(simplices_.size() / 2);
-// for(const auto& s : simplices_) {
-// if (s.dim() <= dim + 1 and s.dim() >= dim)
-// simplices.emplace_back(s);
-// }
-
- spd::debug("Enter slice diagram, line = {}, simplices.size = {}", line, simplices.size());
-
- for(auto& simplex : simplices) {
- Real value = line.weighted_push(simplex.position());
-// spd::debug("in slice_diagram, simplex = {}, value = {}\n", simplex, value);
- simplex.set_value(value);
- }
-
- std::sort(simplices.begin(), simplices.end(),
- [](const Simplex& a, const Simplex& b) { return a.value() < b.value(); });
- std::map<Index, Index> index_map;
- for(Index i = 0; i < (int) simplices.size(); i++) {
- index_map[simplices[i].id()] = i;
- }
-
- phat::boundary_matrix<> phat_matrix;
- phat_matrix.set_num_cols(simplices.size());
- std::vector<Index> bd_in_slice_filtration;
- for(Index i = 0; i < (int) simplices.size(); i++) {
- phat_matrix.set_dim(i, simplices[i].dim());
- bd_in_slice_filtration.clear();
- //std::cout << "new col" << i << std::endl;
- for(int j = 0; j < (int) simplices[i].boundary().size(); j++) {
- // F[i] contains the indices of its facet wrt to the
- // original filtration. We have to express it, however,
- // wrt to the filtration along the slice. That is why
- // we need the index_map
- //std::cout << "Found " << F[i].bd[j] << ", returning " << index_map[F[i].bd[j]] << std::endl;
- bd_in_slice_filtration.push_back(index_map[simplices[i].boundary()[j]]);
- }
- std::sort(bd_in_slice_filtration.begin(), bd_in_slice_filtration.end());
- phat_matrix.set_col(i, bd_in_slice_filtration);
- }
- phat::persistence_pairs phat_persistence_pairs;
- phat::compute_persistence_pairs<phat::twist_reduction>(phat_persistence_pairs, phat_matrix);
-
- dgm.clear();
- constexpr Real real_inf = std::numeric_limits<Real>::infinity();
- for(long i = 0; i < (long) phat_persistence_pairs.get_num_pairs(); i++) {
- std::pair<phat::index, phat::index> new_pair = phat_persistence_pairs.get_pair(i);
- bool is_finite_pair = new_pair.second != phat::k_infinity_index;
- Real birth = simplices.at(new_pair.first).value();
- Real death = is_finite_pair ? simplices.at(new_pair.second).value() : real_inf;
- int dim = simplices[new_pair.first].dim();
- assert(dim + 1 == simplices[new_pair.second].dim());
- if (birth != death) {
- dgm.add_point(dim, birth, death);
- }
- }
-
- spdlog::debug("Exiting slice_diagram, #dgm[0] = {}", dgm.get_diagram(0).size());
-
- return dgm.get_diagram(dim);
- }
-
- Box Bifiltration::bounding_box() const
- {
- return bounding_box_;
- }
-
- Real Bifiltration::minimal_coordinate() const
- {
- return std::min(bounding_box_.lower_left().x, bounding_box_.lower_left().y);
- }
-
- void Bifiltration::translate(Real a)
- {
- bounding_box_.translate(a);
- for(auto& simplex : simplices_) {
- simplex.translate(a);
- }
- }
-
- Real Bifiltration::max_x() const
- {
- if (simplices_.empty())
- return 1;
- auto me = std::max_element(simplices_.cbegin(), simplices_.cend(),
- [](const auto& s_a, const auto& s_b) { return s_a.position().x < s_b.position().x; });
- assert(me != simplices_.cend());
- return me->position().x;
- }
-
- Real Bifiltration::max_y() const
- {
- if (simplices_.empty())
- return 1;
- auto me = std::max_element(simplices_.cbegin(), simplices_.cend(),
- [](const auto& s_a, const auto& s_b) { return s_a.position().y < s_b.position().y; });
- assert(me != simplices_.cend());
- return me->position().y;
- }
-
- Real Bifiltration::min_x() const
- {
- if (simplices_.empty())
- return 0;
- auto me = std::min_element(simplices_.cbegin(), simplices_.cend(),
- [](const auto& s_a, const auto& s_b) { return s_a.position().x < s_b.position().x; });
- assert(me != simplices_.cend());
- return me->position().x;
- }
-
- Real Bifiltration::min_y() const
- {
- if (simplices_.empty())
- return 0;
- auto me = std::min_element(simplices_.cbegin(), simplices_.cend(),
- [](const auto& s_a, const auto& s_b) { return s_a.position().y < s_b.position().y; });
- assert(me != simplices_.cend());
- return me->position().y;
- }
-
- void Bifiltration::add_simplex(md::Index _id, md::Point birth, int _dim, const md::Column& _bdry)
- {
- simplices_.emplace_back(_id, birth, _dim, _bdry);
- }
-
- void Bifiltration::save(const std::string& filename, md::BifiltrationFormat format)
- {
- switch(format) {
- case BifiltrationFormat::rivet:
- throw std::runtime_error("Not implemented");
- break;
- case BifiltrationFormat::phat_like: {
- std::ofstream f(filename);
- if (not f.good()) {
- std::cerr << "Bifiltration::save: cannot open file " << filename << std::endl;
- throw std::runtime_error("Cannot open file for writing ");
- }
- f << simplices_.size() << "\n";
-
- for(const auto& s : simplices_) {
- f << s.dim() << " " << s.position().x << " " << s.position().y << " ";
- for(int b : s.boundary()) {
- f << b << " ";
- }
- f << std::endl;
- }
-
- }
- break;
- }
- }
-
- void Bifiltration::postprocess_rivet_format()
- {
- std::map<Column, Index> facets_to_ids;
-
- // fill the map
- for(Index i = 0; i < (Index) simplices_.size(); ++i) {
- assert(simplices_[i].id() == i);
- facets_to_ids[simplices_[i].vertices_] = i;
- }
-
-// for(const auto& s : simplices_) {
-// facets_to_ids[s] = s.id();
-// }
-
- // main loop
- for(auto& s : simplices_) {
- assert(not s.vertices_.empty());
- assert(s.facet_indices_.empty());
- Column facet_indices;
- for(Index i = 0; i <= s.dim(); ++i) {
- Column facet;
- for(Index j : s.vertices_) {
- if (j != i)
- facet.push_back(j);
- }
- auto facet_index = facets_to_ids.at(facet);
- facet_indices.push_back(facet_index);
- }
- s.facet_indices_ = facet_indices;
- } // loop over simplices
- }
-
- std::ostream& operator<<(std::ostream& os, const Bifiltration& bif)
- {
- os << "Bifiltration, axes = " << bif.parameter_1_name_ << ", " << bif.parameter_2_name_ << std::endl;
- for(const auto& s : bif.simplices()) {
- os << s << std::endl;
- }
- return os;
- }
-
- BifiltrationProxy::BifiltrationProxy(const md::Bifiltration& bif, int dim)
- :
- dim_(dim),
- bif_(bif)
- {
- cache_positions();
- }
-
- void BifiltrationProxy::cache_positions() const
- {
- cached_positions_.clear();
- for(const auto& simplex : bif_.simplices()) {
- if (simplex.dim() == dim_ or simplex.dim() == dim_ + 1)
- cached_positions_.push_back(simplex.position());
- }
- }
-
- PointVec BifiltrationProxy::positions() const
- {
- if (cached_positions_.empty()) {
- cache_positions();
- }
- return cached_positions_;
- }
-
- // translate all points by vector (a,a)
- void BifiltrationProxy::translate(Real a)
- {
- bif_.translate(a);
- }
-
- // return minimal value of x- and y-coordinates
- // among all simplices
- Real BifiltrationProxy::minimal_coordinate() const
- {
- return bif_.minimal_coordinate();
- }
-
- // return box that contains positions of all simplices
- Box BifiltrationProxy::bounding_box() const
- {
- return bif_.bounding_box();
- }
-
- Real BifiltrationProxy::max_x() const
- {
- return bif_.max_x();
- }
-
- Real BifiltrationProxy::max_y() const
- {
- return bif_.max_y();
- }
-
- Real BifiltrationProxy::min_x() const
- {
- return bif_.min_x();
- }
-
- Real BifiltrationProxy::min_y() const
- {
- return bif_.min_y();
- }
-
-
- Diagram BifiltrationProxy::weighted_slice_diagram(const DualPoint& slice) const
- {
- return bif_.weighted_slice_diagram(slice, dim_);
- }
-
-}
-
diff --git a/matching/src/box.cpp b/matching/src/box.cpp
deleted file mode 100644
index c128698..0000000
--- a/matching/src/box.cpp
+++ /dev/null
@@ -1,61 +0,0 @@
-
-#include "box.h"
-
-namespace md {
-
- std::ostream& operator<<(std::ostream& os, const Box& box)
- {
- os << "Box(lower_left = " << box.lower_left() << ", upper_right = " << box.upper_right() << ")";
- return os;
- }
-
- Box get_enclosing_box(const Box& box_a, const Box& box_b)
- {
- Point lower_left(std::min(box_a.lower_left().x, box_b.lower_left().x),
- std::min(box_a.lower_left().y, box_b.lower_left().y));
- Point upper_right(std::max(box_a.upper_right().x, box_b.upper_right().x),
- std::max(box_a.upper_right().y, box_b.upper_right().y));
- return Box(lower_left, upper_right);
- }
-
- void Box::translate(md::Real a)
- {
- ll.x += a;
- ll.y += a;
- ur.x += a;
- ur.y += a;
- }
-
- std::vector<Box> Box::refine() const
- {
- std::vector<Box> result;
-
-// 1 | 2
-// 0 | 3
-
- Point new_ll = lower_left();
- Point new_ur = center();
- result.emplace_back(new_ll, new_ur);
-
- new_ll.y = center().y;
- new_ur.y = ur.y;
- result.emplace_back(new_ll, new_ur);
-
- new_ll = center();
- new_ur = upper_right();
- result.emplace_back(new_ll, new_ur);
-
- new_ll.y = ll.y;
- new_ur.y = center().y;
- result.emplace_back(new_ll, new_ur);
-
- return result;
- }
-
- std::vector<Point> Box::corners() const
- {
- return {ll, Point(ll.x, ur.y), ur, Point(ur.x, ll.y)};
- };
-
-
-}
diff --git a/matching/src/cell_with_value.cpp b/matching/src/cell_with_value.cpp
deleted file mode 100644
index d8fd7d4..0000000
--- a/matching/src/cell_with_value.cpp
+++ /dev/null
@@ -1,247 +0,0 @@
-#include <spdlog/spdlog.h>
-#include <spdlog/fmt/ostr.h>
-
-namespace spd = spdlog;
-
-#include "cell_with_value.h"
-
-namespace md {
-
-#ifdef MD_DEBUG
- long long int CellWithValue::max_id = 0;
-#endif
-
- Real CellWithValue::value_at(ValuePoint vp) const
- {
- switch(vp) {
- case ValuePoint::upper_left :
- return upper_left_value_;
- case ValuePoint::upper_right :
- return upper_right_value_;
- case ValuePoint::lower_left :
- return lower_left_value_;
- case ValuePoint::lower_right :
- return lower_right_value_;
- case ValuePoint::center:
- return central_value_;
- }
- // to shut up compiler warning
- return 1.0 / 0.0;
- }
-
- bool CellWithValue::has_value_at(ValuePoint vp) const
- {
- switch(vp) {
- case ValuePoint::upper_left :
- return upper_left_value_ >= 0;
- case ValuePoint::upper_right :
- return upper_right_value_ >= 0;
- case ValuePoint::lower_left :
- return lower_left_value_ >= 0;
- case ValuePoint::lower_right :
- return lower_right_value_ >= 0;
- case ValuePoint::center:
- return central_value_ >= 0;
- }
- // to shut up compiler warning
- return 1.0 / 0.0;
- }
-
- DualPoint CellWithValue::value_point(md::ValuePoint vp) const
- {
- switch(vp) {
- case ValuePoint::upper_left :
- return dual_box().upper_left();
- case ValuePoint::upper_right :
- return dual_box().upper_right();
- case ValuePoint::lower_left :
- return dual_box().lower_left();
- case ValuePoint::lower_right :
- return dual_box().lower_right();
- case ValuePoint::center:
- return dual_box().center();
- }
- // to shut up compiler warning
- return DualPoint();
- }
-
- bool CellWithValue::has_corner_value() const
- {
- return has_lower_left_value() or has_lower_right_value() or has_upper_left_value()
- or has_upper_right_value();
- }
-
- Real CellWithValue::stored_upper_bound() const
- {
- assert(has_max_possible_value_);
- return max_possible_value_;
- }
-
- Real CellWithValue::max_corner_value() const
- {
- return std::max({lower_left_value_, lower_right_value_, upper_left_value_, upper_right_value_});
- }
-
- Real CellWithValue::min_value() const
- {
- Real result = std::numeric_limits<Real>::max();
- for(auto vp : k_all_vps) {
- if (not has_value_at(vp)) {
- continue;
- }
- result = std::min(result, value_at(vp));
- }
- return result;
- }
-
- std::vector<CellWithValue> CellWithValue::get_refined_cells() const
- {
- std::vector<CellWithValue> result;
- result.reserve(4);
- for(const auto& refined_box : dual_box_.refine()) {
-
- CellWithValue refined_cell(refined_box, level() + 1);
-
-#ifdef MD_DEBUG
- refined_cell.parent_ids = parent_ids;
- refined_cell.parent_ids.push_back(id);
- refined_cell.id = ++max_id;
-#endif
-
- if (refined_box.lower_left() == dual_box_.lower_left()) {
- // _|_
- // H|_
-
- refined_cell.set_value_at(ValuePoint::lower_left, lower_left_value_);
- refined_cell.set_value_at(ValuePoint::upper_right, central_value_);
-
- } else if (refined_box.upper_right() == dual_box_.upper_right()) {
- // _|H
- // _|_
-
- refined_cell.set_value_at(ValuePoint::lower_left, central_value_);
- refined_cell.set_value_at(ValuePoint::upper_right, upper_right_value_);
-
- } else if (refined_box.lower_right() == dual_box_.lower_right()) {
- // _|_
- // _|H
-
- refined_cell.set_value_at(ValuePoint::lower_right, lower_right_value_);
- refined_cell.set_value_at(ValuePoint::upper_left, central_value_);
-
- } else if (refined_box.upper_left() == dual_box_.upper_left()) {
-
- // H|_
- // _|_
-
- refined_cell.set_value_at(ValuePoint::lower_right, central_value_);
- refined_cell.set_value_at(ValuePoint::upper_left, upper_left_value_);
- }
- result.emplace_back(refined_cell);
- }
- return result;
- }
-
- void CellWithValue::set_value_at(md::ValuePoint vp, md::Real new_value)
- {
- if (has_value_at(vp))
- spd::error("CellWithValue: trying to re-assign value!, this = {}, vp = {}", *this, vp);
-
- switch(vp) {
- case ValuePoint::upper_left :
- upper_left_value_ = new_value;
- break;
- case ValuePoint::upper_right :
- upper_right_value_ = new_value;
- break;
- case ValuePoint::lower_left :
- lower_left_value_ = new_value;
- break;
- case ValuePoint::lower_right :
- lower_right_value_ = new_value;
- break;
- case ValuePoint::center:
- central_value_ = new_value;
- break;
- }
-
-
- }
-
- int CellWithValue::num_values() const
- {
- int result = 0;
- for(ValuePoint vp : k_all_vps) {
- result += has_value_at(vp);
- }
- return result;
- }
-
-
- void CellWithValue::set_max_possible_value(Real new_upper_bound)
- {
- assert(new_upper_bound >= central_value_);
- assert(new_upper_bound >= lower_left_value_);
- assert(new_upper_bound >= lower_right_value_);
- assert(new_upper_bound >= upper_left_value_);
- assert(new_upper_bound >= upper_right_value_);
- has_max_possible_value_ = true;
- max_possible_value_ = new_upper_bound;
- }
-
- std::ostream& operator<<(std::ostream& os, const ValuePoint& vp)
- {
- switch(vp) {
- case ValuePoint::upper_left :
- os << "upper_left";
- break;
- case ValuePoint::upper_right :
- os << "upper_right";
- break;
- case ValuePoint::lower_left :
- os << "lower_left";
- break;
- case ValuePoint::lower_right :
- os << "lower_right";
- break;
- case ValuePoint::center:
- os << "center";
- break;
- default:
- os << "FORGOTTEN ValuePoint";
- }
- return os;
- }
-
-
-
- std::ostream& operator<<(std::ostream& os, const CellWithValue& cell)
- {
- os << "CellWithValue(box = " << cell.dual_box() << ", ";
-
-#ifdef MD_DEBUG
- os << "id = " << cell.id;
- if (not cell.parent_ids.empty())
- os << ", parent_ids = " << container_to_string(cell.parent_ids) << ", ";
-#endif
-
- for(ValuePoint vp : k_all_vps) {
- if (cell.has_value_at(vp)) {
- os << "value = " << cell.value_at(vp);
- os << ", at " << vp << " " << cell.value_point(vp);
- }
- }
-
- os << ", max_corner_value = ";
- if (cell.has_max_possible_value()) {
- os << cell.stored_upper_bound();
- } else {
- os << "-";
- }
-
- os << ", level = " << cell.level() << ")";
- return os;
- }
-
-} // namespace md
-
diff --git a/matching/src/common_util.cpp b/matching/src/common_util.cpp
deleted file mode 100644
index 96c3388..0000000
--- a/matching/src/common_util.cpp
+++ /dev/null
@@ -1,243 +0,0 @@
-#include <vector>
-#include <utility>
-#include <cmath>
-#include <ostream>
-#include <limits>
-#include <algorithm>
-
-#include <common_util.h>
-
-#include "spdlog/spdlog.h"
-#include "spdlog/fmt/ostr.h"
-
-namespace md {
-
-
- int gcd(int a, int b)
- {
- assert(a != 0 or b != 0);
- // make b <= a
- std::tie(b, a) = std::minmax({ abs(a), abs(b) });
- if (b == 0)
- return a;
- while((a = a % b)) {
- std::swap(a, b);
- }
- return b;
- }
-
- int signum(int a)
- {
- if (a < 0)
- return -1;
- else if (a > 0)
- return 1;
- else
- return 0;
- }
-
- Rational reduce(Rational frac)
- {
- int d = gcd(frac.numerator, frac.denominator);
- frac.numerator /= d;
- frac.denominator /= d;
- return frac;
- }
-
- void Rational::reduce() { *this = md::reduce(*this); }
-
-
- Rational& Rational::operator*=(const md::Rational& rhs)
- {
- numerator *= rhs.numerator;
- denominator *= rhs.denominator;
- reduce();
- return *this;
- }
-
- Rational& Rational::operator/=(const md::Rational& rhs)
- {
- numerator *= rhs.denominator;
- denominator *= rhs.numerator;
- reduce();
- return *this;
- }
-
- Rational& Rational::operator+=(const md::Rational& rhs)
- {
- numerator = numerator * rhs.denominator + denominator * rhs.numerator;
- denominator *= rhs.denominator;
- reduce();
- return *this;
- }
-
- Rational& Rational::operator-=(const md::Rational& rhs)
- {
- numerator = numerator * rhs.denominator - denominator * rhs.numerator;
- denominator *= rhs.denominator;
- reduce();
- return *this;
- }
-
-
- Rational midpoint(Rational a, Rational b)
- {
- return reduce({a.numerator * b.denominator + a.denominator * b.numerator, 2 * a.denominator * b.denominator });
- }
-
- Rational operator+(Rational a, const Rational& b)
- {
- a += b;
- return a;
- }
-
- Rational operator-(Rational a, const Rational& b)
- {
- a -= b;
- return a;
- }
-
- Rational operator*(Rational a, const Rational& b)
- {
- a *= b;
- return a;
- }
-
- Rational operator/(Rational a, const Rational& b)
- {
- a /= b;
- return a;
- }
-
- bool is_less(Rational a, Rational b)
- {
- // compute a - b = a_1 / a_2 - b_1 / b_2
- long numer = a.numerator * b.denominator - a.denominator * b.numerator;
- long denom = a.denominator * b.denominator;
- assert(denom != 0);
- return signum(numer) * signum(denom) < 0;
- }
-
- bool operator==(const Rational& a, const Rational& b)
- {
- return std::tie(a.numerator, a.denominator) == std::tie(b.numerator, b.denominator);
- }
-
- bool operator<(const Rational& a, const Rational& b)
- {
- // do not remove signum - overflow
- long numer = a.numerator * b.denominator - a.denominator * b.numerator;
- long denom = a.denominator * b.denominator;
- assert(denom != 0);
-// spdlog::debug("a = {}, b = {}, numer = {}, denom = {}, result = {}", a, b, numer, denom, signum(numer) * signum(denom) <= 0);
- return signum(numer) * signum(denom) < 0;
- }
-
- bool is_leq(Rational a, Rational b)
- {
- // compute a - b = a_1 / a_2 - b_1 / b_2
- long numer = a.numerator * b.denominator - a.denominator * b.numerator;
- long denom = a.denominator * b.denominator;
- assert(denom != 0);
- return signum(numer) * signum(denom) <= 0;
- }
-
- bool is_greater(Rational a, Rational b)
- {
- return not is_leq(a, b);
- }
-
- bool is_geq(Rational a, Rational b)
- {
- return not is_less(a, b);
- }
-
- Point operator+(const Point& u, const Point& v)
- {
- return Point(u.x + v.x, u.y + v.y);
- }
-
- Point operator-(const Point& u, const Point& v)
- {
- return Point(u.x - v.x, u.y - v.y);
- }
-
- Point least_upper_bound(const Point& u, const Point& v)
- {
- return Point(std::max(u.x, v.x), std::max(u.y, v.y));
- }
-
- Point greatest_lower_bound(const Point& u, const Point& v)
- {
- return Point(std::min(u.x, v.x), std::min(u.y, v.y));
- }
-
- Point max_point()
- {
- return Point(std::numeric_limits<Real>::max(), std::numeric_limits<Real>::min());
- }
-
- Point min_point()
- {
- return Point(-std::numeric_limits<Real>::max(), -std::numeric_limits<Real>::min());
- }
-
- std::ostream& operator<<(std::ostream& ostr, const Point& vec)
- {
- ostr << "(" << vec.x << ", " << vec.y << ")";
- return ostr;
- }
-
- Real l_infty_norm(const Point& v)
- {
- return std::max(std::abs(v.x), std::abs(v.y));
- }
-
- Real l_2_norm(const Point& v)
- {
- return v.norm();
- }
-
- Real l_2_dist(const Point& x, const Point& y)
- {
- return l_2_norm(x - y);
- }
-
- Real l_infty_dist(const Point& x, const Point& y)
- {
- return l_infty_norm(x - y);
- }
-
- void DiagramKeeper::add_point(int dim, md::Real birth, md::Real death)
- {
- data_[dim].emplace_back(birth, death);
- }
-
- DiagramKeeper::Diagram DiagramKeeper::get_diagram(int dim) const
- {
- if (data_.count(dim) == 1)
- return data_.at(dim);
- else
- return DiagramKeeper::Diagram();
- }
-
- // return true, if line starts with #
- // or contains only spaces
- bool ignore_line(const std::string& s)
- {
- for(auto c : s) {
- if (isspace(c))
- continue;
- return (c == '#');
- }
- return true;
- }
-
-
-
- std::ostream& operator<<(std::ostream& os, const Rational& a)
- {
- os << a.numerator << " / " << a.denominator;
- return os;
- }
-}
diff --git a/matching/src/dual_box.cpp b/matching/src/dual_box.cpp
deleted file mode 100644
index ff4d30c..0000000
--- a/matching/src/dual_box.cpp
+++ /dev/null
@@ -1,194 +0,0 @@
-#include <random>
-
-#include "spdlog/spdlog.h"
-#include "spdlog/fmt/ostr.h"
-
-namespace spd = spdlog;
-
-#include "dual_box.h"
-
-namespace md {
-
- std::ostream& operator<<(std::ostream& os, const DualBox& db)
- {
- os << "DualBox(" << db.lower_left_ << ", " << db.upper_right_ << ")";
- return os;
- }
-
- DualBox::DualBox(DualPoint ll, DualPoint ur)
- :lower_left_(ll), upper_right_(ur)
- {
- }
-
- std::vector<DualPoint> DualBox::corners() const
- {
- return {lower_left_,
- DualPoint(axis_type(), angle_type(), lower_left_.lambda(), upper_right_.mu()),
- upper_right_,
- DualPoint(axis_type(), angle_type(), upper_right_.lambda(), lower_left_.mu())};
- }
-
- std::vector<DualPoint> DualBox::push_change_points(const Point& p) const
- {
- std::vector<DualPoint> result;
- result.reserve(2);
-
- bool is_y_type = lower_left_.is_y_type();
- bool is_flat = lower_left_.is_flat();
-
- auto mu_from_lambda = [p, is_y_type, is_flat](Real lambda) {
- bool is_x_type = not is_y_type, is_steep = not is_flat;
- if (is_y_type and is_flat) {
- return p.y - lambda * p.x;
- } else if (is_y_type and is_steep) {
- return p.y - p.x / lambda;
- } else if (is_x_type and is_flat) {
- return p.x - p.y / lambda;
- } else if (is_x_type and is_steep) {
- return p.x - lambda * p.y;
- }
- // to shut up compiler warning
- return static_cast<Real>(1.0 / 0.0);
- };
-
- auto lambda_from_mu = [p, is_y_type, is_flat](Real mu) {
- bool is_x_type = not is_y_type, is_steep = not is_flat;
- if (is_y_type and is_flat) {
- return (p.y - mu) / p.x;
- } else if (is_y_type and is_steep) {
- return p.x / (p.y - mu);
- } else if (is_x_type and is_flat) {
- return p.y / (p.x - mu);
- } else if (is_x_type and is_steep) {
- return (p.x - mu) / p.y;
- }
- // to shut up compiler warning
- return static_cast<Real>(1.0 / 0.0);
- };
-
- // all inequalities below are strict: equality means it is a corner
- // and critical_points() returns corners anyway
-
- Real mu_intersect_min = mu_from_lambda(lambda_min());
-
- if (mu_min() < mu_intersect_min && mu_intersect_min < mu_max())
- result.emplace_back(axis_type(), angle_type(), lambda_min(), mu_intersect_min);
-
- Real mu_intersect_max = mu_from_lambda(lambda_max());
-
- if (mu_max() < mu_intersect_max && mu_intersect_max < mu_max())
- result.emplace_back(axis_type(), angle_type(), lambda_max(), mu_intersect_max);
-
- Real lambda_intersect_min = lambda_from_mu(mu_min());
-
- if (lambda_min() < lambda_intersect_min && lambda_intersect_min < lambda_max())
- result.emplace_back(axis_type(), angle_type(), lambda_intersect_min, mu_min());
-
- Real lambda_intersect_max = lambda_from_mu(mu_max());
- if (lambda_min() < lambda_intersect_max && lambda_intersect_max < lambda_max())
- result.emplace_back(axis_type(), angle_type(), lambda_intersect_max, mu_max());
-
- assert(result.size() <= 2);
-
- if (result.size() > 2) {
- fmt::print("Error in push_change_points, p = {}, dual_box = {}, result = {}\n", p, *this,
- container_to_string(result));
- throw std::runtime_error("push_change_points returned more than 2 points");
- }
-
- return result;
- }
-
- std::vector<DualPoint> DualBox::critical_points(const Point& /*p*/) const
- {
- // maximal difference is attained at corners
- return corners();
-// std::vector<DualPoint> result;
-// result.reserve(6);
-// for(auto dp : corners()) result.push_back(dp);
-// for(auto dp : push_change_points(p)) result.push_back(dp);
-// return result;
- }
-
- std::vector<DualPoint> DualBox::random_points(int n) const
- {
- assert(n >= 0);
- std::mt19937_64 gen(1);
- std::vector<DualPoint> result;
- result.reserve(n);
- std::uniform_real_distribution<Real> mu_distr(mu_min(), mu_max());
- std::uniform_real_distribution<Real> lambda_distr(lambda_min(), lambda_max());
- for(int i = 0; i < n; ++i) {
- result.emplace_back(axis_type(), angle_type(), lambda_distr(gen), mu_distr(gen));
- }
- return result;
- }
-
- bool DualBox::sanity_check() const
- {
- lower_left_.sanity_check();
- upper_right_.sanity_check();
-
- if (lower_left_.angle_type() != upper_right_.angle_type())
- throw std::runtime_error("angle types differ");
-
- if (lower_left_.axis_type() != upper_right_.axis_type())
- throw std::runtime_error("axis types differ");
-
- if (lower_left_.lambda() >= upper_right_.lambda())
- throw std::runtime_error("lambda of lower_left_ greater than lambda of upper_right ");
-
- if (lower_left_.mu() >= upper_right_.mu())
- throw std::runtime_error("mu of lower_left_ greater than mu of upper_right ");
-
- return true;
- }
-
- std::vector<DualBox> DualBox::refine() const
- {
- std::vector<DualBox> result;
-
- result.reserve(4);
-
- Real lambda_middle = (lower_left().lambda() + upper_right().lambda()) / 2.0;
- Real mu_middle = (lower_left().mu() + upper_right().mu()) / 2.0;
-
- DualPoint refinement_center(axis_type(), angle_type(), lambda_middle, mu_middle);
-
- result.emplace_back(lower_left_, refinement_center);
-
- result.emplace_back(DualPoint(axis_type(), angle_type(), lambda_middle, mu_min()),
- DualPoint(axis_type(), angle_type(), lambda_max(), mu_middle));
-
- result.emplace_back(refinement_center, upper_right_);
-
- result.emplace_back(DualPoint(axis_type(), angle_type(), lambda_min(), mu_middle),
- DualPoint(axis_type(), angle_type(), lambda_middle, mu_max()));
- return result;
- }
-
- bool DualBox::operator==(const DualBox& other) const
- {
- return lower_left() == other.lower_left() and
- upper_right() == other.upper_right();
- }
-
- bool DualBox::contains(const DualPoint& dp) const
- {
- return dp.angle_type() == angle_type() and dp.axis_type() == axis_type() and
- mu_max() >= dp.mu() and
- mu_min() <= dp.mu() and
- lambda_min() <= dp.lambda() and
- lambda_max() >= dp.lambda();
- }
-
- DualPoint DualBox::lower_right() const
- {
- return DualPoint(lower_left_.axis_type(), lower_left_.angle_type(), lambda_max(), mu_min());
- }
-
- DualPoint DualBox::upper_left() const
- {
- return DualPoint(lower_left_.axis_type(), lower_left_.angle_type(), lambda_min(), mu_max());
- }
-}
diff --git a/matching/src/dual_point.cpp b/matching/src/dual_point.cpp
deleted file mode 100644
index 1c00b58..0000000
--- a/matching/src/dual_point.cpp
+++ /dev/null
@@ -1,282 +0,0 @@
-#include <tuple>
-
-#include "dual_point.h"
-
-namespace md {
-
- std::ostream& operator<<(std::ostream& os, const AxisType& at)
- {
- if (at == AxisType::x_type)
- os << "x-type";
- else
- os << "y-type";
- return os;
- }
-
- std::ostream& operator<<(std::ostream& os, const AngleType& at)
- {
- if (at == AngleType::flat)
- os << "flat";
- else
- os << "steep";
- return os;
- }
-
- std::ostream& operator<<(std::ostream& os, const DualPoint& dp)
- {
- os << "Line(" << dp.axis_type() << ", ";
- os << dp.angle_type() << ", ";
- os << dp.lambda() << ", ";
- os << dp.mu() << ", equation: ";
- if (not dp.is_vertical()) {
- os << "y = " << dp.y_slope() << " x + " << dp.y_intercept();
- } else {
- os << "x = " << dp.x_intercept();
- }
- os << " )";
- return os;
- }
-
- bool DualPoint::operator<(const DualPoint& rhs) const
- {
- return std::tie(axis_type_, angle_type_, lambda_, mu_)
- < std::tie(rhs.axis_type_, rhs.angle_type_, rhs.lambda_, rhs.mu_);
- }
-
- DualPoint::DualPoint(AxisType axis_type, AngleType angle_type, Real lambda, Real mu)
- :
- axis_type_(axis_type),
- angle_type_(angle_type),
- lambda_(lambda),
- mu_(mu)
- {
- assert(sanity_check());
- }
-
- bool DualPoint::sanity_check() const
- {
- if (lambda_ < 0.0)
- throw std::runtime_error("Invalid line, negative lambda");
- if (lambda_ > 1.0)
- throw std::runtime_error("Invalid line, lambda > 1");
- if (mu_ < 0.0)
- throw std::runtime_error("Invalid line, negative mu");
- return true;
- }
-
- Real DualPoint::gamma() const
- {
- if (is_steep())
- return atan(Real(1.0) / lambda_);
- else
- return atan(lambda_);
- }
-
- DualPoint midpoint(DualPoint x, DualPoint y)
- {
- assert(x.angle_type() == y.angle_type() and x.axis_type() == y.axis_type());
- Real lambda_mid = (x.lambda() + y.lambda()) / 2;
- Real mu_mid = (x.mu() + y.mu()) / 2;
- return DualPoint(x.axis_type(), x.angle_type(), lambda_mid, mu_mid);
-
- }
-
- // return k in the line equation y = kx + b
- Real DualPoint::y_slope() const
- {
- if (is_flat())
- return lambda();
- else
- return Real(1.0) / lambda();
- }
-
- // return k in the line equation x = ky + b
- Real DualPoint::x_slope() const
- {
- if (is_flat())
- return Real(1.0) / lambda();
- else
- return lambda();
- }
-
- // return b in the line equation y = kx + b
- Real DualPoint::y_intercept() const
- {
- if (is_y_type()) {
- return mu();
- } else {
- // x = x_slope * y + mu = x_slope * (y + mu / x_slope)
- // x-intercept is -mu/x_slope = -mu * y_slope
- return -mu() * y_slope();
- }
- }
-
- // return k in the line equation x = ky + b
- Real DualPoint::x_intercept() const
- {
- if (is_x_type()) {
- return mu();
- } else {
- // y = y_slope * x + mu = y_slope (x + mu / y_slope)
- // x_intercept is -mu/y_slope = -mu * x_slope
- return -mu() * x_slope();
- }
- }
-
- Real DualPoint::x_from_y(Real y) const
- {
- if (is_horizontal())
- throw std::runtime_error("x_from_y called on horizontal line");
- else
- return x_slope() * y + x_intercept();
- }
-
- Real DualPoint::y_from_x(Real x) const
- {
- if (is_vertical())
- throw std::runtime_error("x_from_y called on horizontal line");
- else
- return y_slope() * x + y_intercept();
- }
-
- bool DualPoint::is_horizontal() const
- {
- return is_flat() and lambda() == 0;
- }
-
- bool DualPoint::is_vertical() const
- {
- return is_steep() and lambda() == 0;
- }
-
- bool DualPoint::contains(Point p) const
- {
- if (is_vertical())
- return p.x == x_from_y(p.y);
- else
- return p.y == y_from_x(p.x);
- }
-
- bool DualPoint::goes_below(Point p) const
- {
- if (is_vertical())
- return p.x <= x_from_y(p.y);
- else
- return p.y >= y_from_x(p.x);
- }
-
- bool DualPoint::goes_above(Point p) const
- {
- if (is_vertical())
- return p.x >= x_from_y(p.y);
- else
- return p.y <= y_from_x(p.x);
- }
-
- Point DualPoint::push(Point p) const
- {
- Point result;
- // if line is below p, we push horizontally
- bool horizontal_push = goes_below(p);
- if (is_x_type()) {
- if (is_flat()) {
- if (horizontal_push) {
- result.x = p.y / lambda() + mu();
- result.y = p.y;
- } else {
- // vertical push
- result.x = p.x;
- result.y = lambda() * (p.x - mu());
- }
- } else {
- // steep
- if (horizontal_push) {
- result.x = lambda() * p.y + mu();
- result.y = p.y;
- } else {
- // vertical push
- result.x = p.x;
- result.y = (p.x - mu()) / lambda();
- }
- }
- } else {
- // y-type
- if (is_flat()) {
- if (horizontal_push) {
- result.x = (p.y - mu()) / lambda();
- result.y = p.y;
- } else {
- // vertical push
- result.x = p.x;
- result.y = lambda() * p.x + mu();
- }
- } else {
- // steep
- if (horizontal_push) {
- result.x = (p.y - mu()) * lambda();
- result.y = p.y;
- } else {
- // vertical push
- result.x = p.x;
- result.y = p.x / lambda() + mu();
- }
- }
- }
- return result;
- }
-
- Real DualPoint::weighted_push(Point p) const
- {
- // if line is below p, we push horizontally
- bool horizontal_push = goes_below(p);
- if (is_x_type()) {
- if (is_flat()) {
- if (horizontal_push) {
- return p.y;
- } else {
- // vertical push
- return lambda() * (p.x - mu());
- }
- } else {
- // steep
- if (horizontal_push) {
- return lambda() * p.y;
- } else {
- // vertical push
- return (p.x - mu());
- }
- }
- } else {
- // y-type
- if (is_flat()) {
- if (horizontal_push) {
- return p.y - mu();
- } else {
- // vertical push
- return lambda() * p.x;
- }
- } else {
- // steep
- if (horizontal_push) {
- return lambda() * (p.y - mu());
- } else {
- // vertical push
- return p.x;
- }
- }
- }
- }
-
- bool DualPoint::operator==(const DualPoint& other) const
- {
- return axis_type() == other.axis_type() and
- angle_type() == other.angle_type() and
- mu() == other.mu() and
- lambda() == other.lambda();
- }
-
- Real DualPoint::weight() const
- {
- return lambda_ / sqrt(1 + lambda_ * lambda_);
- }
-} // namespace md
diff --git a/matching/src/main.cpp b/matching/src/main.cpp
index f1472be..2093457 100644
--- a/matching/src/main.cpp
+++ b/matching/src/main.cpp
@@ -18,12 +18,20 @@
#include "box.h"
#include "matching_distance.h"
+using Real = double;
+
using namespace md;
namespace fs = std::experimental::filesystem;
+void force_instantiation()
+{
+ DualBox<Real> db;
+ std::cout << db;
+}
+
#ifdef PRINT_HEAT_MAP
-void print_heat_map(const md::HeatMaps& hms, std::string fname, const CalculationParams& params)
+void print_heat_map(const md::HeatMaps<Real>& hms, std::string fname, const CalculationParams<Real>& params)
{
spd::debug("Entered print_heat_map");
std::set<Real> mu_vals, lambda_vals;
@@ -143,7 +151,7 @@ int main(int argc, char** argv)
bool help = false;
bool no_stop_asap = false;
- CalculationParams params;
+ CalculationParams<Real> params;
#ifdef PRINT_HEAT_MAP
bool heatmap_only = false;
@@ -178,8 +186,8 @@ int main(int argc, char** argv)
auto bounds_list = split_by_delim(bounds_list_str, ',');
auto traverse_list = split_by_delim(traverse_list_str, ',');
- Bifiltration bif_a(fname_a);
- Bifiltration bif_b(fname_b);
+ Bifiltration<Real> bif_a(fname_a);
+ Bifiltration<Real> bif_b(fname_b);
bif_a.sanity_check();
bif_b.sanity_check();
@@ -207,11 +215,11 @@ int main(int argc, char** argv)
}
struct ExperimentResult {
- CalculationParams params {CalculationParams()};
+ CalculationParams<Real> params {CalculationParams()};
int n_hera_calls {0};
double total_milliseconds_elapsed {0};
- double distance {0};
- double actual_error {std::numeric_limits<double>::max()};
+ Real distance {0};
+ Real actual_error {std::numeric_limits<double>::max()};
int actual_max_depth {0};
int x_wins {0};
@@ -250,7 +258,7 @@ int main(int argc, char** argv)
ExperimentResult() { }
- ExperimentResult(CalculationParams p, int nhc, double tme, double d)
+ ExperimentResult(CalculationParams<Real> p, int nhc, double tme, double d)
:
params(p), n_hera_calls(nhc), total_milliseconds_elapsed(tme), distance(d) { }
};
@@ -267,7 +275,7 @@ int main(int argc, char** argv)
std::map<std::tuple<BoundStrategy, TraverseStrategy>, ExperimentResult> results;
for(BoundStrategy bound_strategy : bound_strategies) {
for(TraverseStrategy traverse_strategy : traverse_strategies) {
- CalculationParams params_experiment;
+ CalculationParams<Real> params_experiment;
params_experiment.bound_strategy = bound_strategy;
params_experiment.traverse_strategy = traverse_strategy;
params_experiment.max_depth = params.max_depth;
@@ -366,8 +374,9 @@ int main(int argc, char** argv)
spd::debug("Will use {} bound, {} traverse strategy", params.bound_strategy, params.traverse_strategy);
- Real dist = matching_distance(bif_a, bif_b, params);
+ Real dist = matching_distance<Real>(bif_a, bif_b, params);
std::cout << dist << std::endl;
#endif
+ force_instantiation();
return 0;
}
diff --git a/matching/src/matching_distance.cpp b/matching/src/matching_distance.cpp
deleted file mode 100644
index e53233f..0000000
--- a/matching/src/matching_distance.cpp
+++ /dev/null
@@ -1,150 +0,0 @@
-#include <chrono>
-#include <tuple>
-#include <algorithm>
-
-#include "common_defs.h"
-
-#include "matching_distance.h"
-
-namespace md {
-
- Real matching_distance(const Bifiltration& bif_a, const Bifiltration& bif_b,
- CalculationParams& params)
- {
- Real result;
- // compute distance only in one dimension
- if (params.dim != CalculationParams::ALL_DIMENSIONS) {
- BifiltrationProxy bifp_a(bif_a, params.dim);
- BifiltrationProxy bifp_b(bif_b, params.dim);
- DistanceCalculator<BifiltrationProxy> runner(bifp_a, bifp_b, params);
- result = runner.distance();
- params.n_hera_calls = runner.get_hera_calls_number();
- } else {
- // compute distance in all dimensions, return maximal
- result = -1;
- for(int dim = 0; dim < std::max(bif_a.maximal_dim(), bif_b.maximal_dim()); ++dim) {
- BifiltrationProxy bifp_a(bif_a, params.dim);
- BifiltrationProxy bifp_b(bif_a, params.dim);
- DistanceCalculator<BifiltrationProxy> runner(bifp_a, bifp_b, params);
- result = std::max(result, runner.distance());
- params.n_hera_calls += runner.get_hera_calls_number();
- }
- }
- return result;
- }
-
-
- Real matching_distance(const ModulePresentation& mod_a, const ModulePresentation& mod_b,
- CalculationParams& params)
- {
- DistanceCalculator<ModulePresentation> runner(mod_a, mod_b, params);
- Real result = runner.distance();
- params.n_hera_calls = runner.get_hera_calls_number();
- return result;
- }
-
- std::istream& operator>>(std::istream& is, BoundStrategy& s)
- {
- std::string ss;
- is >> ss;
- if (ss == "bruteforce") {
- s = BoundStrategy::bruteforce;
- } else if (ss == "local_grob") {
- s = BoundStrategy::local_dual_bound;
- } else if (ss == "local_combined") {
- s = BoundStrategy::local_combined;
- } else if (ss == "local_refined") {
- s = BoundStrategy::local_dual_bound_refined;
- } else if (ss == "local_for_each_point") {
- s = BoundStrategy::local_dual_bound_for_each_point;
- } else {
- throw std::runtime_error("UNKNOWN BOUND STRATEGY");
- }
- return is;
- }
-
- BoundStrategy bs_from_string(std::string s)
- {
- std::stringstream ss(s);
- BoundStrategy result;
- ss >> result;
- return result;
- }
-
- TraverseStrategy ts_from_string(std::string s)
- {
- std::stringstream ss(s);
- TraverseStrategy result;
- ss >> result;
- return result;
- }
-
- std::istream& operator>>(std::istream& is, TraverseStrategy& s)
- {
- std::string ss;
- is >> ss;
- if (ss == "DFS") {
- s = TraverseStrategy::depth_first;
- } else if (ss == "BFS") {
- s = TraverseStrategy::breadth_first;
- } else if (ss == "BFS-VAL") {
- s = TraverseStrategy::breadth_first_value;
- } else if (ss == "UB") {
- s = TraverseStrategy::upper_bound;
- } else {
- throw std::runtime_error("UNKNOWN TRAVERSE STRATEGY");
- }
- return is;
- }
-
- std::ostream& operator<<(std::ostream& os, const UbExperimentRecord& r)
- {
- os << r.time << "\t" << r.n_hera_calls << "\t" << r.error << "\t" << r.lower_bound << "\t" << r.upper_bound;
- return os;
- }
-
- std::ostream& operator<<(std::ostream& os, const BoundStrategy& s)
- {
- switch(s) {
- case BoundStrategy::bruteforce :
- os << "bruteforce";
- break;
- case BoundStrategy::local_dual_bound :
- os << "local_grob";
- break;
- case BoundStrategy::local_combined :
- os << "local_combined";
- break;
- case BoundStrategy::local_dual_bound_refined :
- os << "local_refined";
- break;
- case BoundStrategy::local_dual_bound_for_each_point :
- os << "local_for_each_point";
- break;
- default:
- os << "FORGOTTEN BOUND STRATEGY";
- }
- return os;
- }
-
- std::ostream& operator<<(std::ostream& os, const TraverseStrategy& s)
- {
- switch(s) {
- case TraverseStrategy::depth_first :
- os << "DFS";
- break;
- case TraverseStrategy::breadth_first :
- os << "BFS";
- break;
- case TraverseStrategy::breadth_first_value :
- os << "BFS-VAL";
- break;
- case TraverseStrategy::upper_bound :
- os << "UB";
- break;
- default:
- os << "FORGOTTEN TRAVERSE STRATEGY";
- }
- return os;
- }
-}
diff --git a/matching/src/persistence_module.cpp b/matching/src/persistence_module.cpp
deleted file mode 100644
index efb20ef..0000000
--- a/matching/src/persistence_module.cpp
+++ /dev/null
@@ -1,177 +0,0 @@
-#include <numeric>
-#include <algorithm>
-#include <unordered_set>
-
-#include <phat/boundary_matrix.h>
-#include <phat/compute_persistence_pairs.h>
-
-#include "persistence_module.h"
-
-namespace md {
-
- /**
- *
- * @param values vector of length n
- * @return [a_1,...,a_n] such that
- * 1) values[a_1] <= values[a_2] <= ... <= values[a_n]
- * 2) a_1,...,a_n is a permutation of 1,..,n
- */
-
- template<typename T>
- IndexVec get_sorted_indices(const std::vector<T>& values)
- {
- IndexVec result(values.size());
- std::iota(result.begin(), result.end(), 0);
- std::sort(result.begin(), result.end(),
- [&values](size_t a, size_t b) { return values[a] < values[b]; });
- return result;
- }
-
- // helper function to initialize const member positions_ in ModulePresentation
- PointVec
- concat_gen_and_rel_positions(const PointVec& generators, const ModulePresentation::RelVec& relations)
- {
- std::unordered_set<Point> ps(generators.begin(), generators.end());
- for(const auto& rel : relations) {
- ps.insert(rel.position_);
- }
- return PointVec(ps.begin(), ps.end());
- }
-
-
- void ModulePresentation::init_boundaries()
- {
- max_x_ = std::numeric_limits<Real>::max();
- max_y_ = std::numeric_limits<Real>::max();
- min_x_ = -std::numeric_limits<Real>::max();
- min_y_ = -std::numeric_limits<Real>::max();
-
- for(const auto& gen : positions_) {
- min_x_ = std::min(gen.x, min_x_);
- min_y_ = std::min(gen.y, min_y_);
- max_x_ = std::max(gen.x, max_x_);
- max_y_ = std::max(gen.y, max_y_);
- }
-
- bounding_box_ = Box(Point(min_x_, min_y_), Point(max_x_, max_y_));
- }
-
-
- ModulePresentation::ModulePresentation(const PointVec& _generators, const RelVec& _relations) :
- generators_(_generators),
- relations_(_relations)
- {
- init_boundaries();
- }
-
- void ModulePresentation::translate(md::Real a)
- {
- for(auto& g : generators_) {
- g.translate(a);
- }
-
- for(auto& r : relations_) {
- r.position_.translate(a);
- }
-
- positions_ = concat_gen_and_rel_positions(generators_, relations_);
- init_boundaries();
- }
-
-
- /**
- *
- * @param slice line on which generators are projected
- * @param sorted_indices [a_1,...,a_n] s.t. wpush(generator[a_1]) <= wpush(generator[a_2]) <= ..
- * @param projections sorted weighted pushes of generators
- */
-
- void
- ModulePresentation::project_generators(const DualPoint& slice, IndexVec& sorted_indices, RealVec& projections) const
- {
- size_t num_gens = generators_.size();
-
- RealVec gen_values;
- gen_values.reserve(num_gens);
- for(const auto& pos : generators_) {
- gen_values.push_back(slice.weighted_push(pos));
- }
- sorted_indices = get_sorted_indices(gen_values);
- projections.clear();
- projections.reserve(num_gens);
- for(auto i : sorted_indices) {
- projections.push_back(gen_values[i]);
- }
- }
-
- void ModulePresentation::project_relations(const DualPoint& slice, IndexVec& sorted_rel_indices,
- RealVec& projections) const
- {
- size_t num_rels = relations_.size();
-
- RealVec rel_values;
- rel_values.reserve(num_rels);
- for(const auto& rel : relations_) {
- rel_values.push_back(slice.weighted_push(rel.position_));
- }
- sorted_rel_indices = get_sorted_indices(rel_values);
- projections.clear();
- projections.reserve(num_rels);
- for(auto i : sorted_rel_indices) {
- projections.push_back(rel_values[i]);
- }
- }
-
- Diagram ModulePresentation::weighted_slice_diagram(const DualPoint& slice) const
- {
- IndexVec sorted_gen_indices, sorted_rel_indices;
- RealVec gen_projections, rel_projections;
-
- project_generators(slice, sorted_gen_indices, gen_projections);
- project_relations(slice, sorted_rel_indices, rel_projections);
-
- phat::boundary_matrix<> phat_matrix;
-
- phat_matrix.set_num_cols(relations_.size());
-
- for(Index i = 0; i < (Index) relations_.size(); i++) {
- IndexVec current_relation = relations_[sorted_rel_indices[i]].components_;
- for(auto& j : current_relation) {
- j = sorted_gen_indices[j];
- }
- std::sort(current_relation.begin(), current_relation.end());
- phat_matrix.set_dim(i, current_relation.size());
- phat_matrix.set_col(i, current_relation);
- }
-
- phat::persistence_pairs phat_persistence_pairs;
- phat::compute_persistence_pairs<phat::twist_reduction>(phat_persistence_pairs, phat_matrix);
-
- Diagram dgm;
-
- constexpr Real real_inf = std::numeric_limits<Real>::infinity();
-
- for(Index i = 0; i < (Index) phat_persistence_pairs.get_num_pairs(); i++) {
- std::pair<phat::index, phat::index> new_pair = phat_persistence_pairs.get_pair(i);
- bool is_finite_pair = new_pair.second != phat::k_infinity_index;
- Real birth = gen_projections.at(new_pair.first);
- Real death = is_finite_pair ? rel_projections.at(new_pair.second) : real_inf;
- if (birth != death) {
- dgm.emplace_back(birth, death);
- }
- }
-
- return dgm;
- }
-
- PointVec ModulePresentation::positions() const
- {
- return positions_;
- }
-
- Box ModulePresentation::bounding_box() const
- {
- return bounding_box_;
- }
-
-}
diff --git a/matching/src/simplex.cpp b/matching/src/simplex.cpp
deleted file mode 100644
index 6b53680..0000000
--- a/matching/src/simplex.cpp
+++ /dev/null
@@ -1,121 +0,0 @@
-#include "simplex.h"
-
-namespace md {
-
- std::ostream& operator<<(std::ostream& os, const AbstractSimplex& s)
- {
- os << "AbstractSimplex(id = " << s.id << ", vertices_ = " << container_to_string(s.vertices_) << ")";
- return os;
- }
-
- bool operator<(const AbstractSimplex& a, const AbstractSimplex& b)
- {
- return a.vertices_ < b.vertices_;
- }
-
- bool operator==(const AbstractSimplex& s1, const AbstractSimplex& s2)
- {
- return s1.vertices_ == s2.vertices_;
- }
-
- void AbstractSimplex::push_back(int v)
- {
- vertices_.push_back(v);
- std::sort(vertices_.begin(), vertices_.end());
- }
-
- AbstractSimplex::AbstractSimplex(std::vector<int> vertices, bool sort)
- :vertices_(vertices)
- {
- if (sort)
- std::sort(vertices_.begin(), vertices_.end());
- }
-
- std::vector<AbstractSimplex> AbstractSimplex::facets() const
- {
- std::vector<AbstractSimplex> result;
- for (int i = 0; i < static_cast<int>(vertices_.size()); ++i) {
- std::vector<int> facet_vertices;
- facet_vertices.reserve(dim());
- for (int j = 0; j < static_cast<int>(vertices_.size()); ++j) {
- if (j != i)
- facet_vertices.push_back(vertices_[j]);
- }
- if (!facet_vertices.empty()) {
- result.emplace_back(facet_vertices, false);
- }
- }
- return result;
- }
-
- Simplex::Simplex(md::Index id, md::Point birth, int dim, const md::Column& bdry)
- :
- id_(id),
- pos_(birth),
- dim_(dim),
- facet_indices_(bdry) { }
-
- void Simplex::translate(Real a)
- {
- pos_.translate(a);
- }
-
- void Simplex::init_rivet(std::string s)
- {
- auto delim_pos = s.find_first_of(";");
- assert(delim_pos > 0);
- std::string vertices_str = s.substr(0, delim_pos);
- std::string pos_str = s.substr(delim_pos + 1);
- assert(not vertices_str.empty() and not pos_str.empty());
- // get vertices
- std::stringstream vertices_ss(vertices_str);
- int dim = 0;
- int vertex;
- while (vertices_ss >> vertex) {
- dim++;
- vertices_.push_back(vertex);
- }
- //
- std::sort(vertices_.begin(), vertices_.end());
- assert(dim > 0);
-
- std::stringstream pos_ss(pos_str);
- // TODO: get rid of 1-criticaltiy assumption
- pos_ss >> pos_.x >> pos_.y;
- }
-
- void Simplex::init_phat_like(std::string s)
- {
- facet_indices_.clear();
- std::stringstream ss(s);
- ss >> dim_ >> pos_.x >> pos_.y;
- if (dim_ > 0) {
- facet_indices_.reserve(dim_ + 1);
- for (int j = 0; j <= dim_; j++) {
- Index k;
- ss >> k;
- facet_indices_.push_back(k);
- }
- }
- }
-
- Simplex::Simplex(Index _id, std::string s, BifiltrationFormat input_format)
- :id_(_id)
- {
- switch (input_format) {
- case BifiltrationFormat::phat_like :
- init_phat_like(s);
- break;
- case BifiltrationFormat::rivet :
- init_rivet(s);
- break;
- }
- }
-
- std::ostream& operator<<(std::ostream& os, const Simplex& x)
- {
- os << "Simplex(id = " << x.id() << ", dim = " << x.dim();
- os << ", boundary = " << container_to_string(x.boundary()) << ", pos = " << x.position() << ")";
- return os;
- }
-}
diff --git a/matching/src/test_generator.cpp b/matching/src/test_generator.cpp
index e8f128f..a2f0625 100644
--- a/matching/src/test_generator.cpp
+++ b/matching/src/test_generator.cpp
@@ -11,9 +11,12 @@
#include "common_util.h"
#include "bifiltration.h"
+using Real = double;
using Index = md::Index;
-using Point = md::Point;
+using Point = md::Point<Real>;
+using Bifiltration = md::Bifiltration<Real>;
using Column = md::Column;
+using Simplex = md::Simplex<Real>;
int g_max_coord = 100;
@@ -100,7 +103,7 @@ void generate_positions(const ASimplex& s, ASimplexToBirthMap& simplex_to_birth,
}
}
-md::Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_simplices)
+Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_simplices)
{
ASimplexToBirthMap simplex_to_birth;
@@ -122,13 +125,13 @@ md::Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_
add_if_top(candidate_simplex, top_simplices);
}
- Point upper_bound{static_cast<md::Real>(g_max_coord), static_cast<md::Real>(g_max_coord)};
+ Point upper_bound{static_cast<Real>(g_max_coord), static_cast<Real>(g_max_coord)};
for(const auto& top_simplex : top_simplices) {
generate_positions(top_simplex, simplex_to_birth, upper_bound);
}
std::vector<std::pair<ASimplex, Point>> simplex_birth_pairs{simplex_to_birth.begin(), simplex_to_birth.end()};
- std::vector<md::Column> boundaries{simplex_to_birth.size(), md::Column()};
+ std::vector<Column> boundaries{simplex_to_birth.size(), Column()};
// assign ids and save boundaries
int id = 0;
@@ -138,7 +141,7 @@ md::Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_
ASimplex& simplex = simplex_birth_pairs[i].first;
if (simplex.dim() == dim) {
simplex.id = id++;
- md::Column bdry;
+ Column bdry;
for(auto& facet : simplex.facets()) {
auto facet_iter = std::find_if(simplex_birth_pairs.begin(), simplex_birth_pairs.end(),
[facet](const std::pair<ASimplex, Point>& sbp) { return facet == sbp.first; });
@@ -153,7 +156,7 @@ md::Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_
}
// create vector of Simplex-es
- std::vector<md::Simplex> simplices;
+ std::vector<Simplex> simplices;
for(int i = 0; i < (int) simplex_birth_pairs.size(); ++i) {
int id = simplex_birth_pairs[i].first.id;
int dim = simplex_birth_pairs[i].first.dim();
@@ -164,13 +167,13 @@ md::Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_
// sort by id
std::sort(simplices.begin(), simplices.end(),
- [](const md::Simplex& s1, const md::Simplex& s2) { return s1.id() < s2.id(); });
+ [](const Simplex& s1, const Simplex& s2) { return s1.id() < s2.id(); });
for(int i = 0; i < (int)simplices.size(); ++i) {
assert(simplices[i].id() == i);
assert(i == 0 || simplices[i].dim() >= simplices[i - 1].dim());
}
- return md::Bifiltration(simplices.begin(), simplices.end());
+ return Bifiltration(simplices.begin(), simplices.end());
}
int main(int argc, char** argv)
diff --git a/matching/src/tests/test_common.cpp b/matching/src/tests/test_common.cpp
index c55577e..9079a56 100644
--- a/matching/src/tests/test_common.cpp
+++ b/matching/src/tests/test_common.cpp
@@ -8,56 +8,24 @@
#include "simplex.h"
#include "matching_distance.h"
-using namespace md;
+//using namespace md;
+using Real = double;
+using Point = md::Point<Real>;
+using Bifiltration = md::Bifiltration<Real>;
+using BifiltrationProxy = md::BifiltrationProxy<Real>;
+using CalculationParams = md::CalculationParams<Real>;
+using CellWithValue = md::CellWithValue<Real>;
+using DualPoint = md::DualPoint<Real>;
+using DualBox = md::DualBox<Real>;
+using Simplex = md::Simplex<Real>;
+using AbstractSimplex = md::AbstractSimplex;
+using BoundStrategy = md::BoundStrategy;
+using TraverseStrategy = md::TraverseStrategy;
+using AxisType = md::AxisType;
+using AngleType = md::AngleType;
+using ValuePoint = md::ValuePoint;
+using Column = md::Column;
-TEST_CASE("Rational", "[common_utils][rational]")
-{
- // gcd
- REQUIRE(gcd(10, 5) == 5);
- REQUIRE(gcd(5, 10) == 5);
- REQUIRE(gcd(5, 7) == 1);
- REQUIRE(gcd(7, 5) == 1);
- REQUIRE(gcd(13, 0) == 13);
- REQUIRE(gcd(0, 13) == 13);
- REQUIRE(gcd(16, 24) == 8);
- REQUIRE(gcd(24, 16) == 8);
- REQUIRE(gcd(16, 32) == 16);
- REQUIRE(gcd(32, 16) == 16);
-
-
- // reduce
- REQUIRE(reduce({2, 1}) == std::make_pair(2, 1));
- REQUIRE(reduce({1, 2}) == std::make_pair(1, 2));
- REQUIRE(reduce({2, 2}) == std::make_pair(1, 1));
- REQUIRE(reduce({0, 2}) == std::make_pair(0, 1));
- REQUIRE(reduce({0, 20}) == std::make_pair(0, 1));
- REQUIRE(reduce({35, 49}) == std::make_pair(5, 7));
- REQUIRE(reduce({35, 25}) == std::make_pair(7, 5));
-
- // midpoint
- REQUIRE(midpoint(Rational {0, 1}, Rational {1, 2}) == std::make_pair(1, 4));
- REQUIRE(midpoint(Rational {1, 4}, Rational {1, 2}) == std::make_pair(3, 8));
- REQUIRE(midpoint(Rational {1, 2}, Rational {1, 2}) == std::make_pair(1, 2));
- REQUIRE(midpoint(Rational {1, 2}, Rational {1, 1}) == std::make_pair(3, 4));
- REQUIRE(midpoint(Rational {3, 7}, Rational {5, 14}) == std::make_pair(11, 28));
-
-
- // arithmetic
-
- REQUIRE(Rational(1, 2) + Rational(3, 5) == Rational(11, 10));
- REQUIRE(Rational(2, 5) - Rational(3, 10) == Rational(1, 10));
- REQUIRE(Rational(2, 3) * Rational(4, 7) == Rational(8, 21));
- REQUIRE(Rational(2, 3) * Rational(3, 2) == Rational(1));
- REQUIRE(Rational(2, 3) / Rational(3, 2) == Rational(4, 9));
- REQUIRE(Rational(1, 2) * Rational(3, 5) == Rational(3, 10));
-
- // comparison
- REQUIRE(Rational(100000, 2000000) < Rational(100001, 2000000));
- REQUIRE(!(Rational(100001, 2000000) < Rational(100000, 2000000)));
- REQUIRE(!(Rational(100000, 2000000) < Rational(100000, 2000000)));
- REQUIRE(Rational(-100000, 2000000) < Rational(100001, 2000000));
- REQUIRE(Rational(-100001, 2000000) < Rational(100000, 2000000));
-};
TEST_CASE("AbstractSimplex", "[abstract_simplex]")
{
diff --git a/matching/src/tests/test_matching_distance.cpp b/matching/src/tests/test_matching_distance.cpp
index df9345e..82da530 100644
--- a/matching/src/tests/test_matching_distance.cpp
+++ b/matching/src/tests/test_matching_distance.cpp
@@ -11,7 +11,25 @@
#include "simplex.h"
#include "matching_distance.h"
-using namespace md;
+using Real = double;
+using Point = md::Point<Real>;
+using Bifiltration = md::Bifiltration<Real>;
+using BifiltrationProxy = md::BifiltrationProxy<Real>;
+using CalculationParams = md::CalculationParams<Real>;
+using CellWithValue = md::CellWithValue<Real>;
+using DualPoint = md::DualPoint<Real>;
+using DualBox = md::DualBox<Real>;
+using Simplex = md::Simplex<Real>;
+using AbstractSimplex = md::AbstractSimplex;
+using BoundStrategy = md::BoundStrategy;
+using TraverseStrategy = md::TraverseStrategy;
+using AxisType = md::AxisType;
+using AngleType = md::AngleType;
+using ValuePoint = md::ValuePoint;
+using Column = md::Column;
+
+using md::k_corner_vps;
+
namespace spd = spdlog;
TEST_CASE("Different bounds", "[bounds]")
@@ -40,7 +58,7 @@ TEST_CASE("Different bounds", "[bounds]")
BifiltrationProxy bifp_a(bif_a, params.dim);
BifiltrationProxy bifp_b(bif_b, params.dim);
- DistanceCalculator<BifiltrationProxy> calc(bifp_a, bifp_b, params);
+ md::DistanceCalculator<Real, BifiltrationProxy> calc(bifp_a, bifp_b, params);
// REQUIRE(calc.max_x_ == Approx(max_x));
// REQUIRE(calc.max_y_ == Approx(max_y));