summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorArnur Nigmetov <nigmetov@tugraz.at>2019-09-16 17:26:42 +0200
committerArnur Nigmetov <nigmetov@tugraz.at>2019-09-16 17:26:42 +0200
commit7f6b96187423ba374b697f6f411c4a70d24ee297 (patch)
tree215ae5e4e422a3ce8b0b2055b794581886681b0b
parent498e457b221e9e70131ad73d9b881285d8d9572e (diff)
Add traits for float in Wasserstein.
-rw-r--r--geom_matching/wasserstein/example/wasserstein_dist.cpp17
-rw-r--r--geom_matching/wasserstein/include/wasserstein.h11
2 files changed, 20 insertions, 8 deletions
diff --git a/geom_matching/wasserstein/example/wasserstein_dist.cpp b/geom_matching/wasserstein/example/wasserstein_dist.cpp
index 25e1f68..cbe83e2 100644
--- a/geom_matching/wasserstein/example/wasserstein_dist.cpp
+++ b/geom_matching/wasserstein/example/wasserstein_dist.cpp
@@ -40,15 +40,16 @@ derivative works thereof, in binary and source code form.
#include "wasserstein.h"
-// any container of pairs of doubles can be used,
+// any container of pairs of Reals can be used,
// we use vector in this example.
int main(int argc, char* argv[])
{
- using PairVector = std::vector<std::pair<double, double>>;
+ using Real = double;
+ using PairVector = std::vector<std::pair<Real, Real>>;
PairVector diagramA, diagramB;
- hera::AuctionParams<double> params;
+ hera::AuctionParams<Real> params;
params.max_num_phases = 800;
opts::Options ops(argc, argv);
@@ -87,7 +88,7 @@ int main(int argc, char* argv[])
std::cout << ops << std::endl;
}
- if (!hera::read_diagram_point_set<double, PairVector>(dgm_fname_1, diagramA)) {
+ if (!hera::read_diagram_point_set<Real, PairVector>(dgm_fname_1, diagramA)) {
std::exit(1);
}
@@ -101,7 +102,7 @@ int main(int argc, char* argv[])
}
if (params.wasserstein_power == 1.0) {
- hera::remove_duplicates<double>(diagramA, diagramB);
+ hera::remove_duplicates<Real>(diagramA, diagramB);
}
//default relative error: 1%
@@ -112,11 +113,11 @@ int main(int argc, char* argv[])
// default for internal metric is l_infinity
if (std::isinf(params.internal_p)) {
- params.internal_p = hera::get_infinity<double>();
+ params.internal_p = hera::get_infinity<Real>();
}
- if (not hera::is_p_valid_norm<double>(params.internal_p)) {
+ if (not hera::is_p_valid_norm<Real>(params.internal_p)) {
std::cerr << "internal-p was \"" << params.internal_p << "\", must be a number >= 1.0 or inf. Cannot proceed. " << std::endl;
std::exit(1);
}
@@ -144,7 +145,7 @@ int main(int argc, char* argv[])
spdlog::set_level(spdlog::level::info);
#endif
- double res = hera::wasserstein_dist(diagramA, diagramB, params, log_filename_prefix);
+ Real res = hera::wasserstein_dist(diagramA, diagramB, params, log_filename_prefix);
std::cout << std::setprecision(15) << res << std::endl;
if (print_relative_error)
diff --git a/geom_matching/wasserstein/include/wasserstein.h b/geom_matching/wasserstein/include/wasserstein.h
index 35d0bf6..17bb211 100644
--- a/geom_matching/wasserstein/include/wasserstein.h
+++ b/geom_matching/wasserstein/include/wasserstein.h
@@ -67,6 +67,17 @@ struct DiagramTraits<PairContainer_, std::pair<double, double>>
static RealType get_y(const PointType& p) { return p.second; }
};
+template<class PairContainer_>
+struct DiagramTraits<PairContainer_, std::pair<float, float>>
+{
+ using PointType = std::pair<float, float>;
+ using RealType = float;
+ using Container = std::vector<PointType>;
+
+ static RealType get_x(const PointType& p) { return p.first; }
+ static RealType get_y(const PointType& p) { return p.second; }
+};
+
namespace ws
{