summaryrefslogtreecommitdiff
path: root/wasserstein/tests/tests_reader.h
blob: f2d5735e2ec503d27170612b8de902f4e2c1be43 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#ifndef WASSERSTEIN_TESTS_READER_H
#define WASSERSTEIN_TESTS_READER_H

#include <vector>
#include <string>
#include <ostream>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <cassert>
#include <cmath>

#include "hera_infinity.h"

namespace  hera_test {
    /// Splits `s` on every occurrence of `delim`.
    ///
    /// Follows std::getline semantics: consecutive delimiters produce empty
    /// tokens, a trailing delimiter does not produce a trailing empty token,
    /// and an empty string yields an empty vector.
    inline std::vector<std::string> split_on_delim(const std::string& s, char delim)
    {
        std::stringstream ss(s);
        std::string token;
        std::vector<std::string> tokens;
        while (std::getline(ss, token, delim)) {
            tokens.push_back(token);
        }
        return tokens;
    }


    /// One row of a test-list file (the error message refers to test_list.txt).
    ///
    /// A row must consist of exactly five fields separated by single spaces:
    ///     <file_1> <file_2> <q> <internal_p> <answer>
    /// q must be finite and >= 1; internal_p must be >= 1 or equal to the
    /// sentinel returned by hera::get_infinity<double>().
    /// NOTE(review): presumably q is the Wasserstein degree, internal_p the
    /// ground-metric norm and answer the expected distance between the inputs
    /// stored in file_1 and file_2 — confirm against the test driver.
    struct TestFromFileCase
    {

        std::string file_1;
        std::string file_2;
        double q;
        double internal_p;
        double answer;

        /// Parses one row of the test-list file.
        /// @param s  the row text (split on single ' ' characters)
        /// @throws std::runtime_error     if the row does not have exactly five
        ///                                fields, or q / internal_p are out of
        ///                                range (NaN included)
        /// @throws std::invalid_argument  if a numeric field cannot be parsed (std::stod)
        /// @throws std::out_of_range      if a numeric field does not fit in a double (std::stod)
        TestFromFileCase(const std::string& s)
        {
            const std::vector<std::string> tokens = split_on_delim(s, ' ');
            // Checked explicitly rather than with assert(): assert is compiled out
            // under NDEBUG, which would let rows with extra fields through silently
            // and turn short rows into an opaque std::out_of_range from at().
            if (tokens.size() != 5) {
                throw std::runtime_error("Bad line in test_list.txt");
            }

            file_1 = tokens[0];
            file_2 = tokens[1];
            q = std::stod(tokens[2]);
            internal_p = std::stod(tokens[3]);
            answer = std::stod(tokens[4]);

            // Written as !(x >= 1.0) rather than x < 1.0 so that NaN (which
            // std::stod produces for "nan") is rejected too.
            const bool bad_q = !(q >= 1.0) or std::isinf(q);
            const bool bad_internal_p =
                internal_p != hera::get_infinity<double>() and !(internal_p >= 1.0);
            if (bad_q or bad_internal_p) {
                throw std::runtime_error("Bad line in test_list.txt");
            }
        }
    };

    /// Writes a one-line, human-readable description of a test case in the form
    /// "[<file_1>, <file_2>, q = <q>, norm = <internal_p>, answer = <answer>]".
    /// internal_p is printed as "infinity" when it equals hera::get_infinity<double>().
    inline std::ostream& operator<<(std::ostream& out, const TestFromFileCase& s)
    {
        out << "[" << s.file_1 << ", " << s.file_2 << ", q = " << s.q << ", norm = ";
        // Compare against the same sentinel the constructor validates against.
        if (s.internal_p != hera::get_infinity<double>()) {
            out << s.internal_p;
        } else {
            out << "infinity";
        }
        out << ", answer = " << s.answer << "]";
        return out;
    }
} // namespace hera_test
#endif //WASSERSTEIN_TESTS_READER_H