summaryrefslogtreecommitdiff
path: root/geom_matching/wasserstein/include/wasserstein_pure_geom.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'geom_matching/wasserstein/include/wasserstein_pure_geom.hpp')
-rw-r--r--geom_matching/wasserstein/include/wasserstein_pure_geom.hpp5
1 files changed, 2 insertions, 3 deletions
diff --git a/geom_matching/wasserstein/include/wasserstein_pure_geom.hpp b/geom_matching/wasserstein/include/wasserstein_pure_geom.hpp
index 2a57599..096d95d 100644
--- a/geom_matching/wasserstein/include/wasserstein_pure_geom.hpp
+++ b/geom_matching/wasserstein/include/wasserstein_pure_geom.hpp
@@ -30,7 +30,7 @@ namespace ws
using AuctionRunnerJacR = typename hera::ws::AuctionRunnerJac<Real, hera::ws::AuctionOracleKDTreePureGeom<Real>, hera::ws::dnn::DynamicPointVector<Real>>;
-double wasserstein_cost(const DynamicPointVector<double>& set_A, const DynamicPointVector<double>& set_B, const AuctionParams<double>& params)
+inline double wasserstein_cost(const DynamicPointVector<double>& set_A, const DynamicPointVector<double>& set_B, const AuctionParams<double>& params)
{
if (params.wasserstein_power < 1.0) {
throw std::runtime_error("Bad q in Wasserstein " + std::to_string(params.wasserstein_power));
@@ -72,10 +72,9 @@ double wasserstein_cost(const DynamicPointVector<double>& set_A, const DynamicPo
auction.run_auction();
return auction.get_wasserstein_cost();
}
-
}
-double wasserstein_dist(const DynamicPointVector<double>& set_A, const DynamicPointVector<double>& set_B, const AuctionParams<double>& params)
+inline double wasserstein_dist(const DynamicPointVector<double>& set_A, const DynamicPointVector<double>& set_B, const AuctionParams<double>& params)
{
return std::pow(wasserstein_cost(set_A, set_B, params), 1.0 / params.wasserstein_power);
}