diff options
-rw-r--r-- | geom_matching/wasserstein/include/wasserstein.h | 40 |
1 files changed, 8 insertions, 32 deletions
diff --git a/geom_matching/wasserstein/include/wasserstein.h b/geom_matching/wasserstein/include/wasserstein.h index d8d6b2e..db6ce11 100644 --- a/geom_matching/wasserstein/include/wasserstein.h +++ b/geom_matching/wasserstein/include/wasserstein.h @@ -48,7 +48,6 @@ namespace hera template<class PairContainer_, class PointType_ = typename std::remove_reference< decltype(*std::declval<PairContainer_>().begin())>::type > struct DiagramTraits { - using Container = PairContainer_; using PointType = PointType_; using RealType = typename std::remove_reference< decltype(std::declval<PointType>()[0]) >::type; @@ -56,34 +55,11 @@ struct DiagramTraits static RealType get_y(const PointType& p) { return p[1]; } }; -template<class PairContainer_> -struct DiagramTraits<PairContainer_, std::pair<long double, long double>> +template<class PairContainer_, class RealType_> +struct DiagramTraits<PairContainer_, std::pair<RealType_, RealType_>> { - using PointType = std::pair<long double, long double>; - using RealType = long double; - using Container = std::vector<PointType>; - - static RealType get_x(const PointType& p) { return p.first; } - static RealType get_y(const PointType& p) { return p.second; } -}; - -template<class PairContainer_> -struct DiagramTraits<PairContainer_, std::pair<double, double>> -{ - using PointType = std::pair<double, double>; - using RealType = double; - using Container = std::vector<PointType>; - - static RealType get_x(const PointType& p) { return p.first; } - static RealType get_y(const PointType& p) { return p.second; } -}; - -template<class PairContainer_> -struct DiagramTraits<PairContainer_, std::pair<float, float>> -{ - using PointType = std::pair<float, float>; - using RealType = float; - using Container = std::vector<PointType>; + using RealType = RealType_; + using PointType = std::pair<RealType, RealType>; static RealType get_x(const PointType& p) { return p.first; } static RealType get_y(const PointType& p) { return p.second; } @@ -106,11 +82,11 @@ namespace ws std::map<PointType, int> m1, m2; - for(const auto& pair1 : dgm1) { + for(auto&& pair1 : dgm1) { m1[pair1]++; } - for(const auto& pair2 : dgm2) { + for(auto&& pair2 : dgm2) { m2[pair2]++; } @@ -296,7 +272,7 @@ wasserstein_cost(const PairContainer& A, std::vector<RealType> x_plus_B, x_minus_B, y_plus_B, y_minus_B; // loop over A, add projections of A-points to corresponding positions // in B-vector - for(auto& pair_A : A) { + for(auto&& pair_A : A) { a_empty = false; RealType x = Traits::get_x(pair_A); RealType y = Traits::get_y(pair_A); @@ -315,7 +291,7 @@ wasserstein_cost(const PairContainer& A, } } // the same for B - for(auto& pair_B : B) { + for(auto&& pair_B : B) { b_empty = false; RealType x = Traits::get_x(pair_B); RealType y = Traits::get_y(pair_B); |