diff options
author | Arnur Nigmetov <a.nigmetov@gmail.com> | 2018-01-20 19:11:29 +0100 |
---|---|---|
committer | Arnur Nigmetov <a.nigmetov@gmail.com> | 2018-01-20 19:11:29 +0100 |
commit | 0cc35ad04f9c2997014d7cf62a12f697e79fb534 (patch) | |
tree | 744c07bc2c12fba193934ac98417c5063bead189 /geom_matching/wasserstein/include/dnn/local | |
parent | 3552ce68bc7654df35da471bd937b09a9fde101f (diff) |
Major rewrite, templatized version
Diffstat (limited to 'geom_matching/wasserstein/include/dnn/local')
3 files changed, 58 insertions, 30 deletions
diff --git a/geom_matching/wasserstein/include/dnn/local/kd-tree.h b/geom_matching/wasserstein/include/dnn/local/kd-tree.h index 13eaf27..8e52a5c 100644 --- a/geom_matching/wasserstein/include/dnn/local/kd-tree.h +++ b/geom_matching/wasserstein/include/dnn/local/kd-tree.h @@ -1,5 +1,5 @@ -#ifndef DNN_LOCAL_KD_TREE_H -#define DNN_LOCAL_KD_TREE_H +#ifndef HERA_WS_DNN_LOCAL_KD_TREE_H +#define HERA_WS_DNN_LOCAL_KD_TREE_H #include "../utils.h" #include "search-functors.h" @@ -13,6 +13,10 @@ #include <boost/static_assert.hpp> #include <boost/type_traits.hpp> +namespace hera +{ +namespace ws +{ namespace dnn { // Weighted KDTree @@ -48,8 +52,9 @@ namespace dnn template<class Range> void init(const Range& range); - DistanceType weight(PointHandle p) { return weights_[indices_[p]]; } + DistanceType weight(PointHandle p) { return weights_[indices_[p]]; } void change_weight(PointHandle p, DistanceType w); + void adjust_weights(DistanceType delta); // subtract delta from all weights HandleDistance find(PointHandle q) const; Result findR(PointHandle q, DistanceType r) const; // all neighbors within r @@ -83,7 +88,9 @@ namespace dnn HandleMap indices_; double wassersteinPower; }; -} +} // dnn +} // ws +} // hera #include "kd-tree.hpp" diff --git a/geom_matching/wasserstein/include/dnn/local/kd-tree.hpp b/geom_matching/wasserstein/include/dnn/local/kd-tree.hpp index 22108aa..3a4f0eb 100644 --- a/geom_matching/wasserstein/include/dnn/local/kd-tree.hpp +++ b/geom_matching/wasserstein/include/dnn/local/kd-tree.hpp @@ -9,14 +9,14 @@ #include "def_debug_ws.h" template<class T> -dnn::KDTree<T>:: +hera::ws::dnn::KDTree<T>:: KDTree(const Traits& traits, HandleContainer&& handles, double _wassersteinPower): traits_(traits), tree_(std::move(handles)), wassersteinPower(_wassersteinPower) { assert(wassersteinPower >= 1.0); init(); } template<class T> template<class Range> -dnn::KDTree<T>:: +hera::ws::dnn::KDTree<T>:: KDTree(const Traits& traits, const Range& range, double _wassersteinPower): traits_(traits), wassersteinPower(_wassersteinPower) { @@ -27,7 +27,7 @@ KDTree(const Traits& traits, const Range& range, double _wassersteinPower): template<class T> template<class Range> void -dnn::KDTree<T>:: +hera::ws::dnn::KDTree<T>:: init(const Range& range) { size_t sz = std::distance(std::begin(range), std::end(range)); @@ -41,7 +41,7 @@ init(const Range& range) template<class T> void -dnn::KDTree<T>:: +hera::ws::dnn::KDTree<T>:: init() { if (tree_.empty()) @@ -61,7 +61,7 @@ init() template<class T> struct -dnn::KDTree<T>::OrderTree +hera::ws::dnn::KDTree<T>::OrderTree { OrderTree(HCIterator b_, HCIterator e_, size_t i_, const Traits& traits_): b(b_), e(e_), i(i_), traits(traits_) {} @@ -114,7 +114,7 @@ dnn::KDTree<T>::OrderTree template<class T> template<class ResultsFunctor> void -dnn::KDTree<T>:: +hera::ws::dnn::KDTree<T>:: search(PointHandle q, ResultsFunctor& rf) const { typedef typename HandleContainer::const_iterator HCIterator; @@ -123,7 +123,7 @@ search(PointHandle q, ResultsFunctor& rf) const if (tree_.empty()) return; - DistanceType D = std::numeric_limits<DistanceType>::infinity(); + DistanceType D = std::numeric_limits<DistanceType>::max(); // TODO: use tbb::scalable_allocator for the queue std::queue<KDTreeNode> nodes; @@ -140,14 +140,16 @@ search(PointHandle q, ResultsFunctor& rf) const i = (i + 1) % traits().dimension(); HCIterator m = b + (e - b)/2; - DistanceType dist = pow(traits().distance(q, *m), wassersteinPower) + weights_[m - tree_.begin()]; + + DistanceType dist = (wassersteinPower == 1.0) ? traits().distance(q, *m) + weights_[m - tree_.begin()] : std::pow(traits().distance(q, *m), wassersteinPower) + weights_[m - tree_.begin()]; D = rf(*m, dist); // we are really searching w.r.t L_\infty ball; could prune better with an L_2 ball Coordinate diff = cmp.diff(q, *m); // diff returns signed distance - DistanceType diffToWasserPower = (diff > 0 ? 1.0 : -1.0) * pow(fabs(diff), wassersteinPower); + + DistanceType diffToWasserPower = (wassersteinPower == 1.0) ? diff : ((diff > 0 ? 1.0 : -1.0) * std::pow(fabs(diff), wassersteinPower)); size_t lm = m + 1 + (e - (m+1))/2 - tree_.begin(); if (e > m + 1 && diffToWasserPower - subtree_weights_[lm] >= -D) { @@ -163,7 +165,20 @@ search(PointHandle q, ResultsFunctor& rf) const template<class T> void -dnn::KDTree<T>:: +hera::ws::dnn::KDTree<T>:: +adjust_weights(DistanceType delta) +{ + for(auto& w : weights_) + w -= delta; + + for(auto& sw : subtree_weights_) + sw -= delta; +} + + +template<class T> +void +hera::ws::dnn::KDTree<T>:: change_weight(PointHandle p, DistanceType w) { size_t idx = indices_[p]; @@ -246,32 +261,32 @@ change_weight(PointHandle p, DistanceType w) } template<class T> -typename dnn::KDTree<T>::HandleDistance -dnn::KDTree<T>:: +typename hera::ws::dnn::KDTree<T>::HandleDistance +hera::ws::dnn::KDTree<T>:: find(PointHandle q) const { - dnn::NNRecord<HandleDistance> nn; + hera::ws::dnn::NNRecord<HandleDistance> nn; search(q, nn); return nn.result; } template<class T> -typename dnn::KDTree<T>::Result -dnn::KDTree<T>:: +typename hera::ws::dnn::KDTree<T>::Result +hera::ws::dnn::KDTree<T>:: findR(PointHandle q, DistanceType r) const { - dnn::rNNRecord<HandleDistance> rnn(r); + hera::ws::dnn::rNNRecord<HandleDistance> rnn(r); search(q, rnn); std::sort(rnn.result.begin(), rnn.result.end()); return rnn.result; } template<class T> -typename dnn::KDTree<T>::Result -dnn::KDTree<T>:: +typename hera::ws::dnn::KDTree<T>::Result +hera::ws::dnn::KDTree<T>:: findK(PointHandle q, size_t k) const { - dnn::kNNRecord<HandleDistance> knn(k); + hera::ws::dnn::kNNRecord<HandleDistance> knn(k); search(q, knn); std::sort(knn.result.begin(), knn.result.end()); return knn.result; @@ -279,7 +294,7 @@ findK(PointHandle q, size_t k) const template<class T> -struct dnn::KDTree<T>::CoordinateComparison +struct hera::ws::dnn::KDTree<T>::CoordinateComparison { CoordinateComparison(size_t i, const Traits& traits): i_(i), traits_(traits) {} @@ -297,7 +312,7 @@ struct dnn::KDTree<T>::CoordinateComparison template<class T> void -dnn::KDTree<T>:: +hera::ws::dnn::KDTree<T>:: printWeights(void) { #ifndef FOR_R_TDA diff --git a/geom_matching/wasserstein/include/dnn/local/search-functors.h b/geom_matching/wasserstein/include/dnn/local/search-functors.h index f257d0c..1419f22 100644 --- a/geom_matching/wasserstein/include/dnn/local/search-functors.h +++ b/geom_matching/wasserstein/include/dnn/local/search-functors.h @@ -1,8 +1,12 @@ -#ifndef DNN_LOCAL_SEARCH_FUNCTORS_H -#define DNN_LOCAL_SEARCH_FUNCTORS_H +#ifndef HERA_WS_DNN_LOCAL_SEARCH_FUNCTORS_H +#define HERA_WS_DNN_LOCAL_SEARCH_FUNCTORS_H #include <boost/range/algorithm/heap_algorithm.hpp> +namespace hera +{ +namespace ws +{ namespace dnn { @@ -28,7 +32,7 @@ struct NNRecord typedef typename HandleDistance::PointHandle PointHandle; typedef typename HandleDistance::DistanceType DistanceType; - NNRecord() { result.d = std::numeric_limits<DistanceType>::infinity(); } + NNRecord() { result.d = std::numeric_limits<DistanceType>::max(); } DistanceType operator()(PointHandle p, DistanceType d) { if (d < result.d) { result.p = p; result.d = d; } return result.d; } HandleDistance result; }; @@ -67,7 +71,7 @@ struct kNNRecord result.push_back(HandleDistance(p,d)); boost::push_heap(result); if (result.size() < k) - return std::numeric_limits<DistanceType>::infinity(); + return std::numeric_limits<DistanceType>::max(); } else if (d < result[0].d) { boost::pop_heap(result); @@ -84,6 +88,8 @@ struct kNNRecord HDContainer result; }; -} +} // dnn +} // ws +} // hera #endif // DNN_LOCAL_SEARCH_FUNCTORS_H |