diff options
Diffstat (limited to 'geom_matching')
-rw-r--r-- | geom_matching/wasserstein/example/wasserstein_dist.cpp | 17 | ||||
-rw-r--r-- | geom_matching/wasserstein/include/wasserstein.h | 11 |
2 files changed, 20 insertions, 8 deletions
diff --git a/geom_matching/wasserstein/example/wasserstein_dist.cpp b/geom_matching/wasserstein/example/wasserstein_dist.cpp index 25e1f68..cbe83e2 100644 --- a/geom_matching/wasserstein/example/wasserstein_dist.cpp +++ b/geom_matching/wasserstein/example/wasserstein_dist.cpp @@ -40,15 +40,16 @@ derivative works thereof, in binary and source code form. #include "wasserstein.h" -// any container of pairs of doubles can be used, +// any container of pairs of Reals can be used, // we use vector in this example. int main(int argc, char* argv[]) { - using PairVector = std::vector<std::pair<double, double>>; + using Real = double; + using PairVector = std::vector<std::pair<Real, Real>>; PairVector diagramA, diagramB; - hera::AuctionParams<double> params; + hera::AuctionParams<Real> params; params.max_num_phases = 800; opts::Options ops(argc, argv); @@ -87,7 +88,7 @@ int main(int argc, char* argv[]) std::cout << ops << std::endl; } - if (!hera::read_diagram_point_set<double, PairVector>(dgm_fname_1, diagramA)) { + if (!hera::read_diagram_point_set<Real, PairVector>(dgm_fname_1, diagramA)) { std::exit(1); } @@ -101,7 +102,7 @@ int main(int argc, char* argv[]) } if (params.wasserstein_power == 1.0) { - hera::remove_duplicates<double>(diagramA, diagramB); + hera::remove_duplicates<Real>(diagramA, diagramB); } //default relative error: 1% @@ -112,11 +113,11 @@ int main(int argc, char* argv[]) // default for internal metric is l_infinity if (std::isinf(params.internal_p)) { - params.internal_p = hera::get_infinity<double>(); + params.internal_p = hera::get_infinity<Real>(); } - if (not hera::is_p_valid_norm<double>(params.internal_p)) { + if (not hera::is_p_valid_norm<Real>(params.internal_p)) { std::cerr << "internal-p was \"" << params.internal_p << "\", must be a number >= 1.0 or inf. Cannot proceed. " << std::endl; std::exit(1); } @@ -144,7 +145,7 @@ int main(int argc, char* argv[]) spdlog::set_level(spdlog::level::info); #endif - double res = hera::wasserstein_dist(diagramA, diagramB, params, log_filename_prefix); + Real res = hera::wasserstein_dist(diagramA, diagramB, params, log_filename_prefix); std::cout << std::setprecision(15) << res << std::endl; if (print_relative_error) diff --git a/geom_matching/wasserstein/include/wasserstein.h b/geom_matching/wasserstein/include/wasserstein.h index 35d0bf6..17bb211 100644 --- a/geom_matching/wasserstein/include/wasserstein.h +++ b/geom_matching/wasserstein/include/wasserstein.h @@ -67,6 +67,17 @@ struct DiagramTraits<PairContainer_, std::pair<double, double>> static RealType get_y(const PointType& p) { return p.second; } }; +template<class PairContainer_> +struct DiagramTraits<PairContainer_, std::pair<float, float>> +{ + using PointType = std::pair<float, float>; + using RealType = float; + using Container = std::vector<PointType>; + + static RealType get_x(const PointType& p) { return p.first; } + static RealType get_y(const PointType& p) { return p.second; } +}; + namespace ws { |