diff options
author | Arnur Nigmetov <anigmetov@lbl.gov> | 2021-04-21 12:44:41 -0700 |
---|---|---|
committer | Arnur Nigmetov <anigmetov@lbl.gov> | 2021-04-21 12:44:41 -0700 |
commit | 7af824834e97c703c7724a6649d059639c8a1e36 (patch) | |
tree | 3e2cadf375342596ab8b80e886b6e447af9ccba1 | |
parent | 66789dd1319402fe573395e20a4e534a9e7142c7 (diff) |
Ignore diagonal points in wasserstein_cost, support points with both coords infinite
-rw-r--r-- | wasserstein/include/wasserstein.h | 60 | ||||
-rw-r--r-- | wasserstein/tests/data/test_inf_1_A | 3 | ||||
-rw-r--r-- | wasserstein/tests/data/test_inf_1_B | 4 | ||||
-rw-r--r-- | wasserstein/tests/data/test_inf_2_A | 3 | ||||
-rw-r--r-- | wasserstein/tests/data/test_inf_2_B | 5 | ||||
-rw-r--r-- | wasserstein/tests/data/test_inf_3_A | 6 | ||||
-rw-r--r-- | wasserstein/tests/data/test_inf_3_B | 5 | ||||
-rw-r--r-- | wasserstein/tests/data/test_list.txt | 3 | ||||
-rw-r--r-- | wasserstein/tests/test_hera_wasserstein.cpp | 17 |
9 files changed, 86 insertions, 20 deletions
diff --git a/wasserstein/include/wasserstein.h b/wasserstein/include/wasserstein.h index db6ce11..3d9fd17 100644 --- a/wasserstein/include/wasserstein.h +++ b/wasserstein/include/wasserstein.h @@ -255,6 +255,9 @@ wasserstein_cost(const PairContainer& A, //using PointType = typename Traits::PointType; using RealType = typename Traits::RealType; + constexpr RealType plus_inf = std::numeric_limits<RealType>::infinity(); + constexpr RealType minus_inf = -std::numeric_limits<RealType>::infinity(); + if (hera::ws::are_equal(A, B)) { return 0.0; } @@ -270,19 +273,34 @@ wasserstein_cost(const PairContainer& A, // coordinates of points at infinity std::vector<RealType> x_plus_A, x_minus_A, y_plus_A, y_minus_A; std::vector<RealType> x_plus_B, x_minus_B, y_plus_B, y_minus_B; + // points with both coordinates infinite are treated as equal + int n_minus_inf_plus_inf_A = 0; + int n_plus_inf_minus_inf_A = 0; + int n_minus_inf_plus_inf_B = 0; + int n_plus_inf_minus_inf_B = 0; // loop over A, add projections of A-points to corresponding positions // in B-vector for(auto&& pair_A : A) { a_empty = false; RealType x = Traits::get_x(pair_A); RealType y = Traits::get_y(pair_A); - if ( x == std::numeric_limits<RealType>::infinity()) { + + // skip diagonal points, including (inf, inf), (-inf, -inf) + if (x == y) { + continue; + } + + if (x == plus_inf && y == minus_inf) { + n_plus_inf_minus_inf_A++; + } else if (x == minus_inf && y == plus_inf) { + n_minus_inf_plus_inf_A++; + } else if ( x == plus_inf) { y_plus_A.push_back(y); - } else if (x == -std::numeric_limits<RealType>::infinity()) { + } else if (x == minus_inf) { y_minus_A.push_back(y); - } else if (y == std::numeric_limits<RealType>::infinity()) { + } else if (y == plus_inf) { x_plus_A.push_back(x); - } else if (y == -std::numeric_limits<RealType>::infinity()) { + } else if (y == minus_inf) { x_minus_A.push_back(x); } else { dgm_A.emplace_back(x, y, DgmPoint::NORMAL); @@ -295,13 +313,22 @@ wasserstein_cost(const PairContainer& A, b_empty = false; RealType x = Traits::get_x(pair_B); RealType y = Traits::get_y(pair_B); - if (x == std::numeric_limits<RealType>::infinity()) { + + if (x == y) { + continue; + } + + if (x == plus_inf && y == minus_inf) { + n_plus_inf_minus_inf_B++; + } else if (x == minus_inf && y == plus_inf) { + n_minus_inf_plus_inf_B++; + } else if (x == plus_inf) { y_plus_B.push_back(y); - } else if (x == -std::numeric_limits<RealType>::infinity()) { + } else if (x == minus_inf) { y_minus_B.push_back(y); - } else if (y == std::numeric_limits<RealType>::infinity()) { + } else if (y == plus_inf) { x_plus_B.push_back(x); - } else if (y == -std::numeric_limits<RealType>::infinity()) { + } else if (y == minus_inf) { x_minus_B.push_back(x); } else { dgm_A.emplace_back(x, y, DgmPoint::DIAG); @@ -310,10 +337,16 @@ wasserstein_cost(const PairContainer& A, } } - RealType infinity_cost = ws::get_one_dimensional_cost(x_plus_A, x_plus_B, params.wasserstein_power); - infinity_cost += ws::get_one_dimensional_cost(x_minus_A, x_minus_B, params.wasserstein_power); - infinity_cost += ws::get_one_dimensional_cost(y_plus_A, y_plus_B, params.wasserstein_power); - infinity_cost += ws::get_one_dimensional_cost(y_minus_A, y_minus_B, params.wasserstein_power); + RealType infinity_cost = 0; + + if (n_plus_inf_minus_inf_A != n_plus_inf_minus_inf_B || n_minus_inf_plus_inf_A != n_minus_inf_plus_inf_B) + infinity_cost = plus_inf; + else { + infinity_cost += ws::get_one_dimensional_cost(x_plus_A, x_plus_B, params.wasserstein_power); + infinity_cost += ws::get_one_dimensional_cost(x_minus_A, x_minus_B, params.wasserstein_power); + infinity_cost += ws::get_one_dimensional_cost(y_plus_A, y_plus_B, params.wasserstein_power); + infinity_cost += ws::get_one_dimensional_cost(y_minus_A, y_minus_B, params.wasserstein_power); + } if (a_empty) return total_cost_B + infinity_cost; @@ -321,8 +354,7 @@ wasserstein_cost(const PairContainer& A, if (b_empty) return total_cost_A + infinity_cost; - - if (infinity_cost == std::numeric_limits<RealType>::infinity()) { + if (infinity_cost == plus_inf) { return infinity_cost; } else { return infinity_cost + wasserstein_cost_vec(dgm_A, dgm_B, params, _log_filename_prefix); diff --git a/wasserstein/tests/data/test_inf_1_A b/wasserstein/tests/data/test_inf_1_A new file mode 100644 index 0000000..c773f02 --- /dev/null +++ b/wasserstein/tests/data/test_inf_1_A @@ -0,0 +1,3 @@ +-inf inf +-inf inf +2 1 diff --git a/wasserstein/tests/data/test_inf_1_B b/wasserstein/tests/data/test_inf_1_B new file mode 100644 index 0000000..a55e496 --- /dev/null +++ b/wasserstein/tests/data/test_inf_1_B @@ -0,0 +1,4 @@ +-inf inf +-inf inf +inf inf +4 9 diff --git a/wasserstein/tests/data/test_inf_2_A b/wasserstein/tests/data/test_inf_2_A new file mode 100644 index 0000000..c773f02 --- /dev/null +++ b/wasserstein/tests/data/test_inf_2_A @@ -0,0 +1,3 @@ +-inf inf +-inf inf +2 1 diff --git a/wasserstein/tests/data/test_inf_2_B b/wasserstein/tests/data/test_inf_2_B new file mode 100644 index 0000000..6d7e751 --- /dev/null +++ b/wasserstein/tests/data/test_inf_2_B @@ -0,0 +1,5 @@ +-inf inf +-inf inf +inf -inf +inf inf +4 9 diff --git a/wasserstein/tests/data/test_inf_3_A b/wasserstein/tests/data/test_inf_3_A new file mode 100644 index 0000000..4f3fc2f --- /dev/null +++ b/wasserstein/tests/data/test_inf_3_A @@ -0,0 +1,6 @@ +-inf inf +-inf inf +inf -inf +-inf -inf +-inf -inf +2 1 diff --git a/wasserstein/tests/data/test_inf_3_B b/wasserstein/tests/data/test_inf_3_B new file mode 100644 index 0000000..6d7e751 --- /dev/null +++ b/wasserstein/tests/data/test_inf_3_B @@ -0,0 +1,5 @@ +-inf inf +-inf inf +inf -inf +inf inf +4 9 diff --git a/wasserstein/tests/data/test_list.txt b/wasserstein/tests/data/test_list.txt index 27340d8..b1ba6ed 100644 --- a/wasserstein/tests/data/test_list.txt +++ b/wasserstein/tests/data/test_list.txt @@ -19,3 +19,6 @@ test_100_A test_100_B 3.0 2.0 2.09695346034248 test_diag1_A test_diag1_B 1.0 -1.0 0.0 test_diag2_A test_diag2_B 1.0 -1.0 0.0 test_diag3_A test_diag3_B 1.0 -1.0 0.0 +test_inf_1_A test_inf_1_B 1.0 -1.0 3.0 +test_inf_2_A test_inf_2_B 1.0 -1.0 inf +test_inf_3_A test_inf_3_B 1.0 -1.0 3.0 diff --git a/wasserstein/tests/test_hera_wasserstein.cpp b/wasserstein/tests/test_hera_wasserstein.cpp index 0a80d2f..621dd5a 100644 --- a/wasserstein/tests/test_hera_wasserstein.cpp +++ b/wasserstein/tests/test_hera_wasserstein.cpp @@ -130,9 +130,14 @@ TEST_CASE("file cases", "wasserstein_dist") SECTION("from file:") { - const char* file_name = "../tests/data/test_list.txt"; + const char* file_name = "test_list.txt"; std::ifstream f; f.open(file_name); + if (!f.good()) { + std::cerr << "Must run from tests/data" << std::endl; + REQUIRE(false); + } + std::vector<TestFromFileCase> test_params; std::string s; while (std::getline(f, s)) { @@ -147,13 +152,13 @@ TEST_CASE("file cases", "wasserstein_dist") REQUIRE( read_file_A ); REQUIRE( read_file_B ); double hera_answer = hera::wasserstein_dist(diagram_A, diagram_B, params); - REQUIRE( fabs(hera_answer - ts.answer) <= 0.01 * hera_answer ); - std::cout << ts << " PASSED " << std::endl; + bool is_correct = (hera_answer == ts.answer) || (fabs(hera_answer - ts.answer) <= 0.01 * hera_answer); + REQUIRE(is_correct); } } SECTION("from DIPHA file:") { - const char* file_name = "../tests/data/test_list.txt"; + const char* file_name = "test_list.txt"; std::ifstream f; f.open(file_name); std::vector<TestFromFileCase> test_params; @@ -167,8 +172,8 @@ TEST_CASE("file cases", "wasserstein_dist") params.internal_p = ts.internal_p; bool read_file_A = hera::read_diagram_dipha<double, PairVector>(ts.file_1 + std::string(".pd.dipha"), 1, diagram_A); bool read_file_B = hera::read_diagram_dipha<double, PairVector>(ts.file_2 + std::string(".pd.dipha"), 1, diagram_B); - REQUIRE( read_file_A ); - REQUIRE( read_file_B ); + if (!read_file_A) + continue; double hera_answer = hera::wasserstein_dist(diagram_A, diagram_B, params); REQUIRE( fabs(hera_answer - ts.answer) <= 0.01 * hera_answer ); std::cout << ts << " PASSED " << std::endl; |