summaryrefslogtreecommitdiff
path: root/matching/src/tests/test_matching_distance.cpp
blob: 90baa0fe1ebeba52b7e87bc0b4796f4d070b5204 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#include "catch/catch.hpp"

#include <cmath>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include "spdlog/spdlog.h"
#include "spdlog/fmt/ostr.h"

#include "common_util.h"
#include "simplex.h"
#include "matching_distance.h"

using namespace md;
namespace spd = spdlog;

/// Cross-checks the displacement bounds reported by DistanceCalculator
/// against each other and against directly computed weighted pushes.
///
/// Setup: two identical bifiltrations, each with one simplex per integer
/// grid point (i, j), 0 <= i <= max_x, 0 <= j <= max_y.
///
/// For every dual box of calc.get_refined_grid(5, false, false) and every
/// grid point p, the test requires:
///   * local dual bound >= refined local bound;
///   * for each corner vp_a in k_corner_vps, the single-point bound at vp_a
///     is <= the refined local bound;
///   * for each pair of corners (vp_b not <= vp_a), the difference of the
///     weighted pushes of p is within both the single-point bound and the
///     refined local bound;
///   * for the points returned by db.random_points(100), the difference
///     between their weighted push of p and the one at corner vp_a is within
///     the single-point bound.
/// When a check is about to fail, context is printed to std::cerr and the
/// offending computation is re-run with spdlog at debug level (the level is
/// afterwards set to info) before REQUIRE reports the failure.
TEST_CASE("Different bounds", "[bounds]")
{
    std::vector<Simplex> simplices;
    std::vector<Point> points;

    Real max_x = 10;
    Real max_y = 20;

    // One simplex per grid point; simplex ids are consecutive from 0.
    // NOTE(review): the 0 and empty Column() presumably make these
    // 0-dimensional simplices with no boundary -- confirm against Simplex.
    int simplex_id = 0;
    for(int i = 0; i <= max_x; ++i) {
        for(int j = 0; j <= max_y; ++j) {
            Point p(i, j);
            simplices.emplace_back(simplex_id++, p, 0, Column());
            points.push_back(p);
        }
    }

    Bifiltration bif_a(simplices.begin(), simplices.end());
    Bifiltration bif_b(simplices.begin(), simplices.end());

    CalculationParams params;
    params.initialization_depth = 2;

    BifiltrationProxy bifp_a(bif_a, params.dim);
    BifiltrationProxy bifp_b(bif_b, params.dim);

    DistanceCalculator<BifiltrationProxy> calc(bifp_a, bifp_b, params);

    // NOTE(review): disabled checks on the calculator's internal extents;
    // kept for reference.
//    REQUIRE(calc.max_x_ == Approx(max_x));
//    REQUIRE(calc.max_y_ == Approx(max_y));

    // Collect the dual boxes of the refined grid; every box is tested below
    // against every grid point collected above.
    std::vector<DualBox> boxes;
    for(CellWithValue c : calc.get_refined_grid(5, false, false)) {
        boxes.push_back(c.dual_box());
    }

    for(DualBox db : boxes) {
        Real local_bound = calc.get_local_dual_bound(db);
        Real local_bound_refined = calc.get_local_refined_bound(db);
        REQUIRE(local_bound >= local_bound_refined);
        for(Point p : points) {
            for(ValuePoint vp_a : k_corner_vps) {
                CellWithValue dual_cell(db, 1);
                DualPoint corner_a = dual_cell.value_point(vp_a);
                Real wp_a = corner_a.weighted_push(p);
                dual_cell.set_value_at(vp_a, wp_a);
                Real point_bound = calc.get_max_displacement_single_point(dual_cell, vp_a, p);

                // Checked once per corner, outside the pairwise loop below:
                // the check does not depend on vp_b, and the pairwise loop
                // skips every vp_b for the largest corner, which would leave
                // that corner's single-point bound unchecked.
                if (not(point_bound <= Approx(local_bound_refined))) {
                    std::cerr << "ERROR point: " << p << ", box = " << db << ", point bound = " << point_bound
                              << ", refined local = " << local_bound_refined << std::endl;
                    spd::set_level(spd::level::debug);
                    calc.get_max_displacement_single_point(dual_cell, vp_a, p);
                    calc.get_local_refined_bound(db);
                    spd::set_level(spd::level::info);
                }
                REQUIRE(point_bound <= Approx(local_bound_refined));

                // Corners with vp_b <= vp_a are skipped so each pair of
                // corners is compared only once.
                for(ValuePoint vp_b : k_corner_vps) {
                    if (vp_b <= vp_a)
                        continue;
                    DualPoint corner_b = dual_cell.value_point(vp_b);
                    Real wp_b = corner_b.weighted_push(p);
                    Real diff = std::fabs(wp_a - wp_b);
                    REQUIRE(diff <= Approx(point_bound));
                    REQUIRE(diff <= Approx(local_bound_refined));
                }

                // The weighted push at corner vp_a must stay within
                // point_bound of the weighted push at sampled points of the box.
                for(DualPoint l_random : db.random_points(100)) {
                    Real wp_random = l_random.weighted_push(p);
                    Real diff = std::fabs(wp_a - wp_random);
                    if (not(diff <= Approx(point_bound))) {
                        // Diagnostics only: label the failure by how many
                        // critical points the box reports for p.
                        if (db.critical_points(p).size() > 4) {
                            std::cerr << "ERROR interesting case" << std::endl;
                        } else {
                            std::cerr << "ERROR boring case" << std::endl;
                        }
                        spd::set_level(spd::level::debug);
                        l_random.weighted_push(p);
                        spd::set_level(spd::level::info);
                        std::cerr << "ERROR point: " << p << ", box = " << db << ", point bound = " << point_bound
                                  << ", refined local = " << local_bound_refined;
                        std::cerr << ", random_dual_point = " << l_random << ", wp_random = " << wp_random
                                  << ", diff = " << diff << std::endl;
                    }
                    REQUIRE(diff <= Approx(point_bound));
                }
            }
        }
    }
}

