diff options
author | Arnur Nigmetov <a.nigmetov@gmail.com> | 2018-01-20 19:11:29 +0100 |
---|---|---|
committer | Arnur Nigmetov <a.nigmetov@gmail.com> | 2018-01-20 19:11:29 +0100 |
commit | 0cc35ad04f9c2997014d7cf62a12f697e79fb534 (patch) | |
tree | 744c07bc2c12fba193934ac98417c5063bead189 /geom_matching/wasserstein/include/basic_defs_ws.h | |
parent | 3552ce68bc7654df35da471bd937b09a9fde101f (diff) |
Major rewrite, templatized version
Diffstat (limited to 'geom_matching/wasserstein/include/basic_defs_ws.h')
-rw-r--r-- | geom_matching/wasserstein/include/basic_defs_ws.h | 325 |
1 files changed, 271 insertions, 54 deletions
diff --git a/geom_matching/wasserstein/include/basic_defs_ws.h b/geom_matching/wasserstein/include/basic_defs_ws.h index db305c0..58d6fd2 100644 --- a/geom_matching/wasserstein/include/basic_defs_ws.h +++ b/geom_matching/wasserstein/include/basic_defs_ws.h @@ -29,91 +29,308 @@ derivative works thereof, in binary and source code form. #define BASIC_DEFS_WS_H #include <vector> -#include <cmath> +#include <math.h> #include <cstddef> #include <unordered_map> #include <unordered_set> #include <string> +#include <iomanip> +#include <locale> #include <cassert> +#include <limits> +#include <ostream> +#include <typeinfo> #ifdef _WIN32 #include <ciso646> #endif +#ifndef FOR_R_TDA +#include "spdlog/spdlog.h" +#include "spdlog/fmt/fmt.h" +#include "spdlog/fmt/ostr.h" +#endif +#include "dnn/geometry/euclidean-dynamic.h" #include "def_debug_ws.h" #define MIN_VALID_ID 10 -namespace geom_ws { +namespace hera +{ -using IdxType = int; -using IdxValPair = std::pair<IdxType, double>; +template<class Real = double> +bool is_infinity(const Real& x) +{ + return x == Real(-1); +}; +template<class Real = double> +Real get_infinity() +{ + return Real( -1 ); +} -struct Point { - double x, y; - bool operator==(const Point& other) const; - bool operator!=(const Point& other) const; - Point(double ax, double ay) : x(ax), y(ay) {} - Point() : x(0.0), y(0.0) {} -#ifndef FOR_R_TDA - friend std::ostream& operator<<(std::ostream& output, const Point p); -#endif +template<class Real = double> +bool is_p_valid_norm(const Real& p) +{ + return is_infinity<Real>(p) or p >= Real(1); +} + +template<class Real = double> +struct AuctionParams +{ + Real wasserstein_power { 1.0 }; + Real delta { 0.01 }; // relative error + Real internal_p { get_infinity<Real>() }; + Real initial_epsilon { 0.0 }; // 0.0 means maxVal / 4.0 + Real epsilon_common_ratio { 5.0 }; + Real gamma_threshold { 0.0 }; // for experiments, not in use now + int max_num_phases { std::numeric_limits<decltype(max_num_phases)>::max() }; + size_t max_bids_per_round { 1 }; // imitate Gauss-Seidel is default behaviour + unsigned int dim { 2 }; // for pure geometric version only; ignored in persistence diagrams }; -struct DiagramPoint +namespace ws { - // data members - // Points above the diagonal have type NORMAL - // Projections onto the diagonal have type DIAG - // for DIAG points only x-coordinate is relevant - enum Type { NORMAL, DIAG}; - double x, y; - Type type; - // methods - DiagramPoint(double xx, double yy, Type ttype); - bool isDiagonal(void) const { return type == DIAG; } - bool isNormal(void) const { return type == NORMAL; } - double getRealX() const; // return the x-coord - double getRealY() const; // return the y-coord - double persistenceLp(const double p) const; + + using IdxType = int; + + constexpr size_t k_invalid_index = std::numeric_limits<IdxType>::max(); + + template<class Real = double> + using IdxValPair = std::pair<IdxType, Real>; + + + + template<class R> + std::ostream& operator<<(std::ostream& output, const IdxValPair<R> p) + { + output << fmt::format("({0}, {1})", p.first, p.second); + return output; + } + + enum class OwnerType { k_none, k_normal, k_diagonal }; + + std::ostream& operator<<(std::ostream& s, const OwnerType t) + { + switch(t) + { + case OwnerType::k_none : s << "NONE"; break; + case OwnerType::k_normal: s << "NORMAL"; break; + case OwnerType::k_diagonal: s << "DIAGONAL"; break; + } + return s; + } + + template<class Real = double> + struct Point { + Real x, y; + bool operator==(const Point& other) const; + bool operator!=(const Point& other) const; + Point(Real _x, Real _y) : x(_x), y(_y) {} + Point() : x(0.0), y(0.0) {} + }; + #ifndef FOR_R_TDA - friend std::ostream& operator<<(std::ostream& output, const DiagramPoint p); + template<class Real = double> + std::ostream& operator<<(std::ostream& output, const Point<Real> p); #endif - struct LexicographicCmp + template <class T> + inline void hash_combine(std::size_t & seed, const T & v) { - bool operator()(const DiagramPoint& p1, const DiagramPoint& p2) const - { return p1.type < p2.type || (p1.type == p2.type && (p1.x < p2.x || (p1.x == p2.x && p1.y < p2.y))); } + std::hash<T> hasher; + seed ^= hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + } + + template<class Real_ = double> + struct DiagramPoint + { + using Real = Real_; + // data members + // Points above the diagonal have type NORMAL + // Projections onto the diagonal have type DIAG + // for DIAG points only x-coordinate is relevant + enum Type { NORMAL, DIAG}; + Real x, y; + Type type; + // methods + DiagramPoint(Real xx, Real yy, Type ttype); + bool is_diagonal() const { return type == DIAG; } + bool is_normal() const { return type == NORMAL; } + Real getRealX() const; // return the x-coord + Real getRealY() const; // return the y-coord + Real persistence_lp(const Real p) const; + struct LexicographicCmp + { + bool operator()(const DiagramPoint& p1, const DiagramPoint& p2) const + { return p1.type < p2.type || (p1.type == p2.type && (p1.x < p2.x || (p1.x == p2.x && p1.y < p2.y))); } + }; + + const Real& operator[](const int idx) const + { + switch(idx) + { + case 0 : return x; + break; + case 1 : return y; + break; + default: throw std::out_of_range("DiagramPoint has dimension 2"); + } + } + + Real& operator[](const int idx) + { + switch(idx) + { + case 0 : return x; + break; + case 1 : return y; + break; + default: throw std::out_of_range("DiagramPoint has dimension 2"); + } + } + }; -}; -double sqrDist(const Point& a, const Point& b); -double dist(const Point& a, const Point& b); -double distLInf(const DiagramPoint& a, const DiagramPoint& b); -double distLp(const DiagramPoint& a, const DiagramPoint& b, const double p); -double persistenceLp(const DiagramPoint& a, const double p); -template<typename DiagPointContainer> -double getFurthestDistance3Approx(DiagPointContainer& A, DiagPointContainer& B) -{ - double result { 0.0 }; - DiagramPoint begA = *(A.begin()); - DiagramPoint optB = *(B.begin()); - for(const auto& pointB : B) { - if (distLInf(begA, pointB) > result) { - result = distLInf(begA, pointB); - optB = pointB; + template<class Real> + struct DiagramPointHash { + size_t operator()(const DiagramPoint<Real> &p) const + { + std::size_t seed = 0; + hash_combine(seed, std::hash<Real>(p.x)); + hash_combine(seed, std::hash<Real>(p.y)); + hash_combine(seed, std::hash<bool>(p.is_diagonal())); + return seed; + } + }; + + +#ifndef FOR_R_TDA + template <class Real = double> + std::ostream& operator<<(std::ostream& output, const DiagramPoint<Real> p); +#endif + + template<class Real> + void format_arg(fmt::BasicFormatter<char> &f, const char *&format_str, const DiagramPoint<Real>&p) { + if (p.is_diagonal()) { + f.writer().write("({0},{1}, DIAG)", p.x, p.y); + } else { + f.writer().write("({0},{1}, NORM)", p.x, p.y); } } - for(const auto& pointA : A) { - if (distLInf(pointA, optB) > result) { - result = distLInf(pointA, optB); + + + template<class Real, class Pt> + struct DistImpl + { + Real operator()(const Pt& a, const Pt& b, const Real p, const int dim) + { + Real result = 0.0; + if (hera::is_infinity(p)) { + for(int d = 0; d < dim; ++d) { + result = std::max(result, std::fabs(a[d] - b[d])); + } + } else if (p == 1.0) { + for(int d = 0; d < dim; ++d) { + result += std::fabs(a[d] - b[d]); + } + } else { + assert(p > 1.0); + for(int d = 0; d < dim; ++d) { + result += std::pow(std::fabs(a[d] - b[d]), p); + } + result = std::pow(result, 1.0 / p); + } + return result; } + }; + + template<class Real> + struct DistImpl<Real, DiagramPoint<Real>> + { + Real operator()(const DiagramPoint<Real>& a, const DiagramPoint<Real>& b, const Real p, const int dim) + { + Real result = 0.0; + if ( a.is_diagonal() and b.is_diagonal()) { + return result; + } else if (hera::is_infinity(p)) { + result = std::max(std::fabs(a.getRealX() - b.getRealX()), std::fabs(a.getRealY() - b.getRealY())); + } else if (p == 1.0) { + result = std::fabs(a.getRealX() - b.getRealX()) + std::fabs(a.getRealY() - b.getRealY()); + } else { + assert(p > 1.0); + result = std::pow(std::pow(std::fabs(a.getRealX() - b.getRealX()), p) + std::pow(std::fabs(a.getRealY() - b.getRealY()), p), 1.0 / p); + } + return result; + } + }; + + template<class R, class Pt> + R dist_lp(const Pt& a, const Pt& b, const R p, const int dim) + { + return DistImpl<R, Pt>()(a, b, p, dim); } - return result; -} -} // end of namespace geom_ws + // TODO + template<class Real, typename DiagPointContainer> + double getFurthestDistance3Approx(DiagPointContainer& A, DiagPointContainer& B, const Real p) + { + int dim = 2; + Real result { 0.0 }; + DiagramPoint<Real> begA = *(A.begin()); + DiagramPoint<Real> optB = *(B.begin()); + for(const auto& pointB : B) { + if (dist_lp(begA, pointB, p, dim) > result) { + result = dist_lp(begA, pointB, p, dim); + optB = pointB; + } + } + for(const auto& pointA : A) { + if (dist_lp(pointA, optB, p, dim) > result) { + result = dist_lp(pointA, optB, p, dim); + } + } + return result; + } + + template<class Real> + Real getFurthestDistance3Approx_pg(const hera::ws::dnn::DynamicPointVector<Real>& A, const hera::ws::dnn::DynamicPointVector<Real>& B, const Real p, const int dim) + { + Real result { 0.0 }; + int opt_b_idx = 0; + for(size_t b_idx = 0; b_idx < B.size(); ++b_idx) { + if (dist_lp(A[0], B[b_idx], p, dim) > result) { + result = dist_lp(A[0], B[b_idx], p, dim); + opt_b_idx = b_idx; + } + } + + for(size_t a_idx = 0; a_idx < A.size(); ++a_idx) { + result = std::max(result, dist_lp(A[a_idx], B[opt_b_idx], p, dim)); + } + + return result; + } + + + template<class Container> + std::string format_container_to_log(const Container& cont); + + template<class Real, class IndexContainer> + std::string format_point_set_to_log(const IndexContainer& indices, const std::vector<DiagramPoint<Real>>& points); + + template<class T> + std::string format_int(T i); + +} // ws +} // hera + + + +#include "basic_defs_ws.hpp" + + #endif |