From e77514fbcb7e8aa2d91747e95e1250820d223bae Mon Sep 17 00:00:00 2001 From: Arnur Nigmetov Date: Thu, 25 Jan 2018 10:33:41 +0100 Subject: Fix for inf-only diagrams --- geom_matching/wasserstein/include/wasserstein.h | 3 + .../wasserstein/tests/test_hera_wasserstein.cpp | 355 +++++++++++++++++++++ 2 files changed, 358 insertions(+) diff --git a/geom_matching/wasserstein/include/wasserstein.h b/geom_matching/wasserstein/include/wasserstein.h index d843a04..b90a545 100644 --- a/geom_matching/wasserstein/include/wasserstein.h +++ b/geom_matching/wasserstein/include/wasserstein.h @@ -228,6 +228,9 @@ namespace ws throw std::runtime_error("Bad epsilon factor in Wasserstein " + std::to_string(params.epsilon_common_ratio)); } + if (A.empty() and B.empty()) + return 0.0; + RealType result; // just use Gauss-Seidel diff --git a/geom_matching/wasserstein/tests/test_hera_wasserstein.cpp b/geom_matching/wasserstein/tests/test_hera_wasserstein.cpp index 2ee49ad..a1a257b 100644 --- a/geom_matching/wasserstein/tests/test_hera_wasserstein.cpp +++ b/geom_matching/wasserstein/tests/test_hera_wasserstein.cpp @@ -201,3 +201,358 @@ TEST_CASE("file cases", "wasserstein_dist") } } } + + + +TEST_CASE("infinity points", "wasserstein_dist") +{ + PairVector diagram_A, diagram_B; + hera::AuctionParams params; + params.wasserstein_power = 1.0; + params.delta = 0.01; + params.internal_p = hera::get_infinity(); + params.initial_epsilon = 0.0; + params.epsilon_common_ratio = 0.0; + params.max_num_phases = 30; + params.gamma_threshold = 0.0; + params.max_bids_per_round = 0; // use Jacobi + + // do not use Hera's infinity! it is -1 + double inf = std::numeric_limits::infinity(); + + SECTION("two points at infinity, no finite points") { + + // edge cost 1.0 + diagram_A.emplace_back(1.0, inf); + diagram_B.emplace_back(2.0, inf); + + double d = hera::wasserstein_dist<>(diagram_A, diagram_B, params); + double corr_answer = 1.0; + REQUIRE( fabs(d - corr_answer) <= 0.00000000001 ); + } + + SECTION("two points at infinity") { + + // edge cost 3.0 + diagram_A.emplace_back(10.0, 20.0); // (5, 5) + diagram_B.emplace_back(13.0, 19.0); // (3, 3) + + // edge cost 1.0 + diagram_A.emplace_back(1.0, inf); + diagram_B.emplace_back(2.0, inf); + + double d = hera::wasserstein_dist<>(diagram_A, diagram_B, params); + double corr_answer = 1.0 + 3.0; + REQUIRE( fabs(d - corr_answer) <= 0.00000000001 ); + } + + SECTION("three points at infinity, no finite points") { + + // edge cost 1.0 + diagram_A.emplace_back(1.0, inf); + diagram_B.emplace_back(2.0, inf); + diagram_B.emplace_back(2.0, inf); + + double d = hera::wasserstein_dist<>(diagram_A, diagram_B, params); + double corr_answer = inf; + REQUIRE( d == corr_answer ); + } + + SECTION("three points at infinity") { + + // edge cost 3.0 + diagram_A.emplace_back(10.0, 20.0); // (5, 5) + diagram_B.emplace_back(13.0, 19.0); // (3, 3) + + // edge cost 1.0 + diagram_A.emplace_back(1.0, inf); + diagram_A.emplace_back(1.0, inf); + diagram_B.emplace_back(2.0, inf); + + double d = hera::wasserstein_dist<>(diagram_A, diagram_B, params); + double corr_answer = inf; + REQUIRE( d == corr_answer ); + } + + + SECTION("all four corners at infinity, no finite points, finite answer") { + + // edge cost 1.0 + diagram_A.emplace_back(1.0, inf); + diagram_B.emplace_back(2.0, inf); + + // edge cost 1.0 + diagram_A.emplace_back(1.0, -inf); + diagram_B.emplace_back(2.0, -inf); + + // edge cost 1.0 + diagram_A.emplace_back(inf, 1.0); + diagram_B.emplace_back(inf, 2.0); + + // edge cost 1.0 + diagram_A.emplace_back(-inf, 1.0); + diagram_B.emplace_back(-inf, 2.0); + + double d = hera::wasserstein_dist<>(diagram_A, diagram_B, params); + double corr_answer = 4.0; + + REQUIRE( d == corr_answer ); + } + + SECTION("all four corners at infinity, no finite points, infinite answer-1") { + + // edge cost 1.0 + diagram_A.emplace_back(1.0, inf); + diagram_A.emplace_back(1.0, inf); + diagram_B.emplace_back(2.0, inf); + + // edge cost 1.0 + diagram_A.emplace_back(1.0, -inf); + diagram_B.emplace_back(2.0, -inf); + + // edge cost 1.0 + diagram_A.emplace_back(inf, 1.0); + diagram_B.emplace_back(inf, 2.0); + + // edge cost 1.0 + diagram_A.emplace_back(-inf, 1.0); + diagram_B.emplace_back(-inf, 2.0); + + double d1 = hera::wasserstein_dist<>(diagram_A, diagram_B, params); + double d2 = hera::wasserstein_dist<>(diagram_B, diagram_A, params); + double corr_answer = inf; + + REQUIRE( d1 == corr_answer ); + REQUIRE( d2 == corr_answer ); + } + + SECTION("all four corners at infinity, no finite points, infinite answer-2") { + + // edge cost 1.0 + diagram_A.emplace_back(1.0, inf); + diagram_B.emplace_back(2.0, inf); + + // edge cost 1.0 + diagram_A.emplace_back(1.0, -inf); + diagram_B.emplace_back(2.0, -inf); + diagram_B.emplace_back(2.0, -inf); + + // edge cost 1.0 + diagram_A.emplace_back(inf, 1.0); + diagram_B.emplace_back(inf, 2.0); + + // edge cost 1.0 + diagram_A.emplace_back(-inf, 1.0); + diagram_B.emplace_back(-inf, 2.0); + + double d1 = hera::wasserstein_dist<>(diagram_A, diagram_B, params); + double d2 = hera::wasserstein_dist<>(diagram_B, diagram_A, params); + double corr_answer = inf; + + REQUIRE( d1 == corr_answer ); + REQUIRE( d2 == corr_answer ); + } + + SECTION("all four corners at infinity, no finite points, infinite answer-3") { + + // edge cost 1.0 + diagram_A.emplace_back(1.0, inf); + diagram_B.emplace_back(2.0, inf); + + // edge cost 1.0 + diagram_A.emplace_back(1.0, -inf); + diagram_B.emplace_back(2.0, -inf); + + // edge cost 1.0 + diagram_A.emplace_back(inf, 1.0); + diagram_A.emplace_back(inf, 1.0); + diagram_B.emplace_back(inf, 2.0); + + // edge cost 1.0 + diagram_A.emplace_back(-inf, 1.0); + diagram_B.emplace_back(-inf, 2.0); + + double d1 = hera::wasserstein_dist<>(diagram_A, diagram_B, params); + double d2 = hera::wasserstein_dist<>(diagram_B, diagram_A, params); + double corr_answer = inf; + + REQUIRE( d1 == corr_answer ); + REQUIRE( d2 == corr_answer ); + } + + SECTION("all four corners at infinity, no finite points, infinite answer-4") { + + // edge cost 1.0 + diagram_A.emplace_back(1.0, inf); + diagram_B.emplace_back(2.0, inf); + + // edge cost 1.0 + diagram_A.emplace_back(1.0, -inf); + diagram_B.emplace_back(2.0, -inf); + + // edge cost 1.0 + diagram_A.emplace_back(inf, 1.0); + diagram_B.emplace_back(inf, 2.0); + + // edge cost 1.0 + diagram_A.emplace_back(-inf, 1.0); + diagram_B.emplace_back(-inf, 2.0); + diagram_B.emplace_back(-inf, 2.0); + + double d1 = hera::wasserstein_dist<>(diagram_A, diagram_B, params); + double d2 = hera::wasserstein_dist<>(diagram_B, diagram_A, params); + double corr_answer = inf; + + REQUIRE( d1 == corr_answer ); + REQUIRE( d2 == corr_answer ); + } + + SECTION("all four corners at infinity, with finite points, infinite answer-1") { + + diagram_A.emplace_back(1.0, inf); + diagram_A.emplace_back(1.0, inf); + diagram_B.emplace_back(2.0, inf); + + diagram_A.emplace_back(1.0, -inf); + diagram_B.emplace_back(2.0, -inf); + + diagram_A.emplace_back(inf, 1.0); + diagram_B.emplace_back(inf, 2.0); + + diagram_A.emplace_back(-inf, 1.0); + diagram_B.emplace_back(-inf, 2.0); + + // finite edge + diagram_A.emplace_back(10.0, 20.0); + diagram_B.emplace_back(13.0, 19.0); + + double d1 = hera::wasserstein_dist<>(diagram_A, diagram_B, params); + double d2 = hera::wasserstein_dist<>(diagram_B, diagram_A, params); + double corr_answer = inf; + + REQUIRE( d1 == corr_answer ); + REQUIRE( d2 == corr_answer ); + } + + SECTION("all four corners at infinity, with finite points, infinite answer-2") { + + diagram_A.emplace_back(1.0, inf); + diagram_B.emplace_back(2.0, inf); + + diagram_A.emplace_back(1.0, -inf); + diagram_B.emplace_back(2.0, -inf); + diagram_B.emplace_back(2.0, -inf); + + diagram_A.emplace_back(inf, 1.0); + diagram_B.emplace_back(inf, 2.0); + + diagram_A.emplace_back(-inf, 1.0); + diagram_B.emplace_back(-inf, 2.0); + + // finite edge + diagram_A.emplace_back(10.0, 20.0); + diagram_B.emplace_back(13.0, 19.0); + + double d1 = hera::wasserstein_dist<>(diagram_A, diagram_B, params); + double d2 = hera::wasserstein_dist<>(diagram_B, diagram_A, params); + double corr_answer = inf; + + REQUIRE( d1 == corr_answer ); + REQUIRE( d2 == corr_answer ); + } + + SECTION("all four corners at infinity, with finite points, infinite answer-3") { + + diagram_A.emplace_back(1.0, inf); + diagram_B.emplace_back(2.0, inf); + + diagram_A.emplace_back(1.0, -inf); + diagram_B.emplace_back(2.0, -inf); + + diagram_A.emplace_back(inf, 1.0); + diagram_A.emplace_back(inf, 1.0); + diagram_B.emplace_back(inf, 2.0); + + diagram_A.emplace_back(-inf, 1.0); + diagram_B.emplace_back(-inf, 2.0); + + // finite edge + diagram_A.emplace_back(10.0, 20.0); + diagram_B.emplace_back(13.0, 19.0); + + double d1 = hera::wasserstein_dist<>(diagram_A, diagram_B, params); + double d2 = hera::wasserstein_dist<>(diagram_B, diagram_A, params); + double corr_answer = inf; + + REQUIRE( d1 == corr_answer ); + REQUIRE( d2 == corr_answer ); + } + + SECTION("all four corners at infinity, no finite points, infinite answer-4") { + + diagram_A.emplace_back(1.0, inf); + diagram_B.emplace_back(2.0, inf); + + diagram_A.emplace_back(1.0, -inf); + diagram_B.emplace_back(2.0, -inf); + + diagram_A.emplace_back(inf, 1.0); + diagram_B.emplace_back(inf, 2.0); + + diagram_A.emplace_back(-inf, 1.0); + diagram_B.emplace_back(-inf, 2.0); + diagram_B.emplace_back(-inf, 2.0); + + // finite edge + diagram_A.emplace_back(10.0, 20.0); + diagram_B.emplace_back(13.0, 19.0); + + double d1 = hera::wasserstein_dist<>(diagram_A, diagram_B, params); + double d2 = hera::wasserstein_dist<>(diagram_B, diagram_A, params); + double corr_answer = inf; + + REQUIRE( d1 == corr_answer ); + REQUIRE( d2 == corr_answer ); + } + + + SECTION("simple small example with finite answer") { + diagram_A.emplace_back(1.0, inf); + diagram_B.emplace_back(2.0, inf); + + diagram_A.emplace_back(1.9, inf); + diagram_B.emplace_back(1.1, inf); + + // 1.1 - 1.0 + 2.0 - 1.9 = 0.2 + + diagram_A.emplace_back(inf, 1.0); + diagram_B.emplace_back(inf, 2.0); + + diagram_A.emplace_back(inf, 1.9); + diagram_B.emplace_back(inf, 1.1); + + + // finite edge + diagram_A.emplace_back(10.0, 20.0); + diagram_B.emplace_back(13.0, 19.0); + + double d1 = hera::wasserstein_dist<>(diagram_A, diagram_B, params); + double d2 = hera::wasserstein_dist<>(diagram_B, diagram_A, params); + double corr_answer = 3.0 + 0.2 + 0.2; + + REQUIRE( d1 == corr_answer ); + REQUIRE( d2 == corr_answer ); + + params.wasserstein_power = 2.0; + + d1 = hera::wasserstein_dist<>(diagram_A, diagram_B, params); + d2 = hera::wasserstein_dist<>(diagram_B, diagram_A, params); + corr_answer = std::sqrt(3.0 * 3.0 + 4 * 0.1 * 0.1); + + REQUIRE( fabs(d1 - corr_answer) < 0.000000000001 ); + REQUIRE( fabs(d2 - corr_answer) < 0.000000000001 ); + + } + +} + -- cgit v1.2.3