summaryrefslogtreecommitdiff
path: root/geom_matching/wasserstein/include/basic_defs_ws.h
blob: 58d6fd2106074a453040cad1e90399b6b3ec36a9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
/*

Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov
All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

You are under no obligation whatsoever to provide any bug fixes, patches, or
upgrades to the features, functionality or performance of the source code
(Enhancements) to anyone; however, if you choose to make your Enhancements
available either publicly, or directly to copyright holder,
without imposing a separate written license agreement for such Enhancements,
then you hereby grant the following license: a  non-exclusive, royalty-free
perpetual license to install, use, modify, prepare derivative works, incorporate
into other computer software, distribute, and sublicense such enhancements or
derivative works thereof, in binary and source code form.

 */

#ifndef BASIC_DEFS_WS_H
#define BASIC_DEFS_WS_H

#include <math.h>
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <iomanip>
#include <limits>
#include <locale>
#include <ostream>
#include <stdexcept>
#include <string>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#ifdef _WIN32
#include <ciso646>
#endif

#ifndef FOR_R_TDA
#include "spdlog/spdlog.h"
#include "spdlog/fmt/fmt.h"
#include "spdlog/fmt/ostr.h"
#endif

#include "dnn/geometry/euclidean-dynamic.h"
#include "def_debug_ws.h"

// NOTE(review): not referenced anywhere in this header; presumably the smallest
// id value treated as valid by code that includes it - confirm against users.
#define MIN_VALID_ID 10

namespace hera
{

// Tests whether x carries the "infinity" marker used by this library.
// Infinity is not a floating-point inf here: it is encoded as the sentinel
// value Real(-1) (the same value get_infinity() produces).
//
// @param x  value to test
// @return   true iff x compares equal to Real(-1)
template<class Real = double>
bool is_infinity(const Real& x)
{
    return x == Real(-1);
}

// Produces the sentinel that stands for "infinity" in this library,
// namely the value -1 converted to Real.
template<class Real = double>
Real get_infinity()
{
    const Real infinity_marker = Real(-1);
    return infinity_marker;
}

// Checks that p is an acceptable norm exponent: either the infinity
// sentinel (see is_infinity) or a value not smaller than 1.
template<class Real = double>
bool is_p_valid_norm(const Real& p)
{
    if (is_infinity<Real>(p)) {
        return true;
    }
    return p >= Real(1);
}

// Bundle of tunable parameters with their default values.
// NOTE(review): the consumers of these fields are not visible in this header;
// the per-field notes below restate the original inline remarks.
template<class Real = double>
struct AuctionParams
{
    // NOTE(review): presumably the exponent of the Wasserstein distance - confirm.
    Real wasserstein_power { 1.0 };
    // Relative error.
    Real delta { 0.01 };
    // Defaults to the infinity sentinel returned by get_infinity().
    Real internal_p { get_infinity<Real>() };
    // 0.0 means maxVal / 4.0.
    Real initial_epsilon { 0.0 };
    Real epsilon_common_ratio { 5.0 };
    // For experiments, not in use now.
    Real gamma_threshold { 0.0 };
    // Defaults to the largest int, i.e. effectively unbounded.
    int max_num_phases { std::numeric_limits<int>::max() };
    // Imitating Gauss-Seidel is the default behaviour.
    size_t max_bids_per_round { 1 };
    // For the pure geometric version only; ignored in persistence diagrams.
    unsigned int dim { 2 };
};

namespace ws
{

    // Integer type used for indices (first component of IdxValPair below).
    using IdxType = int;

    // Largest value representable by IdxType; by its name, presumably used as
    // the "invalid index" marker.
    // NOTE(review): declared as size_t although derived from IdxType (int) -
    // confirm that comparisons against IdxType values behave as intended.
    constexpr size_t k_invalid_index = std::numeric_limits<IdxType>::max();

