summaryrefslogtreecommitdiff
path: root/geom_matching/wasserstein/include/wasserstein.h
diff options
context:
space:
mode:
authorArnur Nigmetov <nigmetov@tugraz.at>2019-02-06 22:45:18 +0100
committerArnur Nigmetov <nigmetov@tugraz.at>2019-02-06 22:45:18 +0100
commitba25f264b5d309efcf77a6b72d1b784ae97f741f (patch)
treefb536fc0d0e9f241e6c7b37b45a94e31d25f1a8f /geom_matching/wasserstein/include/wasserstein.h
parent657f73321f04d5d1c4cec8085ec43a73633b96af (diff)
Switch to opts, tolerate max-iterations-exceeded.
1. Use opts.h for command-line parsing in wasserstein_dist. 2. Real relative error at the end of auction is stored in params, params is passed by reference. 3. If -e command line option is given to wasserstein_dist, relative error will be printed. 4. If -t option is given to wasserstein_dist, no exception will be thrown, if maximum number of iterations is exceeded. 5. Run wasserstein_dist -h to see all options.
Diffstat (limited to 'geom_matching/wasserstein/include/wasserstein.h')
-rw-r--r--geom_matching/wasserstein/include/wasserstein.h7
1 files changed, 4 insertions, 3 deletions
diff --git a/geom_matching/wasserstein/include/wasserstein.h b/geom_matching/wasserstein/include/wasserstein.h
index a24bada..35d0bf6 100644
--- a/geom_matching/wasserstein/include/wasserstein.h
+++ b/geom_matching/wasserstein/include/wasserstein.h
@@ -212,7 +212,7 @@ namespace ws
template<class RealType>
inline RealType wasserstein_cost_vec(const std::vector<DiagramPoint<RealType>>& A,
const std::vector<DiagramPoint<RealType>>& B,
- const AuctionParams<RealType>& params,
+ AuctionParams<RealType>& params,
const std::string& _log_filename_prefix)
{
if (params.wasserstein_power < 1.0) {
@@ -237,6 +237,7 @@ namespace ws
AuctionRunnerGS<RealType> auction(A, B, params, _log_filename_prefix);
auction.run_auction();
result = auction.get_wasserstein_cost();
+ params.final_relative_error = auction.get_relative_error();
return result;
}
@@ -248,7 +249,7 @@ template<class PairContainer>
inline typename DiagramTraits<PairContainer>::RealType
wasserstein_cost(const PairContainer& A,
const PairContainer& B,
- const AuctionParams< typename DiagramTraits<PairContainer>::RealType >& params,
+ AuctionParams< typename DiagramTraits<PairContainer>::RealType >& params,
const std::string& _log_filename_prefix = "")
{
using Traits = DiagramTraits<PairContainer>;
@@ -335,7 +336,7 @@ template<class PairContainer>
inline typename DiagramTraits<PairContainer>::RealType
wasserstein_dist(const PairContainer& A,
const PairContainer& B,
- const AuctionParams<typename DiagramTraits<PairContainer>::RealType> params,
+ AuctionParams<typename DiagramTraits<PairContainer>::RealType>& params,
const std::string& _log_filename_prefix = "")
{
using Real = typename DiagramTraits<PairContainer>::RealType;