summaryrefslogtreecommitdiff
path: root/geom_matching/wasserstein/include/dnn/local/search-functors.h
blob: f257d0c916a936fd85c04485badda9971a638615 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#ifndef DNN_LOCAL_SEARCH_FUNCTORS_H
#define DNN_LOCAL_SEARCH_FUNCTORS_H

#include <algorithm>
#include <cassert>
#include <limits>

#include <boost/range/algorithm/heap_algorithm.hpp>

namespace dnn
{

// A (point handle, distance) pair. The handle, distance and container types
// are taken from the NN traits class; HDContainer is re-exported so that the
// record functors can name the container type through HandleDistance alone.
template<class NN>
struct HandleDistance
{
    typedef             typename NN::PointHandle                                    PointHandle;
    typedef             typename NN::DistanceType                                   DistanceType;
    typedef             typename NN::HDContainer                                    HDContainer;

                        // Value-initialize both members so that a default-constructed
                        // pair never carries indeterminate values (reading those is
                        // undefined behaviour for built-in types).
                        HandleDistance(): p(), d()                                  {}
                        HandleDistance(PointHandle pp, DistanceType dd):
                            p(pp), d(dd)                                            {}

    // Ordering looks at the distance only; the handle does not take part.
    bool                operator<(const HandleDistance& other) const                { return d < other.d; }

    PointHandle         p;
    DistanceType        d;
};

// Functor that remembers the single candidate with the smallest distance seen
// so far. Each call offers one (handle, distance) candidate and returns the
// best distance recorded after taking that candidate into account.
template<class HandleDistance>
struct NNRecord
{
    typedef         typename HandleDistance::PointHandle                            PointHandle;
    typedef         typename HandleDistance::DistanceType                           DistanceType;

    // Starts with a "nothing found yet" record: a value-initialized handle and
    // the largest representable distance. numeric_limits<T>::infinity() yields
    // T() (i.e. zero) for types without an infinity, which would make every
    // candidate lose the `d < result.d` comparison, so max() is used for those.
    NNRecord()
    {
        typedef std::numeric_limits<DistanceType> Limits;
        result.p = PointHandle();
        result.d = Limits::has_infinity ? Limits::infinity() : Limits::max();
    }

    // Keeps the candidate only if it is strictly closer than the current best
    // (ties keep the earlier candidate). Returns the current best distance.
    DistanceType    operator()(PointHandle p, DistanceType d)
    {
        if (d < result.d)
        {
            result.p = p;
            result.d = d;
        }
        return result.d;
    }

    HandleDistance  result;
};

// Functor that gathers every candidate whose distance does not exceed a fixed
// radius `r`. Accepted candidates are appended to `result` in call order.
template<class HandleDistance>
struct rNNRecord
{
    typedef typename HandleDistance::PointHandle    PointHandle;
    typedef typename HandleDistance::DistanceType   DistanceType;
    typedef typename HandleDistance::HDContainer    HDContainer;

    rNNRecord(DistanceType r_)
        : r(r_)
    {
    }

    // Offers one candidate; always returns the radius, whether or not the
    // candidate was accepted.
    DistanceType operator()(PointHandle p, DistanceType d)
    {
        // Written as a negated `<=` rather than `>` so that an unordered
        // comparison (e.g. NaN) is rejected exactly as `if (d <= r)` would.
        if (!(d <= r))
            return r;

        HandleDistance accepted(p, d);
        result.push_back(accepted);
        return r;
    }

    DistanceType    r;
    HDContainer     result;
};

// Functor that keeps the (up to) k candidates with the smallest distances.
// `result` is maintained as a binary max-heap under HandleDistance's
// operator<, so result[0] is the current worst of the kept candidates.
// NOTE(review): the code reads `.d` from the heap top, i.e. it assumes
// operator< orders records by distance - confirm for custom HandleDistance types.
template<class HandleDistance>
struct kNNRecord
{
    typedef         typename HandleDistance::PointHandle                            PointHandle;
    typedef         typename HandleDistance::DistanceType                           DistanceType;
    typedef         typename HandleDistance::HDContainer                            HDContainer;

                    kNNRecord(unsigned k_): k(k_)                                   {}

    // Value returned while no k-th best distance exists yet. For types without
    // an infinity, numeric_limits<T>::infinity() is T() (zero), so max() is
    // used instead to keep the "no bound" meaning.
    static DistanceType unbounded()
    {
        typedef std::numeric_limits<DistanceType> Limits;
        return Limits::has_infinity ? Limits::infinity() : Limits::max();
    }

    // Offers one candidate. Returns unbounded() until k candidates have been
    // collected, and the largest kept distance afterwards.
    DistanceType    operator()(PointHandle p, DistanceType d)
    {
        // With k == 0 the heap stays empty, so result[0] below would be an
        // out-of-range access. Record nothing; the returned value is the
        // conservative "no bound" choice since the caller's use of it is not
        // visible from this file.
        if (k == 0)
            return unbounded();

        if (result.size() < k)
        {
            // Still filling up: every candidate is kept.
            result.push_back(HandleDistance(p,d));
            std::push_heap(result.begin(), result.end());
            if (result.size() < k)
                return unbounded();
        } else if (d < result[0].d)
        {
            // Full and the candidate beats the current worst: move the worst
            // to the back, overwrite it, and restore the heap.
            std::pop_heap(result.begin(), result.end());
            result.back() = HandleDistance(p,d);
            std::push_heap(result.begin(), result.end());
        }

        // Cheap heap sanity check: the root is never closer than its first child.
        assert( result.size() < 2 || result[0].d >= result[1].d );
        return result[0].d;
    }

    unsigned        k;
    HDContainer     result;
};

}

#endif // DNN_LOCAL_SEARCH_FUNCTORS_H