    // (index, value) pair; the value type defaults to double.
    template<class Real = double>
    using IdxValPair = std::pair<IdxType, Real>;



    // Writes an (index, value) pair to a stream as "(index, value)".
    //
    // The fmt library is only included when FOR_R_TDA is not defined (see the
    // include section of this header), so the fmt-based formatting is used in
    // that configuration only; FOR_R_TDA builds fall back to plain stream
    // insertion so that the header still compiles.
    //
    // @param output  destination stream
    // @param p       pair to print
    // @return        the same stream, to allow chaining
    template<class R>
    std::ostream& operator<<(std::ostream& output, const IdxValPair<R> p)
    {
#ifndef FOR_R_TDA
        output << fmt::format("({0}, {1})", p.first, p.second);
#else
        output << "(" << p.first << ", " << p.second << ")";
#endif
        return output;
    }

    // Kind of owner; the enumerators are printed by operator<< below as
    // "NONE", "NORMAL" and "DIAGONAL".
    enum class OwnerType { k_none, k_normal, k_diagonal };

    // Writes the textual name of an OwnerType to a stream.
    //
    // Declared inline: this is a non-template function defined in a header,
    // and without inline every translation unit including the header would
    // emit its own definition, which breaks linking (ODR violation).
    //
    // @param s  destination stream
    // @param t  value to print; a value outside the enumerators prints nothing
    // @return   the same stream, to allow chaining
    inline std::ostream& operator<<(std::ostream& s, const OwnerType t)
    {
        switch(t)
        {
            case OwnerType::k_none : s << "NONE"; break;
            case OwnerType::k_normal: s << "NORMAL"; break;
            case OwnerType::k_diagonal: s << "DIAGONAL"; break;
        }
        return s;
    }

    // Plain 2D point with coordinates x and y.
    // The comparison operators are only declared here.
    template<class Real = double>
    struct Point {
        Real x, y;

        // Origin (0, 0).
        Point() : x(0.0), y(0.0) {}
        // Point with the given coordinates.
        Point(Real x_init, Real y_init) : x(x_init), y(y_init) {}

        bool operator==(const Point& other) const;
        bool operator!=(const Point& other) const;
    };

#ifndef FOR_R_TDA
    // Stream-insertion operator for Point; declaration only, compiled only
    // when FOR_R_TDA is not defined.
    // NOTE(review): the definition is presumably in basic_defs_ws.hpp, which is
    // included at the end of this header - confirm.
    template<class Real = double>
    std::ostream& operator<<(std::ostream& output, const Point<Real> p);
#endif

    // Folds the std::hash of v into the running hash value seed (in place).
    // The hash of v is offset by the constant 0x9e3779b9 and by two shifted
    // copies of the current seed, and the sum is XOR-ed into seed.
    template <class T>
    inline void hash_combine(std::size_t & seed, const T & v)
    {
        const std::size_t mixed = std::hash<T>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        seed = seed ^ mixed;
    }

    // Point of a persistence diagram: two coordinates plus a type tag.
    //   - points above the diagonal have type NORMAL
    //   - projections onto the diagonal have type DIAG
    //   - for DIAG points only the x-coordinate is relevant
    // The constructor, getRealX/getRealY and persistence_lp are only declared here.
    template<class Real_ = double>
    struct DiagramPoint
    {
        using Real = Real_;

        enum Type { NORMAL, DIAG};

        // data members
        Real x, y;
        Type type;

        // methods
        DiagramPoint(Real xx, Real yy, Type ttype);

        bool is_diagonal() const { return type == DIAG; }
        bool is_normal() const { return type == NORMAL; }

        Real getRealX() const; // return the x-coord
        Real getRealY() const; // return the y-coord
        Real persistence_lp(const Real p) const;

