summaryrefslogtreecommitdiff
path: root/geom_bottleneck/include/dnn/local
diff options
context:
space:
mode:
Diffstat (limited to 'geom_bottleneck/include/dnn/local')
-rw-r--r--geom_bottleneck/include/dnn/local/kd-tree.h106
-rw-r--r--geom_bottleneck/include/dnn/local/kd-tree.hpp296
-rw-r--r--geom_bottleneck/include/dnn/local/search-functors.h119
3 files changed, 521 insertions, 0 deletions
diff --git a/geom_bottleneck/include/dnn/local/kd-tree.h b/geom_bottleneck/include/dnn/local/kd-tree.h
new file mode 100644
index 0000000..c1aed2b
--- /dev/null
+++ b/geom_bottleneck/include/dnn/local/kd-tree.h
@@ -0,0 +1,106 @@
+#ifndef HERA_BT_DNN_LOCAL_KD_TREE_H
+#define HERA_BT_DNN_LOCAL_KD_TREE_H
+
+#include "../utils.h"
+#include "search-functors.h"
+
+#include <unordered_map>
+#include <stack>
+
+#include <boost/tuple/tuple.hpp>
+#include <boost/shared_ptr.hpp>
+#include <boost/range/value_type.hpp>
+
+#include <boost/static_assert.hpp>
+#include <boost/type_traits.hpp>
+
+namespace hera {
+namespace bt {
+namespace dnn
+{
+ // Weighted KDTree
+ // Traits_ provides Coordinate, DistanceType, PointType, dimension(), distance(p1,p2), coordinate(p,i)
+ template< class Traits_ >
+ class KDTree
+ {
+ public:
+ typedef Traits_ Traits;
+ typedef hera::bt::dnn::HandleDistance<KDTree> HandleDistance;
+
+ typedef typename Traits::PointType Point;
+ typedef typename Traits::PointHandle PointHandle;
+ typedef typename Traits::Coordinate Coordinate;
+ typedef typename Traits::DistanceType DistanceType;
+ typedef std::vector<PointHandle> HandleContainer;
+ typedef std::vector<HandleDistance> HDContainer; // TODO: use tbb::scalable_allocator
+ typedef HDContainer Result;
+ typedef std::vector<DistanceType> DistanceContainer;
+ typedef std::unordered_map<PointHandle, size_t> HandleMap;
+ //private:
+ typedef typename HandleContainer::iterator HCIterator;
+ typedef std::tuple<HCIterator, HCIterator, size_t, ssize_t> KDTreeNode;
+ typedef std::tuple<HCIterator, HCIterator> KDTreeNodeNoCut;
+
+ //BOOST_STATIC_ASSERT_MSG(has_coordinates<Traits, PointHandle, int>::value, "KDTree requires coordinates");
+
+ public:
+ KDTree(const Traits& traits):
+ traits_(traits) {}
+
+ KDTree(const Traits& traits, HandleContainer&& handles);
+
+ template<class Range>
+ KDTree(const Traits& traits, const Range& range);
+
+ template<class Range>
+ void init(const Range& range);
+
+ HandleDistance find(PointHandle q) const;
+ Result findR(PointHandle q, DistanceType r) const; // all neighbors within r
+ Result findFirstR(PointHandle q, DistanceType r) const; // first neighbor within r
+ Result findK(PointHandle q, size_t k) const; // k nearest neighbors
+
+ HandleDistance find(const Point& q) const { return find(traits().handle(q)); }
+ Result findR(const Point& q, DistanceType r) const { return findR(traits().handle(q), r); }
+ Result findFirstR(const Point& q, DistanceType r) const { return findFirstR(traits().handle(q), r); }
+ Result findK(const Point& q, size_t k) const { return findK(traits().handle(q), k); }
+
+
+
+ template<class ResultsFunctor>
+ void search(PointHandle q, ResultsFunctor& rf) const;
+
+ const Traits& traits() const { return traits_; }
+
+ void get_path_to_root(const size_t idx, std::stack<KDTreeNodeNoCut>& s);
+ // to support deletion
+ void init_n_elems();
+ void delete_point(const size_t idx);
+ void delete_point(PointHandle p);
+ void update_n_elems(const ssize_t idx, const int delta);
+ void increase_n_elems(const ssize_t idx);
+ void decrease_n_elems(const ssize_t idx);
+ size_t get_num_points() const { return num_points_; }
+ //private:
+ void init();
+
+
+ struct CoordinateComparison;
+ struct OrderTree;
+
+ //private:
+ Traits traits_;
+ HandleContainer tree_;
+ std::vector<char> delete_flags_;
+ std::vector<int> subtree_n_elems;
+ HandleMap indices_;
+ std::vector<ssize_t> parents_;
+
+ size_t num_points_;
+ };
+} // dnn
+} // bt
+} // hera
+#include "kd-tree.hpp"
+
+#endif
diff --git a/geom_bottleneck/include/dnn/local/kd-tree.hpp b/geom_bottleneck/include/dnn/local/kd-tree.hpp
new file mode 100644
index 0000000..249fa55
--- /dev/null
+++ b/geom_bottleneck/include/dnn/local/kd-tree.hpp
@@ -0,0 +1,296 @@
+#include <boost/range/counting_range.hpp>
+#include <boost/range/algorithm_ext/push_back.hpp>
+#include <boost/range.hpp>
+
+#include <queue>
+
+#include "../parallel/tbb.h"
+
+template<class T>
+hera::bt::dnn::KDTree<T>::KDTree(const Traits& traits, HandleContainer&& handles):
+ traits_(traits),
+ tree_(std::move(handles)),
+ delete_flags_(handles.size(), static_cast<char>(0) ),
+ subtree_n_elems(handles.size(), static_cast<size_t>(0)),
+ num_points_(handles.size())
+{
+ init();
+}
+
+template<class T>
+template<class Range>
+hera::bt::dnn::KDTree<T>::KDTree(const Traits& traits, const Range& range):
+ traits_(traits)
+{
+ init(range);
+}
+
+template<class T>
+template<class Range>
+void hera::bt::dnn::KDTree<T>::init(const Range& range)
+{
+ size_t sz = std::distance(std::begin(range), std::end(range));
+ subtree_n_elems = std::vector<int>(sz, 0);
+ delete_flags_ = std::vector<char>(sz, 0);
+ num_points_ = sz;
+ tree_.reserve(sz);
+ for (PointHandle h : range)
+ tree_.push_back(h);
+ parents_.resize(sz, -1);
+ init();
+}
+
+template<class T>
+void hera::bt::dnn::KDTree<T>::init()
+{
+ if (tree_.empty())
+ return;
+
+#if defined(TBB)
+ task_group g;
+ g.run(OrderTree(this, tree_.begin(), tree_.end(), -1, 0, traits()));
+ g.wait();
+#else
+ OrderTree(this, tree_.begin(), tree_.end(), -1, 0, traits()).serial();
+#endif
+
+ for (size_t i = 0; i < tree_.size(); ++i)
+ indices_[tree_[i]] = i;
+ init_n_elems();
+}
+
+template<class T>
+struct
+hera::bt::dnn::KDTree<T>::OrderTree
+{
+ OrderTree(KDTree* tree_, HCIterator b_, HCIterator e_, ssize_t p_, size_t i_, const Traits& traits_):
+ tree(tree_), b(b_), e(e_), p(p_), i(i_), traits(traits_) {}
+
+ void operator()() const
+ {
+ if (e - b < 1000)
+ {
+ serial();
+ return;
+ }
+
+ HCIterator m = b + (e - b)/2;
+ ssize_t im = m - tree->tree_.begin();
+ tree->parents_[im] = p;
+
+ CoordinateComparison cmp(i, traits);
+ std::nth_element(b,m,e, cmp);
+ size_t next_i = (i + 1) % traits.dimension();
+
+ task_group g;
+ if (b < m - 1) g.run(OrderTree(tree, b, m, im, next_i, traits));
+ if (e > m + 2) g.run(OrderTree(tree, m+1, e, im, next_i, traits));
+ g.wait();
+ }
+
+ void serial() const
+ {
+ std::queue<KDTreeNode> q;
+ q.push(KDTreeNode(b,e,p,i));
+ while (!q.empty())
+ {
+ HCIterator b, e; ssize_t p; size_t i;
+ std::tie(b,e,p,i) = q.front();
+ q.pop();
+ HCIterator m = b + (e - b)/2;
+ ssize_t im = m - tree->tree_.begin();
+ tree->parents_[im] = p;
+
+ CoordinateComparison cmp(i, traits);
+ std::nth_element(b,m,e, cmp);
+ size_t next_i = (i + 1) % traits.dimension();
+
+ // Replace with a size condition instead?
+ if (b < m - 1)
+ q.push(KDTreeNode(b, m, im, next_i));
+ else if (b < m)
+ tree->parents_[im - 1] = im;
+ if (e > m + 2)
+ q.push(KDTreeNode(m+1, e, im, next_i));
+ else if (e > m + 1)
+ tree->parents_[im + 1] = im;
+ }
+ }
+
+ KDTree* tree;
+ HCIterator b, e;
+ ssize_t p;
+ size_t i;
+ const Traits& traits;
+};
+
+template<class T>
+void hera::bt::dnn::KDTree<T>::update_n_elems(ssize_t idx, const int delta)
+// add delta to the number of points in node idx and update subtree_n_elems
+// for all parents of the node idx
+{
+ //std::cout << "subtree_n_elems.size = " << subtree_n_elems.size() << std::endl;
+ // update the node itself
+ while (idx != -1)
+ {
+ //std::cout << idx << std::endl;
+ subtree_n_elems[idx] += delta;
+ idx = parents_[idx];
+ }
+}
+
+template<class T>
+void hera::bt::dnn::KDTree<T>::increase_n_elems(const ssize_t idx)
+{
+ update_n_elems(idx, static_cast<ssize_t>(1));
+}
+
+template<class T>
+void hera::bt::dnn::KDTree<T>::decrease_n_elems(const ssize_t idx)
+{
+ update_n_elems(idx, static_cast<ssize_t>(-1));
+}
+
+template<class T>
+void hera::bt::dnn::KDTree<T>::init_n_elems()
+{
+ for(size_t idx = 0; idx < tree_.size(); ++idx) {
+ increase_n_elems(idx);
+ }
+}
+
+
+template<class T>
+template<class ResultsFunctor>
+void hera::bt::dnn::KDTree<T>::search(PointHandle q, ResultsFunctor& rf) const
+{
+ typedef typename HandleContainer::const_iterator HCIterator;
+ typedef std::tuple<HCIterator, HCIterator, size_t> KDTreeNode;
+
+ if (tree_.empty())
+ return;
+
+ DistanceType D = std::numeric_limits<DistanceType>::infinity();
+
+ // TODO: use tbb::scalable_allocator for the queue
+ std::queue<KDTreeNode> nodes;
+
+ nodes.push(KDTreeNode(tree_.begin(), tree_.end(), 0));
+
+ //std::cout << "started kdtree::search" << std::endl;
+
+ while (!nodes.empty())
+ {
+ HCIterator b, e; size_t i;
+ std::tie(b,e,i) = nodes.front();
+ nodes.pop();
+
+ CoordinateComparison cmp(i, traits());
+ i = (i + 1) % traits().dimension();
+
+ HCIterator m = b + (e - b)/2;
+ size_t m_idx = m - tree_.begin();
+ // ignore deleted points
+ if ( delete_flags_[m_idx] == 0 ) {
+ DistanceType dist = traits().distance(q, *m);
+ // + weights_[m - tree_.begin()];
+ //std::cout << "Supplied to functor: m : ";
+ //std::cout << "(" << (*(*m))[0] << ", " << (*(*m))[1] << ")";
+ //std::cout << " and q : ";
+ //std::cout << "(" << (*q)[0] << ", " << (*q)[1] << ")" << std::endl;
+ //std::cout << "dist^q + weight = " << dist << std::endl;
+ //std::cout << "weight = " << weights_[m - tree_.begin()] << std::endl;
+ //std::cout << "dist = " << traits().distance(q, *m) << std::endl;
+ //std::cout << "dist^q = " << pow(traits().distance(q, *m), wassersteinPower) << std::endl;
+
+ D = rf(*m, dist);
+ }
+ // we are really searching w.r.t L_\infty ball; could prune better with an L_2 ball
+ Coordinate diff = cmp.diff(q, *m); // diff returns signed distance
+ DistanceType diffToWasserPower = (diff > 0 ? 1.0 : -1.0) * fabs(diff);
+
+ size_t lm = m + 1 + (e - (m+1))/2 - tree_.begin();
+ if ( subtree_n_elems[lm] > 0 ) {
+ if (e > m + 1 && diffToWasserPower >= -D) {
+ nodes.push(KDTreeNode(m+1, e, i));
+ }
+ }
+
+ size_t rm = b + (m - b) / 2 - tree_.begin();
+ if ( subtree_n_elems[rm] > 0 ) {
+ if (b < m && diffToWasserPower <= D) {
+ nodes.push(KDTreeNode(b, m, i));
+ }
+ }
+ }
+ //std::cout << "exited kdtree::search" << std::endl;
+}
+
+template<class T>
+typename hera::bt::dnn::KDTree<T>::HandleDistance hera::bt::dnn::KDTree<T>::find(PointHandle q) const
+{
+ hera::bt::dnn::NNRecord<HandleDistance> nn;
+ search(q, nn);
+ return nn.result;
+}
+
+template<class T>
+typename hera::bt::dnn::KDTree<T>::Result hera::bt::dnn::KDTree<T>::findR(PointHandle q, DistanceType r) const
+{
+ hera::bt::dnn::rNNRecord<HandleDistance> rnn(r);
+ search(q, rnn);
+ //std::sort(rnn.result.begin(), rnn.result.end());
+ return rnn.result;
+}
+
+template<class T>
+typename hera::bt::dnn::KDTree<T>::Result hera::bt::dnn::KDTree<T>::findFirstR(PointHandle q, DistanceType r) const
+{
+ hera::bt::dnn::firstrNNRecord<HandleDistance> rnn(r);
+ search(q, rnn);
+ return rnn.result;
+}
+
+template<class T>
+typename hera::bt::dnn::KDTree<T>::Result hera::bt::dnn::KDTree<T>::findK(PointHandle q, size_t k) const
+{
+ hera::bt::dnn::kNNRecord<HandleDistance> knn(k);
+ search(q, knn);
+ // do we need this???
+ std::sort(knn.result.begin(), knn.result.end());
+ return knn.result;
+}
+
+template<class T>
+struct hera::bt::dnn::KDTree<T>::CoordinateComparison
+{
+ CoordinateComparison(size_t i, const Traits& traits):
+ i_(i), traits_(traits) {}
+
+ bool operator()(PointHandle p1, PointHandle p2) const { return coordinate(p1) < coordinate(p2); }
+ Coordinate diff(PointHandle p1, PointHandle p2) const { return coordinate(p1) - coordinate(p2); }
+
+ Coordinate coordinate(PointHandle p) const { return traits_.coordinate(p, i_); }
+ size_t axis() const { return i_; }
+
+ private:
+ size_t i_;
+ const Traits& traits_;
+};
+
+template<class T>
+void hera::bt::dnn::KDTree<T>::delete_point(const size_t idx)
+{
+ // prevent double deletion
+ assert(delete_flags_[idx] == 0);
+ delete_flags_[idx] = 1;
+ decrease_n_elems(idx);
+ --num_points_;
+}
+
+template<class T>
+void hera::bt::dnn::KDTree<T>::delete_point(PointHandle p)
+{
+ delete_point(indices_[p]);
+}
+
diff --git a/geom_bottleneck/include/dnn/local/search-functors.h b/geom_bottleneck/include/dnn/local/search-functors.h
new file mode 100644
index 0000000..63ad11d
--- /dev/null
+++ b/geom_bottleneck/include/dnn/local/search-functors.h
@@ -0,0 +1,119 @@
+#ifndef HERA_BT_DNN_LOCAL_SEARCH_FUNCTORS_H
+#define HERA_BT_DNN_LOCAL_SEARCH_FUNCTORS_H
+
+#include <boost/range/algorithm/heap_algorithm.hpp>
+
+namespace hera
+{
+namespace bt
+{
+namespace dnn
+{
+
+template<class NN>
+struct HandleDistance
+{
+ typedef typename NN::PointHandle PointHandle;
+ typedef typename NN::DistanceType DistanceType;
+ typedef typename NN::HDContainer HDContainer;
+
+ HandleDistance() {}
+ HandleDistance(PointHandle pp, DistanceType dd):
+ p(pp), d(dd) {}
+ bool operator<(const HandleDistance& other) const { return d < other.d; }
+
+ PointHandle p;
+ DistanceType d;
+};
+
+template<class HandleDistance>
+struct NNRecord
+{
+ typedef typename HandleDistance::PointHandle PointHandle;
+ typedef typename HandleDistance::DistanceType DistanceType;
+
+ NNRecord() { result.d = std::numeric_limits<DistanceType>::infinity(); }
+ DistanceType operator()(PointHandle p, DistanceType d) { if (d < result.d) { result.p = p; result.d = d; } return result.d; }
+ HandleDistance result;
+};
+
+template<class HandleDistance>
+struct rNNRecord
+{
+ typedef typename HandleDistance::PointHandle PointHandle;
+ typedef typename HandleDistance::DistanceType DistanceType;
+ typedef typename HandleDistance::HDContainer HDContainer;
+
+ rNNRecord(DistanceType r_): r(r_) {}
+ DistanceType operator()(PointHandle p, DistanceType d)
+ {
+ if (d <= r)
+ result.push_back(HandleDistance(p,d));
+ return r;
+ }
+
+ DistanceType r;
+ HDContainer result;
+};
+
+template<class HandleDistance>
+struct firstrNNRecord
+{
+ typedef typename HandleDistance::PointHandle PointHandle;
+ typedef typename HandleDistance::DistanceType DistanceType;
+ typedef typename HandleDistance::HDContainer HDContainer;
+
+ firstrNNRecord(DistanceType r_): r(r_) {}
+
+ DistanceType operator()(PointHandle p, DistanceType d)
+ {
+ if (d <= r) {
+ result.push_back(HandleDistance(p,d));
+ return -100000000.0;
+ } else {
+ return r;
+ }
+ }
+
+ DistanceType r;
+ HDContainer result;
+};
+
+
+template<class HandleDistance>
+struct kNNRecord
+{
+ typedef typename HandleDistance::PointHandle PointHandle;
+ typedef typename HandleDistance::DistanceType DistanceType;
+ typedef typename HandleDistance::HDContainer HDContainer;
+
+ kNNRecord(unsigned k_): k(k_) {}
+ DistanceType operator()(PointHandle p, DistanceType d)
+ {
+ if (result.size() < k)
+ {
+ result.push_back(HandleDistance(p,d));
+ boost::push_heap(result);
+ if (result.size() < k)
+ return std::numeric_limits<DistanceType>::infinity();
+ } else if (d < result[0].d)
+ {
+ boost::pop_heap(result);
+ result.back() = HandleDistance(p,d);
+ boost::push_heap(result);
+ }
+ if ( result.size() > 1 ) {
+ assert( result[0].d >= result[1].d );
+ }
+ return result[0].d;
+ }
+
+ unsigned k;
+ HDContainer result;
+};
+
+} // dnn
+} // bt
+} // hera
+
+#endif // HERA_BT_DNN_LOCAL_SEARCH_FUNCTORS_H