summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorArnur Nigmetov <a.nigmetov@gmail.com>2018-01-25 10:33:41 +0100
committerArnur Nigmetov <a.nigmetov@gmail.com>2018-01-25 10:33:41 +0100
commite77514fbcb7e8aa2d91747e95e1250820d223bae (patch)
treebcb7f751d7c48a751756487d0d8f1088f04ff754
parent0cc35ad04f9c2997014d7cf62a12f697e79fb534 (diff)
Fix for inf-only diagrams
-rw-r--r--geom_matching/wasserstein/include/wasserstein.h3
-rw-r--r--geom_matching/wasserstein/tests/test_hera_wasserstein.cpp355
2 files changed, 358 insertions, 0 deletions
diff --git a/geom_matching/wasserstein/include/wasserstein.h b/geom_matching/wasserstein/include/wasserstein.h
index d843a04..b90a545 100644
--- a/geom_matching/wasserstein/include/wasserstein.h
+++ b/geom_matching/wasserstein/include/wasserstein.h
@@ -228,6 +228,9 @@ namespace ws
throw std::runtime_error("Bad epsilon factor in Wasserstein " + std::to_string(params.epsilon_common_ratio));
}
+ if (A.empty() and B.empty())
+ return 0.0;
+
RealType result;
// just use Gauss-Seidel
diff --git a/geom_matching/wasserstein/tests/test_hera_wasserstein.cpp b/geom_matching/wasserstein/tests/test_hera_wasserstein.cpp
index 2ee49ad..a1a257b 100644
--- a/geom_matching/wasserstein/tests/test_hera_wasserstein.cpp
+++ b/geom_matching/wasserstein/tests/test_hera_wasserstein.cpp
@@ -201,3 +201,358 @@ TEST_CASE("file cases", "wasserstein_dist")
}
}
}
+
+
+
+TEST_CASE("infinity points", "wasserstein_dist")
+{
+ PairVector diagram_A, diagram_B;
+ hera::AuctionParams<double> params;
+ params.wasserstein_power = 1.0;
+ params.delta = 0.01;
+ params.internal_p = hera::get_infinity<double>();
+ params.initial_epsilon = 0.0;
+ params.epsilon_common_ratio = 0.0;
+ params.max_num_phases = 30;
+ params.gamma_threshold = 0.0;
+ params.max_bids_per_round = 0; // use Jacobi
+
+ // do not use Hera's infinity! it is -1
+ double inf = std::numeric_limits<double>::infinity();
+
+ SECTION("two points at infinity, no finite points") {
+
+ // edge cost 1.0
+ diagram_A.emplace_back(1.0, inf);
+ diagram_B.emplace_back(2.0, inf);
+
+ double d = hera::wasserstein_dist<>(diagram_A, diagram_B, params);
+ double corr_answer = 1.0;
+ REQUIRE( fabs(d - corr_answer) <= 0.00000000001 );
+ }
+
+ SECTION("two points at infinity") {
+
+ // edge cost 3.0
+ diagram_A.emplace_back(10.0, 20.0); // (5, 5)
+ diagram_B.emplace_back(13.0, 19.0); // (3, 3)
+
+ // edge cost 1.0
+ diagram_A.emplace_back(1.0, inf);
+ diagram_B.emplace_back(2.0, inf);
+
+ double d = hera::wasserstein_dist<>(diagram_A, diagram_B, params);
+ double corr_answer = 1.0 + 3.0;
+ REQUIRE( fabs(d - corr_answer) <= 0.00000000001 );
+ }
+
+ SECTION("three points at infinity, no finite points") {
+
+ // edge cost 1.0
+ diagram_A.emplace_back(1.0, inf);
+ diagram_B.emplace_back(2.0, inf);
+ diagram_B.emplace_back(2.0, inf);
+
+ double d = hera::wasserstein_dist<>(diagram_A, diagram_B, params);
+ double corr_answer = inf;
+ REQUIRE( d == corr_answer );
+ }
+
+ SECTION("three points at infinity") {
+
+ // edge cost 3.0
+ diagram_A.emplace_back(10.0, 20.0); // (5, 5)
+ diagram_B.emplace_back(13.0, 19.0); // (3, 3)
+
+ // edge cost 1.0
+ diagram_A.emplace_back(1.0, inf);
+ diagram_A.emplace_back(1.0, inf);
+ diagram_B.emplace_back(2.0, inf);
+
+ double d = hera::wasserstein_dist<>(diagram_A, diagram_B, params);
+ double corr_answer = inf;
+ REQUIRE( d == corr_answer );
+ }
+
+
+ SECTION("all four corners at infinity, no finite points, finite answer") {
+
+ // edge cost 1.0
+ diagram_A.emplace_back(1.0, inf);
+ diagram_B.emplace_back(2.0, inf);
+
+ // edge cost 1.0
+ diagram_A.emplace_back(1.0, -inf);
+ diagram_B.emplace_back(2.0, -inf);
+
+ // edge cost 1.0
+ diagram_A.emplace_back(inf, 1.0);
+ diagram_B.emplace_back(inf, 2.0);
+
+ // edge cost 1.0
+ diagram_A.emplace_back(-inf, 1.0);
+ diagram_B.emplace_back(-inf, 2.0);
+
+ double d = hera::wasserstein_dist<>(diagram_A, diagram_B, params);
+ double corr_answer = 4.0;
+
+ REQUIRE( d == corr_answer );
+ }
+
+ SECTION("all four corners at infinity, no finite points, infinite answer-1") {
+
+ // edge cost 1.0
+ diagram_A.emplace_back(1.0, inf);
+ diagram_A.emplace_back(1.0, inf);
+ diagram_B.emplace_back(2.0, inf);
+
+ // edge cost 1.0
+ diagram_A.emplace_back(1.0, -inf);
+ diagram_B.emplace_back(2.0, -inf);
+
+ // edge cost 1.0
+ diagram_A.emplace_back(inf, 1.0);
+ diagram_B.emplace_back(inf, 2.0);
+
+ // edge cost 1.0
+ diagram_A.emplace_back(-inf, 1.0);
+ diagram_B.emplace_back(-inf, 2.0);
+
+ double d1 = hera::wasserstein_dist<>(diagram_A, diagram_B, params);
+ double d2 = hera::wasserstein_dist<>(diagram_B, diagram_A, params);
+ double corr_answer = inf;
+
+ REQUIRE( d1 == corr_answer );
+ REQUIRE( d2 == corr_answer );
+ }
+
+ SECTION("all four corners at infinity, no finite points, infinite answer-2") {
+
+ // edge cost 1.0
+ diagram_A.emplace_back(1.0, inf);
+ diagram_B.emplace_back(2.0, inf);
+
+ // edge cost 1.0
+ diagram_A.emplace_back(1.0, -inf);
+ diagram_B.emplace_back(2.0, -inf);
+ diagram_B.emplace_back(2.0, -inf);
+
+ // edge cost 1.0
+ diagram_A.emplace_back(inf, 1.0);
+ diagram_B.emplace_back(inf, 2.0);
+
+ // edge cost 1.0
+ diagram_A.emplace_back(-inf, 1.0);
+ diagram_B.emplace_back(-inf, 2.0);
+
+ double d1 = hera::wasserstein_dist<>(diagram_A, diagram_B, params);
+ double d2 = hera::wasserstein_dist<>(diagram_B, diagram_A, params);
+ double corr_answer = inf;
+
+ REQUIRE( d1 == corr_answer );
+ REQUIRE( d2 == corr_answer );
+ }
+
+ SECTION("all four corners at infinity, no finite points, infinite answer-3") {
+
+ // edge cost 1.0
+ diagram_A.emplace_back(1.0, inf);
+ diagram_B.emplace_back(2.0, inf);
+
+ // edge cost 1.0
+ diagram_A.emplace_back(1.0, -inf);
+ diagram_B.emplace_back(2.0, -inf);
+
+ // edge cost 1.0
+ diagram_A.emplace_back(inf, 1.0);
+ diagram_A.emplace_back(inf, 1.0);
+ diagram_B.emplace_back(inf, 2.0);
+
+ // edge cost 1.0
+ diagram_A.emplace_back(-inf, 1.0);
+ diagram_B.emplace_back(-inf, 2.0);
+
+ double d1 = hera::wasserstein_dist<>(diagram_A, diagram_B, params);
+ double d2 = hera::wasserstein_dist<>(diagram_B, diagram_A, params);
+ double corr_answer = inf;
+
+ REQUIRE( d1 == corr_answer );
+ REQUIRE( d2 == corr_answer );
+ }
+
+ SECTION("all four corners at infinity, no finite points, infinite answer-4") {
+
+ // edge cost 1.0
+ diagram_A.emplace_back(1.0, inf);
+ diagram_B.emplace_back(2.0, inf);
+
+ // edge cost 1.0
+ diagram_A.emplace_back(1.0, -inf);
+ diagram_B.emplace_back(2.0, -inf);
+
+ // edge cost 1.0
+ diagram_A.emplace_back(inf, 1.0);
+ diagram_B.emplace_back(inf, 2.0);
+
+ // edge cost 1.0
+ diagram_A.emplace_back(-inf, 1.0);
+ diagram_B.emplace_back(-inf, 2.0);
+ diagram_B.emplace_back(-inf, 2.0);
+
+ double d1 = hera::wasserstein_dist<>(diagram_A, diagram_B, params);
+ double d2 = hera::wasserstein_dist<>(diagram_B, diagram_A, params);
+ double corr_answer = inf;
+
+ REQUIRE( d1 == corr_answer );
+ REQUIRE( d2 == corr_answer );
+ }
+
+ SECTION("all four corners at infinity, with finite points, infinite answer-1") {
+
+ diagram_A.emplace_back(1.0, inf);
+ diagram_A.emplace_back(1.0, inf);
+ diagram_B.emplace_back(2.0, inf);
+
+ diagram_A.emplace_back(1.0, -inf);
+ diagram_B.emplace_back(2.0, -inf);
+
+ diagram_A.emplace_back(inf, 1.0);
+ diagram_B.emplace_back(inf, 2.0);
+
+ diagram_A.emplace_back(-inf, 1.0);
+ diagram_B.emplace_back(-inf, 2.0);
+
+ // finite edge
+ diagram_A.emplace_back(10.0, 20.0);
+ diagram_B.emplace_back(13.0, 19.0);
+
+ double d1 = hera::wasserstein_dist<>(diagram_A, diagram_B, params);
+ double d2 = hera::wasserstein_dist<>(diagram_B, diagram_A, params);
+ double corr_answer = inf;
+
+ REQUIRE( d1 == corr_answer );
+ REQUIRE( d2 == corr_answer );
+ }
+
+ SECTION("all four corners at infinity, with finite points, infinite answer-2") {
+
+ diagram_A.emplace_back(1.0, inf);
+ diagram_B.emplace_back(2.0, inf);
+
+ diagram_A.emplace_back(1.0, -inf);
+ diagram_B.emplace_back(2.0, -inf);
+ diagram_B.emplace_back(2.0, -inf);
+
+ diagram_A.emplace_back(inf, 1.0);
+ diagram_B.emplace_back(inf, 2.0);
+
+ diagram_A.emplace_back(-inf, 1.0);
+ diagram_B.emplace_back(-inf, 2.0);
+
+ // finite edge
+ diagram_A.emplace_back(10.0, 20.0);
+ diagram_B.emplace_back(13.0, 19.0);
+
+ double d1 = hera::wasserstein_dist<>(diagram_A, diagram_B, params);
+ double d2 = hera::wasserstein_dist<>(diagram_B, diagram_A, params);
+ double corr_answer = inf;
+
+ REQUIRE( d1 == corr_answer );
+ REQUIRE( d2 == corr_answer );
+ }
+
+ SECTION("all four corners at infinity, with finite points, infinite answer-3") {
+
+ diagram_A.emplace_back(1.0, inf);
+ diagram_B.emplace_back(2.0, inf);
+
+ diagram_A.emplace_back(1.0, -inf);
+ diagram_B.emplace_back(2.0, -inf);
+
+ diagram_A.emplace_back(inf, 1.0);
+ diagram_A.emplace_back(inf, 1.0);
+ diagram_B.emplace_back(inf, 2.0);
+
+ diagram_A.emplace_back(-inf, 1.0);
+ diagram_B.emplace_back(-inf, 2.0);
+
+ // finite edge
+ diagram_A.emplace_back(10.0, 20.0);
+ diagram_B.emplace_back(13.0, 19.0);
+
+ double d1 = hera::wasserstein_dist<>(diagram_A, diagram_B, params);
+ double d2 = hera::wasserstein_dist<>(diagram_B, diagram_A, params);
+ double corr_answer = inf;
+
+ REQUIRE( d1 == corr_answer );
+ REQUIRE( d2 == corr_answer );
+ }
+
+ SECTION("all four corners at infinity, no finite points, infinite answer-4") {
+
+ diagram_A.emplace_back(1.0, inf);
+ diagram_B.emplace_back(2.0, inf);
+
+ diagram_A.emplace_back(1.0, -inf);
+ diagram_B.emplace_back(2.0, -inf);
+
+ diagram_A.emplace_back(inf, 1.0);
+ diagram_B.emplace_back(inf, 2.0);
+
+ diagram_A.emplace_back(-inf, 1.0);
+ diagram_B.emplace_back(-inf, 2.0);
+ diagram_B.emplace_back(-inf, 2.0);
+
+ // finite edge
+ diagram_A.emplace_back(10.0, 20.0);
+ diagram_B.emplace_back(13.0, 19.0);
+
+ double d1 = hera::wasserstein_dist<>(diagram_A, diagram_B, params);
+ double d2 = hera::wasserstein_dist<>(diagram_B, diagram_A, params);
+ double corr_answer = inf;
+
+ REQUIRE( d1 == corr_answer );
+ REQUIRE( d2 == corr_answer );
+ }
+
+
+ SECTION("simple small example with finite answer") {
+ diagram_A.emplace_back(1.0, inf);
+ diagram_B.emplace_back(2.0, inf);
+
+ diagram_A.emplace_back(1.9, inf);
+ diagram_B.emplace_back(1.1, inf);
+
+ // 1.1 - 1.0 + 2.0 - 1.9 = 0.2
+
+ diagram_A.emplace_back(inf, 1.0);
+ diagram_B.emplace_back(inf, 2.0);
+
+ diagram_A.emplace_back(inf, 1.9);
+ diagram_B.emplace_back(inf, 1.1);
+
+
+ // finite edge
+ diagram_A.emplace_back(10.0, 20.0);
+ diagram_B.emplace_back(13.0, 19.0);
+
+ double d1 = hera::wasserstein_dist<>(diagram_A, diagram_B, params);
+ double d2 = hera::wasserstein_dist<>(diagram_B, diagram_A, params);
+ double corr_answer = 3.0 + 0.2 + 0.2;
+
+ REQUIRE( d1 == corr_answer );
+ REQUIRE( d2 == corr_answer );
+
+ params.wasserstein_power = 2.0;
+
+ d1 = hera::wasserstein_dist<>(diagram_A, diagram_B, params);
+ d2 = hera::wasserstein_dist<>(diagram_B, diagram_A, params);
+ corr_answer = std::sqrt(3.0 * 3.0 + 4 * 0.1 * 0.1);
+
+ REQUIRE( fabs(d1 - corr_answer) < 0.000000000001 );
+ REQUIRE( fabs(d2 - corr_answer) < 0.000000000001 );
+
+ }
+
+}
+