From 0cc35ad04f9c2997014d7cf62a12f697e79fb534 Mon Sep 17 00:00:00 2001 From: Arnur Nigmetov Date: Sat, 20 Jan 2018 19:11:29 +0100 Subject: Major rewrite, templatized version --- geom_matching/wasserstein/include/wasserstein.h | 355 ++++++++++++++++++------ 1 file changed, 268 insertions(+), 87 deletions(-) (limited to 'geom_matching/wasserstein/include/wasserstein.h') diff --git a/geom_matching/wasserstein/include/wasserstein.h b/geom_matching/wasserstein/include/wasserstein.h index c3e9280..d843a04 100644 --- a/geom_matching/wasserstein/include/wasserstein.h +++ b/geom_matching/wasserstein/include/wasserstein.h @@ -26,138 +26,319 @@ derivative works thereof, in binary and source code form. */ -#ifndef WASSERSTEIN_H -#define WASSERSTEIN_H +#ifndef HERA_WASSERSTEIN_H +#define HERA_WASSERSTEIN_H #include #include #include +#include "def_debug_ws.h" #include "basic_defs_ws.h" +#include "diagram_reader.h" +#include "auction_runner_gs.h" +#include "auction_runner_gs_single_diag.h" +#include "auction_runner_jac.h" +#include "auction_runner_fr.h" -// use Gauss-Seidel version; comment out to switch to Jacobi (not recommended) -#define GAUSS_SEIDEL_AUCTION -namespace geom_ws { +namespace hera +{ -using PairVector = std::vector>; +template().begin())>::type > +struct DiagramTraits +{ + using Container = PairContainer_; + using PointType = PointType_; + using RealType = typename std::remove_reference< decltype(std::declval()[0]) >::type; -// get Wasserstein distance between two persistence diagrams -double wassersteinDistVec(const std::vector& A, - const std::vector& B, - const double q, - const double delta, - const double _internal_p = std::numeric_limits::infinity(), - const double _initialEpsilon = 0.0, - const double _epsFactor = 0.0); + static RealType get_x(const PointType& p) { return p[0]; } + static RealType get_y(const PointType& p) { return p[1]; } +}; -// get Wasserstein cost (distance^q) between two persistence diagrams -double wassersteinCostVec(const std::vector& A, - const std::vector& B, - const double q, - const double delta, - const double _internal_p = std::numeric_limits::infinity(), - const double _initialEpsilon = 0.0, - const double _epsFactor = 0.0); +template +struct DiagramTraits> +{ + using PointType = std::pair; + using RealType = double; + using Container = std::vector; + static RealType get_x(const PointType& p) { return p.first; } + static RealType get_y(const PointType& p) { return p.second; } +}; -// compare as multisets -template -bool areEqual(PairContainer& dgm1, PairContainer& dgm2) + +namespace ws { - if (dgm1.size() != dgm2.size()) { - return false; - } - std::map, int> m1, m2; + // compare as multisets + template + bool are_equal(const PairContainer& dgm1, const PairContainer& dgm2) + { + if (dgm1.size() != dgm2.size()) { + return false; + } + + using Traits = typename hera::DiagramTraits; + using PointType = typename Traits::PointType; + + std::map m1, m2; + + for(const auto& pair1 : dgm1) { + m1[pair1]++; + } - for(const auto& pair1 : dgm1) { - m1[pair1]++; + for(const auto& pair2 : dgm2) { + m2[pair2]++; + } + + return m1 == m2; } - for(const auto& pair2 : dgm2) { - m2[pair2]++; + // to handle points with one coordinate = infinity + template + RealType get_one_dimensional_cost(std::vector& set_A, + std::vector& set_B, + const RealType wasserstein_power) + { + if (set_A.size() != set_B.size()) { + return std::numeric_limits::infinity(); + } + std::sort(set_A.begin(), set_A.end()); + std::sort(set_B.begin(), set_B.end()); + RealType result = 0.0; + for(size_t i = 0; i < set_A.size(); ++i) { + result += std::pow(std::fabs(set_A[i] - set_B[i]), wasserstein_power); + } + return result; } - return m1 == m2; -} + + template + struct SplitProblemInput + { + std::vector> A_1; + std::vector> B_1; + std::vector> A_2; + std::vector> B_2; + + std::unordered_map A_1_indices; + std::unordered_map A_2_indices; + std::unordered_map B_1_indices; + std::unordered_map B_2_indices; + + RealType mid_coord { 0.0 }; + RealType strip_width { 0.0 }; + + void init_vectors(size_t n) + { + + A_1_indices.clear(); + A_2_indices.clear(); + B_1_indices.clear(); + B_2_indices.clear(); + + A_1.clear(); + A_2.clear(); + B_1.clear(); + B_2.clear(); + + A_1.reserve(n / 2); + B_1.reserve(n / 2); + A_2.reserve(n / 2); + B_2.reserve(n / 2); + } + + void init(const std::vector>& A, + const std::vector>& B) + { + using DiagramPointR = DiagramPoint; + + init_vectors(A.size()); + + RealType min_sum = std::numeric_limits::max(); + RealType max_sum = -std::numeric_limits::max(); + for(const auto& p_A : A) { + RealType s = p_A[0] + p_A[1]; + if (s > max_sum) + max_sum = s; + if (s < min_sum) + min_sum = s; + mid_coord += s; + } + + mid_coord /= A.size(); + + strip_width = 0.25 * (max_sum - min_sum); + + auto first_diag_iter = std::upper_bound(A.begin(), A.end(), 0, [](const int& a, const DiagramPointR& p) { return a < (int)(p.is_diagonal()); }); + size_t num_normal_A_points = std::distance(A.begin(), first_diag_iter); + + // process all normal points in A, + // projections follow normal points + for(size_t i = 0; i < A.size(); ++i) { + + assert(i < num_normal_A_points and A.is_normal() or i >= num_normal_A_points and A.is_diagonal()); + assert(i < num_normal_A_points and B.is_diagonal() or i >= num_normal_A_points and B.is_normal()); + + RealType s = i < num_normal_A_points ? A[i][0] + A[i][1] : B[i][0] + B[i][1]; + + if (s < mid_coord + strip_width) { + // add normal point and its projection to the + // left half + A_1.push_back(A[i]); + B_1.push_back(B[i]); + A_1_indices[i] = A_1.size() - 1; + B_1_indices[i] = B_1.size() - 1; + } + + if (s > mid_coord - strip_width) { + // to the right half + A_2.push_back(A[i]); + B_2.push_back(B[i]); + A_2_indices[i] = A_2.size() - 1; + B_2_indices[i] = B_2.size() - 1; + } + + } + } // end init + + }; + + + // CAUTION: + // this function assumes that all coordinates are finite + // points at infinity are processed in wasserstein_cost + template + RealType wasserstein_cost_vec(const std::vector>& A, + const std::vector>& B, + const AuctionParams& params, + const std::string& _log_filename_prefix) + { + if (params.wasserstein_power < 1.0) { + throw std::runtime_error("Bad q in Wasserstein " + std::to_string(params.wasserstein_power)); + } + if (params.delta < 0.0) { + throw std::runtime_error("Bad delta in Wasserstein " + std::to_string(params.delta)); + } + if (params.initial_epsilon < 0.0) { + throw std::runtime_error("Bad initial epsilon in Wasserstein" + std::to_string(params.initial_epsilon)); + } + if (params.epsilon_common_ratio < 0.0) { + throw std::runtime_error("Bad epsilon factor in Wasserstein " + std::to_string(params.epsilon_common_ratio)); + } + + RealType result; + + // just use Gauss-Seidel + AuctionRunnerGS auction(A, B, params, _log_filename_prefix); + auction.run_auction(); + result = auction.get_wasserstein_cost(); + return result; + } + +} // ws + + template -double wassersteinDist(PairContainer& A, PairContainer& B, const double q, const double delta, const double _internal_p = std::numeric_limits::infinity(), const double _initialEpsilon = 0.0, const double _epsFactor = 0.0) +typename DiagramTraits::RealType +wasserstein_cost(const PairContainer& A, + const PairContainer& B, + const AuctionParams< typename DiagramTraits::RealType >& params, + const std::string& _log_filename_prefix = "") { - if (areEqual(A, B)) { + using Traits = DiagramTraits; + + //using PointType = typename Traits::PointType; + using RealType = typename Traits::RealType; + + if (hera::ws::are_equal(A, B)) { return 0.0; } - bool a_empty { true }; - bool b_empty { true }; + bool a_empty = true; + bool b_empty = true; + RealType total_cost_A = 0.0; + RealType total_cost_B = 0.0; + + using DgmPoint = hera::ws::DiagramPoint; - std::vector dgmA, dgmB; + std::vector dgm_A, dgm_B; + // coordinates of points at infinity + std::vector x_plus_A, x_minus_A, y_plus_A, y_minus_A; + std::vector x_plus_B, x_minus_B, y_plus_B, y_minus_B; // loop over A, add projections of A-points to corresponding positions // in B-vector - for(auto& pairA : A) { + for(auto& pair_A : A) { a_empty = false; - double x = pairA.first; - double y = pairA.second; - dgmA.push_back(DiagramPoint(x, y, DiagramPoint::NORMAL)); - dgmB.push_back(DiagramPoint(x, y, DiagramPoint::DIAG)); + RealType x = Traits::get_x(pair_A); + RealType y = Traits::get_y(pair_A); + if ( x == std::numeric_limits::infinity()) { + y_plus_A.push_back(y); + } else if (x == -std::numeric_limits::infinity()) { + y_minus_A.push_back(y); + } else if (y == std::numeric_limits::infinity()) { + x_plus_A.push_back(x); + } else if (y == -std::numeric_limits::infinity()) { + x_minus_A.push_back(x); + } else { + dgm_A.emplace_back(x, y, DgmPoint::NORMAL); + dgm_B.emplace_back(x, y, DgmPoint::DIAG); + total_cost_A += std::pow(dgm_A.back().persistence_lp(params.internal_p), params.wasserstein_power); + } } // the same for B - for(auto& pairB : B) { + for(auto& pair_B : B) { b_empty = false; - double x = pairB.first; - double y = pairB.second; - dgmA.push_back(DiagramPoint(x, y, DiagramPoint::DIAG)); - dgmB.push_back(DiagramPoint(x, y, DiagramPoint::NORMAL)); + RealType x = Traits::get_x(pair_B); + RealType y = Traits::get_y(pair_B); + if (x == std::numeric_limits::infinity()) { + y_plus_B.push_back(y); + } else if (x == -std::numeric_limits::infinity()) { + y_minus_B.push_back(y); + } else if (y == std::numeric_limits::infinity()) { + x_plus_B.push_back(x); + } else if (y == -std::numeric_limits::infinity()) { + x_minus_B.push_back(x); + } else { + dgm_A.emplace_back(x, y, DgmPoint::DIAG); + dgm_B.emplace_back(x, y, DgmPoint::NORMAL); + total_cost_B += std::pow(dgm_B.back().persistence_lp(params.internal_p), params.wasserstein_power); + } } - if (a_empty && b_empty) - return 0.0; + RealType infinity_cost = ws::get_one_dimensional_cost(x_plus_A, x_plus_B, params.wasserstein_power); + infinity_cost += ws::get_one_dimensional_cost(x_minus_A, x_minus_B, params.wasserstein_power); + infinity_cost += ws::get_one_dimensional_cost(y_plus_A, y_plus_B, params.wasserstein_power); + infinity_cost += ws::get_one_dimensional_cost(y_minus_A, y_minus_B, params.wasserstein_power); if (a_empty) - dgmA.clear(); + return total_cost_B + infinity_cost; if (b_empty) - dgmB.clear(); + return total_cost_A + infinity_cost; - return wassersteinDistVec(dgmA, dgmB, q, delta, _internal_p, _initialEpsilon, _epsFactor); -} -template -double wassersteinCost(PairContainer& A, PairContainer& B, const double q, const double delta, const double _internal_p = std::numeric_limits::infinity(), const double _initialEpsilon = 0.0, const double _epsFactor = 0.0) -{ - if (areEqual(A, B)) { - return 0.0; + if (infinity_cost == std::numeric_limits::infinity()) { + return infinity_cost; + } else { + return infinity_cost + wasserstein_cost_vec(dgm_A, dgm_B, params, _log_filename_prefix); } - std::vector dgmA, dgmB; - // loop over A, add projections of A-points to corresponding positions - // in B-vector - for(auto& pairA : A) { - double x = pairA.first; - double y = pairA.second; - dgmA.push_back(DiagramPoint(x, y, DiagramPoint::NORMAL)); - dgmB.push_back(DiagramPoint(x, y, DiagramPoint::DIAG)); - } - // the same for B - for(auto& pairB : B) { - double x = pairB.first; - double y = pairB.second; - dgmA.push_back(DiagramPoint(x, y, DiagramPoint::DIAG)); - dgmB.push_back(DiagramPoint(x, y, DiagramPoint::NORMAL)); - } - - return wassersteinCostVec(dgmA, dgmB, q, delta, _internal_p, _initialEpsilon, _epsFactor); } +template +typename DiagramTraits::RealType +wasserstein_dist(PairContainer& A, + PairContainer& B, + const AuctionParams::RealType> params, + const std::string& _log_filename_prefix = "") +{ + using Real = typename DiagramTraits::RealType; + return std::pow(hera::wasserstein_cost(A, B, params, _log_filename_prefix), Real(1.)/params.wasserstein_power); +} -// fill in result with points from file fname -// return false if file can't be opened -// or error occurred while reading -bool readDiagramPointSet(const char* fname, PairVector& result); -bool readDiagramPointSet(const std::string& fname, PairVector& result); -void removeDuplicates(PairVector& dgmA, PairVector& dgmB); - -} // end of namespace geom_ws +} // end of namespace hera #endif -- cgit v1.2.3