diff options
author | Arnur Nigmetov <nigmetov@tugraz.at> | 2019-12-03 20:34:28 +0100 |
---|---|---|
committer | Arnur Nigmetov <nigmetov@tugraz.at> | 2019-12-03 20:34:28 +0100 |
commit | 29ed61e4dd575edc9960cc6414ab2e20c02a62e1 (patch) | |
tree | 4134a502f4a91c6c8a8fcc37cd24897445dbf1e3 /wasserstein/include/dnn/local/search-functors.h | |
parent | 5a59cfad45c155f8af89c2c6d82db2848d52a953 (diff) |
Rename directories for bottleneck and Wasserstein
Diffstat (limited to 'wasserstein/include/dnn/local/search-functors.h')
-rw-r--r-- | wasserstein/include/dnn/local/search-functors.h | 95 |
1 files changed, 95 insertions, 0 deletions
diff --git a/wasserstein/include/dnn/local/search-functors.h b/wasserstein/include/dnn/local/search-functors.h new file mode 100644 index 0000000..1419f22 --- /dev/null +++ b/wasserstein/include/dnn/local/search-functors.h @@ -0,0 +1,95 @@ +#ifndef HERA_WS_DNN_LOCAL_SEARCH_FUNCTORS_H +#define HERA_WS_DNN_LOCAL_SEARCH_FUNCTORS_H + +#include <boost/range/algorithm/heap_algorithm.hpp> + +namespace hera +{ +namespace ws +{ +namespace dnn +{ + +template<class NN> +struct HandleDistance +{ + typedef typename NN::PointHandle PointHandle; + typedef typename NN::DistanceType DistanceType; + typedef typename NN::HDContainer HDContainer; + + HandleDistance() {} + HandleDistance(PointHandle pp, DistanceType dd): + p(pp), d(dd) {} + bool operator<(const HandleDistance& other) const { return d < other.d; } + + PointHandle p; + DistanceType d; +}; + +template<class HandleDistance> +struct NNRecord +{ + typedef typename HandleDistance::PointHandle PointHandle; + typedef typename HandleDistance::DistanceType DistanceType; + + NNRecord() { result.d = std::numeric_limits<DistanceType>::max(); } + DistanceType operator()(PointHandle p, DistanceType d) { if (d < result.d) { result.p = p; result.d = d; } return result.d; } + HandleDistance result; +}; + +template<class HandleDistance> +struct rNNRecord +{ + typedef typename HandleDistance::PointHandle PointHandle; + typedef typename HandleDistance::DistanceType DistanceType; + typedef typename HandleDistance::HDContainer HDContainer; + + rNNRecord(DistanceType r_): r(r_) {} + DistanceType operator()(PointHandle p, DistanceType d) + { + if (d <= r) + result.push_back(HandleDistance(p,d)); + return r; + } + + DistanceType r; + HDContainer result; +}; + +template<class HandleDistance> +struct kNNRecord +{ + typedef typename HandleDistance::PointHandle PointHandle; + typedef typename HandleDistance::DistanceType DistanceType; + typedef typename HandleDistance::HDContainer HDContainer; + + kNNRecord(unsigned k_): k(k_) {} + DistanceType operator()(PointHandle p, DistanceType d) + { + if (result.size() < k) + { + result.push_back(HandleDistance(p,d)); + boost::push_heap(result); + if (result.size() < k) + return std::numeric_limits<DistanceType>::max(); + } else if (d < result[0].d) + { + boost::pop_heap(result); + result.back() = HandleDistance(p,d); + boost::push_heap(result); + } + if ( result.size() > 1 ) { + assert( result[0].d >= result[1].d ); + } + return result[0].d; + } + + unsigned k; + HDContainer result; +}; + +} // dnn +} // ws +} // hera + +#endif // DNN_LOCAL_SEARCH_FUNCTORS_H |