From 0cc35ad04f9c2997014d7cf62a12f697e79fb534 Mon Sep 17 00:00:00 2001 From: Arnur Nigmetov Date: Sat, 20 Jan 2018 19:11:29 +0100 Subject: Major rewrite, templatized version --- geom_matching/wasserstein/include/basic_defs_ws.h | 325 ++++++++++++++++++---- 1 file changed, 271 insertions(+), 54 deletions(-) (limited to 'geom_matching/wasserstein/include/basic_defs_ws.h') diff --git a/geom_matching/wasserstein/include/basic_defs_ws.h b/geom_matching/wasserstein/include/basic_defs_ws.h index db305c0..58d6fd2 100644 --- a/geom_matching/wasserstein/include/basic_defs_ws.h +++ b/geom_matching/wasserstein/include/basic_defs_ws.h @@ -29,91 +29,308 @@ derivative works thereof, in binary and source code form. #define BASIC_DEFS_WS_H #include -#include +#include #include #include #include #include +#include +#include #include +#include +#include +#include #ifdef _WIN32 #include #endif +#ifndef FOR_R_TDA +#include "spdlog/spdlog.h" +#include "spdlog/fmt/fmt.h" +#include "spdlog/fmt/ostr.h" +#endif +#include "dnn/geometry/euclidean-dynamic.h" #include "def_debug_ws.h" #define MIN_VALID_ID 10 -namespace geom_ws { +namespace hera +{ -using IdxType = int; -using IdxValPair = std::pair; +template +bool is_infinity(const Real& x) +{ + return x == Real(-1); +}; +template +Real get_infinity() +{ + return Real( -1 ); +} -struct Point { - double x, y; - bool operator==(const Point& other) const; - bool operator!=(const Point& other) const; - Point(double ax, double ay) : x(ax), y(ay) {} - Point() : x(0.0), y(0.0) {} -#ifndef FOR_R_TDA - friend std::ostream& operator<<(std::ostream& output, const Point p); -#endif +template +bool is_p_valid_norm(const Real& p) +{ + return is_infinity(p) or p >= Real(1); +} + +template +struct AuctionParams +{ + Real wasserstein_power { 1.0 }; + Real delta { 0.01 }; // relative error + Real internal_p { get_infinity() }; + Real initial_epsilon { 0.0 }; // 0.0 means maxVal / 4.0 + Real epsilon_common_ratio { 5.0 }; + Real gamma_threshold { 0.0 }; // for experiments, not in use now + int max_num_phases { std::numeric_limits::max() }; + size_t max_bids_per_round { 1 }; // imitate Gauss-Seidel is default behaviour + unsigned int dim { 2 }; // for pure geometric version only; ignored in persistence diagrams }; -struct DiagramPoint +namespace ws { - // data members - // Points above the diagonal have type NORMAL - // Projections onto the diagonal have type DIAG - // for DIAG points only x-coordinate is relevant - enum Type { NORMAL, DIAG}; - double x, y; - Type type; - // methods - DiagramPoint(double xx, double yy, Type ttype); - bool isDiagonal(void) const { return type == DIAG; } - bool isNormal(void) const { return type == NORMAL; } - double getRealX() const; // return the x-coord - double getRealY() const; // return the y-coord - double persistenceLp(const double p) const; + + using IdxType = int; + + constexpr size_t k_invalid_index = std::numeric_limits::max(); + + template + using IdxValPair = std::pair; + + + + template + std::ostream& operator<<(std::ostream& output, const IdxValPair p) + { + output << fmt::format("({0}, {1})", p.first, p.second); + return output; + } + + enum class OwnerType { k_none, k_normal, k_diagonal }; + + std::ostream& operator<<(std::ostream& s, const OwnerType t) + { + switch(t) + { + case OwnerType::k_none : s << "NONE"; break; + case OwnerType::k_normal: s << "NORMAL"; break; + case OwnerType::k_diagonal: s << "DIAGONAL"; break; + } + return s; + } + + template + struct Point { + Real x, y; + bool operator==(const Point& other) const; + bool operator!=(const Point& other) const; + Point(Real _x, Real _y) : x(_x), y(_y) {} + Point() : x(0.0), y(0.0) {} + }; + #ifndef FOR_R_TDA - friend std::ostream& operator<<(std::ostream& output, const DiagramPoint p); + template + std::ostream& operator<<(std::ostream& output, const Point p); #endif - struct LexicographicCmp + template + inline void hash_combine(std::size_t & seed, const T & v) { - bool operator()(const DiagramPoint& p1, const DiagramPoint& p2) const - { return p1.type < p2.type || (p1.type == p2.type && (p1.x < p2.x || (p1.x == p2.x && p1.y < p2.y))); } + std::hash hasher; + seed ^= hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + } + + template + struct DiagramPoint + { + using Real = Real_; + // data members + // Points above the diagonal have type NORMAL + // Projections onto the diagonal have type DIAG + // for DIAG points only x-coordinate is relevant + enum Type { NORMAL, DIAG}; + Real x, y; + Type type; + // methods + DiagramPoint(Real xx, Real yy, Type ttype); + bool is_diagonal() const { return type == DIAG; } + bool is_normal() const { return type == NORMAL; } + Real getRealX() const; // return the x-coord + Real getRealY() const; // return the y-coord + Real persistence_lp(const Real p) const; + struct LexicographicCmp + { + bool operator()(const DiagramPoint& p1, const DiagramPoint& p2) const + { return p1.type < p2.type || (p1.type == p2.type && (p1.x < p2.x || (p1.x == p2.x && p1.y < p2.y))); } + }; + + const Real& operator[](const int idx) const + { + switch(idx) + { + case 0 : return x; + break; + case 1 : return y; + break; + default: throw std::out_of_range("DiagramPoint has dimension 2"); + } + } + + Real& operator[](const int idx) + { + switch(idx) + { + case 0 : return x; + break; + case 1 : return y; + break; + default: throw std::out_of_range("DiagramPoint has dimension 2"); + } + } + }; -}; -double sqrDist(const Point& a, const Point& b); -double dist(const Point& a, const Point& b); -double distLInf(const DiagramPoint& a, const DiagramPoint& b); -double distLp(const DiagramPoint& a, const DiagramPoint& b, const double p); -double persistenceLp(const DiagramPoint& a, const double p); -template -double getFurthestDistance3Approx(DiagPointContainer& A, DiagPointContainer& B) -{ - double result { 0.0 }; - DiagramPoint begA = *(A.begin()); - DiagramPoint optB = *(B.begin()); - for(const auto& pointB : B) { - if (distLInf(begA, pointB) > result) { - result = distLInf(begA, pointB); - optB = pointB; + template + struct DiagramPointHash { + size_t operator()(const DiagramPoint &p) const + { + std::size_t seed = 0; + hash_combine(seed, std::hash(p.x)); + hash_combine(seed, std::hash(p.y)); + hash_combine(seed, std::hash(p.is_diagonal())); + return seed; + } + }; + + +#ifndef FOR_R_TDA + template + std::ostream& operator<<(std::ostream& output, const DiagramPoint p); +#endif + + template + void format_arg(fmt::BasicFormatter &f, const char *&format_str, const DiagramPoint&p) { + if (p.is_diagonal()) { + f.writer().write("({0},{1}, DIAG)", p.x, p.y); + } else { + f.writer().write("({0},{1}, NORM)", p.x, p.y); } } - for(const auto& pointA : A) { - if (distLInf(pointA, optB) > result) { - result = distLInf(pointA, optB); + + + template + struct DistImpl + { + Real operator()(const Pt& a, const Pt& b, const Real p, const int dim) + { + Real result = 0.0; + if (hera::is_infinity(p)) { + for(int d = 0; d < dim; ++d) { + result = std::max(result, std::fabs(a[d] - b[d])); + } + } else if (p == 1.0) { + for(int d = 0; d < dim; ++d) { + result += std::fabs(a[d] - b[d]); + } + } else { + assert(p > 1.0); + for(int d = 0; d < dim; ++d) { + result += std::pow(std::fabs(a[d] - b[d]), p); + } + result = std::pow(result, 1.0 / p); + } + return result; } + }; + + template + struct DistImpl> + { + Real operator()(const DiagramPoint& a, const DiagramPoint& b, const Real p, const int dim) + { + Real result = 0.0; + if ( a.is_diagonal() and b.is_diagonal()) { + return result; + } else if (hera::is_infinity(p)) { + result = std::max(std::fabs(a.getRealX() - b.getRealX()), std::fabs(a.getRealY() - b.getRealY())); + } else if (p == 1.0) { + result = std::fabs(a.getRealX() - b.getRealX()) + std::fabs(a.getRealY() - b.getRealY()); + } else { + assert(p > 1.0); + result = std::pow(std::pow(std::fabs(a.getRealX() - b.getRealX()), p) + std::pow(std::fabs(a.getRealY() - b.getRealY()), p), 1.0 / p); + } + return result; + } + }; + + template + R dist_lp(const Pt& a, const Pt& b, const R p, const int dim) + { + return DistImpl()(a, b, p, dim); } - return result; -} -} // end of namespace geom_ws + // TODO + template + double getFurthestDistance3Approx(DiagPointContainer& A, DiagPointContainer& B, const Real p) + { + int dim = 2; + Real result { 0.0 }; + DiagramPoint begA = *(A.begin()); + DiagramPoint optB = *(B.begin()); + for(const auto& pointB : B) { + if (dist_lp(begA, pointB, p, dim) > result) { + result = dist_lp(begA, pointB, p, dim); + optB = pointB; + } + } + for(const auto& pointA : A) { + if (dist_lp(pointA, optB, p, dim) > result) { + result = dist_lp(pointA, optB, p, dim); + } + } + return result; + } + + template + Real getFurthestDistance3Approx_pg(const hera::ws::dnn::DynamicPointVector& A, const hera::ws::dnn::DynamicPointVector& B, const Real p, const int dim) + { + Real result { 0.0 }; + int opt_b_idx = 0; + for(size_t b_idx = 0; b_idx < B.size(); ++b_idx) { + if (dist_lp(A[0], B[b_idx], p, dim) > result) { + result = dist_lp(A[0], B[b_idx], p, dim); + opt_b_idx = b_idx; + } + } + + for(size_t a_idx = 0; a_idx < A.size(); ++a_idx) { + result = std::max(result, dist_lp(A[a_idx], B[opt_b_idx], p, dim)); + } + + return result; + } + + + template + std::string format_container_to_log(const Container& cont); + + template + std::string format_point_set_to_log(const IndexContainer& indices, const std::vector>& points); + + template + std::string format_int(T i); + +} // ws +} // hera + + + +#include "basic_defs_ws.hpp" + + #endif -- cgit v1.2.3