#include #include #include #include #include #include "../parallel/tbb.h" #include "def_debug_ws.h" template dnn::KDTree:: KDTree(const Traits& traits, HandleContainer&& handles, double _wassersteinPower): traits_(traits), tree_(std::move(handles)), wassersteinPower(_wassersteinPower) { assert(wassersteinPower >= 1.0); init(); } template template dnn::KDTree:: KDTree(const Traits& traits, const Range& range, double _wassersteinPower): traits_(traits), wassersteinPower(_wassersteinPower) { assert( wassersteinPower >= 1.0); init(range); } template template void dnn::KDTree:: init(const Range& range) { size_t sz = std::distance(std::begin(range), std::end(range)); tree_.reserve(sz); weights_.resize(sz, 0); subtree_weights_.resize(sz, 0); for (PointHandle h : range) tree_.push_back(h); init(); } template void dnn::KDTree:: init() { if (tree_.empty()) return; #if defined(TBB) task_group g; g.run(OrderTree(tree_.begin(), tree_.end(), 0, traits())); g.wait(); #else OrderTree(tree_.begin(), tree_.end(), 0, traits()).serial(); #endif for (size_t i = 0; i < tree_.size(); ++i) indices_[tree_[i]] = i; } template struct dnn::KDTree::OrderTree { OrderTree(HCIterator b_, HCIterator e_, size_t i_, const Traits& traits_): b(b_), e(e_), i(i_), traits(traits_) {} void operator()() const { if (e - b < 1000) { serial(); return; } HCIterator m = b + (e - b)/2; CoordinateComparison cmp(i, traits); std::nth_element(b,m,e, cmp); size_t next_i = (i + 1) % traits.dimension(); task_group g; if (b < m - 1) g.run(OrderTree(b, m, next_i, traits)); if (e > m + 2) g.run(OrderTree(m+1, e, next_i, traits)); g.wait(); } void serial() const { std::queue q; q.push(KDTreeNode(b,e,i)); while (!q.empty()) { HCIterator b, e; size_t i; std::tie(b,e,i) = q.front(); q.pop(); HCIterator m = b + (e - b)/2; CoordinateComparison cmp(i, traits); std::nth_element(b,m,e, cmp); size_t next_i = (i + 1) % traits.dimension(); // Replace with a size condition instead? if (b < m - 1) q.push(KDTreeNode(b, m, next_i)); if (e - m > 2) q.push(KDTreeNode(m+1, e, next_i)); } } HCIterator b, e; size_t i; const Traits& traits; }; template template void dnn::KDTree:: search(PointHandle q, ResultsFunctor& rf) const { typedef typename HandleContainer::const_iterator HCIterator; typedef std::tuple KDTreeNode; if (tree_.empty()) return; DistanceType D = std::numeric_limits::infinity(); // TODO: use tbb::scalable_allocator for the queue std::queue nodes; nodes.push(KDTreeNode(tree_.begin(), tree_.end(), 0)); while (!nodes.empty()) { HCIterator b, e; size_t i; std::tie(b,e,i) = nodes.front(); nodes.pop(); CoordinateComparison cmp(i, traits()); i = (i + 1) % traits().dimension(); HCIterator m = b + (e - b)/2; DistanceType dist = pow(traits().distance(q, *m), wassersteinPower) + weights_[m - tree_.begin()]; D = rf(*m, dist); // we are really searching w.r.t L_\infty ball; could prune better with an L_2 ball Coordinate diff = cmp.diff(q, *m); // diff returns signed distance DistanceType diffToWasserPower = (diff > 0 ? 1.0 : -1.0) * pow(fabs(diff), wassersteinPower); size_t lm = m + 1 + (e - (m+1))/2 - tree_.begin(); if (e > m + 1 && diffToWasserPower - subtree_weights_[lm] >= -D) { nodes.push(KDTreeNode(m+1, e, i)); } size_t rm = b + (m - b) / 2 - tree_.begin(); if (b < m && diffToWasserPower + subtree_weights_[rm] <= D) { nodes.push(KDTreeNode(b, m, i)); } } } template void dnn::KDTree:: increase_weight(PointHandle p, DistanceType w) { size_t idx = indices_[p]; // weight should only increase assert( weights_[idx] <= w ); weights_[idx] = w; typedef std::tuple KDTreeNode; // find the path down the tree to this node // not an ideal strategy, but // it's not clear how to move up from the node in general std::stack s; s.push(KDTreeNode(tree_.begin(),tree_.end())); do { HCIterator b,e; std::tie(b,e) = s.top(); size_t im = b + (e - b)/2 - tree_.begin(); if (idx == im) break; else if (idx < im) s.push(KDTreeNode(b, tree_.begin() + im)); else // idx > im s.push(KDTreeNode(tree_.begin() + im + 1, e)); } while(1); // update subtree_weights_ on the path to the root DistanceType min_w = w; while (!s.empty()) { HCIterator b,e; std::tie(b,e) = s.top(); HCIterator m = b + (e - b)/2; size_t im = m - tree_.begin(); s.pop(); // left and right children if (b < m) { size_t lm = b + (m - b)/2 - tree_.begin(); if (subtree_weights_[lm] < min_w) min_w = subtree_weights_[lm]; } if (e > m + 1) { size_t rm = m + 1 + (e - (m+1))/2 - tree_.begin(); if (subtree_weights_[rm] < min_w) min_w = subtree_weights_[rm]; } if (weights_[im] < min_w) { min_w = weights_[im]; } if (subtree_weights_[im] < min_w ) // increase weight subtree_weights_[im] = min_w; else break; } } template typename dnn::KDTree::HandleDistance dnn::KDTree:: find(PointHandle q) const { dnn::NNRecord nn; search(q, nn); return nn.result; } template typename dnn::KDTree::Result dnn::KDTree:: findR(PointHandle q, DistanceType r) const { dnn::rNNRecord rnn(r); search(q, rnn); std::sort(rnn.result.begin(), rnn.result.end()); return rnn.result; } template typename dnn::KDTree::Result dnn::KDTree:: findK(PointHandle q, size_t k) const { dnn::kNNRecord knn(k); search(q, knn); std::sort(knn.result.begin(), knn.result.end()); return knn.result; } template struct dnn::KDTree::CoordinateComparison { CoordinateComparison(size_t i, const Traits& traits): i_(i), traits_(traits) {} bool operator()(PointHandle p1, PointHandle p2) const { return coordinate(p1) < coordinate(p2); } Coordinate diff(PointHandle p1, PointHandle p2) const { return coordinate(p1) - coordinate(p2); } Coordinate coordinate(PointHandle p) const { return traits_.coordinate(p, i_); } size_t axis() const { return i_; } private: size_t i_; const Traits& traits_; }; template void dnn::KDTree:: printWeights(void) { #ifndef FOR_R_TDA std::cout << "weights_:" << std::endl; for(const auto ph : indices_) { std::cout << "idx = " << ph.second << ": (" << (ph.first)->at(0) << ", " << (ph.first)->at(1) << ") weight = " << weights_[ph.second] << std::endl; } std::cout << "subtree_weights_:" << std::endl; for(size_t idx = 0; idx < subtree_weights_.size(); ++idx) { std::cout << idx << " : " << subtree_weights_[idx] << std::endl; } #endif }