From 0cc35ad04f9c2997014d7cf62a12f697e79fb534 Mon Sep 17 00:00:00 2001
From: Arnur Nigmetov
Date: Sat, 20 Jan 2018 19:11:29 +0100
Subject: Major rewrite, templatized version

---
 geom_bottleneck/include/dnn/local/kd-tree.h        | 116 ++++++++
 geom_bottleneck/include/dnn/local/kd-tree.hpp      | 325 +++++++++++++++++++++
 .../include/dnn/local/search-functors.h            | 140 ++++++++++
 3 files changed, 581 insertions(+)
 create mode 100644 geom_bottleneck/include/dnn/local/kd-tree.h
 create mode 100644 geom_bottleneck/include/dnn/local/kd-tree.hpp
 create mode 100644 geom_bottleneck/include/dnn/local/search-functors.h

(limited to 'geom_bottleneck/include/dnn/local')

diff --git a/geom_bottleneck/include/dnn/local/kd-tree.h b/geom_bottleneck/include/dnn/local/kd-tree.h
new file mode 100644
index 0000000..c1aed2b
--- /dev/null
+++ b/geom_bottleneck/include/dnn/local/kd-tree.h
@@ -0,0 +1,116 @@
+#ifndef HERA_BT_DNN_LOCAL_KD_TREE_H
+#define HERA_BT_DNN_LOCAL_KD_TREE_H
+
+#include "../utils.h"
+#include "search-functors.h"
+
+// NOTE(review): system header names are inferred from the names used below; confirm.
+#include <cstddef>
+#include <stack>
+#include <tuple>
+#include <unordered_map>
+#include <vector>
+
+#include <sys/types.h>  // ssize_t
+
+namespace hera {
+namespace bt {
+namespace dnn
+{
+    // KD-tree over point handles, stored implicitly in the flat array tree_:
+    // the root of every subrange [b, e) is its middle element b + (e - b)/2.
+    // delete_point() only flags a point; search() skips flagged points.
+    //
+    // Traits_ provides Coordinate, DistanceType, PointType, PointHandle,
+    // dimension(), distance(p1,p2), coordinate(p,i), handle(p)
+    template< class Traits_ >
+    class KDTree
+    {
+        public:
+            typedef         Traits_                                     Traits;
+            typedef         hera::bt::dnn::HandleDistance<KDTree>       HandleDistance;
+
+            typedef         typename Traits::PointType                  Point;
+            typedef         typename Traits::PointHandle                PointHandle;
+            typedef         typename Traits::Coordinate                 Coordinate;
+            typedef         typename Traits::DistanceType               DistanceType;
+            typedef         std::vector<PointHandle>                    HandleContainer;
+            typedef         std::vector<HandleDistance>                 HDContainer;   // TODO: use tbb::scalable_allocator
+            typedef         HDContainer                                 Result;
+            // NOTE(review): element type assumed from the name; unused in this patch
+            typedef         std::vector<DistanceType>                   DistanceContainer;
+            typedef         std::unordered_map<PointHandle, size_t>     HandleMap;
+        //private:
+            typedef         typename HandleContainer::iterator          HCIterator;
+            // (begin, end, index of the parent node, splitting axis)
+            typedef         std::tuple<HCIterator, HCIterator, ssize_t, size_t>   KDTreeNode;
+            // NOTE(review): element types assumed from the name; unused in this patch
+            typedef         std::tuple<HCIterator, HCIterator>          KDTreeNodeNoCut;
+
+            //BOOST_STATIC_ASSERT_MSG(has_coordinates::value, "KDTree requires coordinates");
+
+        public:
+            // Empty tree; fill it with init(range).
+                            KDTree(const Traits& traits):
+                                traits_(traits), num_points_(0)         {}
+
+            // Takes over handles and builds the tree on them.
+                            KDTree(const Traits& traits, HandleContainer&& handles);
+
+            // Copies the handles out of range and builds the tree on them.
+            template<class Range>
+                            KDTree(const Traits& traits, const Range& range);
+
+            template<class Range>
+            void            init(const Range& range);
+
+            HandleDistance  find(PointHandle q) const;                          // nearest neighbor
+            Result          findR(PointHandle q, DistanceType r) const;         // all neighbors within r
+            Result          findFirstR(PointHandle q, DistanceType r) const;    // first neighbor within r
+            Result          findK(PointHandle q, size_t k) const;               // k nearest neighbors
+
+            HandleDistance  find(const Point& q) const                          { return find(traits().handle(q)); }
+            Result          findR(const Point& q, DistanceType r) const         { return findR(traits().handle(q), r); }
+            Result          findFirstR(const Point& q, DistanceType r) const    { return findFirstR(traits().handle(q), r); }
+            Result          findK(const Point& q, size_t k) const               { return findK(traits().handle(q), k); }
+
+            // Generic traversal: rf(handle, distance) is called for every visited
+            // point that is not deleted; its return value becomes the pruning radius.
+            template<class ResultsFunctor>
+            void            search(PointHandle q, ResultsFunctor& rf) const;
+
+            const Traits&   traits() const                                      { return traits_; }
+
+            // NOTE(review): stack element type assumed; no definition is visible in this patch
+            void            get_path_to_root(const size_t idx, std::stack<KDTreeNodeNoCut>& s);
+            // to support deletion
+            void            init_n_elems();
+            void            delete_point(const size_t idx);
+            void            delete_point(PointHandle p);
+            void            update_n_elems(const ssize_t idx, const int delta);
+            void            increase_n_elems(const ssize_t idx);
+            void            decrease_n_elems(const ssize_t idx);
+            size_t          get_num_points() const                              { return num_points_; }
+        //private:
+            void            init();
+
+
+            struct CoordinateComparison;
+            struct OrderTree;
+
+        //private:
+            Traits                  traits_;
+            HandleContainer         tree_;
+            std::vector<char>       delete_flags_;      // 1 if the point at that index is deleted
+            std::vector<int>        subtree_n_elems;    // points not deleted in the subtree rooted at that index
+            HandleMap               indices_;           // handle -> index in tree_
+            std::vector<ssize_t>    parents_;           // index of the parent node, -1 for the root
+
+            size_t                  num_points_;        // points not deleted
+    };
+} // dnn
+} // bt
+} // hera
+#include "kd-tree.hpp"
+
+#endif
diff --git a/geom_bottleneck/include/dnn/local/kd-tree.hpp b/geom_bottleneck/include/dnn/local/kd-tree.hpp
new file mode 100644
index 0000000..249fa55
--- /dev/null
+++ b/geom_bottleneck/include/dnn/local/kd-tree.hpp
@@ -0,0 +1,325 @@
+// NOTE(review): system header names are inferred from the names used below; confirm.
+#include <algorithm>
+#include <cassert>
+#include <iterator>
+#include <limits>
+#include <queue>
+#include <tuple>
+#include <utility>
+
+#include "../parallel/tbb.h"
+
+// Builds the tree on handles, which is moved into tree_.
+// Every size is read from tree_: members are initialized in declaration order,
+// so handles has already been moved from when the later members are set up.
+// parents_ is sized here because init() writes into it.
+template<class T>
+hera::bt::dnn::KDTree<T>::KDTree(const Traits& traits, HandleContainer&& handles):
+    traits_(traits),
+    tree_(std::move(handles)),
+    delete_flags_(tree_.size(), static_cast<char>(0)),
+    subtree_n_elems(tree_.size(), static_cast<int>(0)),
+    parents_(tree_.size(), -1),
+    num_points_(tree_.size())
+{
+    init();
+}
+
+template<class T>
+template<class Range>
+hera::bt::dnn::KDTree<T>::KDTree(const Traits& traits, const Range& range):
+    traits_(traits)
+{
+    init(range);
+}
+
+// (Re)builds the tree from the handles in range; earlier contents are dropped.
+template<class T>
+template<class Range>
+void hera::bt::dnn::KDTree<T>::init(const Range& range)
+{
+    tree_.assign(std::begin(range), std::end(range));
+    const size_t sz = tree_.size();
+    subtree_n_elems.assign(sz, 0);
+    delete_flags_.assign(sz, 0);
+    parents_.assign(sz, -1);
+    indices_.clear();
+    num_points_ = sz;
+    init();
+}
+
+// Reorders tree_ into the implicit kd-tree layout, then fills indices_ and
+// subtree_n_elems. parents_ and subtree_n_elems must already have tree_.size()
+// entries.
+template<class T>
+void hera::bt::dnn::KDTree<T>::init()
+{
+    if (tree_.empty())
+        return;
+
+#if defined(TBB)
+    task_group g;
+    g.run(OrderTree(this, tree_.begin(), tree_.end(), -1, 0, traits()));
+    g.wait();
+#else
+    OrderTree(this, tree_.begin(), tree_.end(), -1, 0, traits()).serial();
+#endif
+
+    for (size_t i = 0; i < tree_.size(); ++i)
+        indices_[tree_[i]] = i;
+    init_n_elems();
+}
+
+// Functor that reorders [b, e) of tree->tree_ into kd-tree layout: the middle
+// element becomes the median along axis i (std::nth_element), its parent index
+// is recorded in tree->parents_, and both halves are processed with the next axis.
+template<class T>
+struct
+hera::bt::dnn::KDTree<T>::OrderTree
+{
+                OrderTree(KDTree* tree_, HCIterator b_, HCIterator e_, ssize_t p_, size_t i_, const Traits& traits_):
+                    tree(tree_), b(b_), e(e_), p(p_), i(i_), traits(traits_)            {}
+
+    // Task version: ranges shorter than 1000 go to serial(), larger ones are
+    // split and both halves run as tasks of a task_group.
+    void        operator()() const
+    {
+        if (e - b < 1000)
+        {
+            serial();
+            return;
+        }
+
+        HCIterator  m  = b + (e - b)/2;
+        ssize_t     im = m - tree->tree_.begin();
+        tree->parents_[im]  = p;
+
+        CoordinateComparison cmp(i, traits);
+        std::nth_element(b,m,e, cmp);
+        size_t next_i = (i + 1) % traits.dimension();
+
+        // sizes are compared instead of iterators so that no iterator outside
+        // [begin, end] is ever formed
+        task_group g;
+        if (m - b > 1)  g.run(OrderTree(tree, b,   m, im, next_i, traits));
+        if (e - m > 2)  g.run(OrderTree(tree, m+1, e, im, next_i, traits));
+        g.wait();
+    }
+
+    // Same ordering done iteratively with an explicit queue of subranges.
+    void        serial() const
+    {
+        std::queue<KDTreeNode> q;
+        q.push(KDTreeNode(b,e,p,i));
+        while (!q.empty())
+        {
+            HCIterator b, e; ssize_t p; size_t i;
+            std::tie(b,e,p,i) = q.front();
+            q.pop();
+            HCIterator  m  = b + (e - b)/2;
+            ssize_t     im = m - tree->tree_.begin();
+            tree->parents_[im]  = p;
+
+            CoordinateComparison cmp(i, traits);
+            std::nth_element(b,m,e, cmp);
+            size_t next_i = (i + 1) % traits.dimension();
+
+            // sizes are compared instead of iterators so that no iterator
+            // outside [begin, end] is ever formed; a one-element half is a
+            // leaf and only needs its parent recorded
+            if (m - b > 1)
+                q.push(KDTreeNode(b, m, im, next_i));
+            else if (m - b > 0)
+                tree->parents_[im - 1] = im;
+            if (e - m > 2)
+                q.push(KDTreeNode(m+1, e, im, next_i));
+            else if (e - m > 1)
+                tree->parents_[im + 1] = im;
+        }
+    }
+
+    KDTree*         tree;
+    HCIterator      b, e;
+    ssize_t         p;
+    size_t          i;
+    const Traits&   traits;
+};
+
+// Adds delta to subtree_n_elems of node idx and of all its ancestors
+// (following parents_ until the root's parent index -1).
+template<class T>
+void hera::bt::dnn::KDTree<T>::update_n_elems(ssize_t idx, const int delta)
+{
+    while (idx != -1)
+    {
+        subtree_n_elems[idx] += delta;
+        idx = parents_[idx];
+    }
+}
+
+template<class T>
+void hera::bt::dnn::KDTree<T>::increase_n_elems(const ssize_t idx)
+{
+    update_n_elems(idx, static_cast<int>(1));
+}
+
+template<class T>
+void hera::bt::dnn::KDTree<T>::decrease_n_elems(const ssize_t idx)
+{
+    update_n_elems(idx, static_cast<int>(-1));
+}
+
+// Counts every point once in its own node and in all of its ancestors.
+template<class T>
+void hera::bt::dnn::KDTree<T>::init_n_elems()
+{
+    for(size_t idx = 0; idx < tree_.size(); ++idx) {
+        increase_n_elems(idx);
+    }
+}
+
+
+// Traversal around the query q, driven by a FIFO queue of subranges.
+// rf(handle, dist) is called for the middle point of every visited subrange
+// unless that point is deleted; the value it returns is used as the radius D
+// for pruning. A half is visited only if it is non-empty, the splitting
+// coordinate of q is within D of the splitting point on that side, and the
+// half still holds a point that is not deleted.
+template<class T>
+template<class ResultsFunctor>
+void hera::bt::dnn::KDTree<T>::search(PointHandle q, ResultsFunctor& rf) const
+{
+    typedef typename HandleContainer::const_iterator        HCIterator;
+    typedef std::tuple<HCIterator, HCIterator, size_t>      KDTreeNode;       // (begin, end, axis)
+
+    if (tree_.empty())
+        return;
+
+    DistanceType    D  = std::numeric_limits<DistanceType>::infinity();
+
+    // TODO: use tbb::scalable_allocator for the queue
+    std::queue<KDTreeNode>  nodes;
+
+    nodes.push(KDTreeNode(tree_.begin(), tree_.end(), 0));
+
+    while (!nodes.empty())
+    {
+        HCIterator b, e; size_t i;
+        std::tie(b,e,i) = nodes.front();
+        nodes.pop();
+
+        CoordinateComparison cmp(i, traits());
+        i = (i + 1) % traits().dimension();
+
+        HCIterator  m     = b + (e - b)/2;
+        size_t      m_idx = m - tree_.begin();
+        // ignore deleted points
+        if ( delete_flags_[m_idx] == 0 ) {
+            DistanceType dist = traits().distance(q, *m);
+            D = rf(*m, dist);
+        }
+        // we are really searching w.r.t L_\infty ball; could prune better with an L_2 ball
+        // signed distance from the splitting point to q along the splitting axis
+        DistanceType diff = static_cast<DistanceType>(cmp.diff(q, *m));
+
+        // The emptiness test comes first: for an empty half the index of its
+        // would-be root can be tree_.size(), which must not be used to read
+        // subtree_n_elems.
+        if (e > m + 1 && diff >= -D) {
+            size_t upper_root = (m + 1) + (e - (m + 1))/2 - tree_.begin();
+            if ( subtree_n_elems[upper_root] > 0 )
+                nodes.push(KDTreeNode(m+1, e, i));
+        }
+
+        if (b < m && diff <= D) {
+            size_t lower_root = b + (m - b)/2 - tree_.begin();
+            if ( subtree_n_elems[lower_root] > 0 )
+                nodes.push(KDTreeNode(b, m, i));
+        }
+    }
+}
+
+// Closest point to q that the search reports. If it reports no point, the
+// result has distance infinity and a value-initialized handle.
+template<class T>
+typename hera::bt::dnn::KDTree<T>::HandleDistance hera::bt::dnn::KDTree<T>::find(PointHandle q) const
+{
+    hera::bt::dnn::NNRecord<HandleDistance> nn;
+    search(q, nn);
+    return nn.result;
+}
+
+// All reported points with distance <= r, in the order the search visits them.
+template<class T>
+typename hera::bt::dnn::KDTree<T>::Result hera::bt::dnn::KDTree<T>::findR(PointHandle q, DistanceType r) const
+{
+    hera::bt::dnn::rNNRecord<HandleDistance> rnn(r);
+    search(q, rnn);
+    //std::sort(rnn.result.begin(), rnn.result.end());
+    return rnn.result;
+}
+
+// The first point with distance <= r that the search visits, or an empty
+// result if there is none.
+template<class T>
+typename hera::bt::dnn::KDTree<T>::Result hera::bt::dnn::KDTree<T>::findFirstR(PointHandle q, DistanceType r) const
+{
+    hera::bt::dnn::firstrNNRecord<HandleDistance> rnn(r);
+    search(q, rnn);
+    return rnn.result;
+}
+
+// Up to k closest points to q, sorted by increasing distance.
+template<class T>
+typename hera::bt::dnn::KDTree<T>::Result hera::bt::dnn::KDTree<T>::findK(PointHandle q, size_t k) const
+{
+    hera::bt::dnn::kNNRecord<HandleDistance> knn(k);
+    search(q, knn);
+    // the functor keeps its results in heap order, so sort them for the caller
+    std::sort(knn.result.begin(), knn.result.end());
+    return knn.result;
+}
+
+// Compares and subtracts the i-th coordinates of two handles.
+template<class T>
+struct hera::bt::dnn::KDTree<T>::CoordinateComparison
+{
+                CoordinateComparison(size_t i, const Traits& traits):
+                    i_(i), traits_(traits)                              {}
+
+    bool        operator()(PointHandle p1, PointHandle p2) const        { return coordinate(p1) < coordinate(p2); }
+    Coordinate  diff(PointHandle p1, PointHandle p2) const              { return coordinate(p1) - coordinate(p2); }
+
+    Coordinate  coordinate(PointHandle p) const                         { return traits_.coordinate(p, i_); }
+    size_t      axis() const                                            { return i_; }
+
+    private:
+        size_t          i_;
+        const Traits&   traits_;
+};
+
+// Flags the point at index idx as deleted and removes it from the counts of
+// its node and all ancestors. Deleting the same index twice is a caller error.
+template<class T>
+void hera::bt::dnn::KDTree<T>::delete_point(const size_t idx)
+{
+    // prevent double deletion
+    assert(delete_flags_[idx] == 0);
+    delete_flags_[idx] = 1;
+    decrease_n_elems(idx);
+    --num_points_;
+}
+
+// Deletes the point with handle p. A handle that is not in the tree is a
+// caller error; it is looked up with find() so that it is never inserted into
+// indices_ (operator[] would map it to index 0 and delete an unrelated point).
+template<class T>
+void hera::bt::dnn::KDTree<T>::delete_point(PointHandle p)
+{
+    typename HandleMap::const_iterator it = indices_.find(p);
+    assert(it != indices_.end());
+    if (it == indices_.end())
+        return;
+    delete_point(it->second);
+}
+
diff --git a/geom_bottleneck/include/dnn/local/search-functors.h b/geom_bottleneck/include/dnn/local/search-functors.h
new file mode 100644
index 0000000..63ad11d
--- /dev/null
+++ b/geom_bottleneck/include/dnn/local/search-functors.h
@@ -0,0 +1,140 @@
+#ifndef HERA_BT_DNN_LOCAL_SEARCH_FUNCTORS_H
+#define HERA_BT_DNN_LOCAL_SEARCH_FUNCTORS_H
+
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <limits>
+
+namespace hera
+{
+namespace bt
+{
+namespace dnn
+{
+
+// A point handle together with a distance; ordered by distance only.
+// NN supplies the PointHandle, DistanceType and HDContainer types.
+template<class NN>
+struct HandleDistance
+{
+    typedef         typename NN::PointHandle                                    PointHandle;
+    typedef         typename NN::DistanceType                                   DistanceType;
+    typedef         typename NN::HDContainer                                    HDContainer;
+
+                    HandleDistance(): p(), d()                                  {}
+                    HandleDistance(PointHandle pp, DistanceType dd):
+                        p(pp), d(dd)                                            {}
+    bool            operator<(const HandleDistance& other) const                { return d < other.d; }
+
+    PointHandle     p;
+    DistanceType    d;
+};
+
+// Keeps the closest point passed in so far and returns its distance.
+template<class HandleDistance>
+struct NNRecord
+{
+    typedef         typename HandleDistance::PointHandle                        PointHandle;
+    typedef         typename HandleDistance::DistanceType                       DistanceType;
+
+                    NNRecord()                                                  { result.d = std::numeric_limits<DistanceType>::infinity(); }
+    DistanceType    operator()(PointHandle p, DistanceType d)                   { if (d < result.d) { result.p = p; result.d = d; } return result.d; }
+    HandleDistance  result;
+};
+
+// Collects every point passed in with d <= r; always returns r.
+template<class HandleDistance>
+struct rNNRecord
+{
+    typedef         typename HandleDistance::PointHandle                        PointHandle;
+    typedef         typename HandleDistance::DistanceType                       DistanceType;
+    typedef         typename HandleDistance::HDContainer                        HDContainer;
+
+                    rNNRecord(DistanceType r_): r(r_)                           {}
+    DistanceType    operator()(PointHandle p, DistanceType d)
+    {
+        if (d <= r)
+            result.push_back(HandleDistance(p,d));
+        return r;
+    }
+
+    DistanceType    r;
+    HDContainer     result;
+};
+
+// Records the first point passed in with d <= r. From then on it records
+// nothing more and returns minus infinity, so that KDTree::search opens no
+// further subtree for finite coordinates. (The former finite sentinel
+// -100000000.0 still let through coordinate differences of 1e8 or more, and a
+// later miss reset the radius to r.) Assumes DistanceType has an infinity.
+template<class HandleDistance>
+struct firstrNNRecord
+{
+    typedef         typename HandleDistance::PointHandle                        PointHandle;
+    typedef         typename HandleDistance::DistanceType                       DistanceType;
+    typedef         typename HandleDistance::HDContainer                        HDContainer;
+
+                    firstrNNRecord(DistanceType r_): r(r_)                      {}
+
+    DistanceType    operator()(PointHandle p, DistanceType d)
+    {
+        if (!result.empty())
+            return -std::numeric_limits<DistanceType>::infinity();
+        if (d <= r) {
+            result.push_back(HandleDistance(p,d));
+            return -std::numeric_limits<DistanceType>::infinity();
+        }
+        return r;
+    }
+
+    DistanceType    r;
+    HDContainer     result;
+};
+
+
+// Keeps the k closest points passed in so far as a max-heap on distance, so
+// result[0] is the farthest of them. Returns infinity until k points are held
+// and the largest kept distance afterwards. With k == 0 nothing is recorded
+// and minus infinity is returned.
+template<class HandleDistance>
+struct kNNRecord
+{
+    typedef         typename HandleDistance::PointHandle                        PointHandle;
+    typedef         typename HandleDistance::DistanceType                       DistanceType;
+    typedef         typename HandleDistance::HDContainer                        HDContainer;
+
+                    kNNRecord(std::size_t k_): k(k_)                            {}
+    DistanceType    operator()(PointHandle p, DistanceType d)
+    {
+        // no room for any result: result[0] below would not exist
+        if (k == 0)
+            return -std::numeric_limits<DistanceType>::infinity();
+
+        if (result.size() < k)
+        {
+            result.push_back(HandleDistance(p,d));
+            std::push_heap(result.begin(), result.end());
+            if (result.size() < k)
+                return std::numeric_limits<DistanceType>::infinity();
+        } else if (d < result[0].d)
+        {
+            std::pop_heap(result.begin(), result.end());
+            result.back() = HandleDistance(p,d);
+            std::push_heap(result.begin(), result.end());
+        }
+        if ( result.size() > 1 ) {
+            assert( result[0].d >= result[1].d );
+        }
+        return result[0].d;
+    }
+
+    std::size_t     k;
+    HDContainer     result;
+};
+
+} // dnn
+} // bt
+} // hera
+
+#endif // HERA_BT_DNN_LOCAL_SEARCH_FUNCTORS_H
-- 
cgit v1.2.3