diff options
Diffstat (limited to 'wasserstein/tests/test_hera_wasserstein.cpp')
-rw-r--r-- | wasserstein/tests/test_hera_wasserstein.cpp | 29 |
1 files changed, 23 insertions, 6 deletions
diff --git a/wasserstein/tests/test_hera_wasserstein.cpp b/wasserstein/tests/test_hera_wasserstein.cpp index 0a80d2f..6f5de3b 100644 --- a/wasserstein/tests/test_hera_wasserstein.cpp +++ b/wasserstein/tests/test_hera_wasserstein.cpp @@ -111,6 +111,18 @@ TEST_CASE("simple cases", "wasserstein_dist") } + SECTION("trivial: two diagrams differing by diagonal point") { + + diagram_A.emplace_back(0.0, 1.0); + diagram_B.emplace_back(0.0, 0.0); + diagram_B.emplace_back(0.0, 1.0); + + double d1 = hera::wasserstein_cost<>(diagram_A, diagram_B, params); + double d2 = hera::wasserstein_cost<>(diagram_B, diagram_A, params); + REQUIRE( fabs(d2) <= 0.00000000001 ); + REQUIRE( fabs(d1) <= 0.00000000001 ); + } + } @@ -130,9 +142,14 @@ TEST_CASE("file cases", "wasserstein_dist") SECTION("from file:") { - const char* file_name = "../tests/data/test_list.txt"; + const char* file_name = "test_list.txt"; std::ifstream f; f.open(file_name); + if (!f.good()) { + std::cerr << "Must run from tests/data" << std::endl; + REQUIRE(false); + } + std::vector<TestFromFileCase> test_params; std::string s; while (std::getline(f, s)) { @@ -147,13 +164,13 @@ TEST_CASE("file cases", "wasserstein_dist") REQUIRE( read_file_A ); REQUIRE( read_file_B ); double hera_answer = hera::wasserstein_dist(diagram_A, diagram_B, params); - REQUIRE( fabs(hera_answer - ts.answer) <= 0.01 * hera_answer ); - std::cout << ts << " PASSED " << std::endl; + bool is_correct = (hera_answer == ts.answer) || (fabs(hera_answer - ts.answer) <= 0.01 * hera_answer); + REQUIRE(is_correct); } } SECTION("from DIPHA file:") { - const char* file_name = "../tests/data/test_list.txt"; + const char* file_name = "test_list.txt"; std::ifstream f; f.open(file_name); std::vector<TestFromFileCase> test_params; @@ -167,8 +184,8 @@ TEST_CASE("file cases", "wasserstein_dist") params.internal_p = ts.internal_p; bool read_file_A = hera::read_diagram_dipha<double, PairVector>(ts.file_1 + std::string(".pd.dipha"), 1, diagram_A); bool read_file_B = hera::read_diagram_dipha<double, PairVector>(ts.file_2 + std::string(".pd.dipha"), 1, diagram_B); - REQUIRE( read_file_A ); - REQUIRE( read_file_B ); + if (!read_file_A) + continue; double hera_answer = hera::wasserstein_dist(diagram_A, diagram_B, params); REQUIRE( fabs(hera_answer - ts.answer) <= 0.01 * hera_answer ); std::cout << ts << " PASSED " << std::endl; |