        // Strict ordering by type first, then by x, then by y.
        struct LexicographicCmp
        {
            bool    operator()(const DiagramPoint& p1, const DiagramPoint& p2) const
            {
                if (p1.type != p2.type) {
                    return p1.type < p2.type;
                }
                if (p1.x < p2.x) {
                    return true;
                }
                return p1.x == p2.x && p1.y < p2.y;
            }
        };

        // Read-only coordinate access: index 0 is x, index 1 is y.
        // Throws std::out_of_range for any other index.
        const Real& operator[](const int idx) const
        {
            if (idx == 0) {
                return x;
            }
            if (idx == 1) {
                return y;
            }
            throw std::out_of_range("DiagramPoint has dimension 2");
        }

        // Mutable coordinate access: index 0 is x, index 1 is y.
        // Throws std::out_of_range for any other index.
        Real& operator[](const int idx)
        {
            if (idx == 0) {
                return x;
            }
            if (idx == 1) {
                return y;
            }
            throw std::out_of_range("DiagramPoint has dimension 2");
        }

    };


    // Hash functor for DiagramPoint, suitable as the Hash parameter of the
    // unordered containers. Combines x, y and the diagonal flag via hash_combine.
    //
    // hash_combine applies std::hash<T> to its argument itself, so the raw
    // values are passed. The previous form std::hash<Real>(p.x) tried to
    // construct a hasher object from a coordinate and did not compile once the
    // functor was instantiated.
    template<class Real>
    struct DiagramPointHash {
        // @param p  point to hash
        // @return   combined hash of p.x, p.y and p.is_diagonal()
        size_t operator()(const DiagramPoint<Real> &p) const
        {
            std::size_t seed = 0;
            hash_combine(seed, p.x);
            hash_combine(seed, p.y);
            hash_combine(seed, p.is_diagonal());
            return seed;
        }
    };


#ifndef FOR_R_TDA
    // Stream-insertion operator for DiagramPoint; declaration only.
    template <class Real = double>
    std::ostream& operator<<(std::ostream& output, const DiagramPoint<Real> p);

    // fmt formatting hook for DiagramPoint: writes "(x,y, DIAG)" for diagonal
    // points and "(x,y, NORM)" otherwise. format_str is not used.
    //
    // Kept inside the FOR_R_TDA guard because it depends on the fmt library,
    // which this header includes only when FOR_R_TDA is not defined.
    template<class Real>
    void format_arg(fmt::BasicFormatter<char> &f, const char *&format_str, const DiagramPoint<Real>&p) {
        if (p.is_diagonal()) {
            f.writer().write("({0},{1}, DIAG)", p.x, p.y);
        } else {
            f.writer().write("({0},{1}, NORM)", p.x, p.y);
        }
    }
#endif


    // Generic l_p distance between two points whose coordinates are read with
    // operator[] for indices 0 .. dim-1.
    //   - p equal to the infinity sentinel: largest absolute coordinate difference
    //   - p == 1: sum of absolute coordinate differences
    //   - otherwise (asserted p > 1): p-th root of the sum of p-th powers
    template<class Real, class Pt>
    struct DistImpl
    {
        Real operator()(const Pt& a, const Pt& b, const Real p, const int dim)
        {
            Real acc = 0.0;

            if (hera::is_infinity(p)) {
                for(int k = 0; k < dim; ++k) {
                    acc = std::max(acc, std::fabs(a[k] - b[k]));
                }
                return acc;
            }

            if (p == 1.0) {
                for(int k = 0; k < dim; ++k) {
                    acc += std::fabs(a[k] - b[k]);
                }
                return acc;
            }

            assert(p > 1.0);
            for(int k = 0; k < dim; ++k) {
                acc += std::pow(std::fabs(a[k] - b[k]), p);
            }
            acc = std::pow(acc, 1.0 / p);
            return acc;
        }
    };

