summaryrefslogtreecommitdiff
path: root/wasserstein/tests
diff options
context:
space:
mode:
authorArnur Nigmetov <anigmetov@lbl.gov>2021-04-21 13:18:41 -0700
committerArnur Nigmetov <anigmetov@lbl.gov>2021-04-21 13:18:41 -0700
commitb528c4067a8aac346eb307d3c23b82d5953cfe2d (patch)
treec46898aa945cc94fddd1e1c75827878623ac6d28 /wasserstein/tests
parent7af824834e97c703c7724a6649d059639c8a1e36 (diff)
Fix Wasserstein: ignore diagonal points in are_equal.
Diffstat (limited to 'wasserstein/tests')
-rw-r--r--wasserstein/tests/test_hera_wasserstein.cpp12
1 files changed, 12 insertions, 0 deletions
diff --git a/wasserstein/tests/test_hera_wasserstein.cpp b/wasserstein/tests/test_hera_wasserstein.cpp
index 621dd5a..6f5de3b 100644
--- a/wasserstein/tests/test_hera_wasserstein.cpp
+++ b/wasserstein/tests/test_hera_wasserstein.cpp
@@ -111,6 +111,18 @@ TEST_CASE("simple cases", "wasserstein_dist")
}
+ SECTION("trivial: two diagrams differing by diagonal point") {
+
+ diagram_A.emplace_back(0.0, 1.0);
+ diagram_B.emplace_back(0.0, 0.0);
+ diagram_B.emplace_back(0.0, 1.0);
+
+ double d1 = hera::wasserstein_cost<>(diagram_A, diagram_B, params);
+ double d2 = hera::wasserstein_cost<>(diagram_B, diagram_A, params);
+ REQUIRE( fabs(d2) <= 0.00000000001 );
+ REQUIRE( fabs(d1) <= 0.00000000001 );
+ }
+
}