diff options
Diffstat (limited to 'geom_matching/wasserstein/src/wasserstein.cpp')
-rw-r--r-- | geom_matching/wasserstein/src/wasserstein.cpp | 65 |
1 files changed, 59 insertions, 6 deletions
diff --git a/geom_matching/wasserstein/src/wasserstein.cpp b/geom_matching/wasserstein/src/wasserstein.cpp index cac811b..fc1b662 100644 --- a/geom_matching/wasserstein/src/wasserstein.cpp +++ b/geom_matching/wasserstein/src/wasserstein.cpp @@ -36,7 +36,7 @@ derivative works thereof, in binary and source code form. #ifdef GAUSS_SEIDEL_AUCTION #include "auction_runner_gs.h" #else -#include "auction_runner_jak.h" +#include "auction_runner_jac.h" #endif namespace geom_ws { @@ -50,29 +50,78 @@ double wassersteinDistVec(const std::vector<DiagramPoint>& A, const double _epsFactor) { if (q < 1) { +#ifndef FOR_R_TDA std::cerr << "Wasserstein distance not defined for q = " << q << ", must be >= 1" << std::endl; - throw "Bad q in Wasserstein"; +#endif + throw std::runtime_error("Bad q in Wasserstein " + std::to_string(q)); } if (delta < 0.0) { +#ifndef FOR_R_TDA std::cerr << "Relative error " << delta << ", must be > 0" << std::endl; - throw "Bad delta in Wasserstein"; +#endif + throw std::runtime_error("Bad delta in Wasserstein " + std::to_string(delta)); } if (_initialEpsilon < 0.0) { +#ifndef FOR_R_TDA std::cerr << "Initial epsilon = " << _initialEpsilon << ", must be non-negative" << std::endl; - throw "Bad delta in Wasserstein"; +#endif + throw std::runtime_error("Bad initial epsilon in Wasserstein" + std::to_string(_initialEpsilon)); } if (_epsFactor < 0.0) { +#ifndef FOR_R_TDA std::cerr << "Epsilon factor = " << _epsFactor << ", must be non-negative" << std::endl; - throw "Bad delta in Wasserstein"; +#endif + throw std::runtime_error("Bad epsilon factor in Wasserstein " + std::to_string(_epsFactor)); } + #ifdef GAUSS_SEIDEL_AUCTION AuctionRunnerGS auction(A, B, q, delta, _internal_p, _initialEpsilon, _epsFactor); #else - AuctionRunnerJak auction(A, B, q, delta, _internal_p); + AuctionRunnerJac auction(A, B, q, delta, _internal_p); #endif return auction.getWassersteinDistance(); } +double wassersteinCostVec(const std::vector<DiagramPoint>& A, + const std::vector<DiagramPoint>& B, + const double q, + const double delta, + const double _internal_p, + const double _initialEpsilon, + const double _epsFactor) +{ + if (q < 1) { +#ifndef FOR_R_TDA + std::cerr << "Wasserstein distance not defined for q = " << q << ", must be >= 1" << std::endl; +#endif + throw std::runtime_error("Bad q in Wasserstein " + std::to_string(q)); + } + if (delta < 0.0) { +#ifndef FOR_R_TDA + std::cerr << "Relative error " << delta << ", must be > 0" << std::endl; +#endif + throw std::runtime_error("Bad delta in Wasserstein " + std::to_string(delta)); + } + if (_initialEpsilon < 0.0) { +#ifndef FOR_R_TDA + std::cerr << "Initial epsilon = " << _initialEpsilon << ", must be non-negative" << std::endl; +#endif + throw std::runtime_error("Bad initial epsilon in Wasserstein" + std::to_string(_initialEpsilon)); + } + if (_epsFactor < 0.0) { +#ifndef FOR_R_TDA + std::cerr << "Epsilon factor = " << _epsFactor << ", must be non-negative" << std::endl; +#endif + throw std::runtime_error("Bad epsilon factor in Wasserstein " + std::to_string(_epsFactor)); + } +#ifdef GAUSS_SEIDEL_AUCTION + AuctionRunnerGS auction(A, B, q, delta, _internal_p, _initialEpsilon, _epsFactor); +#else + AuctionRunnerJac auction(A, B, q, delta, _internal_p); +#endif + return auction.getWassersteinCost(); +} + bool readDiagramPointSet(const std::string& fname, std::vector<std::pair<double, double>>& result) { return readDiagramPointSet(fname.c_str(), result); @@ -84,7 +133,9 @@ bool readDiagramPointSet(const char* fname, std::vector<std::pair<double, double result.clear(); std::ifstream f(fname); if (!f.good()) { +#ifndef FOR_R_TDA std::cerr << "Cannot open file " << fname << std::endl; +#endif return false; } std::string line; @@ -109,7 +160,9 @@ bool readDiagramPointSet(const char* fname, std::vector<std::pair<double, double double x, y; std::istringstream iss(line); if (not(iss >> x >> y)) { +#ifndef FOR_R_TDA std::cerr << "Error in file " << fname << ", line number " << lineNumber << ": cannot parse \"" << line << "\"" << std::endl; +#endif return false; } result.push_back(std::make_pair(x,y)); |