diff options
author | Arnur Nigmetov <nigmetov@tugraz.at> | 2019-09-16 17:26:42 +0200 |
---|---|---|
committer | Arnur Nigmetov <nigmetov@tugraz.at> | 2019-09-16 17:26:42 +0200 |
commit | 7f6b96187423ba374b697f6f411c4a70d24ee297 (patch) | |
tree | 215ae5e4e422a3ce8b0b2055b794581886681b0b | |
parent | 498e457b221e9e70131ad73d9b881285d8d9572e (diff) |
Add traits for float in Wasserstein.
-rw-r--r-- | geom_matching/wasserstein/example/wasserstein_dist.cpp | 17 | ||||
-rw-r--r-- | geom_matching/wasserstein/include/wasserstein.h | 11 |
2 files changed, 20 insertions, 8 deletions
diff --git a/geom_matching/wasserstein/example/wasserstein_dist.cpp b/geom_matching/wasserstein/example/wasserstein_dist.cpp index 25e1f68..cbe83e2 100644 --- a/geom_matching/wasserstein/example/wasserstein_dist.cpp +++ b/geom_matching/wasserstein/example/wasserstein_dist.cpp @@ -40,15 +40,16 @@ derivative works thereof, in binary and source code form. #include "wasserstein.h" -// any container of pairs of doubles can be used, +// any container of pairs of Reals can be used, // we use vector in this example. int main(int argc, char* argv[]) { - using PairVector = std::vector<std::pair<double, double>>; + using Real = double; + using PairVector = std::vector<std::pair<Real, Real>>; PairVector diagramA, diagramB; - hera::AuctionParams<double> params; + hera::AuctionParams<Real> params; params.max_num_phases = 800; opts::Options ops(argc, argv); @@ -87,7 +88,7 @@ int main(int argc, char* argv[]) std::cout << ops << std::endl; } - if (!hera::read_diagram_point_set<double, PairVector>(dgm_fname_1, diagramA)) { + if (!hera::read_diagram_point_set<Real, PairVector>(dgm_fname_1, diagramA)) { std::exit(1); } @@ -101,7 +102,7 @@ int main(int argc, char* argv[]) } if (params.wasserstein_power == 1.0) { - hera::remove_duplicates<double>(diagramA, diagramB); + hera::remove_duplicates<Real>(diagramA, diagramB); } //default relative error: 1% @@ -112,11 +113,11 @@ int main(int argc, char* argv[]) // default for internal metric is l_infinity if (std::isinf(params.internal_p)) { - params.internal_p = hera::get_infinity<double>(); + params.internal_p = hera::get_infinity<Real>(); } - if (not hera::is_p_valid_norm<double>(params.internal_p)) { + if (not hera::is_p_valid_norm<Real>(params.internal_p)) { std::cerr << "internal-p was \"" << params.internal_p << "\", must be a number >= 1.0 or inf. Cannot proceed. " << std::endl; std::exit(1); } @@ -144,7 +145,7 @@ int main(int argc, char* argv[]) spdlog::set_level(spdlog::level::info); #endif - double res = hera::wasserstein_dist(diagramA, diagramB, params, log_filename_prefix); + Real res = hera::wasserstein_dist(diagramA, diagramB, params, log_filename_prefix); std::cout << std::setprecision(15) << res << std::endl; if (print_relative_error) diff --git a/geom_matching/wasserstein/include/wasserstein.h b/geom_matching/wasserstein/include/wasserstein.h index 35d0bf6..17bb211 100644 --- a/geom_matching/wasserstein/include/wasserstein.h +++ b/geom_matching/wasserstein/include/wasserstein.h @@ -67,6 +67,17 @@ struct DiagramTraits<PairContainer_, std::pair<double, double>> static RealType get_y(const PointType& p) { return p.second; } }; +template<class PairContainer_> +struct DiagramTraits<PairContainer_, std::pair<float, float>> +{ + using PointType = std::pair<float, float>; + using RealType = float; + using Container = std::vector<PointType>; + + static RealType get_x(const PointType& p) { return p.first; } + static RealType get_y(const PointType& p) { return p.second; } +}; + namespace ws { |