From 6ff8072e8c5a6dc1301e884f5a648a0b63bdd48a Mon Sep 17 00:00:00 2001 From: Arnur Nigmetov Date: Tue, 3 Dec 2019 20:34:28 +0100 Subject: Rename directories for bottleneck and Wasserstein --- wasserstein/include/dnn/local/kd-tree.hpp | 330 ++++++++++++++++++++++++++++++ 1 file changed, 330 insertions(+) create mode 100644 wasserstein/include/dnn/local/kd-tree.hpp (limited to 'wasserstein/include/dnn/local/kd-tree.hpp') diff --git a/wasserstein/include/dnn/local/kd-tree.hpp b/wasserstein/include/dnn/local/kd-tree.hpp new file mode 100644 index 0000000..bdeef45 --- /dev/null +++ b/wasserstein/include/dnn/local/kd-tree.hpp @@ -0,0 +1,330 @@ +#include +#include +#include + +#include +#include + +#include "../parallel/tbb.h" +#include "def_debug_ws.h" + +template +hera::ws::dnn::KDTree:: +KDTree(const Traits& traits, HandleContainer&& handles, double _wassersteinPower): + traits_(traits), tree_(std::move(handles)), wassersteinPower(_wassersteinPower) +{ assert(wassersteinPower >= 1.0); init(); } + +template +template +hera::ws::dnn::KDTree:: +KDTree(const Traits& traits, const Range& range, double _wassersteinPower): + traits_(traits), wassersteinPower(_wassersteinPower) +{ + assert( wassersteinPower >= 1.0); + init(range); +} + +template +template +void +hera::ws::dnn::KDTree:: +init(const Range& range) +{ + size_t sz = std::distance(std::begin(range), std::end(range)); + tree_.reserve(sz); + weights_.resize(sz, 0); + subtree_weights_.resize(sz, 0); + for (PointHandle h : range) + tree_.push_back(h); + init(); +} + +template +void +hera::ws::dnn::KDTree:: +init() +{ + if (tree_.empty()) + return; + +#if defined(TBB) + task_group g; + g.run(OrderTree(tree_.begin(), tree_.end(), 0, traits())); + g.wait(); +#else + OrderTree(tree_.begin(), tree_.end(), 0, traits()).serial(); +#endif + + for (size_t i = 0; i < tree_.size(); ++i) + indices_[tree_[i]] = i; +} + +template +struct +hera::ws::dnn::KDTree::OrderTree +{ + OrderTree(HCIterator b_, HCIterator e_, size_t i_, const Traits& traits_): + b(b_), e(e_), i(i_), traits(traits_) {} + + void operator()() const + { + if (e - b < 1000) + { + serial(); + return; + } + + HCIterator m = b + (e - b)/2; + CoordinateComparison cmp(i, traits); + std::nth_element(b,m,e, cmp); + size_t next_i = (i + 1) % traits.dimension(); + + task_group g; + if (b < m - 1) g.run(OrderTree(b, m, next_i, traits)); + if (e > m + 2) g.run(OrderTree(m+1, e, next_i, traits)); + g.wait(); + } + + void serial() const + { + std::queue q; + q.push(KDTreeNode(b,e,i)); + while (!q.empty()) + { + HCIterator b, e; size_t i; + std::tie(b,e,i) = q.front(); + q.pop(); + HCIterator m = b + (e - b)/2; + + CoordinateComparison cmp(i, traits); + std::nth_element(b,m,e, cmp); + size_t next_i = (i + 1) % traits.dimension(); + + // Replace with a size condition instead? + if (m - b > 1) q.push(KDTreeNode(b, m, next_i)); + if (e - m > 2) q.push(KDTreeNode(m+1, e, next_i)); + } + } + + HCIterator b, e; + size_t i; + const Traits& traits; +}; + +template +template +void +hera::ws::dnn::KDTree:: +search(PointHandle q, ResultsFunctor& rf) const +{ + typedef typename HandleContainer::const_iterator HCIterator; + typedef std::tuple KDTreeNode; + + if (tree_.empty()) + return; + + DistanceType D = std::numeric_limits::max(); + + // TODO: use tbb::scalable_allocator for the queue + std::queue nodes; + + nodes.push(KDTreeNode(tree_.begin(), tree_.end(), 0)); + + while (!nodes.empty()) + { + HCIterator b, e; size_t i; + std::tie(b,e,i) = nodes.front(); + nodes.pop(); + + CoordinateComparison cmp(i, traits()); + i = (i + 1) % traits().dimension(); + + HCIterator m = b + (e - b)/2; + + DistanceType dist = (wassersteinPower == 1.0) ? traits().distance(q, *m) + weights_[m - tree_.begin()] : std::pow(traits().distance(q, *m), wassersteinPower) + weights_[m - tree_.begin()]; + + + D = rf(*m, dist); + + // we are really searching w.r.t L_\infty ball; could prune better with an L_2 ball + Coordinate diff = cmp.diff(q, *m); // diff returns signed distance + + DistanceType diffToWasserPower = (wassersteinPower == 1.0) ? diff : ((diff > 0 ? 1.0 : -1.0) * std::pow(fabs(diff), wassersteinPower)); + + size_t lm = m + 1 + (e - (m+1))/2 - tree_.begin(); + if (e > m + 1 && diffToWasserPower - subtree_weights_[lm] >= -D) { + nodes.push(KDTreeNode(m+1, e, i)); + } + + size_t rm = b + (m - b) / 2 - tree_.begin(); + if (b < m && diffToWasserPower + subtree_weights_[rm] <= D) { + nodes.push(KDTreeNode(b, m, i)); + } + } +} + +template +void +hera::ws::dnn::KDTree:: +adjust_weights(DistanceType delta) +{ + for(auto& w : weights_) + w -= delta; + + for(auto& sw : subtree_weights_) + sw -= delta; +} + + +template +void +hera::ws::dnn::KDTree:: +change_weight(PointHandle p, DistanceType w) +{ + size_t idx = indices_[p]; + + if ( weights_[idx] == w ) { + return; + } + + bool weight_increases = ( weights_[idx] < w ); + weights_[idx] = w; + + typedef std::tuple KDTreeNode; + + // find the path down the tree to this node + // not an ideal strategy, but // it's not clear how to move up from the node in general + std::stack s; + s.push(KDTreeNode(tree_.begin(),tree_.end())); + + do + { + HCIterator b,e; + std::tie(b,e) = s.top(); + + size_t im = b + (e - b)/2 - tree_.begin(); + + if (idx == im) + break; + else if (idx < im) + s.push(KDTreeNode(b, tree_.begin() + im)); + else // idx > im + s.push(KDTreeNode(tree_.begin() + im + 1, e)); + } while(1); + + // update subtree_weights_ on the path to the root + DistanceType min_w = w; + while (!s.empty()) + { + HCIterator b,e; + std::tie(b,e) = s.top(); + HCIterator m = b + (e - b)/2; + size_t im = m - tree_.begin(); + s.pop(); + + + // left and right children + if (b < m) + { + size_t lm = b + (m - b)/2 - tree_.begin(); + if (subtree_weights_[lm] < min_w) + min_w = subtree_weights_[lm]; + } + + if (e > m + 1) + { + size_t rm = m + 1 + (e - (m+1))/2 - tree_.begin(); + if (subtree_weights_[rm] < min_w) + min_w = subtree_weights_[rm]; + } + + if (weights_[im] < min_w) { + min_w = weights_[im]; + } + + if (weight_increases) { + + if (subtree_weights_[im] < min_w ) // increase weight + subtree_weights_[im] = min_w; + else + break; + + } else { + + if (subtree_weights_[im] > min_w ) // decrease weight + subtree_weights_[im] = min_w; + else + break; + + } + } +} + +template +typename hera::ws::dnn::KDTree::HandleDistance +hera::ws::dnn::KDTree:: +find(PointHandle q) const +{ + hera::ws::dnn::NNRecord nn; + search(q, nn); + return nn.result; +} + +template +typename hera::ws::dnn::KDTree::Result +hera::ws::dnn::KDTree:: +findR(PointHandle q, DistanceType r) const +{ + hera::ws::dnn::rNNRecord rnn(r); + search(q, rnn); + std::sort(rnn.result.begin(), rnn.result.end()); + return rnn.result; +} + +template +typename hera::ws::dnn::KDTree::Result +hera::ws::dnn::KDTree:: +findK(PointHandle q, size_t k) const +{ + hera::ws::dnn::kNNRecord knn(k); + search(q, knn); + std::sort(knn.result.begin(), knn.result.end()); + return knn.result; +} + + +template +struct hera::ws::dnn::KDTree::CoordinateComparison +{ + CoordinateComparison(size_t i, const Traits& traits): + i_(i), traits_(traits) {} + + bool operator()(PointHandle p1, PointHandle p2) const { return coordinate(p1) < coordinate(p2); } + Coordinate diff(PointHandle p1, PointHandle p2) const { return coordinate(p1) - coordinate(p2); } + + Coordinate coordinate(PointHandle p) const { return traits_.coordinate(p, i_); } + size_t axis() const { return i_; } + + private: + size_t i_; + const Traits& traits_; +}; + +template +void +hera::ws::dnn::KDTree:: +printWeights(void) +{ +#ifndef FOR_R_TDA + std::cout << "weights_:" << std::endl; + for(const auto ph : indices_) { + std::cout << "idx = " << ph.second << ": (" << (ph.first)->at(0) << ", " << (ph.first)->at(1) << ") weight = " << weights_[ph.second] << std::endl; + } + std::cout << "subtree_weights_:" << std::endl; + for(size_t idx = 0; idx < subtree_weights_.size(); ++idx) { + std::cout << idx << " : " << subtree_weights_[idx] << std::endl; + } +#endif +} + + -- cgit v1.2.3