summaryrefslogtreecommitdiff
path: root/geom_matching/wasserstein/include/dnn/local/kd-tree.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'geom_matching/wasserstein/include/dnn/local/kd-tree.hpp')
-rw-r--r--geom_matching/wasserstein/include/dnn/local/kd-tree.hpp29
1 files changed, 22 insertions, 7 deletions
diff --git a/geom_matching/wasserstein/include/dnn/local/kd-tree.hpp b/geom_matching/wasserstein/include/dnn/local/kd-tree.hpp
index 6b0852c..22108aa 100644
--- a/geom_matching/wasserstein/include/dnn/local/kd-tree.hpp
+++ b/geom_matching/wasserstein/include/dnn/local/kd-tree.hpp
@@ -164,11 +164,15 @@ search(PointHandle q, ResultsFunctor& rf) const
template<class T>
void
dnn::KDTree<T>::
-increase_weight(PointHandle p, DistanceType w)
+change_weight(PointHandle p, DistanceType w)
{
size_t idx = indices_[p];
- // weight should only increase
- assert( weights_[idx] <= w );
+
+ if ( weights_[idx] == w ) {
+ return;
+ }
+
+ bool weight_increases = ( weights_[idx] < w );
weights_[idx] = w;
typedef std::tuple<HCIterator, HCIterator> KDTreeNode;
@@ -223,10 +227,21 @@ increase_weight(PointHandle p, DistanceType w)
min_w = weights_[im];
}
- if (subtree_weights_[im] < min_w ) // increase weight
- subtree_weights_[im] = min_w;
- else
- break;
+ if (weight_increases) {
+
+ if (subtree_weights_[im] < min_w ) // increase weight
+ subtree_weights_[im] = min_w;
+ else
+ break;
+
+ } else {
+
+ if (subtree_weights_[im] > min_w ) // decrease weight
+ subtree_weights_[im] = min_w;
+ else
+ break;
+
+ }
}
}