    // l_p distance specialised for diagram points. Two diagonal points are at
    // distance 0; otherwise the distance is computed from getRealX()/getRealY()
    // with the same three cases as the generic version (infinity sentinel,
    // p == 1, p > 1). The dim argument is not used here.
    template<class Real>
    struct DistImpl<Real, DiagramPoint<Real>>
    {
        Real operator()(const DiagramPoint<Real>& a, const DiagramPoint<Real>& b, const Real p, const int dim)
        {
            Real dist = 0.0;
            if (a.is_diagonal() && b.is_diagonal()) {
                return dist;
            }

            const auto dx = std::fabs(a.getRealX() - b.getRealX());
            const auto dy = std::fabs(a.getRealY() - b.getRealY());

            if (hera::is_infinity(p)) {
                dist = std::max(dx, dy);
            } else if (p == 1.0) {
                dist = dx + dy;
            } else {
                assert(p > 1.0);
                dist = std::pow(std::pow(dx, p) + std::pow(dy, p), 1.0 / p);
            }
            return dist;
        }
    };

    template<class R, class Pt>
    R dist_lp(const Pt& a, const Pt& b, const R p, const int dim)
    {
        return DistImpl<R, Pt>()(a, b, p, dim);
    }

    // TODO
    template<class Real, typename DiagPointContainer>
    double getFurthestDistance3Approx(DiagPointContainer& A, DiagPointContainer& B, const Real p)
    {
        int dim = 2;
        Real result { 0.0 };
        DiagramPoint<Real> begA = *(A.begin());
        DiagramPoint<Real> optB = *(B.begin());
        for(const auto& pointB : B) {
            if (dist_lp(begA, pointB, p, dim) > result) {
                result = dist_lp(begA, pointB, p, dim);
                optB = pointB;
            }
        }
        for(const auto& pointA : A) {
            if (dist_lp(pointA, optB, p, dim) > result) {
                result = dist_lp(pointA, optB, p, dim);
            }
        }
        return result;
    }

    // Point-vector variant of getFurthestDistance3Approx: first find the point
    // of B furthest from A[0], then the point of A furthest from that point.
    //
    // @param A, B  point vectors
    // @param p     norm exponent, passed through to dist_lp
    // @param dim   number of coordinates, passed through to dist_lp
    // @return      the estimate; 0 if either vector is empty (the original
    //              indexed A[0] / B[0] without checking the sizes)
    template<class Real>
    Real getFurthestDistance3Approx_pg(const hera::ws::dnn::DynamicPointVector<Real>& A, const hera::ws::dnn::DynamicPointVector<Real>& B, const Real p, const int dim)
    {
        Real result { 0.0 };
        if (A.size() == 0 || B.size() == 0) {
            return result;
        }

        // size_t rather than int: avoids narrowing the loop index
        size_t opt_b_idx = 0;
        for(size_t b_idx = 0; b_idx < B.size(); ++b_idx) {
            // evaluate each distance once instead of twice
            const Real d = dist_lp(A[0], B[b_idx], p, dim);
            if (d > result) {
                result = d;
                opt_b_idx = b_idx;
            }
        }

        for(size_t a_idx = 0; a_idx < A.size(); ++a_idx) {
            result = std::max(result,  dist_lp(A[a_idx], B[opt_b_idx], p, dim));
        }

        return result;
    }


    // Logging helpers; declarations only.
    // NOTE(review): the definitions are presumably in basic_defs_ws.hpp, which
    // is included at the end of this header - confirm.

    // Returns a string representation of a container.
    template<class Container>
    std::string format_container_to_log(const Container& cont);

    // Returns a string built from a container of indices and a vector of points.
    // NOTE(review): presumably prints the points selected by the indices - confirm.
    template<class Real, class IndexContainer>
    std::string format_point_set_to_log(const IndexContainer& indices, const std::vector<DiagramPoint<Real>>& points);

    // Returns a string representation of i.
    template<class T>
    std::string format_int(T i);

} // ws
} // hera



#include "basic_defs_ws.hpp"


#endif