From 7f6b96187423ba374b697f6f411c4a70d24ee297 Mon Sep 17 00:00:00 2001 From: Arnur Nigmetov Date: Mon, 16 Sep 2019 17:26:42 +0200 Subject: Add traits for float in Wasserstein. --- geom_matching/wasserstein/example/wasserstein_dist.cpp | 17 +++++++++-------- geom_matching/wasserstein/include/wasserstein.h | 11 +++++++++++ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/geom_matching/wasserstein/example/wasserstein_dist.cpp b/geom_matching/wasserstein/example/wasserstein_dist.cpp index 25e1f68..cbe83e2 100644 --- a/geom_matching/wasserstein/example/wasserstein_dist.cpp +++ b/geom_matching/wasserstein/example/wasserstein_dist.cpp @@ -40,15 +40,16 @@ derivative works thereof, in binary and source code form. #include "wasserstein.h" -// any container of pairs of doubles can be used, +// any container of pairs of Reals can be used, // we use vector in this example. int main(int argc, char* argv[]) { - using PairVector = std::vector>; + using Real = double; + using PairVector = std::vector>; PairVector diagramA, diagramB; - hera::AuctionParams params; + hera::AuctionParams params; params.max_num_phases = 800; opts::Options ops(argc, argv); @@ -87,7 +88,7 @@ int main(int argc, char* argv[]) std::cout << ops << std::endl; } - if (!hera::read_diagram_point_set(dgm_fname_1, diagramA)) { + if (!hera::read_diagram_point_set(dgm_fname_1, diagramA)) { std::exit(1); } @@ -101,7 +102,7 @@ int main(int argc, char* argv[]) } if (params.wasserstein_power == 1.0) { - hera::remove_duplicates(diagramA, diagramB); + hera::remove_duplicates(diagramA, diagramB); } //default relative error: 1% @@ -112,11 +113,11 @@ int main(int argc, char* argv[]) // default for internal metric is l_infinity if (std::isinf(params.internal_p)) { - params.internal_p = hera::get_infinity(); + params.internal_p = hera::get_infinity(); } - if (not hera::is_p_valid_norm(params.internal_p)) { + if (not hera::is_p_valid_norm(params.internal_p)) { std::cerr << "internal-p was \"" << params.internal_p << "\", must be a number >= 1.0 or inf. Cannot proceed. " << std::endl; std::exit(1); } @@ -144,7 +145,7 @@ int main(int argc, char* argv[]) spdlog::set_level(spdlog::level::info); #endif - double res = hera::wasserstein_dist(diagramA, diagramB, params, log_filename_prefix); + Real res = hera::wasserstein_dist(diagramA, diagramB, params, log_filename_prefix); std::cout << std::setprecision(15) << res << std::endl; if (print_relative_error) diff --git a/geom_matching/wasserstein/include/wasserstein.h b/geom_matching/wasserstein/include/wasserstein.h index 35d0bf6..17bb211 100644 --- a/geom_matching/wasserstein/include/wasserstein.h +++ b/geom_matching/wasserstein/include/wasserstein.h @@ -67,6 +67,17 @@ struct DiagramTraits> static RealType get_y(const PointType& p) { return p.second; } }; +template +struct DiagramTraits> +{ + using PointType = std::pair; + using RealType = float; + using Container = std::vector; + + static RealType get_x(const PointType& p) { return p.first; } + static RealType get_y(const PointType& p) { return p.second; } +}; + namespace ws { -- cgit v1.2.3