summaryrefslogtreecommitdiff
path: root/include/pd.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'include/pd.hpp')
-rw-r--r--include/pd.hpp82
1 files changed, 66 insertions, 16 deletions
diff --git a/include/pd.hpp b/include/pd.hpp
index 04b8609..2641125 100644
--- a/include/pd.hpp
+++ b/include/pd.hpp
@@ -3,22 +3,32 @@
#include <cstdint>
#include <fstream>
#include <limits>
+#include <map>
#include <string>
#include <vector>
#include "misc.hpp"
-template <typename T> class PD
+template <typename T> class Interval
{
public:
- class Interval
- {
- public:
- T birth;
- T death;
- };
+ inline T length() const { return death - birth; }
+ T birth;
+ T death;
+};
-
+template <typename T> inline bool operator==(Interval<T> x, Interval<T> y) { return x.birth == y.birth && x.death == y.death; }
+template <typename T> inline bool operator!=(Interval<T> x, Interval<T> y) { return !(x == y); }
+template <typename T> inline bool operator<(Interval<T> x, Interval<T> y)
+{
+ return (x.length() < y.length()) ||
+ ((x.length() == y.length()) && (x.birth < y.birth)) ||
+ (((x.length() == y.length()) && (x.birth == y.birth) && (x.death < y.death)));
+}
+
+template <typename T> class PD
+{
+public:
PD() : intervals()
{
};
@@ -27,7 +37,7 @@ public:
{
if (m > 0 && b < d)
{
- Interval interval;
+ Interval<T> interval;
interval.birth = b;
interval.death = d;
intervals.push_back(std::make_pair(interval, m));
@@ -45,18 +55,58 @@ public:
}
}
- typename std::vector<std::pair<PD<T>::Interval, unsigned int> >::size_type size() const { return intervals.size(); }
+ void discretize(T delta)
+ {
+ if (delta <= 0)
+ return;
+ for (auto it = begin(); it != end(); ++it)
+ {
+ it->first.birth = std::round(it->first.birth/delta)*delta;
+ it->first.death = std::round(it->first.death/delta)*delta;
+ }
+ }
+
+ void compress_and_sort()
+ {
+ std::map<Interval<T>, unsigned int> tmp;
+ for (auto it = cbegin(); it != cend(); ++it)
+ {
+ auto existing = tmp.find(it->first);
+ if (existing == tmp.end())
+ {
+ tmp.insert(*it);
+ }
+ else
+ {
+ existing->second += it->second;
+ }
+ }
+ intervals = std::vector<std::pair<Interval<T>, unsigned int> >(tmp.cbegin(), tmp.cend());
+ }
+
+ unsigned int size() const
+ {
+ unsigned int ret = 0;
+ for (auto it = cbegin(); it != cend(); ++it)
+ ret += it->second;
+ return ret;
+ }
+
+ inline typename std::vector<std::pair<Interval<T>, unsigned int> >::size_type size_2() const { return intervals.size(); }
- using Iterator = typename std::vector<std::pair<PD<T>::Interval, unsigned int> >::const_iterator;
+ using Iterator = typename std::vector<std::pair<Interval<T>, unsigned int> >::iterator;
+ using Const_iterator = typename std::vector<std::pair<Interval<T>, unsigned int> >::const_iterator;
- inline Iterator cbegin() const { return intervals.cbegin(); }
- inline Iterator cend() const { return intervals.cend(); }
+ inline Const_iterator cbegin() const { return intervals.cbegin(); }
+ inline Const_iterator cend() const { return intervals.cend(); }
+ inline Iterator begin() { return intervals.begin(); }
+ inline Iterator end() { return intervals.end(); }
private:
- std::vector<std::pair<PD<T>::Interval, unsigned int> > intervals;
+ std::vector<std::pair<Interval<T>, unsigned int> > intervals;
};
-inline double sqdist(PD<double>::Interval x, PD<double>::Interval y)
+template <typename T> inline T sqdist(Interval<T> x, Interval<T> y)
{
return (x.birth - y.birth)*(x.birth - y.birth) + (x.death - y.death)*(x.death - y.death);
}
@@ -68,7 +118,7 @@ template <typename T> T heat_kernel(T sigma, const PD<T> & a, const PD<T> & b)
for (auto it = a.cbegin(); it != a.cend(); ++it)
{
auto x = *it;
- typename PD<T>::Interval xbar;
+ Interval<T> xbar;
xbar.birth = x.first.death;
xbar.death = x.first.birth;