summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGard Spreemann <gspr@nonempty.org>2020-08-05 16:12:17 +0200
committerGard Spreemann <gspr@nonempty.org>2020-08-05 16:12:17 +0200
commit355ff89d1860f8d928bbc97a542b4560c257b9e6 (patch)
tree3ec3df3ff2c99b7c46349844bf2f1339406dada5
parent4e4a4bfc4ccc3f39288678a7550ff12601d13223 (diff)
Improve.
-rw-r--r--include/pd.hpp82
-rw-r--r--src/main.cpp30
2 files changed, 94 insertions, 18 deletions
diff --git a/include/pd.hpp b/include/pd.hpp
index 04b8609..2641125 100644
--- a/include/pd.hpp
+++ b/include/pd.hpp
@@ -3,22 +3,32 @@
#include <cstdint>
#include <fstream>
#include <limits>
+#include <map>
#include <string>
#include <vector>
#include "misc.hpp"
-template <typename T> class PD
+template <typename T> class Interval // Value type holding a (birth, death) pair with coordinate type T; the class itself enforces no ordering between the two.
{
public:
- class Interval
- {
- public:
- T birth;
- T death;
- };
+ inline T length() const { return death - birth; } // Signed difference death - birth; no check that death >= birth.
+ T birth; // First coordinate; has no default initializer here.
+ T death; // Second coordinate; has no default initializer here.
+};
-
+template <typename T> inline bool operator==(Interval<T> x, Interval<T> y) { return x.birth == y.birth && x.death == y.death; } // Exact member-wise equality on birth and death (no tolerance).
+template <typename T> inline bool operator!=(Interval<T> x, Interval<T> y) { return !(x == y); } // Defined as the negation of operator==.
+template <typename T> inline bool operator<(Interval<T> x, Interval<T> y) // Lexicographic order on (length, birth, death); std::map<Interval<T>, ...> uses it as its default key order.
+{
+ return (x.length() < y.length()) || // 1) shorter interval sorts first
+ ((x.length() == y.length()) && (x.birth < y.birth)) || // 2) equal length: smaller birth first
+ (((x.length() == y.length()) && (x.birth == y.birth) && (x.death < y.death))); // 3) equal length and birth: smaller death first
+}
+
+template <typename T> class PD
+{
+public:
PD() : intervals()
{
};
@@ -27,7 +37,7 @@ public:
{
if (m > 0 && b < d)
{
- Interval interval;
+ Interval<T> interval;
interval.birth = b;
interval.death = d;
intervals.push_back(std::make_pair(interval, m));
@@ -45,18 +55,58 @@ public:
}
}
- typename std::vector<std::pair<PD<T>::Interval, unsigned int> >::size_type size() const { return intervals.size(); }
+ void discretize(T delta) // Round every stored birth and death to the nearest multiple of delta, in place.
+ {
+ if (delta <= 0) // Non-positive grid size: leave the diagram untouched.
+ return;
+ for (auto it = begin(); it != end(); ++it) // Mutable iteration; multiplicities (it->second) are not modified.
+ {
+ it->first.birth = std::round(it->first.birth/delta)*delta; // NOTE(review): std::round is declared in <cmath>, which is not among the includes visible in this diff - confirm it is pulled in elsewhere.
+ it->first.death = std::round(it->first.death/delta)*delta; // After rounding, distinct entries may hold identical intervals; nothing is merged in this function.
+ }
+ }
+
+ void compress_and_sort() // Merge entries whose intervals are equivalent under operator< (summing multiplicities) and rebuild the vector in that order.
+ {
+ std::map<Interval<T>, unsigned int> tmp; // interval -> accumulated multiplicity
+ for (auto it = cbegin(); it != cend(); ++it)
+ {
+ auto existing = tmp.find(it->first);
+ if (existing == tmp.end())
+ {
+ tmp.insert(*it); // First occurrence: copy the (interval, multiplicity) pair.
+ }
+ else
+ {
+ existing->second += it->second; // Repeated interval: add its multiplicity to the existing entry.
+ }
+ }
+ intervals = std::vector<std::pair<Interval<T>, unsigned int> >(tmp.cbegin(), tmp.cend()); // std::map iterates in key order, so the new vector is sorted.
+ }
+
+ unsigned int size() const // Sum of the multiplicities of all stored entries.
+ {
+ unsigned int ret = 0;
+ for (auto it = cbegin(); it != cend(); ++it)
+ ret += it->second; // it->second is the entry's multiplicity.
+ return ret;
+ }
+
+ inline typename std::vector<std::pair<Interval<T>, unsigned int> >::size_type size_2() const { return intervals.size(); } // Number of stored (interval, multiplicity) entries, not weighted by multiplicity.
- using Iterator = typename std::vector<std::pair<PD<T>::Interval, unsigned int> >::const_iterator;
+ using Iterator = typename std::vector<std::pair<Interval<T>, unsigned int> >::iterator; // Mutable iterator over (interval, multiplicity) pairs.
+ using Const_iterator = typename std::vector<std::pair<Interval<T>, unsigned int> >::const_iterator; // Read-only iterator over the same pairs.
- inline Iterator cbegin() const { return intervals.cbegin(); }
- inline Iterator cend() const { return intervals.cend(); }
+ inline Const_iterator cbegin() const { return intervals.cbegin(); } // Read-only start of the stored entries.
+ inline Const_iterator cend() const { return intervals.cend(); } // Read-only past-the-end position.
+ inline Iterator begin() { return intervals.begin(); } // Mutable start; lets callers modify entries in place.
+ inline Iterator end() { return intervals.end(); } // Mutable past-the-end position.
private:
- std::vector<std::pair<PD<T>::Interval, unsigned int> > intervals;
+ std::vector<std::pair<Interval<T>, unsigned int> > intervals;
};
-inline double sqdist(PD<double>::Interval x, PD<double>::Interval y)
+template <typename T> inline T sqdist(Interval<T> x, Interval<T> y) // Squared Euclidean distance between x and y viewed as (birth, death) points.
{
return (x.birth - y.birth)*(x.birth - y.birth) + (x.death - y.death)*(x.death - y.death); // (db)^2 + (dd)^2, no square root taken.
}
@@ -68,7 +118,7 @@ template <typename T> T heat_kernel(T sigma, const PD<T> & a, const PD<T> & b)
for (auto it = a.cbegin(); it != a.cend(); ++it)
{
auto x = *it;
- typename PD<T>::Interval xbar;
+ Interval<T> xbar;
xbar.birth = x.first.death;
xbar.death = x.first.birth;
diff --git a/src/main.cpp b/src/main.cpp
index 448ae5a..0626db5 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -17,10 +17,12 @@ void print_help(const std::string & invocation)
std::cout << " If the same file name is given twice, the computation will exploit symmetry." << std::endl;
std::cout << " --dimension, -d dim" << std::endl;
std::cout << " Mandatory integer. Degree to read from the persistence diagrams." << std::endl;
- std::cout << " --sigma, -d s" << std::endl;
+ std::cout << " --sigma, -s σ" << std::endl;
std::cout << " Mandatory decimal. σ parameter in heat kernel." << std::endl;
std::cout << " --finitize, -f x" << std::endl;
std::cout << " Optional decimal. Finitize all infinite intervals to scale x. If not given (default), infinite intervals are ignored." << std::endl;
+ std::cout << " --discretize, -t x" << std::endl;
+ std::cout << " Optional decimal. Discretize to and x-by-x grid." << std::endl;
std::cout << " --out, -o file" << std::endl;
std::cout << " Mandatory output file name. Plain text." << std::endl;
std::cout << " --chunk, -c c" << std::endl;
@@ -77,6 +79,7 @@ int main(int argc, char ** argv)
int dim = -1;
double sigma = std::numeric_limits<double>::quiet_NaN();
double finitization = std::numeric_limits<double>::infinity();
+ double discretization = 0;
std::string out_file_name;
int chunk_size = 10;
@@ -114,6 +117,13 @@ int main(int argc, char ** argv)
else
arg_fail(invocation, world_rank, "Missing argument for --finitize.");
}
+ else if (arg == std::string("--discretize") || arg == std::string("-x"))
+ {
+ if (i < argc - 1)
+ discretization = std::atof(argv[++i]);
+ else
+ arg_fail(invocation, world_rank, "Missing argument for --discretize.");
+ }
else if (arg == std::string("--out") || arg == std::string("-o"))
{
if (i < argc - 1)
@@ -151,6 +161,8 @@ int main(int argc, char ** argv)
arg_fail(invocation, world_rank, "σ must be positive.");
if (finitization <= 0)
arg_fail(invocation, world_rank, "Finitization must be positive.");
+ if (discretization < 0)
+ arg_fail(invocation, world_rank, "Discretization must be non-negative.");
if (out_file_name == std::string(""))
arg_fail(invocation, world_rank, "Need output file.");
if (chunk_size <= 0)
@@ -307,7 +319,7 @@ int main(int argc, char ** argv)
j = idxs[k].second;
int load_status = 1;
-
+
if (i != i_prev)
{
pd_1.clear();
@@ -317,6 +329,13 @@ int main(int argc, char ** argv)
fail(world_rank, std::string("Failed to load file ") + files_1[i] + std::string("."));
pd_1.finitize(finitization);
+ if (discretization > 0)
+ {
+ std::cout << pd_1.size_2() << " --> ";
+ pd_1.discretize(discretization);
+ pd_1.compress_and_sort();
+ std::cout << pd_1.size_2() << std::endl;
+ }
}
if (j != j_prev)
@@ -328,6 +347,13 @@ int main(int argc, char ** argv)
fail(world_rank, std::string("Failed to load file ") + files_2[j] + std::string("."));
pd_2.finitize(finitization);
+ if (discretization > 0)
+ {
+ std::cout << pd_2.size_2() << " --> ";
+ pd_2.discretize(discretization);
+ pd_2.compress_and_sort();
+ std::cout << pd_2.size_2() << std::endl;
+ }
}
result[k - work[0]] = heat_kernel<double>(sigma, pd_1, pd_2);