summaryrefslogtreecommitdiff
path: root/geom_matching/wasserstein/src/wasserstein.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'geom_matching/wasserstein/src/wasserstein.cpp')
-rw-r--r--geom_matching/wasserstein/src/wasserstein.cpp65
1 files changed, 59 insertions, 6 deletions
diff --git a/geom_matching/wasserstein/src/wasserstein.cpp b/geom_matching/wasserstein/src/wasserstein.cpp
index cac811b..fc1b662 100644
--- a/geom_matching/wasserstein/src/wasserstein.cpp
+++ b/geom_matching/wasserstein/src/wasserstein.cpp
@@ -36,7 +36,7 @@ derivative works thereof, in binary and source code form.
#ifdef GAUSS_SEIDEL_AUCTION
#include "auction_runner_gs.h"
#else
-#include "auction_runner_jak.h"
+#include "auction_runner_jac.h"
#endif
namespace geom_ws {
@@ -50,29 +50,78 @@ double wassersteinDistVec(const std::vector<DiagramPoint>& A,
const double _epsFactor)
{
if (q < 1) {
+#ifndef FOR_R_TDA
std::cerr << "Wasserstein distance not defined for q = " << q << ", must be >= 1" << std::endl;
- throw "Bad q in Wasserstein";
+#endif
+ throw std::runtime_error("Bad q in Wasserstein " + std::to_string(q));
}
if (delta < 0.0) {
+#ifndef FOR_R_TDA
std::cerr << "Relative error " << delta << ", must be > 0" << std::endl;
- throw "Bad delta in Wasserstein";
+#endif
+ throw std::runtime_error("Bad delta in Wasserstein " + std::to_string(delta));
}
if (_initialEpsilon < 0.0) {
+#ifndef FOR_R_TDA
std::cerr << "Initial epsilon = " << _initialEpsilon << ", must be non-negative" << std::endl;
- throw "Bad delta in Wasserstein";
+#endif
+ throw std::runtime_error("Bad initial epsilon in Wasserstein" + std::to_string(_initialEpsilon));
}
if (_epsFactor < 0.0) {
+#ifndef FOR_R_TDA
std::cerr << "Epsilon factor = " << _epsFactor << ", must be non-negative" << std::endl;
- throw "Bad delta in Wasserstein";
+#endif
+ throw std::runtime_error("Bad epsilon factor in Wasserstein " + std::to_string(_epsFactor));
}
+
#ifdef GAUSS_SEIDEL_AUCTION
AuctionRunnerGS auction(A, B, q, delta, _internal_p, _initialEpsilon, _epsFactor);
#else
- AuctionRunnerJak auction(A, B, q, delta, _internal_p);
+ AuctionRunnerJac auction(A, B, q, delta, _internal_p);
#endif
return auction.getWassersteinDistance();
}
+double wassersteinCostVec(const std::vector<DiagramPoint>& A,
+ const std::vector<DiagramPoint>& B,
+ const double q,
+ const double delta,
+ const double _internal_p,
+ const double _initialEpsilon,
+ const double _epsFactor)
+{
+ if (q < 1) {
+#ifndef FOR_R_TDA
+ std::cerr << "Wasserstein distance not defined for q = " << q << ", must be >= 1" << std::endl;
+#endif
+ throw std::runtime_error("Bad q in Wasserstein " + std::to_string(q));
+ }
+ if (delta < 0.0) {
+#ifndef FOR_R_TDA
+ std::cerr << "Relative error " << delta << ", must be > 0" << std::endl;
+#endif
+ throw std::runtime_error("Bad delta in Wasserstein " + std::to_string(delta));
+ }
+ if (_initialEpsilon < 0.0) {
+#ifndef FOR_R_TDA
+ std::cerr << "Initial epsilon = " << _initialEpsilon << ", must be non-negative" << std::endl;
+#endif
+ throw std::runtime_error("Bad initial epsilon in Wasserstein" + std::to_string(_initialEpsilon));
+ }
+ if (_epsFactor < 0.0) {
+#ifndef FOR_R_TDA
+ std::cerr << "Epsilon factor = " << _epsFactor << ", must be non-negative" << std::endl;
+#endif
+ throw std::runtime_error("Bad epsilon factor in Wasserstein " + std::to_string(_epsFactor));
+ }
+#ifdef GAUSS_SEIDEL_AUCTION
+ AuctionRunnerGS auction(A, B, q, delta, _internal_p, _initialEpsilon, _epsFactor);
+#else
+ AuctionRunnerJac auction(A, B, q, delta, _internal_p);
+#endif
+ return auction.getWassersteinCost();
+}
+
bool readDiagramPointSet(const std::string& fname, std::vector<std::pair<double, double>>& result)
{
return readDiagramPointSet(fname.c_str(), result);
@@ -84,7 +133,9 @@ bool readDiagramPointSet(const char* fname, std::vector<std::pair<double, double
result.clear();
std::ifstream f(fname);
if (!f.good()) {
+#ifndef FOR_R_TDA
std::cerr << "Cannot open file " << fname << std::endl;
+#endif
return false;
}
std::string line;
@@ -109,7 +160,9 @@ bool readDiagramPointSet(const char* fname, std::vector<std::pair<double, double
double x, y;
std::istringstream iss(line);
if (not(iss >> x >> y)) {
+#ifndef FOR_R_TDA
std::cerr << "Error in file " << fname << ", line number " << lineNumber << ": cannot parse \"" << line << "\"" << std::endl;
+#endif
return false;
}
result.push_back(std::make_pair(x,y));