summaryrefslogtreecommitdiff
path: root/geom_matching/wasserstein/include/dnn/local/kd-tree.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'geom_matching/wasserstein/include/dnn/local/kd-tree.hpp')
-rw-r--r--geom_matching/wasserstein/include/dnn/local/kd-tree.hpp57
1 files changed, 36 insertions, 21 deletions
diff --git a/geom_matching/wasserstein/include/dnn/local/kd-tree.hpp b/geom_matching/wasserstein/include/dnn/local/kd-tree.hpp
index 22108aa..3a4f0eb 100644
--- a/geom_matching/wasserstein/include/dnn/local/kd-tree.hpp
+++ b/geom_matching/wasserstein/include/dnn/local/kd-tree.hpp
@@ -9,14 +9,14 @@
#include "def_debug_ws.h"
template<class T>
-dnn::KDTree<T>::
+hera::ws::dnn::KDTree<T>::
KDTree(const Traits& traits, HandleContainer&& handles, double _wassersteinPower):
traits_(traits), tree_(std::move(handles)), wassersteinPower(_wassersteinPower)
{ assert(wassersteinPower >= 1.0); init(); }
template<class T>
template<class Range>
-dnn::KDTree<T>::
+hera::ws::dnn::KDTree<T>::
KDTree(const Traits& traits, const Range& range, double _wassersteinPower):
traits_(traits), wassersteinPower(_wassersteinPower)
{
@@ -27,7 +27,7 @@ KDTree(const Traits& traits, const Range& range, double _wassersteinPower):
template<class T>
template<class Range>
void
-dnn::KDTree<T>::
+hera::ws::dnn::KDTree<T>::
init(const Range& range)
{
size_t sz = std::distance(std::begin(range), std::end(range));
@@ -41,7 +41,7 @@ init(const Range& range)
template<class T>
void
-dnn::KDTree<T>::
+hera::ws::dnn::KDTree<T>::
init()
{
if (tree_.empty())
@@ -61,7 +61,7 @@ init()
template<class T>
struct
-dnn::KDTree<T>::OrderTree
+hera::ws::dnn::KDTree<T>::OrderTree
{
OrderTree(HCIterator b_, HCIterator e_, size_t i_, const Traits& traits_):
b(b_), e(e_), i(i_), traits(traits_) {}
@@ -114,7 +114,7 @@ dnn::KDTree<T>::OrderTree
template<class T>
template<class ResultsFunctor>
void
-dnn::KDTree<T>::
+hera::ws::dnn::KDTree<T>::
search(PointHandle q, ResultsFunctor& rf) const
{
typedef typename HandleContainer::const_iterator HCIterator;
@@ -123,7 +123,7 @@ search(PointHandle q, ResultsFunctor& rf) const
if (tree_.empty())
return;
- DistanceType D = std::numeric_limits<DistanceType>::infinity();
+ DistanceType D = std::numeric_limits<DistanceType>::max();
// TODO: use tbb::scalable_allocator for the queue
std::queue<KDTreeNode> nodes;
@@ -140,14 +140,16 @@ search(PointHandle q, ResultsFunctor& rf) const
i = (i + 1) % traits().dimension();
HCIterator m = b + (e - b)/2;
- DistanceType dist = pow(traits().distance(q, *m), wassersteinPower) + weights_[m - tree_.begin()];
+
+ DistanceType dist = (wassersteinPower == 1.0) ? traits().distance(q, *m) + weights_[m - tree_.begin()] : std::pow(traits().distance(q, *m), wassersteinPower) + weights_[m - tree_.begin()];
D = rf(*m, dist);
// we are really searching w.r.t L_\infty ball; could prune better with an L_2 ball
Coordinate diff = cmp.diff(q, *m); // diff returns signed distance
- DistanceType diffToWasserPower = (diff > 0 ? 1.0 : -1.0) * pow(fabs(diff), wassersteinPower);
+
+ DistanceType diffToWasserPower = (wassersteinPower == 1.0) ? diff : ((diff > 0 ? 1.0 : -1.0) * std::pow(fabs(diff), wassersteinPower));
size_t lm = m + 1 + (e - (m+1))/2 - tree_.begin();
if (e > m + 1 && diffToWasserPower - subtree_weights_[lm] >= -D) {
@@ -163,7 +165,20 @@ search(PointHandle q, ResultsFunctor& rf) const
template<class T>
void
-dnn::KDTree<T>::
+hera::ws::dnn::KDTree<T>::
+adjust_weights(DistanceType delta)
+{
+ for(auto& w : weights_)
+ w -= delta;
+
+ for(auto& sw : subtree_weights_)
+ sw -= delta;
+}
+
+
+template<class T>
+void
+hera::ws::dnn::KDTree<T>::
change_weight(PointHandle p, DistanceType w)
{
size_t idx = indices_[p];
@@ -246,32 +261,32 @@ change_weight(PointHandle p, DistanceType w)
}
template<class T>
-typename dnn::KDTree<T>::HandleDistance
-dnn::KDTree<T>::
+typename hera::ws::dnn::KDTree<T>::HandleDistance
+hera::ws::dnn::KDTree<T>::
find(PointHandle q) const
{
- dnn::NNRecord<HandleDistance> nn;
+ hera::ws::dnn::NNRecord<HandleDistance> nn;
search(q, nn);
return nn.result;
}
template<class T>
-typename dnn::KDTree<T>::Result
-dnn::KDTree<T>::
+typename hera::ws::dnn::KDTree<T>::Result
+hera::ws::dnn::KDTree<T>::
findR(PointHandle q, DistanceType r) const
{
- dnn::rNNRecord<HandleDistance> rnn(r);
+ hera::ws::dnn::rNNRecord<HandleDistance> rnn(r);
search(q, rnn);
std::sort(rnn.result.begin(), rnn.result.end());
return rnn.result;
}
template<class T>
-typename dnn::KDTree<T>::Result
-dnn::KDTree<T>::
+typename hera::ws::dnn::KDTree<T>::Result
+hera::ws::dnn::KDTree<T>::
findK(PointHandle q, size_t k) const
{
- dnn::kNNRecord<HandleDistance> knn(k);
+ hera::ws::dnn::kNNRecord<HandleDistance> knn(k);
search(q, knn);
std::sort(knn.result.begin(), knn.result.end());
return knn.result;
@@ -279,7 +294,7 @@ findK(PointHandle q, size_t k) const
template<class T>
-struct dnn::KDTree<T>::CoordinateComparison
+struct hera::ws::dnn::KDTree<T>::CoordinateComparison
{
CoordinateComparison(size_t i, const Traits& traits):
i_(i), traits_(traits) {}
@@ -297,7 +312,7 @@ struct dnn::KDTree<T>::CoordinateComparison
template<class T>
void
-dnn::KDTree<T>::
+hera::ws::dnn::KDTree<T>::
printWeights(void)
{
#ifndef FOR_R_TDA