summaryrefslogtreecommitdiff
path: root/geom_matching/wasserstein/include
diff options
context:
space:
mode:
Diffstat (limited to 'geom_matching/wasserstein/include')
-rw-r--r--geom_matching/wasserstein/include/auction_runner_gs.h6
-rw-r--r--geom_matching/wasserstein/include/auction_runner_jac.h28
-rw-r--r--geom_matching/wasserstein/include/basic_defs_ws.h5
-rw-r--r--geom_matching/wasserstein/include/def_debug_ws.h8
-rw-r--r--geom_matching/wasserstein/include/dnn/geometry/euclidean-fixed.h2
-rw-r--r--geom_matching/wasserstein/include/dnn/local/kd-tree.hpp9
-rw-r--r--geom_matching/wasserstein/include/dnn/parallel/tbb.h2
-rw-r--r--geom_matching/wasserstein/include/wasserstein.h36
8 files changed, 70 insertions, 26 deletions
diff --git a/geom_matching/wasserstein/include/auction_runner_gs.h b/geom_matching/wasserstein/include/auction_runner_gs.h
index 34a91e8..80aa9f0 100644
--- a/geom_matching/wasserstein/include/auction_runner_gs.h
+++ b/geom_matching/wasserstein/include/auction_runner_gs.h
@@ -76,8 +76,9 @@ public:
const double _initialEpsilon,
const double _epsFactor);
void setEpsilon(double newVal) { assert(epsilon > 0.0); epsilon = newVal; };
- double getEpsilon(void) const { return epsilon; }
- double getWassersteinDistance(void);
+ double getEpsilon() const { return epsilon; }
+ double getWassersteinDistance();
+ double getWassersteinCost();
static constexpr int maxIterNum { 25 }; // maximal number of iterations of epsilon-scaling
private:
// private data
@@ -94,6 +95,7 @@ private:
double epsilonCommonRatio; // next epsilon = current epsilon / epsilonCommonRatio
double weightAdjConst;
double wassersteinDistance;
+ double wassersteinCost;
// to get the 2 best items
std::unique_ptr<AuctionOracle> oracle;
#ifdef KEEP_UNASSIGNED_ORDERED
diff --git a/geom_matching/wasserstein/include/auction_runner_jac.h b/geom_matching/wasserstein/include/auction_runner_jac.h
index ae0cb56..22d42b0 100644
--- a/geom_matching/wasserstein/include/auction_runner_jac.h
+++ b/geom_matching/wasserstein/include/auction_runner_jac.h
@@ -45,12 +45,13 @@ using AuctionOracle = AuctionOracleKDTreeRestricted;
// 1. epsilonCommonRatio
// 2. maxIterNum
-class AuctionRunnerJak {
+class AuctionRunnerJac {
public:
- AuctionRunnerJak(const std::vector<DiagramPoint>& A, const std::vector<DiagramPoint>& B, const double q, const double _delta, const double _internal_p);
+ AuctionRunnerJac(const std::vector<DiagramPoint>& A, const std::vector<DiagramPoint>& B, const double q, const double _delta, const double _internal_p);
void setEpsilon(double newVal) { assert(epsilon > 0.0); epsilon = newVal; };
- double getEpsilon(void) const { return epsilon; }
- double getWassersteinDistance(void);
+ double getEpsilon() const { return epsilon; }
+ double getWassersteinDistance();
+ double getWassersteinCost();
static constexpr double epsilonCommonRatio { 5 }; // next epsilon = current epsilon / epsilonCommonRatio
static constexpr int maxIterNum { 25 }; // maximal number of iterations of epsilon-scaling
private:
@@ -66,6 +67,7 @@ private:
double internal_p;
double weightAdjConst;
double wassersteinDistance;
+ double wassersteinCost;
std::vector<IdxValPair> bidTable;
// to get the 2 best items
std::unique_ptr<AuctionOracle> oracle;
@@ -76,18 +78,18 @@ private:
// private methods
void assignGoodToBidder(const IdxType bidderIdx, const IdxType itemsIdx);
void assignToBestBidder(const IdxType itemsIdx);
- void clearBidTable(void);
- void runAuction(void);
- void runAuctionPhase(void);
+ void clearBidTable();
+ void runAuction();
+ void runAuctionPhase();
void submitBid(IdxType bidderIdx, const IdxValPair& itemsBidValuePair);
- void flushAssignment(void);
+ void flushAssignment();
// for debug only
- void sanityCheck(void);
- void printDebug(void);
- int countUnhappy(void);
- void printMatching(void);
- double getDistanceToQthPowerInternal(void);
+ void sanityCheck();
+ void printDebug();
+ int countUnhappy();
+ void printMatching();
+ double getDistanceToQthPowerInternal();
};
diff --git a/geom_matching/wasserstein/include/basic_defs_ws.h b/geom_matching/wasserstein/include/basic_defs_ws.h
index 9e9c4ec..365c3bd 100644
--- a/geom_matching/wasserstein/include/basic_defs_ws.h
+++ b/geom_matching/wasserstein/include/basic_defs_ws.h
@@ -33,7 +33,6 @@ derivative works thereof, in binary and source code form.
#include <cstddef>
#include <unordered_map>
#include <unordered_set>
-#include <iostream>
#include <string>
#include <assert.h>
@@ -58,7 +57,9 @@ struct Point {
bool operator!=(const Point& other) const;
Point(double ax, double ay) : x(ax), y(ay) {}
Point() : x(0.0), y(0.0) {}
+#ifndef FOR_R_TDA
friend std::ostream& operator<<(std::ostream& output, const Point p);
+#endif
};
struct DiagramPoint
@@ -76,7 +77,9 @@ struct DiagramPoint
bool isNormal(void) const { return type == NORMAL; }
double getRealX() const; // return the x-coord
double getRealY() const; // return the y-coord
+#ifndef FOR_R_TDA
friend std::ostream& operator<<(std::ostream& output, const DiagramPoint p);
+#endif
struct LexicographicCmp
{
diff --git a/geom_matching/wasserstein/include/def_debug_ws.h b/geom_matching/wasserstein/include/def_debug_ws.h
index 7323c18..6751556 100644
--- a/geom_matching/wasserstein/include/def_debug_ws.h
+++ b/geom_matching/wasserstein/include/def_debug_ws.h
@@ -25,12 +25,16 @@ derivative works thereof, in binary and source code form.
*/
-#ifndef DEF_DEBUG_H
-#define DEF_DEBUG_H
+#ifndef DEF_DEBUG_WS_H
+#define DEF_DEBUG_WS_H
//#define DEBUG_BOUND_MATCH
//#define DEBUG_NEIGHBOUR_ORACLE
//#define DEBUG_MATCHING
//#define DEBUG_AUCTION
+// This symbol should be defined only in the version
+// for R package TDA, to comply with some CRAN rules
+// like no usage of cout, cerr, cin, exit, etc.
+//#define FOR_R_TDA
#endif
diff --git a/geom_matching/wasserstein/include/dnn/geometry/euclidean-fixed.h b/geom_matching/wasserstein/include/dnn/geometry/euclidean-fixed.h
index a6ccef7..e2c5b44 100644
--- a/geom_matching/wasserstein/include/dnn/geometry/euclidean-fixed.h
+++ b/geom_matching/wasserstein/include/dnn/geometry/euclidean-fixed.h
@@ -7,7 +7,7 @@
#include <boost/serialization/access.hpp>
#include <boost/serialization/base_object.hpp>
-#include <iostream>
+//#include <iostream>
#include <fstream>
#include <string>
#include <sstream>
diff --git a/geom_matching/wasserstein/include/dnn/local/kd-tree.hpp b/geom_matching/wasserstein/include/dnn/local/kd-tree.hpp
index 151a4ad..6b0852c 100644
--- a/geom_matching/wasserstein/include/dnn/local/kd-tree.hpp
+++ b/geom_matching/wasserstein/include/dnn/local/kd-tree.hpp
@@ -6,6 +6,7 @@
#include <stack>
#include "../parallel/tbb.h"
+#include "def_debug_ws.h"
template<class T>
dnn::KDTree<T>::
@@ -127,13 +128,8 @@ search(PointHandle q, ResultsFunctor& rf) const
// TODO: use tbb::scalable_allocator for the queue
std::queue<KDTreeNode> nodes;
-
-
nodes.push(KDTreeNode(tree_.begin(), tree_.end(), 0));
-
- //std::cout << "started kdtree::search" << std::endl;
-
while (!nodes.empty())
{
HCIterator b, e; size_t i;
@@ -163,7 +159,6 @@ search(PointHandle q, ResultsFunctor& rf) const
nodes.push(KDTreeNode(b, m, i));
}
}
- //std::cout << "exited kdtree::search" << std::endl;
}
template<class T>
@@ -290,6 +285,7 @@ void
dnn::KDTree<T>::
printWeights(void)
{
+#ifndef FOR_R_TDA
std::cout << "weights_:" << std::endl;
for(const auto ph : indices_) {
std::cout << "idx = " << ph.second << ": (" << (ph.first)->at(0) << ", " << (ph.first)->at(1) << ") weight = " << weights_[ph.second] << std::endl;
@@ -298,6 +294,7 @@ printWeights(void)
for(size_t idx = 0; idx < subtree_weights_.size(); ++idx) {
std::cout << idx << " : " << subtree_weights_[idx] << std::endl;
}
+#endif
}
diff --git a/geom_matching/wasserstein/include/dnn/parallel/tbb.h b/geom_matching/wasserstein/include/dnn/parallel/tbb.h
index 4aa6805..64c59e0 100644
--- a/geom_matching/wasserstein/include/dnn/parallel/tbb.h
+++ b/geom_matching/wasserstein/include/dnn/parallel/tbb.h
@@ -1,7 +1,7 @@
#ifndef PARALLEL_H
#define PARALLEL_H
-#include <iostream>
+//#include <iostream>
#include <vector>
#include <boost/range.hpp>
diff --git a/geom_matching/wasserstein/include/wasserstein.h b/geom_matching/wasserstein/include/wasserstein.h
index 38ac6bd..fcc9e74 100644
--- a/geom_matching/wasserstein/include/wasserstein.h
+++ b/geom_matching/wasserstein/include/wasserstein.h
@@ -49,6 +49,15 @@ double wassersteinDistVec(const std::vector<DiagramPoint>& A,
const double _initialEpsilon = 0.0,
const double _epsFactor = 0.0);
+// get Wasserstein cost (distance^q) between two persistence diagrams
+double wassersteinCostVec(const std::vector<DiagramPoint>& A,
+ const std::vector<DiagramPoint>& B,
+ const double q,
+ const double delta,
+ const double _internal_p = std::numeric_limits<double>::infinity(),
+ const double _initialEpsilon = 0.0,
+ const double _epsFactor = 0.0);
+
// compare as multisets
template<class PairContainer>
@@ -98,6 +107,33 @@ double wassersteinDist(PairContainer& A, PairContainer& B, const double q, const
return wassersteinDistVec(dgmA, dgmB, q, delta, _internal_p, _initialEpsilon, _epsFactor);
}
+template<class PairContainer>
+double wassersteinCost(PairContainer& A, PairContainer& B, const double q, const double delta, const double _internal_p = std::numeric_limits<double>::infinity(), const double _initialEpsilon = 0.0, const double _epsFactor = 0.0)
+{
+ if (areEqual(A, B)) {
+ return 0.0;
+ }
+
+ std::vector<DiagramPoint> dgmA, dgmB;
+ // loop over A, add projections of A-points to corresponding positions
+ // in B-vector
+ for(auto& pairA : A) {
+ double x = pairA.first;
+ double y = pairA.second;
+ dgmA.push_back(DiagramPoint(x, y, DiagramPoint::NORMAL));
+ dgmB.push_back(DiagramPoint(x, y, DiagramPoint::DIAG));
+ }
+ // the same for B
+ for(auto& pairB : B) {
+ double x = pairB.first;
+ double y = pairB.second;
+ dgmA.push_back(DiagramPoint(x, y, DiagramPoint::DIAG));
+ dgmB.push_back(DiagramPoint(x, y, DiagramPoint::NORMAL));
+ }
+
+ return wassersteinCostVec(dgmA, dgmB, q, delta, _internal_p, _initialEpsilon, _epsFactor);
+}
+
// fill in result with points from file fname
// return false if file can't be opened