From 355ff89d1860f8d928bbc97a542b4560c257b9e6 Mon Sep 17 00:00:00 2001 From: Gard Spreemann Date: Wed, 5 Aug 2020 16:12:17 +0200 Subject: Improve. --- include/pd.hpp | 82 ++++++++++++++++++++++++++++++++++++++++++++++------------ src/main.cpp | 30 +++++++++++++++++++-- 2 files changed, 94 insertions(+), 18 deletions(-) diff --git a/include/pd.hpp b/include/pd.hpp index 04b8609..2641125 100644 --- a/include/pd.hpp +++ b/include/pd.hpp @@ -3,22 +3,32 @@ #include #include #include +#include #include #include #include "misc.hpp" -template class PD +template class Interval { public: - class Interval - { - public: - T birth; - T death; - }; + inline T length() const { return death - birth; } + T birth; + T death; +}; - +template inline bool operator==(Interval x, Interval y) { return x.birth == y.birth && x.death == y.death; } +template inline bool operator!=(Interval x, Interval y) { return !(x == y); } +template inline bool operator<(Interval x, Interval y) +{ + return (x.length() < y.length()) || + ((x.length() == y.length()) && (x.birth < y.birth)) || + (((x.length() == y.length()) && (x.birth == y.birth) && (x.death < y.death))); +} + +template class PD +{ +public: PD() : intervals() { }; @@ -27,7 +37,7 @@ public: { if (m > 0 && b < d) { - Interval interval; + Interval interval; interval.birth = b; interval.death = d; intervals.push_back(std::make_pair(interval, m)); @@ -45,18 +55,58 @@ public: } } - typename std::vector::Interval, unsigned int> >::size_type size() const { return intervals.size(); } + void discretize(T delta) + { + if (delta <= 0) + return; + for (auto it = begin(); it != end(); ++it) + { + it->first.birth = std::round(it->first.birth/delta)*delta; + it->first.death = std::round(it->first.death/delta)*delta; + } + } + + void compress_and_sort() + { + std::map, unsigned int> tmp; + for (auto it = cbegin(); it != cend(); ++it) + { + auto existing = tmp.find(it->first); + if (existing == tmp.end()) + { + tmp.insert(*it); + } + else + { + existing->second += it->second; + } + } + intervals = std::vector, unsigned int> >(tmp.cbegin(), tmp.cend()); + } + + unsigned int size() const + { + unsigned int ret = 0; + for (auto it = cbegin(); it != cend(); ++it) + ret += it->second; + return ret; + } + + inline typename std::vector, unsigned int> >::size_type size_2() const { return intervals.size(); } - using Iterator = typename std::vector::Interval, unsigned int> >::const_iterator; + using Iterator = typename std::vector, unsigned int> >::iterator; + using Const_iterator = typename std::vector, unsigned int> >::const_iterator; - inline Iterator cbegin() const { return intervals.cbegin(); } - inline Iterator cend() const { return intervals.cend(); } + inline Const_iterator cbegin() const { return intervals.cbegin(); } + inline Const_iterator cend() const { return intervals.cend(); } + inline Iterator begin() { return intervals.begin(); } + inline Iterator end() { return intervals.end(); } private: - std::vector::Interval, unsigned int> > intervals; + std::vector, unsigned int> > intervals; }; -inline double sqdist(PD::Interval x, PD::Interval y) +template inline T sqdist(Interval x, Interval y) { return (x.birth - y.birth)*(x.birth - y.birth) + (x.death - y.death)*(x.death - y.death); } @@ -68,7 +118,7 @@ template T heat_kernel(T sigma, const PD & a, const PD & b) for (auto it = a.cbegin(); it != a.cend(); ++it) { auto x = *it; - typename PD::Interval xbar; + Interval xbar; xbar.birth = x.first.death; xbar.death = x.first.birth; diff --git a/src/main.cpp b/src/main.cpp index 448ae5a..0626db5 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -17,10 +17,12 @@ void print_help(const std::string & invocation) std::cout << " If the same file name is given twice, the computation will exploit symmetry." << std::endl; std::cout << " --dimension, -d dim" << std::endl; std::cout << " Mandatory integer. Degree to read from the persistence diagrams." << std::endl; - std::cout << " --sigma, -d s" << std::endl; + std::cout << " --sigma, -s σ" << std::endl; std::cout << " Mandatory decimal. σ parameter in heat kernel." << std::endl; std::cout << " --finitize, -f x" << std::endl; std::cout << " Optional decimal. Finitize all infinite intervals to scale x. If not given (default), infinite intervals are ignored." << std::endl; + std::cout << " --discretize, -t x" << std::endl; + std::cout << " Optional decimal. Discretize to and x-by-x grid." << std::endl; std::cout << " --out, -o file" << std::endl; std::cout << " Mandatory output file name. Plain text." << std::endl; std::cout << " --chunk, -c c" << std::endl; @@ -77,6 +79,7 @@ int main(int argc, char ** argv) int dim = -1; double sigma = std::numeric_limits::quiet_NaN(); double finitization = std::numeric_limits::infinity(); + double discretization = 0; std::string out_file_name; int chunk_size = 10; @@ -114,6 +117,13 @@ int main(int argc, char ** argv) else arg_fail(invocation, world_rank, "Missing argument for --finitize."); } + else if (arg == std::string("--discretize") || arg == std::string("-x")) + { + if (i < argc - 1) + discretization = std::atof(argv[++i]); + else + arg_fail(invocation, world_rank, "Missing argument for --discretize."); + } else if (arg == std::string("--out") || arg == std::string("-o")) { if (i < argc - 1) @@ -151,6 +161,8 @@ int main(int argc, char ** argv) arg_fail(invocation, world_rank, "σ must be positive."); if (finitization <= 0) arg_fail(invocation, world_rank, "Finitization must be positive."); + if (discretization < 0) + arg_fail(invocation, world_rank, "Discretization must be non-negative."); if (out_file_name == std::string("")) arg_fail(invocation, world_rank, "Need output file."); if (chunk_size <= 0) @@ -307,7 +319,7 @@ int main(int argc, char ** argv) j = idxs[k].second; int load_status = 1; - + if (i != i_prev) { pd_1.clear(); @@ -317,6 +329,13 @@ int main(int argc, char ** argv) fail(world_rank, std::string("Failed to load file ") + files_1[i] + std::string(".")); pd_1.finitize(finitization); + if (discretization > 0) + { + std::cout << pd_1.size_2() << " --> "; + pd_1.discretize(discretization); + pd_1.compress_and_sort(); + std::cout << pd_1.size_2() << std::endl; + } } if (j != j_prev) @@ -328,6 +347,13 @@ int main(int argc, char ** argv) fail(world_rank, std::string("Failed to load file ") + files_2[j] + std::string(".")); pd_2.finitize(finitization); + if (discretization > 0) + { + std::cout << pd_2.size_2() << " --> "; + pd_2.discretize(discretization); + pd_2.compress_and_sort(); + std::cout << pd_2.size_2() << std::endl; + } } result[k - work[0]] = heat_kernel(sigma, pd_1, pd_2); -- cgit v1.2.3