From 6ff8072e8c5a6dc1301e884f5a648a0b63bdd48a Mon Sep 17 00:00:00 2001 From: Arnur Nigmetov Date: Tue, 3 Dec 2019 20:34:28 +0100 Subject: Rename directories for bottleneck and Wasserstein --- bottleneck/include/dnn/local/search-functors.h | 119 +++++++++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 bottleneck/include/dnn/local/search-functors.h (limited to 'bottleneck/include/dnn/local/search-functors.h') diff --git a/bottleneck/include/dnn/local/search-functors.h b/bottleneck/include/dnn/local/search-functors.h new file mode 100644 index 0000000..63ad11d --- /dev/null +++ b/bottleneck/include/dnn/local/search-functors.h @@ -0,0 +1,119 @@ +#ifndef HERA_BT_DNN_LOCAL_SEARCH_FUNCTORS_H +#define HERA_BT_DNN_LOCAL_SEARCH_FUNCTORS_H + +#include + +namespace hera +{ +namespace bt +{ +namespace dnn +{ + +template +struct HandleDistance +{ + typedef typename NN::PointHandle PointHandle; + typedef typename NN::DistanceType DistanceType; + typedef typename NN::HDContainer HDContainer; + + HandleDistance() {} + HandleDistance(PointHandle pp, DistanceType dd): + p(pp), d(dd) {} + bool operator<(const HandleDistance& other) const { return d < other.d; } + + PointHandle p; + DistanceType d; +}; + +template +struct NNRecord +{ + typedef typename HandleDistance::PointHandle PointHandle; + typedef typename HandleDistance::DistanceType DistanceType; + + NNRecord() { result.d = std::numeric_limits::infinity(); } + DistanceType operator()(PointHandle p, DistanceType d) { if (d < result.d) { result.p = p; result.d = d; } return result.d; } + HandleDistance result; +}; + +template +struct rNNRecord +{ + typedef typename HandleDistance::PointHandle PointHandle; + typedef typename HandleDistance::DistanceType DistanceType; + typedef typename HandleDistance::HDContainer HDContainer; + + rNNRecord(DistanceType r_): r(r_) {} + DistanceType operator()(PointHandle p, DistanceType d) + { + if (d <= r) + result.push_back(HandleDistance(p,d)); + return r; + } + + DistanceType r; + HDContainer result; +}; + +template +struct firstrNNRecord +{ + typedef typename HandleDistance::PointHandle PointHandle; + typedef typename HandleDistance::DistanceType DistanceType; + typedef typename HandleDistance::HDContainer HDContainer; + + firstrNNRecord(DistanceType r_): r(r_) {} + + DistanceType operator()(PointHandle p, DistanceType d) + { + if (d <= r) { + result.push_back(HandleDistance(p,d)); + return -100000000.0; + } else { + return r; + } + } + + DistanceType r; + HDContainer result; +}; + + +template +struct kNNRecord +{ + typedef typename HandleDistance::PointHandle PointHandle; + typedef typename HandleDistance::DistanceType DistanceType; + typedef typename HandleDistance::HDContainer HDContainer; + + kNNRecord(unsigned k_): k(k_) {} + DistanceType operator()(PointHandle p, DistanceType d) + { + if (result.size() < k) + { + result.push_back(HandleDistance(p,d)); + boost::push_heap(result); + if (result.size() < k) + return std::numeric_limits::infinity(); + } else if (d < result[0].d) + { + boost::pop_heap(result); + result.back() = HandleDistance(p,d); + boost::push_heap(result); + } + if ( result.size() > 1 ) { + assert( result[0].d >= result[1].d ); + } + return result[0].d; + } + + unsigned k; + HDContainer result; +}; + +} // dnn +} // bt +} // hera + +#endif // HERA_BT_DNN_LOCAL_SEARCH_FUNCTORS_H -- cgit v1.2.3