summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Arnur Nigmetov <anigmetov@lbl.gov> 2021-04-21 12:44:41 -0700
committer: Arnur Nigmetov <anigmetov@lbl.gov> 2021-04-21 12:44:41 -0700
commit: 7af824834e97c703c7724a6649d059639c8a1e36 (patch)
tree: 3e2cadf375342596ab8b80e886b6e447af9ccba1
parent: 66789dd1319402fe573395e20a4e534a9e7142c7 (diff)
Ignore diagonal points in wasserstein_cost, support points with both coords infinite
-rw-r--r-- wasserstein/include/wasserstein.h | 60
-rw-r--r-- wasserstein/tests/data/test_inf_1_A | 3
-rw-r--r-- wasserstein/tests/data/test_inf_1_B | 4
-rw-r--r-- wasserstein/tests/data/test_inf_2_A | 3
-rw-r--r-- wasserstein/tests/data/test_inf_2_B | 5
-rw-r--r-- wasserstein/tests/data/test_inf_3_A | 6
-rw-r--r-- wasserstein/tests/data/test_inf_3_B | 5
-rw-r--r-- wasserstein/tests/data/test_list.txt | 3
-rw-r--r-- wasserstein/tests/test_hera_wasserstein.cpp | 17
9 files changed, 86 insertions, 20 deletions
diff --git a/wasserstein/include/wasserstein.h b/wasserstein/include/wasserstein.h
index db6ce11..3d9fd17 100644
--- a/wasserstein/include/wasserstein.h
+++ b/wasserstein/include/wasserstein.h
@@ -255,6 +255,9 @@ wasserstein_cost(const PairContainer& A,
//using PointType = typename Traits::PointType;
using RealType = typename Traits::RealType;
+ constexpr RealType plus_inf = std::numeric_limits<RealType>::infinity();
+ constexpr RealType minus_inf = -std::numeric_limits<RealType>::infinity();
+
if (hera::ws::are_equal(A, B)) {
return 0.0;
}
@@ -270,19 +273,34 @@ wasserstein_cost(const PairContainer& A,
// coordinates of points at infinity
std::vector<RealType> x_plus_A, x_minus_A, y_plus_A, y_minus_A;
std::vector<RealType> x_plus_B, x_minus_B, y_plus_B, y_minus_B;
+ // points with both coordinates infinite are treated as equal
+ int n_minus_inf_plus_inf_A = 0;
+ int n_plus_inf_minus_inf_A = 0;
+ int n_minus_inf_plus_inf_B = 0;
+ int n_plus_inf_minus_inf_B = 0;
// loop over A, add projections of A-points to corresponding positions
// in B-vector
for(auto&& pair_A : A) {
a_empty = false;
RealType x = Traits::get_x(pair_A);
RealType y = Traits::get_y(pair_A);
- if ( x == std::numeric_limits<RealType>::infinity()) {
+
+ // skip diagonal points, including (inf, inf), (-inf, -inf)
+ if (x == y) {
+ continue;
+ }
+
+ if (x == plus_inf && y == minus_inf) {
+ n_plus_inf_minus_inf_A++;
+ } else if (x == minus_inf && y == plus_inf) {
+ n_minus_inf_plus_inf_A++;
+ } else if ( x == plus_inf) {
y_plus_A.push_back(y);
- } else if (x == -std::numeric_limits<RealType>::infinity()) {
+ } else if (x == minus_inf) {
y_minus_A.push_back(y);
- } else if (y == std::numeric_limits<RealType>::infinity()) {
+ } else if (y == plus_inf) {
x_plus_A.push_back(x);
- } else if (y == -std::numeric_limits<RealType>::infinity()) {
+ } else if (y == minus_inf) {
x_minus_A.push_back(x);
} else {
dgm_A.emplace_back(x, y, DgmPoint::NORMAL);
@@ -295,13 +313,22 @@ wasserstein_cost(const PairContainer& A,
b_empty = false;
RealType x = Traits::get_x(pair_B);
RealType y = Traits::get_y(pair_B);
- if (x == std::numeric_limits<RealType>::infinity()) {
+
+ if (x == y) {
+ continue;
+ }
+
+ if (x == plus_inf && y == minus_inf) {
+ n_plus_inf_minus_inf_B++;
+ } else if (x == minus_inf && y == plus_inf) {
+ n_minus_inf_plus_inf_B++;
+ } else if (x == plus_inf) {
y_plus_B.push_back(y);
- } else if (x == -std::numeric_limits<RealType>::infinity()) {
+ } else if (x == minus_inf) {
y_minus_B.push_back(y);
- } else if (y == std::numeric_limits<RealType>::infinity()) {
+ } else if (y == plus_inf) {
x_plus_B.push_back(x);
- } else if (y == -std::numeric_limits<RealType>::infinity()) {
+ } else if (y == minus_inf) {
x_minus_B.push_back(x);
} else {
dgm_A.emplace_back(x, y, DgmPoint::DIAG);
@@ -310,10 +337,16 @@ wasserstein_cost(const PairContainer& A,
}
}
- RealType infinity_cost = ws::get_one_dimensional_cost(x_plus_A, x_plus_B, params.wasserstein_power);
- infinity_cost += ws::get_one_dimensional_cost(x_minus_A, x_minus_B, params.wasserstein_power);
- infinity_cost += ws::get_one_dimensional_cost(y_plus_A, y_plus_B, params.wasserstein_power);
- infinity_cost += ws::get_one_dimensional_cost(y_minus_A, y_minus_B, params.wasserstein_power);
+ RealType infinity_cost = 0;
+
+ if (n_plus_inf_minus_inf_A != n_plus_inf_minus_inf_B || n_minus_inf_plus_inf_A != n_minus_inf_plus_inf_B)
+ infinity_cost = plus_inf;
+ else {
+ infinity_cost += ws::get_one_dimensional_cost(x_plus_A, x_plus_B, params.wasserstein_power);
+ infinity_cost += ws::get_one_dimensional_cost(x_minus_A, x_minus_B, params.wasserstein_power);
+ infinity_cost += ws::get_one_dimensional_cost(y_plus_A, y_plus_B, params.wasserstein_power);
+ infinity_cost += ws::get_one_dimensional_cost(y_minus_A, y_minus_B, params.wasserstein_power);
+ }
if (a_empty)
return total_cost_B + infinity_cost;
@@ -321,8 +354,7 @@ wasserstein_cost(const PairContainer& A,
if (b_empty)
return total_cost_A + infinity_cost;
-
- if (infinity_cost == std::numeric_limits<RealType>::infinity()) {
+ if (infinity_cost == plus_inf) {
return infinity_cost;
} else {
return infinity_cost + wasserstein_cost_vec(dgm_A, dgm_B, params, _log_filename_prefix);
diff --git a/wasserstein/tests/data/test_inf_1_A b/wasserstein/tests/data/test_inf_1_A
new file mode 100644
index 0000000..c773f02
--- /dev/null
+++ b/wasserstein/tests/data/test_inf_1_A
@@ -0,0 +1,3 @@
+-inf inf
+-inf inf
+2 1
diff --git a/wasserstein/tests/data/test_inf_1_B b/wasserstein/tests/data/test_inf_1_B
new file mode 100644
index 0000000..a55e496
--- /dev/null
+++ b/wasserstein/tests/data/test_inf_1_B
@@ -0,0 +1,4 @@
+-inf inf
+-inf inf
+inf inf
+4 9
diff --git a/wasserstein/tests/data/test_inf_2_A b/wasserstein/tests/data/test_inf_2_A
new file mode 100644
index 0000000..c773f02
--- /dev/null
+++ b/wasserstein/tests/data/test_inf_2_A
@@ -0,0 +1,3 @@
+-inf inf
+-inf inf
+2 1
diff --git a/wasserstein/tests/data/test_inf_2_B b/wasserstein/tests/data/test_inf_2_B
new file mode 100644
index 0000000..6d7e751
--- /dev/null
+++ b/wasserstein/tests/data/test_inf_2_B
@@ -0,0 +1,5 @@
+-inf inf
+-inf inf
+inf -inf
+inf inf
+4 9
diff --git a/wasserstein/tests/data/test_inf_3_A b/wasserstein/tests/data/test_inf_3_A
new file mode 100644
index 0000000..4f3fc2f
--- /dev/null
+++ b/wasserstein/tests/data/test_inf_3_A
@@ -0,0 +1,6 @@
+-inf inf
+-inf inf
+inf -inf
+-inf -inf
+-inf -inf
+2 1
diff --git a/wasserstein/tests/data/test_inf_3_B b/wasserstein/tests/data/test_inf_3_B
new file mode 100644
index 0000000..6d7e751
--- /dev/null
+++ b/wasserstein/tests/data/test_inf_3_B
@@ -0,0 +1,5 @@
+-inf inf
+-inf inf
+inf -inf
+inf inf
+4 9
diff --git a/wasserstein/tests/data/test_list.txt b/wasserstein/tests/data/test_list.txt
index 27340d8..b1ba6ed 100644
--- a/wasserstein/tests/data/test_list.txt
+++ b/wasserstein/tests/data/test_list.txt
@@ -19,3 +19,6 @@ test_100_A test_100_B 3.0 2.0 2.09695346034248
test_diag1_A test_diag1_B 1.0 -1.0 0.0
test_diag2_A test_diag2_B 1.0 -1.0 0.0
test_diag3_A test_diag3_B 1.0 -1.0 0.0
+test_inf_1_A test_inf_1_B 1.0 -1.0 3.0
+test_inf_2_A test_inf_2_B 1.0 -1.0 inf
+test_inf_3_A test_inf_3_B 1.0 -1.0 3.0
diff --git a/wasserstein/tests/test_hera_wasserstein.cpp b/wasserstein/tests/test_hera_wasserstein.cpp
index 0a80d2f..621dd5a 100644
--- a/wasserstein/tests/test_hera_wasserstein.cpp
+++ b/wasserstein/tests/test_hera_wasserstein.cpp
@@ -130,9 +130,14 @@ TEST_CASE("file cases", "wasserstein_dist")
SECTION("from file:") {
- const char* file_name = "../tests/data/test_list.txt";
+ const char* file_name = "test_list.txt";
std::ifstream f;
f.open(file_name);
+ if (!f.good()) {
+ std::cerr << "Must run from tests/data" << std::endl;
+ REQUIRE(false);
+ }
+
std::vector<TestFromFileCase> test_params;
std::string s;
while (std::getline(f, s)) {
@@ -147,13 +152,13 @@ TEST_CASE("file cases", "wasserstein_dist")
REQUIRE( read_file_A );
REQUIRE( read_file_B );
double hera_answer = hera::wasserstein_dist(diagram_A, diagram_B, params);
- REQUIRE( fabs(hera_answer - ts.answer) <= 0.01 * hera_answer );
- std::cout << ts << " PASSED " << std::endl;
+ bool is_correct = (hera_answer == ts.answer) || (fabs(hera_answer - ts.answer) <= 0.01 * hera_answer);
+ REQUIRE(is_correct);
}
}
SECTION("from DIPHA file:") {
- const char* file_name = "../tests/data/test_list.txt";
+ const char* file_name = "test_list.txt";
std::ifstream f;
f.open(file_name);
std::vector<TestFromFileCase> test_params;
@@ -167,8 +172,8 @@ TEST_CASE("file cases", "wasserstein_dist")
params.internal_p = ts.internal_p;
bool read_file_A = hera::read_diagram_dipha<double, PairVector>(ts.file_1 + std::string(".pd.dipha"), 1, diagram_A);
bool read_file_B = hera::read_diagram_dipha<double, PairVector>(ts.file_2 + std::string(".pd.dipha"), 1, diagram_B);
- REQUIRE( read_file_A );
- REQUIRE( read_file_B );
+ if (!read_file_A)
+ continue;
double hera_answer = hera::wasserstein_dist(diagram_A, diagram_B, params);
REQUIRE( fabs(hera_answer - ts.answer) <= 0.01 * hera_answer );
std::cout << ts << " PASSED " << std::endl;