summaryrefslogtreecommitdiff
path: root/geom_matching/wasserstein/include/basic_defs_ws.h
diff options
context:
space:
mode:
Diffstat (limited to 'geom_matching/wasserstein/include/basic_defs_ws.h')
-rw-r--r--geom_matching/wasserstein/include/basic_defs_ws.h325
1 files changed, 271 insertions, 54 deletions
diff --git a/geom_matching/wasserstein/include/basic_defs_ws.h b/geom_matching/wasserstein/include/basic_defs_ws.h
index db305c0..58d6fd2 100644
--- a/geom_matching/wasserstein/include/basic_defs_ws.h
+++ b/geom_matching/wasserstein/include/basic_defs_ws.h
@@ -29,91 +29,308 @@ derivative works thereof, in binary and source code form.
#define BASIC_DEFS_WS_H
#include <vector>
-#include <cmath>
+#include <math.h>
#include <cstddef>
#include <unordered_map>
#include <unordered_set>
#include <string>
+#include <iomanip>
+#include <locale>
#include <cassert>
+#include <limits>
+#include <ostream>
+#include <typeinfo>
#ifdef _WIN32
#include <ciso646>
#endif
+#ifndef FOR_R_TDA
+#include "spdlog/spdlog.h"
+#include "spdlog/fmt/fmt.h"
+#include "spdlog/fmt/ostr.h"
+#endif
+#include "dnn/geometry/euclidean-dynamic.h"
#include "def_debug_ws.h"
#define MIN_VALID_ID 10
-namespace geom_ws {
+namespace hera
+{
-using IdxType = int;
-using IdxValPair = std::pair<IdxType, double>;
+template<class Real = double>
+bool is_infinity(const Real& x)
+{
+ return x == Real(-1);
+};
+template<class Real = double>
+Real get_infinity()
+{
+ return Real( -1 );
+}
-struct Point {
- double x, y;
- bool operator==(const Point& other) const;
- bool operator!=(const Point& other) const;
- Point(double ax, double ay) : x(ax), y(ay) {}
- Point() : x(0.0), y(0.0) {}
-#ifndef FOR_R_TDA
- friend std::ostream& operator<<(std::ostream& output, const Point p);
-#endif
+template<class Real = double>
+bool is_p_valid_norm(const Real& p)
+{
+ return is_infinity<Real>(p) or p >= Real(1);
+}
+
+template<class Real = double>
+struct AuctionParams
+{
+ Real wasserstein_power { 1.0 };
+ Real delta { 0.01 }; // relative error
+ Real internal_p { get_infinity<Real>() };
+ Real initial_epsilon { 0.0 }; // 0.0 means maxVal / 4.0
+ Real epsilon_common_ratio { 5.0 };
+ Real gamma_threshold { 0.0 }; // for experiments, not in use now
+ int max_num_phases { std::numeric_limits<decltype(max_num_phases)>::max() };
+ size_t max_bids_per_round { 1 }; // imitate Gauss-Seidel is default behaviour
+ unsigned int dim { 2 }; // for pure geometric version only; ignored in persistence diagrams
};
-struct DiagramPoint
+namespace ws
{
- // data members
- // Points above the diagonal have type NORMAL
- // Projections onto the diagonal have type DIAG
- // for DIAG points only x-coordinate is relevant
- enum Type { NORMAL, DIAG};
- double x, y;
- Type type;
- // methods
- DiagramPoint(double xx, double yy, Type ttype);
- bool isDiagonal(void) const { return type == DIAG; }
- bool isNormal(void) const { return type == NORMAL; }
- double getRealX() const; // return the x-coord
- double getRealY() const; // return the y-coord
- double persistenceLp(const double p) const;
+
+ using IdxType = int;
+
+ constexpr size_t k_invalid_index = std::numeric_limits<IdxType>::max();
+
+ template<class Real = double>
+ using IdxValPair = std::pair<IdxType, Real>;
+
+
+
+ template<class R>
+ std::ostream& operator<<(std::ostream& output, const IdxValPair<R> p)
+ {
+ output << fmt::format("({0}, {1})", p.first, p.second);
+ return output;
+ }
+
+ enum class OwnerType { k_none, k_normal, k_diagonal };
+
+ std::ostream& operator<<(std::ostream& s, const OwnerType t)
+ {
+ switch(t)
+ {
+ case OwnerType::k_none : s << "NONE"; break;
+ case OwnerType::k_normal: s << "NORMAL"; break;
+ case OwnerType::k_diagonal: s << "DIAGONAL"; break;
+ }
+ return s;
+ }
+
+ template<class Real = double>
+ struct Point {
+ Real x, y;
+ bool operator==(const Point& other) const;
+ bool operator!=(const Point& other) const;
+ Point(Real _x, Real _y) : x(_x), y(_y) {}
+ Point() : x(0.0), y(0.0) {}
+ };
+
#ifndef FOR_R_TDA
- friend std::ostream& operator<<(std::ostream& output, const DiagramPoint p);
+ template<class Real = double>
+ std::ostream& operator<<(std::ostream& output, const Point<Real> p);
#endif
- struct LexicographicCmp
+ template <class T>
+ inline void hash_combine(std::size_t & seed, const T & v)
{
- bool operator()(const DiagramPoint& p1, const DiagramPoint& p2) const
- { return p1.type < p2.type || (p1.type == p2.type && (p1.x < p2.x || (p1.x == p2.x && p1.y < p2.y))); }
+ std::hash<T> hasher;
+ seed ^= hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+ }
+
+ template<class Real_ = double>
+ struct DiagramPoint
+ {
+ using Real = Real_;
+ // data members
+ // Points above the diagonal have type NORMAL
+ // Projections onto the diagonal have type DIAG
+ // for DIAG points only x-coordinate is relevant
+ enum Type { NORMAL, DIAG};
+ Real x, y;
+ Type type;
+ // methods
+ DiagramPoint(Real xx, Real yy, Type ttype);
+ bool is_diagonal() const { return type == DIAG; }
+ bool is_normal() const { return type == NORMAL; }
+ Real getRealX() const; // return the x-coord
+ Real getRealY() const; // return the y-coord
+ Real persistence_lp(const Real p) const;
+ struct LexicographicCmp
+ {
+ bool operator()(const DiagramPoint& p1, const DiagramPoint& p2) const
+ { return p1.type < p2.type || (p1.type == p2.type && (p1.x < p2.x || (p1.x == p2.x && p1.y < p2.y))); }
+ };
+
+ const Real& operator[](const int idx) const
+ {
+ switch(idx)
+ {
+ case 0 : return x;
+ break;
+ case 1 : return y;
+ break;
+ default: throw std::out_of_range("DiagramPoint has dimension 2");
+ }
+ }
+
+ Real& operator[](const int idx)
+ {
+ switch(idx)
+ {
+ case 0 : return x;
+ break;
+ case 1 : return y;
+ break;
+ default: throw std::out_of_range("DiagramPoint has dimension 2");
+ }
+ }
+
};
-};
-double sqrDist(const Point& a, const Point& b);
-double dist(const Point& a, const Point& b);
-double distLInf(const DiagramPoint& a, const DiagramPoint& b);
-double distLp(const DiagramPoint& a, const DiagramPoint& b, const double p);
-double persistenceLp(const DiagramPoint& a, const double p);
-template<typename DiagPointContainer>
-double getFurthestDistance3Approx(DiagPointContainer& A, DiagPointContainer& B)
-{
- double result { 0.0 };
- DiagramPoint begA = *(A.begin());
- DiagramPoint optB = *(B.begin());
- for(const auto& pointB : B) {
- if (distLInf(begA, pointB) > result) {
- result = distLInf(begA, pointB);
- optB = pointB;
+ template<class Real>
+ struct DiagramPointHash {
+ size_t operator()(const DiagramPoint<Real> &p) const
+ {
+ std::size_t seed = 0;
+ hash_combine(seed, std::hash<Real>(p.x));
+ hash_combine(seed, std::hash<Real>(p.y));
+ hash_combine(seed, std::hash<bool>(p.is_diagonal()));
+ return seed;
+ }
+ };
+
+
+#ifndef FOR_R_TDA
+ template <class Real = double>
+ std::ostream& operator<<(std::ostream& output, const DiagramPoint<Real> p);
+#endif
+
+ template<class Real>
+ void format_arg(fmt::BasicFormatter<char> &f, const char *&format_str, const DiagramPoint<Real>&p) {
+ if (p.is_diagonal()) {
+ f.writer().write("({0},{1}, DIAG)", p.x, p.y);
+ } else {
+ f.writer().write("({0},{1}, NORM)", p.x, p.y);
}
}
- for(const auto& pointA : A) {
- if (distLInf(pointA, optB) > result) {
- result = distLInf(pointA, optB);
+
+
+ template<class Real, class Pt>
+ struct DistImpl
+ {
+ Real operator()(const Pt& a, const Pt& b, const Real p, const int dim)
+ {
+ Real result = 0.0;
+ if (hera::is_infinity(p)) {
+ for(int d = 0; d < dim; ++d) {
+ result = std::max(result, std::fabs(a[d] - b[d]));
+ }
+ } else if (p == 1.0) {
+ for(int d = 0; d < dim; ++d) {
+ result += std::fabs(a[d] - b[d]);
+ }
+ } else {
+ assert(p > 1.0);
+ for(int d = 0; d < dim; ++d) {
+ result += std::pow(std::fabs(a[d] - b[d]), p);
+ }
+ result = std::pow(result, 1.0 / p);
+ }
+ return result;
}
+ };
+
+ template<class Real>
+ struct DistImpl<Real, DiagramPoint<Real>>
+ {
+ Real operator()(const DiagramPoint<Real>& a, const DiagramPoint<Real>& b, const Real p, const int dim)
+ {
+ Real result = 0.0;
+ if ( a.is_diagonal() and b.is_diagonal()) {
+ return result;
+ } else if (hera::is_infinity(p)) {
+ result = std::max(std::fabs(a.getRealX() - b.getRealX()), std::fabs(a.getRealY() - b.getRealY()));
+ } else if (p == 1.0) {
+ result = std::fabs(a.getRealX() - b.getRealX()) + std::fabs(a.getRealY() - b.getRealY());
+ } else {
+ assert(p > 1.0);
+ result = std::pow(std::pow(std::fabs(a.getRealX() - b.getRealX()), p) + std::pow(std::fabs(a.getRealY() - b.getRealY()), p), 1.0 / p);
+ }
+ return result;
+ }
+ };
+
+ template<class R, class Pt>
+ R dist_lp(const Pt& a, const Pt& b, const R p, const int dim)
+ {
+ return DistImpl<R, Pt>()(a, b, p, dim);
}
- return result;
-}
-} // end of namespace geom_ws
+ // TODO
+ template<class Real, typename DiagPointContainer>
+ double getFurthestDistance3Approx(DiagPointContainer& A, DiagPointContainer& B, const Real p)
+ {
+ int dim = 2;
+ Real result { 0.0 };
+ DiagramPoint<Real> begA = *(A.begin());
+ DiagramPoint<Real> optB = *(B.begin());
+ for(const auto& pointB : B) {
+ if (dist_lp(begA, pointB, p, dim) > result) {
+ result = dist_lp(begA, pointB, p, dim);
+ optB = pointB;
+ }
+ }
+ for(const auto& pointA : A) {
+ if (dist_lp(pointA, optB, p, dim) > result) {
+ result = dist_lp(pointA, optB, p, dim);
+ }
+ }
+ return result;
+ }
+
+ template<class Real>
+ Real getFurthestDistance3Approx_pg(const hera::ws::dnn::DynamicPointVector<Real>& A, const hera::ws::dnn::DynamicPointVector<Real>& B, const Real p, const int dim)
+ {
+ Real result { 0.0 };
+ int opt_b_idx = 0;
+ for(size_t b_idx = 0; b_idx < B.size(); ++b_idx) {
+ if (dist_lp(A[0], B[b_idx], p, dim) > result) {
+ result = dist_lp(A[0], B[b_idx], p, dim);
+ opt_b_idx = b_idx;
+ }
+ }
+
+ for(size_t a_idx = 0; a_idx < A.size(); ++a_idx) {
+ result = std::max(result, dist_lp(A[a_idx], B[opt_b_idx], p, dim));
+ }
+
+ return result;
+ }
+
+
+ template<class Container>
+ std::string format_container_to_log(const Container& cont);
+
+ template<class Real, class IndexContainer>
+ std::string format_point_set_to_log(const IndexContainer& indices, const std::vector<DiagramPoint<Real>>& points);
+
+ template<class T>
+ std::string format_int(T i);
+
+} // ws
+} // hera
+
+
+
+#include "basic_defs_ws.hpp"
+
+
#endif