summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormglisse <marc.glisse@inria.fr>2020-01-22 11:03:30 +0000
committerArnur Nigmetov <a.nigmetov@gmail.com>2020-01-22 11:03:30 +0000
commit2f822cc372f940e5eb7ebf1a6fe510bc46a2d9a6 (patch)
treedba7e1cf8b6b39631af2f42702f7b191fbe4740d
parent9a89971855acefe39dce0e2adadf53b88ca8f683 (diff)
parent65758d72a31a01bd39c55d94ad1ce38f1969083f (diff)
Merged in mglisse/hera/container (pull request #1)
Tweaks for containers accepted by wasserstein_distance Approved-by: Arnur Nigmetov
-rw-r--r--geom_matching/wasserstein/include/wasserstein.h40
1 files changed, 8 insertions, 32 deletions
diff --git a/geom_matching/wasserstein/include/wasserstein.h b/geom_matching/wasserstein/include/wasserstein.h
index d8d6b2e..db6ce11 100644
--- a/geom_matching/wasserstein/include/wasserstein.h
+++ b/geom_matching/wasserstein/include/wasserstein.h
@@ -48,7 +48,6 @@ namespace hera
template<class PairContainer_, class PointType_ = typename std::remove_reference< decltype(*std::declval<PairContainer_>().begin())>::type >
struct DiagramTraits
{
- using Container = PairContainer_;
using PointType = PointType_;
using RealType = typename std::remove_reference< decltype(std::declval<PointType>()[0]) >::type;
@@ -56,34 +55,11 @@ struct DiagramTraits
static RealType get_y(const PointType& p) { return p[1]; }
};
-template<class PairContainer_>
-struct DiagramTraits<PairContainer_, std::pair<long double, long double>>
+template<class PairContainer_, class RealType_>
+struct DiagramTraits<PairContainer_, std::pair<RealType_, RealType_>>
{
- using PointType = std::pair<long double, long double>;
- using RealType = long double;
- using Container = std::vector<PointType>;
-
- static RealType get_x(const PointType& p) { return p.first; }
- static RealType get_y(const PointType& p) { return p.second; }
-};
-
-template<class PairContainer_>
-struct DiagramTraits<PairContainer_, std::pair<double, double>>
-{
- using PointType = std::pair<double, double>;
- using RealType = double;
- using Container = std::vector<PointType>;
-
- static RealType get_x(const PointType& p) { return p.first; }
- static RealType get_y(const PointType& p) { return p.second; }
-};
-
-template<class PairContainer_>
-struct DiagramTraits<PairContainer_, std::pair<float, float>>
-{
- using PointType = std::pair<float, float>;
- using RealType = float;
- using Container = std::vector<PointType>;
+ using RealType = RealType_;
+ using PointType = std::pair<RealType, RealType>;
static RealType get_x(const PointType& p) { return p.first; }
static RealType get_y(const PointType& p) { return p.second; }
@@ -106,11 +82,11 @@ namespace ws
std::map<PointType, int> m1, m2;
- for(const auto& pair1 : dgm1) {
+ for(auto&& pair1 : dgm1) {
m1[pair1]++;
}
- for(const auto& pair2 : dgm2) {
+ for(auto&& pair2 : dgm2) {
m2[pair2]++;
}
@@ -296,7 +272,7 @@ wasserstein_cost(const PairContainer& A,
std::vector<RealType> x_plus_B, x_minus_B, y_plus_B, y_minus_B;
// loop over A, add projections of A-points to corresponding positions
// in B-vector
- for(auto& pair_A : A) {
+ for(auto&& pair_A : A) {
a_empty = false;
RealType x = Traits::get_x(pair_A);
RealType y = Traits::get_y(pair_A);
@@ -315,7 +291,7 @@ wasserstein_cost(const PairContainer& A,
}
}
// the same for B
- for(auto& pair_B : B) {
+ for(auto&& pair_B : B) {
b_empty = false;
RealType x = Traits::get_x(pair_B);
RealType y = Traits::get_y(pair_B);