/// Regression test on two small bifiltrations read from disk
/// ("prism_1_lesnick.bif" and "prism_2_lesnick.bif", phat_like format).
///
/// For every combination of bound strategy, traversal strategy and scaling
/// factor lambda, copies of both bifiltrations are scaled by lambda and the
/// computed matching distance must be strictly within lambda * 0.05 of
/// lambda * 1.0 (i.e. within 5% of the expected distance, which is 1.0 for
/// the unscaled inputs).
///
/// The data files are looked up in the directory named by the environment
/// variable MD_TEST_DATA_DIR. If it is unset or empty, the original
/// hard-coded location is used.
TEST_CASE("Bifiltrations from file", "[matching_distance][small_example][lesnick]")
{
    // Default kept for backward compatibility; it is a developer-specific
    // absolute path, so other environments should set MD_TEST_DATA_DIR.
    std::string data_dir = "/home/narn/code/matching_distance/code/python_scripts";
    if (const char* env_dir = std::getenv("MD_TEST_DATA_DIR")) {
        if (*env_dir != '\0')
            data_dir = env_dir;
    }

    std::string fname_a = data_dir + "/prism_1_lesnick.bif";
    std::string fname_b = data_dir + "/prism_2_lesnick.bif";

    Bifiltration bif_a(fname_a, BifiltrationFormat::phat_like);
    Bifiltration bif_b(fname_b, BifiltrationFormat::phat_like);

    // Matching distance between the two unscaled inputs.
    const Real expected_unscaled_distance = 1.0;

    CalculationParams params;

    std::vector<BoundStrategy> bound_strategies {BoundStrategy::local_combined,
                                                 BoundStrategy::local_dual_bound_refined};

    std::vector<TraverseStrategy> traverse_strategies {TraverseStrategy::breadth_first, TraverseStrategy::depth_first};

    std::vector<double> scaling_factors {10, 1.0};

    for(auto bs : bound_strategies) {
        for(auto ts : traverse_strategies) {
            for(double lambda : scaling_factors) {
                // scale() is applied to copies so bif_a / bif_b stay unscaled
                // for the next iteration.
                Bifiltration bif_a_copy(bif_a);
                Bifiltration bif_b_copy(bif_b);
                bif_a_copy.scale(lambda);
                bif_b_copy.scale(lambda);

                // All fields are (re)set every iteration on the same params
                // object, exactly as before.
                params.bound_strategy = bs;
                params.traverse_strategy = ts;
                params.max_depth = 7;
                params.delta = 0.01;
                params.dim = 1;

                Real answer = matching_distance(bif_a_copy, bif_b_copy, params);
                Real correct_answer = lambda * expected_unscaled_distance;
                REQUIRE(std::fabs(answer - correct_answer) < lambda * 0.05);
            }
        }
    }
}