summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Gard Spreemann <gspr@nonempty.org> 2021-08-14 18:32:59 +0200
committer Gard Spreemann <gspr@nonempty.org> 2021-08-14 18:32:59 +0200
commit 66702d9cf122703964dbe22319ae8d97424d496f (patch)
tree 08681d5c5b5878ed4283d5fba2cbb8f4612dbf7c
parent 069338dfb03b4d04c1410b3e24b762b18db5c233 (diff)
parent 2ed9afc052bee7956f6abb195947de1f80cb9d91 (diff)
Merge branch 'upstream/latest' into dfsg/latest
-rw-r--r-- .gitignore 4
-rw-r--r-- bottleneck/include/bottleneck_detail.hpp 28
-rw-r--r-- bottleneck/tests/data/test_880_A 4
-rw-r--r-- bottleneck/tests/data/test_880_B 4
-rw-r--r-- bottleneck/tests/data/test_list.txt 1
-rw-r--r-- bottleneck/tests/test_hera_bottleneck.cpp 16
-rw-r--r-- matching/include/matching_distance.hpp 1
-rw-r--r-- matching/include/persistence_module.h 1
-rw-r--r-- matching/include/persistence_module.hpp 2
-rw-r--r-- wasserstein/include/wasserstein.h 70
-rw-r--r-- wasserstein/tests/data/test_inf_1_A 3
-rw-r--r-- wasserstein/tests/data/test_inf_1_B 4
-rw-r--r-- wasserstein/tests/data/test_inf_2_A 3
-rw-r--r-- wasserstein/tests/data/test_inf_2_B 5
-rw-r--r-- wasserstein/tests/data/test_inf_3_A 6
-rw-r--r-- wasserstein/tests/data/test_inf_3_B 5
-rw-r--r-- wasserstein/tests/data/test_list.txt 3
-rw-r--r-- wasserstein/tests/test_hera_wasserstein.cpp 29
18 files changed, 139 insertions, 50 deletions
diff --git a/.gitignore b/.gitignore
index 8d53708..8404245 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,6 @@
/*.cfg
-/build
-build/
+*build*/
+.idea/
*.gitattributes
*.opensdf
*.sdf
diff --git a/bottleneck/include/bottleneck_detail.hpp b/bottleneck/include/bottleneck_detail.hpp
index 8f51d07..f976909 100644
--- a/bottleneck/include/bottleneck_detail.hpp
+++ b/bottleneck/include/bottleneck_detail.hpp
@@ -51,7 +51,6 @@ namespace hera {
void binarySearch(const Real epsilon,
std::pair<Real, Real>& result,
BoundMatchOracle <Real>& oracle,
- const Real infinityCost,
bool isResultInitializedCorrectly,
const Real distProbeInit)
{
@@ -59,9 +58,6 @@ namespace hera {
Real& distMin = result.first;
Real& distMax = result.second;
- distMin = std::max(distMin, infinityCost);
- distMax = std::max(distMax, infinityCost);
-
Real distProbe;
if (not isResultInitializedCorrectly) {
@@ -87,13 +83,6 @@ namespace hera {
// bounds are correct , perform binary search
distProbe = (distMin + distMax) / 2.0;
while ((distMax - distMin) / distMin >= epsilon) {
-
- if (distMax < infinityCost) {
- distMin = infinityCost;
- distMax = infinityCost;
- break;
- }
-
if (oracle.isMatchLess(distProbe)) {
distMax = distProbe;
} else {
@@ -102,9 +91,6 @@ namespace hera {
distProbe = (distMin + distMax) / 2.0;
}
-
- distMin = std::max(distMin, infinityCost);
- distMax = std::max(distMax, infinityCost);
}
// template<class Real>
@@ -291,7 +277,7 @@ namespace hera {
// get a 3-approximation of maximal distance between A and B
// as a starting value for probe distance
Real distProbe { getFurthestDistance3Approx<Real, DiagramPointSet<Real>>(A, B) };
- binarySearch(epsilon, result, oracle, infinity_cost, false, distProbe);
+ binarySearch(epsilon, result, oracle, false, distProbe);
// to compute longest edge a perfect matching is needed
if (compute_longest_edge and result.first > infinity_cost) {
oracle.isMatchLess(result.second);
@@ -448,8 +434,10 @@ namespace hera {
Real bottleneckDistApprox(DiagramPointSet <Real>& A, DiagramPointSet <Real>& B, const Real epsilon,
MatchingEdge <Real>& longest_edge, bool compute_longest_edge)
{
+ // must compute here: infinity points will be erased in bottleneckDistApproxInterval
+ Real infCost = getInfinityCost(A, B).cost;
auto interval = bottleneckDistApproxInterval<Real>(A, B, epsilon, longest_edge, compute_longest_edge);
- return interval.second;
+ return std::max(infCost, interval.second);
}
@@ -518,6 +506,9 @@ namespace hera {
using DgmPoint = DiagramPoint<Real>;
constexpr Real epsilon = 0.001;
+
+ Real infCost = getInfinityCost(A, B, true).cost;
+
auto interval = bottleneckDistApproxInterval(A, B, epsilon, longest_edge, true);
// if the longest edge is on infinity, the answer is already exact
// this will be detected here and all the code after if
@@ -627,8 +618,9 @@ namespace hera {
pw_dists.push_back(d);
}
- return bottleneckDistExactFromSortedPwDist(A, B, pw_dists, decPrecision, longest_edge,
- compute_longest_edge);
+ Real exactFinite = bottleneckDistExactFromSortedPwDist(A, B, pw_dists, decPrecision, longest_edge, compute_longest_edge);
+
+ return std::max(infCost, exactFinite);
}
} // end namespace bt
diff --git a/bottleneck/tests/data/test_880_A b/bottleneck/tests/data/test_880_A
new file mode 100644
index 0000000..5ad7454
--- /dev/null
+++ b/bottleneck/tests/data/test_880_A
@@ -0,0 +1,4 @@
+0.4217142857142855 inf
+0.8868393268494345 1.8
+0.45290401176484574 1.8
+0.9154285714285715 1.8
diff --git a/bottleneck/tests/data/test_880_B b/bottleneck/tests/data/test_880_B
new file mode 100644
index 0000000..22129a2
--- /dev/null
+++ b/bottleneck/tests/data/test_880_B
@@ -0,0 +1,4 @@
+0.8742857142857143 1.8
+0.46285714285714286 inf
+0.4722695736819835 1.8
+0.8474648251910301 1.8
diff --git a/bottleneck/tests/data/test_list.txt b/bottleneck/tests/data/test_list.txt
index 24b462a..91c03a4 100644
--- a/bottleneck/tests/data/test_list.txt
+++ b/bottleneck/tests/data/test_list.txt
@@ -790,3 +790,4 @@ test_876_A test_876_B 0.01 165
test_877_A test_877_B 0.01 190
test_878_A test_878_B 0.01 242
test_879_A test_879_B 0.01 187
+test_880_A test_880_B 0.0 0.0411428571428574
diff --git a/bottleneck/tests/test_hera_bottleneck.cpp b/bottleneck/tests/test_hera_bottleneck.cpp
index f22e415..c74e800 100644
--- a/bottleneck/tests/test_hera_bottleneck.cpp
+++ b/bottleneck/tests/test_hera_bottleneck.cpp
@@ -159,6 +159,9 @@ TEST_CASE("infinity points", "bottleneckDistApprox")
double d = hera::bottleneckDistApprox<>(diagram_A, diagram_B, delta);
double corr_answer = 1.0;
REQUIRE( fabs(d - corr_answer) <= delta * corr_answer);
+
+ double exact_d = hera::bottleneckDistExact<>(diagram_A, diagram_B);
+ REQUIRE(exact_d == corr_answer);
}
SECTION("two points at infinity") {
@@ -559,15 +562,23 @@ TEST_CASE("file cases", "bottleneck_dist")
SECTION("from file:") {
- for (const auto& ts : test_params) {
+ for (auto& ts : test_params) {
bool read_file_A = hera::readDiagramPointSet(dir_prefix + ts.file_1, diagram_A);
bool read_file_B = hera::readDiagramPointSet(dir_prefix + ts.file_2, diagram_B);
REQUIRE(read_file_A);
REQUIRE(read_file_B);
- double hera_answer = hera::bottleneckDistApprox(diagram_A, diagram_B, ts.delta, longest_edge, true);
+ double hera_answer;
+ if (ts.delta > 0)
+ hera_answer = hera::bottleneckDistApprox(diagram_A, diagram_B, ts.delta, longest_edge, true);
+ else
+ hera_answer = hera::bottleneckDistExact(diagram_A, diagram_B);
std::pair<int, int> hera_le { longest_edge.first.get_user_id(), longest_edge.second.get_user_id() };
+ // cannot store exact answer in test_list.txt, but need to make sure that Exact is called
+ if (ts.delta == 0.0)
+ ts.delta = 0.00000001;
+
REQUIRE((hera_answer == ts.answer or fabs(hera_answer - ts.answer) <= ts.delta * hera_answer));
REQUIRE((ts.longest_edges.empty() or
std::find(ts.longest_edges.begin(), ts.longest_edges.end(), hera_le) != ts.longest_edges.end()));
@@ -604,7 +615,6 @@ TEST_CASE("file cases", "bottleneck_dist")
// << hera_answer_exact << std::endl;
// }
REQUIRE( (not check_longest_edge_cost or fabs(hera_le_cost - hera_answer_exact) < 0.0001 * hera_answer_exact) );
- std::cout << ts << " PASSED " << std::endl;
}
}
diff --git a/matching/include/matching_distance.hpp b/matching/include/matching_distance.hpp
index f7f44a5..9beab1f 100644
--- a/matching/include/matching_distance.hpp
+++ b/matching/include/matching_distance.hpp
@@ -362,6 +362,7 @@ namespace md {
// TODO: think about this - how to call Hera
auto dgm_a = module_a_.weighted_slice_diagram(line);
auto dgm_b = module_b_.weighted_slice_diagram(line);
+
R result;
if (params_.hera_epsilon > static_cast<R>(0)) {
result = hera::bottleneckDistApprox(dgm_a, dgm_b, params_.hera_epsilon) / ( params_.hera_epsilon + 1);
diff --git a/matching/include/persistence_module.h b/matching/include/persistence_module.h
index b68c21e..4df2148 100644
--- a/matching/include/persistence_module.h
+++ b/matching/include/persistence_module.h
@@ -11,6 +11,7 @@
#include "phat/boundary_matrix.h"
#include "phat/compute_persistence_pairs.h"
+#include "phat/algorithms/standard_reduction.h"
#include "common_util.h"
#include "dual_point.h"
diff --git a/matching/include/persistence_module.hpp b/matching/include/persistence_module.hpp
index 7479e02..022b43d 100644
--- a/matching/include/persistence_module.hpp
+++ b/matching/include/persistence_module.hpp
@@ -158,7 +158,7 @@ namespace md {
get_slice_projection_matrix(slice, phat_matrix, gen_projections, rel_projections);
phat::persistence_pairs phat_persistence_pairs;
- phat::compute_persistence_pairs<phat::twist_reduction>(phat_persistence_pairs, phat_matrix);
+ phat::compute_persistence_pairs<phat::standard_reduction>(phat_persistence_pairs, phat_matrix, true);
Diagram<Real> dgm;
diff --git a/wasserstein/include/wasserstein.h b/wasserstein/include/wasserstein.h
index db6ce11..142fcbb 100644
--- a/wasserstein/include/wasserstein.h
+++ b/wasserstein/include/wasserstein.h
@@ -73,21 +73,19 @@ namespace ws
template<class PairContainer>
inline bool are_equal(const PairContainer& dgm1, const PairContainer& dgm2)
{
- if (dgm1.size() != dgm2.size()) {
- return false;
- }
-
using Traits = typename hera::DiagramTraits<PairContainer>;
using PointType = typename Traits::PointType;
std::map<PointType, int> m1, m2;
for(auto&& pair1 : dgm1) {
- m1[pair1]++;
+ if (Traits::get_x(pair1) != Traits::get_y(pair1))
+ m1[pair1]++;
}
for(auto&& pair2 : dgm2) {
- m2[pair2]++;
+ if (Traits::get_x(pair2) != Traits::get_y(pair2))
+ m2[pair2]++;
}
return m1 == m2;
@@ -255,6 +253,9 @@ wasserstein_cost(const PairContainer& A,
//using PointType = typename Traits::PointType;
using RealType = typename Traits::RealType;
+ constexpr RealType plus_inf = std::numeric_limits<RealType>::infinity();
+ constexpr RealType minus_inf = -std::numeric_limits<RealType>::infinity();
+
if (hera::ws::are_equal(A, B)) {
return 0.0;
}
@@ -270,19 +271,34 @@ wasserstein_cost(const PairContainer& A,
// coordinates of points at infinity
std::vector<RealType> x_plus_A, x_minus_A, y_plus_A, y_minus_A;
std::vector<RealType> x_plus_B, x_minus_B, y_plus_B, y_minus_B;
+ // points with both coordinates infinite are treated as equal
+ int n_minus_inf_plus_inf_A = 0;
+ int n_plus_inf_minus_inf_A = 0;
+ int n_minus_inf_plus_inf_B = 0;
+ int n_plus_inf_minus_inf_B = 0;
// loop over A, add projections of A-points to corresponding positions
// in B-vector
for(auto&& pair_A : A) {
a_empty = false;
RealType x = Traits::get_x(pair_A);
RealType y = Traits::get_y(pair_A);
- if ( x == std::numeric_limits<RealType>::infinity()) {
+
+ // skip diagonal points, including (inf, inf), (-inf, -inf)
+ if (x == y) {
+ continue;
+ }
+
+ if (x == plus_inf && y == minus_inf) {
+ n_plus_inf_minus_inf_A++;
+ } else if (x == minus_inf && y == plus_inf) {
+ n_minus_inf_plus_inf_A++;
+ } else if ( x == plus_inf) {
y_plus_A.push_back(y);
- } else if (x == -std::numeric_limits<RealType>::infinity()) {
+ } else if (x == minus_inf) {
y_minus_A.push_back(y);
- } else if (y == std::numeric_limits<RealType>::infinity()) {
+ } else if (y == plus_inf) {
x_plus_A.push_back(x);
- } else if (y == -std::numeric_limits<RealType>::infinity()) {
+ } else if (y == minus_inf) {
x_minus_A.push_back(x);
} else {
dgm_A.emplace_back(x, y, DgmPoint::NORMAL);
@@ -295,13 +311,22 @@ wasserstein_cost(const PairContainer& A,
b_empty = false;
RealType x = Traits::get_x(pair_B);
RealType y = Traits::get_y(pair_B);
- if (x == std::numeric_limits<RealType>::infinity()) {
+
+ if (x == y) {
+ continue;
+ }
+
+ if (x == plus_inf && y == minus_inf) {
+ n_plus_inf_minus_inf_B++;
+ } else if (x == minus_inf && y == plus_inf) {
+ n_minus_inf_plus_inf_B++;
+ } else if (x == plus_inf) {
y_plus_B.push_back(y);
- } else if (x == -std::numeric_limits<RealType>::infinity()) {
+ } else if (x == minus_inf) {
y_minus_B.push_back(y);
- } else if (y == std::numeric_limits<RealType>::infinity()) {
+ } else if (y == plus_inf) {
x_plus_B.push_back(x);
- } else if (y == -std::numeric_limits<RealType>::infinity()) {
+ } else if (y == minus_inf) {
x_minus_B.push_back(x);
} else {
dgm_A.emplace_back(x, y, DgmPoint::DIAG);
@@ -310,10 +335,16 @@ wasserstein_cost(const PairContainer& A,
}
}
- RealType infinity_cost = ws::get_one_dimensional_cost(x_plus_A, x_plus_B, params.wasserstein_power);
- infinity_cost += ws::get_one_dimensional_cost(x_minus_A, x_minus_B, params.wasserstein_power);
- infinity_cost += ws::get_one_dimensional_cost(y_plus_A, y_plus_B, params.wasserstein_power);
- infinity_cost += ws::get_one_dimensional_cost(y_minus_A, y_minus_B, params.wasserstein_power);
+ RealType infinity_cost = 0;
+
+ if (n_plus_inf_minus_inf_A != n_plus_inf_minus_inf_B || n_minus_inf_plus_inf_A != n_minus_inf_plus_inf_B)
+ infinity_cost = plus_inf;
+ else {
+ infinity_cost += ws::get_one_dimensional_cost(x_plus_A, x_plus_B, params.wasserstein_power);
+ infinity_cost += ws::get_one_dimensional_cost(x_minus_A, x_minus_B, params.wasserstein_power);
+ infinity_cost += ws::get_one_dimensional_cost(y_plus_A, y_plus_B, params.wasserstein_power);
+ infinity_cost += ws::get_one_dimensional_cost(y_minus_A, y_minus_B, params.wasserstein_power);
+ }
if (a_empty)
return total_cost_B + infinity_cost;
@@ -321,8 +352,7 @@ wasserstein_cost(const PairContainer& A,
if (b_empty)
return total_cost_A + infinity_cost;
-
- if (infinity_cost == std::numeric_limits<RealType>::infinity()) {
+ if (infinity_cost == plus_inf) {
return infinity_cost;
} else {
return infinity_cost + wasserstein_cost_vec(dgm_A, dgm_B, params, _log_filename_prefix);
diff --git a/wasserstein/tests/data/test_inf_1_A b/wasserstein/tests/data/test_inf_1_A
new file mode 100644
index 0000000..c773f02
--- /dev/null
+++ b/wasserstein/tests/data/test_inf_1_A
@@ -0,0 +1,3 @@
+-inf inf
+-inf inf
+2 1
diff --git a/wasserstein/tests/data/test_inf_1_B b/wasserstein/tests/data/test_inf_1_B
new file mode 100644
index 0000000..a55e496
--- /dev/null
+++ b/wasserstein/tests/data/test_inf_1_B
@@ -0,0 +1,4 @@
+-inf inf
+-inf inf
+inf inf
+4 9
diff --git a/wasserstein/tests/data/test_inf_2_A b/wasserstein/tests/data/test_inf_2_A
new file mode 100644
index 0000000..c773f02
--- /dev/null
+++ b/wasserstein/tests/data/test_inf_2_A
@@ -0,0 +1,3 @@
+-inf inf
+-inf inf
+2 1
diff --git a/wasserstein/tests/data/test_inf_2_B b/wasserstein/tests/data/test_inf_2_B
new file mode 100644
index 0000000..6d7e751
--- /dev/null
+++ b/wasserstein/tests/data/test_inf_2_B
@@ -0,0 +1,5 @@
+-inf inf
+-inf inf
+inf -inf
+inf inf
+4 9
diff --git a/wasserstein/tests/data/test_inf_3_A b/wasserstein/tests/data/test_inf_3_A
new file mode 100644
index 0000000..4f3fc2f
--- /dev/null
+++ b/wasserstein/tests/data/test_inf_3_A
@@ -0,0 +1,6 @@
+-inf inf
+-inf inf
+inf -inf
+-inf -inf
+-inf -inf
+2 1
diff --git a/wasserstein/tests/data/test_inf_3_B b/wasserstein/tests/data/test_inf_3_B
new file mode 100644
index 0000000..6d7e751
--- /dev/null
+++ b/wasserstein/tests/data/test_inf_3_B
@@ -0,0 +1,5 @@
+-inf inf
+-inf inf
+inf -inf
+inf inf
+4 9
diff --git a/wasserstein/tests/data/test_list.txt b/wasserstein/tests/data/test_list.txt
index 27340d8..b1ba6ed 100644
--- a/wasserstein/tests/data/test_list.txt
+++ b/wasserstein/tests/data/test_list.txt
@@ -19,3 +19,6 @@ test_100_A test_100_B 3.0 2.0 2.09695346034248
test_diag1_A test_diag1_B 1.0 -1.0 0.0
test_diag2_A test_diag2_B 1.0 -1.0 0.0
test_diag3_A test_diag3_B 1.0 -1.0 0.0
+test_inf_1_A test_inf_1_B 1.0 -1.0 3.0
+test_inf_2_A test_inf_2_B 1.0 -1.0 inf
+test_inf_3_A test_inf_3_B 1.0 -1.0 3.0
diff --git a/wasserstein/tests/test_hera_wasserstein.cpp b/wasserstein/tests/test_hera_wasserstein.cpp
index 0a80d2f..6f5de3b 100644
--- a/wasserstein/tests/test_hera_wasserstein.cpp
+++ b/wasserstein/tests/test_hera_wasserstein.cpp
@@ -111,6 +111,18 @@ TEST_CASE("simple cases", "wasserstein_dist")
}
+ SECTION("trivial: two diagrams differing by diagonal point") {
+
+ diagram_A.emplace_back(0.0, 1.0);
+ diagram_B.emplace_back(0.0, 0.0);
+ diagram_B.emplace_back(0.0, 1.0);
+
+ double d1 = hera::wasserstein_cost<>(diagram_A, diagram_B, params);
+ double d2 = hera::wasserstein_cost<>(diagram_B, diagram_A, params);
+ REQUIRE( fabs(d2) <= 0.00000000001 );
+ REQUIRE( fabs(d1) <= 0.00000000001 );
+ }
+
}
@@ -130,9 +142,14 @@ TEST_CASE("file cases", "wasserstein_dist")
SECTION("from file:") {
- const char* file_name = "../tests/data/test_list.txt";
+ const char* file_name = "test_list.txt";
std::ifstream f;
f.open(file_name);
+ if (!f.good()) {
+ std::cerr << "Must run from tests/data" << std::endl;
+ REQUIRE(false);
+ }
+
std::vector<TestFromFileCase> test_params;
std::string s;
while (std::getline(f, s)) {
@@ -147,13 +164,13 @@ TEST_CASE("file cases", "wasserstein_dist")
REQUIRE( read_file_A );
REQUIRE( read_file_B );
double hera_answer = hera::wasserstein_dist(diagram_A, diagram_B, params);
- REQUIRE( fabs(hera_answer - ts.answer) <= 0.01 * hera_answer );
- std::cout << ts << " PASSED " << std::endl;
+ bool is_correct = (hera_answer == ts.answer) || (fabs(hera_answer - ts.answer) <= 0.01 * hera_answer);
+ REQUIRE(is_correct);
}
}
SECTION("from DIPHA file:") {
- const char* file_name = "../tests/data/test_list.txt";
+ const char* file_name = "test_list.txt";
std::ifstream f;
f.open(file_name);
std::vector<TestFromFileCase> test_params;
@@ -167,8 +184,8 @@ TEST_CASE("file cases", "wasserstein_dist")
params.internal_p = ts.internal_p;
bool read_file_A = hera::read_diagram_dipha<double, PairVector>(ts.file_1 + std::string(".pd.dipha"), 1, diagram_A);
bool read_file_B = hera::read_diagram_dipha<double, PairVector>(ts.file_2 + std::string(".pd.dipha"), 1, diagram_B);
- REQUIRE( read_file_A );
- REQUIRE( read_file_B );
+ if (!read_file_A)
+ continue;
double hera_answer = hera::wasserstein_dist(diagram_A, diagram_B, params);
REQUIRE( fabs(hera_answer - ts.answer) <= 0.01 * hera_answer );
std::cout << ts << " PASSED " << std::endl;