summaryrefslogtreecommitdiff
path: root/geom_matching
diff options
context:
space:
mode:
authorGard Spreemann <gard.spreemann@epfl.ch>2018-08-20 21:03:51 +0200
committerGard Spreemann <gard.spreemann@epfl.ch>2018-08-20 21:03:51 +0200
commit4c1b1727ab44a773d27b090463ac6db957267136 (patch)
tree53520f77b9e88998181b2bb8c368302aaa5fea3c /geom_matching
parentf8b9be92615662a1f9559608930559138f9c1fd8 (diff)
parent657f73321f04d5d1c4cec8085ec43a73633b96af (diff)
Merge branch 'master' into gspr/mpi
Diffstat (limited to 'geom_matching')
-rw-r--r--geom_matching/wasserstein/CMakeLists.txt8
-rw-r--r--geom_matching/wasserstein/example/wasserstein_dist_point_cloud.cpp2
-rw-r--r--geom_matching/wasserstein/include/auction_oracle_kdtree_pure_geom.hpp17
-rw-r--r--geom_matching/wasserstein/include/auction_oracle_kdtree_restricted.hpp8
-rw-r--r--geom_matching/wasserstein/include/auction_runner_gs.hpp17
-rw-r--r--geom_matching/wasserstein/include/auction_runner_jac.hpp10
-rw-r--r--geom_matching/wasserstein/include/basic_defs_ws.h47
-rw-r--r--geom_matching/wasserstein/include/basic_defs_ws.hpp18
-rw-r--r--geom_matching/wasserstein/include/diagonal_heap.h2
-rw-r--r--geom_matching/wasserstein/include/diagram_reader.h32
-rw-r--r--geom_matching/wasserstein/include/dnn/geometry/euclidean-dynamic.h24
-rw-r--r--geom_matching/wasserstein/include/dnn/local/kd-tree.hpp2
-rw-r--r--geom_matching/wasserstein/include/hera_infinity.h22
-rw-r--r--geom_matching/wasserstein/include/wasserstein.h14
-rw-r--r--geom_matching/wasserstein/include/wasserstein_pure_geom.hpp5
-rw-r--r--geom_matching/wasserstein/tests/test_hera_wasserstein.cpp53
-rw-r--r--geom_matching/wasserstein/tests/test_hera_wasserstein_pure_geom.cpp111
-rw-r--r--geom_matching/wasserstein/tests/tests_reader.h67
18 files changed, 323 insertions, 136 deletions
diff --git a/geom_matching/wasserstein/CMakeLists.txt b/geom_matching/wasserstein/CMakeLists.txt
index fe5fd73..53da46b 100644
--- a/geom_matching/wasserstein/CMakeLists.txt
+++ b/geom_matching/wasserstein/CMakeLists.txt
@@ -46,20 +46,20 @@ file(GLOB WS_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/include/*.h ${CMAKE_CURRENT_SOU
find_package (Threads)
set (libraries ${libraries} ${CMAKE_THREAD_LIBS_INIT})
-add_executable(wasserstein_dist ${CMAKE_CURRENT_SOURCE_DIR}/example/wasserstein_dist.cpp ${WS_HEADERS})
+add_executable(wasserstein_dist ${CMAKE_CURRENT_SOURCE_DIR}/example/wasserstein_dist.cpp ${WS_HEADERS} include/hera_infinity.h)
target_link_libraries(wasserstein_dist PUBLIC ${libraries})
-add_executable(wasserstein_dist_dipha ${CMAKE_CURRENT_SOURCE_DIR}/example/wasserstein_dist_dipha.cpp ${WS_HEADERS})
+add_executable(wasserstein_dist_dipha ${CMAKE_CURRENT_SOURCE_DIR}/example/wasserstein_dist_dipha.cpp ${WS_HEADERS} include/hera_infinity.h)
target_link_libraries(wasserstein_dist_dipha PUBLIC ${libraries})
add_executable(wasserstein_dist_finitize_dipha ${CMAKE_CURRENT_SOURCE_DIR}/example/wasserstein_dist_finitize_dipha.cpp ${WS_HEADERS})
target_link_libraries(wasserstein_dist_finitize_dipha PUBLIC ${libraries})
# pure geometric version, arbitrary dimension
-add_executable(wasserstein_dist_point_cloud ${CMAKE_CURRENT_SOURCE_DIR}/example/wasserstein_dist_point_cloud.cpp ${WS_HEADERS})
+add_executable(wasserstein_dist_point_cloud ${CMAKE_CURRENT_SOURCE_DIR}/example/wasserstein_dist_point_cloud.cpp ${WS_HEADERS} include/hera_infinity.h)
target_link_libraries(wasserstein_dist_point_cloud PUBLIC ${libraries})
# Tests
-add_executable(wasserstein_test ${CMAKE_CURRENT_SOURCE_DIR}/tests/tests_main.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tests/test_hera_wasserstein.cpp)
+add_executable(wasserstein_test ${CMAKE_CURRENT_SOURCE_DIR}/tests/tests_main.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tests/test_hera_wasserstein.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tests/test_hera_wasserstein_pure_geom.cpp include/hera_infinity.h tests/tests_reader.h)
#add_executable(wasserstein_test EXCLUDE_FROM_ALL ${CMAKE_CURRENT_SOURCE_DIR}/tests/test_hera_wasserstein.cpp)
target_link_libraries(wasserstein_test PUBLIC ${libraries})
diff --git a/geom_matching/wasserstein/example/wasserstein_dist_point_cloud.cpp b/geom_matching/wasserstein/example/wasserstein_dist_point_cloud.cpp
index 6f699a4..2f9718e 100644
--- a/geom_matching/wasserstein/example/wasserstein_dist_point_cloud.cpp
+++ b/geom_matching/wasserstein/example/wasserstein_dist_point_cloud.cpp
@@ -35,6 +35,7 @@ derivative works thereof, in binary and source code form.
int main(int argc, char* argv[])
{
+
//{
//int n_points = 3;
//int dim = 3;
@@ -106,7 +107,6 @@ int main(int argc, char* argv[])
params.dim = dimension_A;
-
params.wasserstein_power = (4 <= argc) ? atof(argv[3]) : 1.0;
if (params.wasserstein_power < 1.0) {
std::cerr << "The third argument (wasserstein_degree) was \"" << argv[3] << "\", must be a number >= 1.0. Cannot proceed. " << std::endl;
diff --git a/geom_matching/wasserstein/include/auction_oracle_kdtree_pure_geom.hpp b/geom_matching/wasserstein/include/auction_oracle_kdtree_pure_geom.hpp
index a6bdf10..eaf54cf 100644
--- a/geom_matching/wasserstein/include/auction_oracle_kdtree_pure_geom.hpp
+++ b/geom_matching/wasserstein/include/auction_oracle_kdtree_pure_geom.hpp
@@ -111,11 +111,9 @@ AuctionOracleKDTreePureGeom<Real_, PointContainer_>::get_optimal_bid_debug(IdxTy
Real best_item_value = std::numeric_limits<Real>::max();
Real second_best_item_value = std::numeric_limits<Real>::max();
- for(IdxType item_idx = 0; item_idx < this->items.size(); ++item_idx) {
+ for(size_t item_idx = 0; item_idx < this->items.size(); ++item_idx) {
auto item = this->items[item_idx];
- if (item.type != bidder.type and item_idx != bidder_idx)
- continue;
- auto item_value = std::pow(dist_lp(bidder, item, this->internal_p), this->wasserstein_power, this->dim) + this->prices[item_idx];
+ auto item_value = std::pow(traits.distance(bidder, item), this->wasserstein_power) + this->prices[item_idx];
if (item_value < best_item_value) {
best_item_value = item_value;
best_item_idx = item_idx;
@@ -126,11 +124,10 @@ AuctionOracleKDTreePureGeom<Real_, PointContainer_>::get_optimal_bid_debug(IdxTy
for(size_t item_idx = 0; item_idx < this->items.size(); ++item_idx) {
auto item = this->items[item_idx];
- if (item.type != bidder.type and item_idx != bidder_idx)
- continue;
if (item_idx == best_item_idx)
continue;
- auto item_value = std::pow(dist_lp(bidder, item, this->internal_p), this->wasserstein_power, this->dim) + this->prices[item_idx];
+
+ auto item_value = std::pow(traits.distance(bidder, item), this->wasserstein_power) + this->prices[item_idx];
if (item_value < second_best_item_value) {
second_best_item_value = item_value;
second_best_item_idx = item_idx;
@@ -166,6 +163,12 @@ IdxValPair<Real_> AuctionOracleKDTreePureGeom<Real_, PointContainer_>::get_optim
result.first = best_item_idx;
result.second = ( second_best_item_value - best_item_value ) + this->prices[best_item_idx] + this->epsilon;
+#ifdef DEBUG_KDTREE_RESTR_ORACLE
+ auto bid_debug = get_optimal_bid_debug(bidder_idx);
+ assert(fabs(bid_debug.best_item_value - best_item_value) < 0.000000001);
+ assert(fabs(bid_debug.second_best_item_value - second_best_item_value) < 0.000000001);
+#endif
+
return result;
}
diff --git a/geom_matching/wasserstein/include/auction_oracle_kdtree_restricted.hpp b/geom_matching/wasserstein/include/auction_oracle_kdtree_restricted.hpp
index 0e6f780..3c3cba3 100644
--- a/geom_matching/wasserstein/include/auction_oracle_kdtree_restricted.hpp
+++ b/geom_matching/wasserstein/include/auction_oracle_kdtree_restricted.hpp
@@ -243,11 +243,11 @@ AuctionOracleKDTreeRestricted<Real_, PointContainer_>::get_optimal_bid_debug(Idx
Real best_item_value = std::numeric_limits<Real>::max();
Real second_best_item_value = std::numeric_limits<Real>::max();
- for(IdxType item_idx = 0; item_idx < this->items.size(); ++item_idx) {
+ for(IdxType item_idx = 0; item_idx < static_cast<IdxType>(this->items.size()); ++item_idx) {
auto item = this->items[item_idx];
if (item.type != bidder.type and item_idx != bidder_idx)
continue;
- auto item_value = std::pow(dist_lp(bidder, item, this->internal_p), this->wasserstein_power) + this->prices[item_idx];
+ auto item_value = std::pow(dist_lp(bidder, item, this->internal_p, 2), this->wasserstein_power) + this->prices[item_idx];
if (item_value < best_item_value) {
best_item_value = item_value;
best_item_idx = item_idx;
@@ -258,11 +258,11 @@ AuctionOracleKDTreeRestricted<Real_, PointContainer_>::get_optimal_bid_debug(Idx
for(size_t item_idx = 0; item_idx < this->items.size(); ++item_idx) {
auto item = this->items[item_idx];
- if (item.type != bidder.type and item_idx != bidder_idx)
+ if (item.type != bidder.type and static_cast<IdxType>(item_idx) != bidder_idx)
continue;
if (item_idx == best_item_idx)
continue;
- auto item_value = std::pow(dist_lp(bidder, item, this->internal_p), this->wasserstein_power) + this->prices[item_idx];
+ auto item_value = std::pow(dist_lp(bidder, item, this->internal_p, 2), this->wasserstein_power) + this->prices[item_idx];
if (item_value < second_best_item_value) {
second_best_item_value = item_value;
second_best_item_idx = item_idx;
diff --git a/geom_matching/wasserstein/include/auction_runner_gs.hpp b/geom_matching/wasserstein/include/auction_runner_gs.hpp
index d9f419d..141cb2c 100644
--- a/geom_matching/wasserstein/include/auction_runner_gs.hpp
+++ b/geom_matching/wasserstein/include/auction_runner_gs.hpp
@@ -244,6 +244,12 @@ void AuctionRunnerGS<R, AO, PC>::run_auction_phases(const int max_num_phases, co
flush_assignment();
run_auction_phase();
Real current_result = getDistanceToQthPowerInternal();
+// Real current_result_1 = 0.0;
+// for(size_t i = 0; i < num_bidders; ++i) {
+// current_result_1 += oracle.traits.distance(bidders[i], items[bidders_to_items[i]]);
+// }
+// current_result = current_result_1;
+// assert(fabs(current_result - current_result_1) < 0.001);
Real denominator = current_result - num_bidders * oracle.get_epsilon();
current_result = pow(current_result, 1.0 / wasserstein_power);
#ifdef LOG_AUCTION
@@ -259,6 +265,7 @@ void AuctionRunnerGS<R, AO, PC>::run_auction_phases(const int max_num_phases, co
denominator = pow(denominator, 1.0 / wasserstein_power);
Real numerator = current_result - denominator;
relative_error = numerator / denominator;
+ // spdlog::get("console")->info("relative error = {} / {} = {}, result = {}", numerator, denominator, relative_error, current_result);
#ifdef LOG_AUCTION
console_logger->info("error = {0} / {1} = {2}",
numerator, denominator, relative_error);
@@ -280,6 +287,7 @@ void AuctionRunnerGS<R, AO, PC>::run_auction()
if (num_bidders == 1) {
assign_item_to_bidder(0, 0);
wasserstein_cost = get_item_bidder_cost(0,0);
+ is_distance_computed = true;
return;
}
@@ -319,7 +327,7 @@ void AuctionRunnerGS<R, AO, PC>::run_auction_phase()
#ifdef DEBUG_AUCTION
for(size_t bidder_idx = 0; bidder_idx < num_bidders; ++bidder_idx) {
- if ( bidders_to_items[bidder_idx] < 0 or bidders_to_items[bidder_idx] >= num_bidders) {
+ if ( bidders_to_items[bidder_idx] < 0 or bidders_to_items[bidder_idx] >= (IdxType)num_bidders) {
std::cerr << "After auction terminated bidder " << bidder_idx;
std::cerr << " has no items assigned" << std::endl;
throw std::runtime_error("Auction did not give a perfect matching");
@@ -333,8 +341,7 @@ template<class R, class AO, class PC>
R AuctionRunnerGS<R, AO, PC>::get_item_bidder_cost(const size_t item_idx, const size_t bidder_idx, const bool tolerate_invalid_idx) const
{
if (item_idx != k_invalid_index and bidder_idx != k_invalid_index) {
- return std::pow(dist_lp(bidders[bidder_idx], items[item_idx], internal_p, dimension),
- wasserstein_power);
+ return std::pow(dist_lp(bidders[bidder_idx], items[item_idx], internal_p, dimension), wasserstein_power);
} else {
if (tolerate_invalid_idx)
return R(0.0);
@@ -416,7 +423,7 @@ void AuctionRunnerGS<R, AO, PC>::sanity_check()
}
for(size_t bidder_idx = 0; bidder_idx < num_bidders; ++bidder_idx) {
- assert( bidders_to_items[bidder_idx] == k_invalid_index or ( bidders_to_items[bidder_idx] < num_items and bidders_to_items[bidder_idx] >= 0));
+ assert( bidders_to_items[bidder_idx] == k_invalid_index or ( bidders_to_items[bidder_idx] < (IdxType)num_items and bidders_to_items[bidder_idx] >= 0));
if ( bidders_to_items[bidder_idx] != k_invalid_index) {
@@ -440,7 +447,7 @@ void AuctionRunnerGS<R, AO, PC>::sanity_check()
}
for(IdxType item_idx = 0; item_idx < static_cast<IdxType>(num_bidders); ++item_idx) {
- assert( items_to_bidders[item_idx] == k_invalid_index or ( items_to_bidders[item_idx] < num_items and items_to_bidders[item_idx] >= 0));
+ assert( items_to_bidders[item_idx] == k_invalid_index or ( items_to_bidders[item_idx] < static_cast<IdxType>(num_items) and items_to_bidders[item_idx] >= 0));
if ( items_to_bidders.at(item_idx) != k_invalid_index) {
// check for uniqueness
diff --git a/geom_matching/wasserstein/include/auction_runner_jac.hpp b/geom_matching/wasserstein/include/auction_runner_jac.hpp
index 8663bae..e623f4a 100644
--- a/geom_matching/wasserstein/include/auction_runner_jac.hpp
+++ b/geom_matching/wasserstein/include/auction_runner_jac.hpp
@@ -42,7 +42,6 @@ derivative works thereof, in binary and source code form.
#undef DEBUG_AUCTION
#endif
-
namespace hera {
namespace ws {
@@ -350,6 +349,8 @@ namespace ws {
typename AuctionRunnerJac<R, AO, PC>::Real
AuctionRunnerJac<R, AO, PC>::get_relative_error(const bool debug_output) const
{
+ if (partial_cost == 0.0 and unassigned_bidders.empty())
+ return 0.0;
Real result;
#ifndef WASSERSTEIN_PURE_GEOM
Real gamma = get_gamma();
@@ -558,9 +559,10 @@ namespace ws {
if (num_bidders == 1) {
assign_item_to_bidder(0, 0);
wasserstein_cost = get_item_bidder_cost(0,0);
+ is_distance_computed = true;
return;
}
- double init_eps = (initial_epsilon > 0.0) ? initial_epsilon : oracle.max_val_ / 4.0;
+ R init_eps = (initial_epsilon > 0.0) ? initial_epsilon : oracle.max_val_ / 4.0;
run_auction_phases(max_num_phases, init_eps);
is_distance_computed = true;
wasserstein_cost = partial_cost;
@@ -703,7 +705,11 @@ namespace ws {
template<class R, class AO, class PC>
bool AuctionRunnerJac<R, AO, PC>::continue_auction_phase() const
{
+#ifdef WASSERSTEIN_PURE_GEOM
+ return not unassigned_bidders.empty();
+#else
return not unassigned_bidders.empty() and not is_done();
+#endif
}
template<class R, class AO, class PC>
diff --git a/geom_matching/wasserstein/include/basic_defs_ws.h b/geom_matching/wasserstein/include/basic_defs_ws.h
index 58d6fd2..1c5928f 100644
--- a/geom_matching/wasserstein/include/basic_defs_ws.h
+++ b/geom_matching/wasserstein/include/basic_defs_ws.h
@@ -51,6 +51,7 @@ derivative works thereof, in binary and source code form.
#include "spdlog/fmt/ostr.h"
#endif
+#include "hera_infinity.h"
#include "dnn/geometry/euclidean-dynamic.h"
#include "def_debug_ws.h"
@@ -59,20 +60,20 @@ derivative works thereof, in binary and source code form.
namespace hera
{
-template<class Real = double>
-bool is_infinity(const Real& x)
-{
- return x == Real(-1);
-};
+//template<class Real = double>
+//inline bool is_infinity(const Real& x)
+//{
+// return x == Real(-1);
+//};
+//
+//template<class Real = double>
+//inline Real get_infinity()
+//{
+// return Real( -1 );
+//}
template<class Real = double>
-Real get_infinity()
-{
- return Real( -1 );
-}
-
-template<class Real = double>
-bool is_p_valid_norm(const Real& p)
+inline bool is_p_valid_norm(const Real& p)
{
return is_infinity<Real>(p) or p >= Real(1);
}
@@ -101,10 +102,8 @@ namespace ws
template<class Real = double>
using IdxValPair = std::pair<IdxType, Real>;
-
-
template<class R>
- std::ostream& operator<<(std::ostream& output, const IdxValPair<R> p)
+ inline std::ostream& operator<<(std::ostream& output, const IdxValPair<R> p)
{
output << fmt::format("({0}, {1})", p.first, p.second);
return output;
@@ -112,7 +111,7 @@ namespace ws
enum class OwnerType { k_none, k_normal, k_diagonal };
- std::ostream& operator<<(std::ostream& s, const OwnerType t)
+ inline std::ostream& operator<<(std::ostream& s, const OwnerType t)
{
switch(t)
{
@@ -210,11 +209,11 @@ namespace ws
#ifndef FOR_R_TDA
template <class Real = double>
- std::ostream& operator<<(std::ostream& output, const DiagramPoint<Real> p);
+ inline std::ostream& operator<<(std::ostream& output, const DiagramPoint<Real> p);
#endif
template<class Real>
- void format_arg(fmt::BasicFormatter<char> &f, const char *&format_str, const DiagramPoint<Real>&p) {
+ inline void format_arg(fmt::BasicFormatter<char> &f, const char *&format_str, const DiagramPoint<Real>&p) {
if (p.is_diagonal()) {
f.writer().write("({0},{1}, DIAG)", p.x, p.y);
} else {
@@ -269,14 +268,14 @@ namespace ws
};
template<class R, class Pt>
- R dist_lp(const Pt& a, const Pt& b, const R p, const int dim)
+ inline R dist_lp(const Pt& a, const Pt& b, const R p, const int dim)
{
return DistImpl<R, Pt>()(a, b, p, dim);
}
// TODO
template<class Real, typename DiagPointContainer>
- double getFurthestDistance3Approx(DiagPointContainer& A, DiagPointContainer& B, const Real p)
+ inline double getFurthestDistance3Approx(DiagPointContainer& A, DiagPointContainer& B, const Real p)
{
int dim = 2;
Real result { 0.0 };
@@ -297,7 +296,7 @@ namespace ws
}
template<class Real>
- Real getFurthestDistance3Approx_pg(const hera::ws::dnn::DynamicPointVector<Real>& A, const hera::ws::dnn::DynamicPointVector<Real>& B, const Real p, const int dim)
+ inline Real getFurthestDistance3Approx_pg(const hera::ws::dnn::DynamicPointVector<Real>& A, const hera::ws::dnn::DynamicPointVector<Real>& B, const Real p, const int dim)
{
Real result { 0.0 };
int opt_b_idx = 0;
@@ -317,13 +316,13 @@ namespace ws
template<class Container>
- std::string format_container_to_log(const Container& cont);
+ inline std::string format_container_to_log(const Container& cont);
template<class Real, class IndexContainer>
- std::string format_point_set_to_log(const IndexContainer& indices, const std::vector<DiagramPoint<Real>>& points);
+ inline std::string format_point_set_to_log(const IndexContainer& indices, const std::vector<DiagramPoint<Real>>& points);
template<class T>
- std::string format_int(T i);
+ inline std::string format_int(T i);
} // ws
} // hera
diff --git a/geom_matching/wasserstein/include/basic_defs_ws.hpp b/geom_matching/wasserstein/include/basic_defs_ws.hpp
index 1750b4e..a1153af 100644
--- a/geom_matching/wasserstein/include/basic_defs_ws.hpp
+++ b/geom_matching/wasserstein/include/basic_defs_ws.hpp
@@ -64,7 +64,7 @@ bool Point<Real>::operator!=(const Point<Real>& other) const
#ifndef FOR_R_TDA
template <class Real>
-std::ostream& operator<<(std::ostream& output, const Point<Real> p)
+inline std::ostream& operator<<(std::ostream& output, const Point<Real> p)
{
output << "(" << p.x << ", " << p.y << ")";
return output;
@@ -72,20 +72,20 @@ std::ostream& operator<<(std::ostream& output, const Point<Real> p)
#endif
template <class Real>
-Real sqr_dist(const Point<Real>& a, const Point<Real>& b)
+inline Real sqr_dist(const Point<Real>& a, const Point<Real>& b)
{
return (a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y);
}
template <class Real>
-Real dist(const Point<Real>& a, const Point<Real>& b)
+inline Real dist(const Point<Real>& a, const Point<Real>& b)
{
return sqrt(sqr_dist(a, b));
}
template <class Real>
-Real DiagramPoint<Real>::persistence_lp(const Real p) const
+inline Real DiagramPoint<Real>::persistence_lp(const Real p) const
{
if (is_diagonal())
return 0.0;
@@ -100,7 +100,7 @@ Real DiagramPoint<Real>::persistence_lp(const Real p) const
#ifndef FOR_R_TDA
template <class Real>
-std::ostream& operator<<(std::ostream& output, const DiagramPoint<Real> p)
+inline std::ostream& operator<<(std::ostream& output, const DiagramPoint<Real> p)
{
if ( p.type == DiagramPoint<Real>::DIAG ) {
output << "(" << p.x << ", " << p.y << ", " << 0.5 * (p.x + p.y) << " DIAG )";
@@ -142,7 +142,7 @@ Real DiagramPoint<Real>::getRealY() const
}
template<class Container>
-std::string format_container_to_log(const Container& cont)
+inline std::string format_container_to_log(const Container& cont)
{
std::stringstream result;
result << "[";
@@ -157,7 +157,7 @@ std::string format_container_to_log(const Container& cont)
}
template<class Container>
-std::string format_pair_container_to_log(const Container& cont)
+inline std::string format_pair_container_to_log(const Container& cont)
{
std::stringstream result;
result << "[";
@@ -173,7 +173,7 @@ std::string format_pair_container_to_log(const Container& cont)
template<class Real, class IndexContainer>
-std::string format_point_set_to_log(const IndexContainer& indices,
+inline std::string format_point_set_to_log(const IndexContainer& indices,
const std::vector<DiagramPoint<Real>>& points)
{
std::stringstream result;
@@ -189,7 +189,7 @@ std::string format_point_set_to_log(const IndexContainer& indices,
}
template<class T>
-std::string format_int(T i)
+inline std::string format_int(T i)
{
std::stringstream ss;
ss.imbue(std::locale(""));
diff --git a/geom_matching/wasserstein/include/diagonal_heap.h b/geom_matching/wasserstein/include/diagonal_heap.h
index 9ffee70..3b3c8bc 100644
--- a/geom_matching/wasserstein/include/diagonal_heap.h
+++ b/geom_matching/wasserstein/include/diagonal_heap.h
@@ -129,7 +129,7 @@ using LossesHeapOld = IdxValHeap<Real, CompPairsBySecondLexStruct<Real>>;
#endif
template <class Real>
-std::string losses_heap_to_string(const LossesHeapOld<Real>& h)
+inline std::string losses_heap_to_string(const LossesHeapOld<Real>& h)
{
std::stringstream result;
result << "[";
diff --git a/geom_matching/wasserstein/include/diagram_reader.h b/geom_matching/wasserstein/include/diagram_reader.h
index 55228b4..4b24f78 100644
--- a/geom_matching/wasserstein/include/diagram_reader.h
+++ b/geom_matching/wasserstein/include/diagram_reader.h
@@ -55,31 +55,31 @@ namespace hera {
// cannot choose stod, stof or stold based on RealType,
// lazy solution: partial specialization
template<class RealType = double>
-RealType parse_real_from_str(const std::string& s);
+inline RealType parse_real_from_str(const std::string& s);
template <>
-double parse_real_from_str<double>(const std::string& s)
+inline double parse_real_from_str<double>(const std::string& s)
{
return std::stod(s);
}
template <>
-long double parse_real_from_str<long double>(const std::string& s)
+inline long double parse_real_from_str<long double>(const std::string& s)
{
return std::stold(s);
}
template <>
-float parse_real_from_str<float>(const std::string& s)
+inline float parse_real_from_str<float>(const std::string& s)
{
return std::stof(s);
}
template<class RealType>
-RealType parse_real_from_str(const std::string& s)
+inline RealType parse_real_from_str(const std::string& s)
{
static_assert(sizeof(RealType) != sizeof(RealType), "Must be specialized for each type you want to use, see above");
}
@@ -90,7 +90,7 @@ RealType parse_real_from_str(const std::string& s)
// decPrecision is the maximal decimal precision in the input,
// it is zero if all coordinates in the input are integers
template<class RealType = double, class ContType_ = std::vector<std::pair<RealType, RealType>>>
-bool read_diagram_point_set(const char* fname, ContType_& result, int& decPrecision)
+inline bool read_diagram_point_set(const char* fname, ContType_& result, int& decPrecision)
{
size_t lineNumber { 0 };
result.clear();
@@ -182,7 +182,7 @@ bool read_diagram_point_set(const char* fname, ContType_& result, int& decPrecis
// wrappers
template<class RealType = double, class ContType_ = std::vector<std::pair<RealType, RealType>>>
-bool read_diagram_point_set(const std::string& fname, ContType_& result, int& decPrecision)
+inline bool read_diagram_point_set(const std::string& fname, ContType_& result, int& decPrecision)
{
return read_diagram_point_set<RealType, ContType_>(fname.c_str(), result, decPrecision);
}
@@ -190,21 +190,21 @@ bool read_diagram_point_set(const std::string& fname, ContType_& result, int& de
// these two functions are now just wrappers for the previous ones,
// in case someone needs them; decPrecision is ignored
template<class RealType = double, class ContType_ = std::vector<std::pair<RealType, RealType>>>
-bool read_diagram_point_set(const char* fname, ContType_& result)
+inline bool read_diagram_point_set(const char* fname, ContType_& result)
{
int decPrecision;
return read_diagram_point_set<RealType, ContType_>(fname, result, decPrecision);
}
template<class RealType = double, class ContType_ = std::vector<std::pair<RealType, RealType>>>
-bool read_diagram_point_set(const std::string& fname, ContType_& result)
+inline bool read_diagram_point_set(const std::string& fname, ContType_& result)
{
int decPrecision;
return read_diagram_point_set<RealType, ContType_>(fname.c_str(), result, decPrecision);
}
template<class RealType = double, class ContType_ = std::vector<std::pair<RealType, RealType> > >
-bool read_diagram_dipha(const std::string& fname, unsigned int dim, ContType_& result)
+inline bool read_diagram_dipha(const std::string& fname, unsigned int dim, ContType_& result)
{
std::ifstream file;
file.open(fname, std::ios::in | std::ios::binary);
@@ -274,7 +274,7 @@ bool read_diagram_dipha(const std::string& fname, unsigned int dim, ContType_& r
template<class RealType, class ContType>
-void remove_duplicates(ContType& dgm_A, ContType& dgm_B)
+inline void remove_duplicates(ContType& dgm_A, ContType& dgm_B)
{
std::map<std::pair<RealType, RealType>, int> map_A, map_B;
// copy points to maps
@@ -348,7 +348,7 @@ int finitize(RealType finitization, std::vector<std::pair<RealType, RealType> >&
#ifdef WASSERSTEIN_PURE_GEOM
template<class Real>
-int get_point_dimension(const std::string& line)
+inline int get_point_dimension(const std::string& line)
{
Real x;
int dim = 0;
@@ -361,7 +361,7 @@ int get_point_dimension(const std::string& line)
template<class RealType = double >
-bool read_point_cloud(const char* fname, hera::ws::dnn::DynamicPointVector<RealType>& result, int& dimension, int& decPrecision)
+inline bool read_point_cloud(const char* fname, hera::ws::dnn::DynamicPointVector<RealType>& result, int& dimension, int& decPrecision)
{
using DynamicPointTraitsR = typename hera::ws::dnn::DynamicPointTraits<RealType>;
@@ -443,20 +443,20 @@ bool read_point_cloud(const char* fname, hera::ws::dnn::DynamicPointVector<RealT
// wrappers
template<class RealType = double >
-bool read_point_cloud(const char* fname, hera::ws::dnn::DynamicPointVector<RealType>& result, int& dimension)
+inline bool read_point_cloud(const char* fname, hera::ws::dnn::DynamicPointVector<RealType>& result, int& dimension)
{
int dec_precision;
return read_point_cloud<RealType>(fname, result, dimension, dec_precision);
}
template<class RealType = double >
-bool read_point_cloud(std::string fname, hera::ws::dnn::DynamicPointVector<RealType>& result, int& dimension, int& dec_precision)
+inline bool read_point_cloud(std::string fname, hera::ws::dnn::DynamicPointVector<RealType>& result, int& dimension, int& dec_precision)
{
return read_point_cloud<RealType>(fname.c_str(), result, dimension, dec_precision);
}
template<class RealType = double >
-bool read_point_cloud(std::string fname, hera::ws::dnn::DynamicPointVector<RealType>& result, int& dimension)
+inline bool read_point_cloud(std::string fname, hera::ws::dnn::DynamicPointVector<RealType>& result, int& dimension)
{
return read_point_cloud<RealType>(fname.c_str(), result, dimension);
}
diff --git a/geom_matching/wasserstein/include/dnn/geometry/euclidean-dynamic.h b/geom_matching/wasserstein/include/dnn/geometry/euclidean-dynamic.h
index 4b98309..b003906 100644
--- a/geom_matching/wasserstein/include/dnn/geometry/euclidean-dynamic.h
+++ b/geom_matching/wasserstein/include/dnn/geometry/euclidean-dynamic.h
@@ -8,6 +8,8 @@
#include <boost/serialization/vector.hpp>
#include <cmath>
+#include "hera_infinity.h"
+
namespace hera
{
namespace ws
@@ -89,7 +91,27 @@ struct DynamicPointTraits
DynamicPointTraits(unsigned dim = 0):
dim_(dim) {}
- DistanceType distance(PointType p1, PointType p2) const { return sqrt(sq_distance(p1,p2)); }
+ DistanceType distance(PointType p1, PointType p2) const
+ {
+ Real result = 0.0;
+ if (hera::is_infinity(internal_p)) {
+ // max norm
+ for (unsigned i = 0; i < dimension(); ++i)
+ result = std::max(result, fabs(coordinate(p1,i) - coordinate(p2,i)));
+ } else if (internal_p == Real(1.0)) {
+ // l1-norm
+ for (unsigned i = 0; i < dimension(); ++i)
+ result += fabs(coordinate(p1,i) - coordinate(p2,i));
+ } else if (internal_p == Real(2.0)) {
+ result = sqrt(sq_distance(p1,p2));
+ } else {
+ assert(internal_p > 1.0);
+ for (unsigned i = 0; i < dimension(); ++i)
+ result += std::pow(fabs(coordinate(p1,i) - coordinate(p2,i)), internal_p);
+ result = std::pow(result, Real(1.0) / internal_p);
+ }
+ return result;
+ }
DistanceType distance(PointHandle p1, PointHandle p2) const { return distance(PointType({p1.p}), PointType({p2.p})); }
DistanceType sq_distance(PointType p1, PointType p2) const { Real res = 0; for (unsigned i = 0; i < dimension(); ++i) { Real c1 = coordinate(p1,i), c2 = coordinate(p2,i); res += (c1 - c2)*(c1 - c2); } return res; }
DistanceType sq_distance(PointHandle p1, PointHandle p2) const { return sq_distance(PointType({p1.p}), PointType({p2.p})); }
diff --git a/geom_matching/wasserstein/include/dnn/local/kd-tree.hpp b/geom_matching/wasserstein/include/dnn/local/kd-tree.hpp
index 3a4f0eb..bdeef45 100644
--- a/geom_matching/wasserstein/include/dnn/local/kd-tree.hpp
+++ b/geom_matching/wasserstein/include/dnn/local/kd-tree.hpp
@@ -101,7 +101,7 @@ hera::ws::dnn::KDTree<T>::OrderTree
size_t next_i = (i + 1) % traits.dimension();
// Replace with a size condition instead?
- if (b < m - 1) q.push(KDTreeNode(b, m, next_i));
+ if (m - b > 1) q.push(KDTreeNode(b, m, next_i));
if (e - m > 2) q.push(KDTreeNode(m+1, e, next_i));
}
}
diff --git a/geom_matching/wasserstein/include/hera_infinity.h b/geom_matching/wasserstein/include/hera_infinity.h
new file mode 100644
index 0000000..8d86dbb
--- /dev/null
+++ b/geom_matching/wasserstein/include/hera_infinity.h
@@ -0,0 +1,22 @@
+#ifndef WASSERSTEIN_HERA_INFINITY_H
+#define WASSERSTEIN_HERA_INFINITY_H
+
+// we cannot assume that template parameter Real will always provide infinity() value,
+// so value -1.0 is used to encode infinity (l_inf norm is used by default)
+
+namespace hera {
+
+ template<class Real = double>
+ inline bool is_infinity(const Real& x)
+ {
+ return x == Real(-1);
+ };
+
+ template<class Real = double>
+ inline constexpr Real get_infinity()
+ {
+ return Real(-1);
+ }
+}
+
+#endif //WASSERSTEIN_HERA_INFINITY_H
diff --git a/geom_matching/wasserstein/include/wasserstein.h b/geom_matching/wasserstein/include/wasserstein.h
index b90a545..a24bada 100644
--- a/geom_matching/wasserstein/include/wasserstein.h
+++ b/geom_matching/wasserstein/include/wasserstein.h
@@ -73,7 +73,7 @@ namespace ws
// compare as multisets
template<class PairContainer>
- bool are_equal(const PairContainer& dgm1, const PairContainer& dgm2)
+ inline bool are_equal(const PairContainer& dgm1, const PairContainer& dgm2)
{
if (dgm1.size() != dgm2.size()) {
return false;
@@ -97,7 +97,7 @@ namespace ws
// to handle points with one coordinate = infinity
template<class RealType>
- RealType get_one_dimensional_cost(std::vector<RealType>& set_A,
+ inline RealType get_one_dimensional_cost(std::vector<RealType>& set_A,
std::vector<RealType>& set_B,
const RealType wasserstein_power)
{
@@ -210,7 +210,7 @@ namespace ws
// this function assumes that all coordinates are finite
// points at infinity are processed in wasserstein_cost
template<class RealType>
- RealType wasserstein_cost_vec(const std::vector<DiagramPoint<RealType>>& A,
+ inline RealType wasserstein_cost_vec(const std::vector<DiagramPoint<RealType>>& A,
const std::vector<DiagramPoint<RealType>>& B,
const AuctionParams<RealType>& params,
const std::string& _log_filename_prefix)
@@ -245,7 +245,7 @@ namespace ws
template<class PairContainer>
-typename DiagramTraits<PairContainer>::RealType
+inline typename DiagramTraits<PairContainer>::RealType
wasserstein_cost(const PairContainer& A,
const PairContainer& B,
const AuctionParams< typename DiagramTraits<PairContainer>::RealType >& params,
@@ -332,9 +332,9 @@ wasserstein_cost(const PairContainer& A,
}
template<class PairContainer>
-typename DiagramTraits<PairContainer>::RealType
-wasserstein_dist(PairContainer& A,
- PairContainer& B,
+inline typename DiagramTraits<PairContainer>::RealType
+wasserstein_dist(const PairContainer& A,
+ const PairContainer& B,
const AuctionParams<typename DiagramTraits<PairContainer>::RealType> params,
const std::string& _log_filename_prefix = "")
{
diff --git a/geom_matching/wasserstein/include/wasserstein_pure_geom.hpp b/geom_matching/wasserstein/include/wasserstein_pure_geom.hpp
index 2a57599..096d95d 100644
--- a/geom_matching/wasserstein/include/wasserstein_pure_geom.hpp
+++ b/geom_matching/wasserstein/include/wasserstein_pure_geom.hpp
@@ -30,7 +30,7 @@ namespace ws
using AuctionRunnerJacR = typename hera::ws::AuctionRunnerJac<Real, hera::ws::AuctionOracleKDTreePureGeom<Real>, hera::ws::dnn::DynamicPointVector<Real>>;
-double wasserstein_cost(const DynamicPointVector<double>& set_A, const DynamicPointVector<double>& set_B, const AuctionParams<double>& params)
+inline double wasserstein_cost(const DynamicPointVector<double>& set_A, const DynamicPointVector<double>& set_B, const AuctionParams<double>& params)
{
if (params.wasserstein_power < 1.0) {
throw std::runtime_error("Bad q in Wasserstein " + std::to_string(params.wasserstein_power));
@@ -72,10 +72,9 @@ double wasserstein_cost(const DynamicPointVector<double>& set_A, const DynamicPo
auction.run_auction();
return auction.get_wasserstein_cost();
}
-
}
-double wasserstein_dist(const DynamicPointVector<double>& set_A, const DynamicPointVector<double>& set_B, const AuctionParams<double>& params)
+inline double wasserstein_dist(const DynamicPointVector<double>& set_A, const DynamicPointVector<double>& set_B, const AuctionParams<double>& params)
{
return std::pow(wasserstein_cost(set_A, set_B, params), 1.0 / params.wasserstein_power);
}
diff --git a/geom_matching/wasserstein/tests/test_hera_wasserstein.cpp b/geom_matching/wasserstein/tests/test_hera_wasserstein.cpp
index 3d5db5f..0a80d2f 100644
--- a/geom_matching/wasserstein/tests/test_hera_wasserstein.cpp
+++ b/geom_matching/wasserstein/tests/test_hera_wasserstein.cpp
@@ -8,61 +8,12 @@
#undef LOG_AUCTION
#include "wasserstein.h"
+#include "tests_reader.h"
+using namespace hera_test;
using PairVector = std::vector<std::pair<double, double>>;
-std::vector<std::string> split_on_delim(const std::string& s, char delim)
-{
- std::stringstream ss(s);
- std::string token;
- std::vector<std::string> tokens;
- while(std::getline(ss, token, delim)) {
- tokens.push_back(token);
- }
- return tokens;
-}
-
-
-// single row in a file with test cases
-struct TestFromFileCase {
-
- std::string file_1;
- std::string file_2;
- double q;
- double internal_p;
- double answer;
-
- TestFromFileCase(std::string s)
- {
- auto tokens = split_on_delim(s, ' ');
- assert(tokens.size() == 5);
-
- file_1 = tokens.at(0);
- file_2 = tokens.at(1);
- q = std::stod(tokens.at(2));
- internal_p = std::stod(tokens.at(3));
- answer = std::stod(tokens.at(4));
-
- if ( q < 1.0 or std::isinf(q) or
- (internal_p != hera::get_infinity<double>() and internal_p < 1.0)) {
- throw std::runtime_error("Bad line in test_list.txt");
- }
- }
-};
-
-std::ostream& operator<<(std::ostream& out, const TestFromFileCase& s)
-{
- out << "[" << s.file_1 << ", " << s.file_2 << ", q = " << s.q << ", norm = ";
- if (s.internal_p != hera::get_infinity()) {
- out << s.internal_p;
- } else {
- out << "infinity";
- }
- out << ", answer = " << s.answer << "]";
- return out;
-}
-
TEST_CASE("simple cases", "wasserstein_dist")
{
diff --git a/geom_matching/wasserstein/tests/test_hera_wasserstein_pure_geom.cpp b/geom_matching/wasserstein/tests/test_hera_wasserstein_pure_geom.cpp
new file mode 100644
index 0000000..9603ceb
--- /dev/null
+++ b/geom_matching/wasserstein/tests/test_hera_wasserstein_pure_geom.cpp
@@ -0,0 +1,111 @@
+#include "catch/catch.hpp"
+
+#include <sstream>
+#include <iostream>
+
+
+#undef LOG_AUCTION
+
+#include "wasserstein_pure_geom.hpp"
+#include "tests_reader.h"
+
+using namespace hera_test;
+
+TEST_CASE("simple point clouds", "wasserstein_dist_pure_geom")
+{
+// int n_points = 3;
+// int dim = 3;
+// using Traits = hera::ws::dnn::DynamicPointTraits<double>;
+// hera::ws::dnn::DynamicPointTraits<double> traits(dim);
+// hera::ws::dnn::DynamicPointVector<double> dgm_a = traits.container(n_points);
+// hera::ws::dnn::DynamicPointVector<double> dgm_b = traits.container(n_points);
+//
+// dgm_a[0][0] = 0.0;
+// dgm_a[0][1] = 0.0;
+// dgm_a[0][2] = 0.0;
+//
+// dgm_a[1][0] = 1.0;
+// dgm_a[1][1] = 0.0;
+// dgm_a[1][2] = 0.0;
+//
+// dgm_a[2][0] = 0.0;
+// dgm_a[2][1] = 1.0;
+// dgm_a[2][2] = 1.0;
+//
+// dgm_b[0][0] = 0.0;
+// dgm_b[0][1] = 0.1;
+// dgm_b[0][2] = 0.1;
+//
+// dgm_b[1][0] = 1.1;
+// dgm_b[1][1] = 0.0;
+// dgm_b[1][2] = 0.0;
+//
+// dgm_b[2][0] = 0.0;
+// dgm_b[2][1] = 1.1;
+// dgm_b[2][2] = 0.9;
+
+ const int dim = 3;
+ using Traits = hera::ws::dnn::DynamicPointTraits<double>;
+ hera::ws::dnn::DynamicPointTraits<double> traits(dim);
+ hera::AuctionParams<double> params;
+ params.dim = dim;
+ params.wasserstein_power = 1.0;
+ params.delta = 0.01;
+ params.internal_p = hera::get_infinity<double>();
+ params.initial_epsilon = 0.0;
+ params.epsilon_common_ratio = 0.0;
+ params.max_num_phases = 30;
+ params.gamma_threshold = 0.0;
+ params.max_bids_per_round = 0; // use Jacobi
+
+
+ SECTION("trivial: two single-point diagrams-1") {
+
+ int n_points = 1;
+ hera::ws::dnn::DynamicPointVector<double> dgm_a = traits.container(n_points);
+ hera::ws::dnn::DynamicPointVector<double> dgm_b = traits.container(n_points);
+
+ dgm_a[0][0] = 0.0;
+ dgm_a[0][1] = 0.0;
+ dgm_a[0][2] = 0.0;
+
+ dgm_b[0][0] = 1.0;
+ dgm_b[0][1] = 1.0;
+ dgm_b[0][2] = 1.0;
+
+ std::vector<size_t> max_bids { 1, 10, 0 };
+ std::vector<int> internal_ps{ 1, 2, static_cast<int>(hera::get_infinity()) };
+ std::vector<double> wasserstein_powers { 1, 2, 3 };
+
+ for(auto internal_p : internal_ps) {
+ // there is only one point, so the answer does not depend wasserstein power
+ double correct_answer;
+ switch (internal_p) {
+ case 1 :
+ correct_answer = 3.0;
+ break;
+ case 2 :
+ correct_answer = sqrt(3.0);
+ break;
+ case static_cast<int>(hera::get_infinity()) :
+ correct_answer = 1.0;
+ break;
+ default :
+ throw std::runtime_error("Correct answer not specified in test case");
+ }
+
+ for (auto max_bid : max_bids) {
+ for (auto wasserstein_power : wasserstein_powers) {
+ params.max_bids_per_round = max_bid;
+ params.internal_p = internal_p;
+ params.wasserstein_power = wasserstein_power;
+ double d1 = hera::ws::wasserstein_dist(dgm_a, dgm_b, params);
+ double d2 = hera::ws::wasserstein_dist(dgm_b, dgm_a, params);
+ REQUIRE(fabs(d1 - d2) <= 0.00000000001);
+ REQUIRE(fabs(d1 - correct_answer) <= 0.00000000001);
+ }
+ }
+ }
+ }
+}
+
diff --git a/geom_matching/wasserstein/tests/tests_reader.h b/geom_matching/wasserstein/tests/tests_reader.h
new file mode 100644
index 0000000..f2d5735
--- /dev/null
+++ b/geom_matching/wasserstein/tests/tests_reader.h
@@ -0,0 +1,67 @@
+#ifndef WASSERSTEIN_TESTS_READER_H
+#define WASSERSTEIN_TESTS_READER_H
+
+#include <vector>
+#include <string>
+#include <ostream>
+#include <iostream>
+#include <sstream>
+#include <cassert>
+#include <cmath>
+
+#include "hera_infinity.h"
+
+namespace hera_test {
+ inline std::vector<std::string> split_on_delim(const std::string& s, char delim)
+ {
+ std::stringstream ss(s);
+ std::string token;
+ std::vector<std::string> tokens;
+ while (std::getline(ss, token, delim)) {
+ tokens.push_back(token);
+ }
+ return tokens;
+ }
+
+
+ // single row in a file with test cases
+ struct TestFromFileCase
+ {
+
+ std::string file_1;
+ std::string file_2;
+ double q;
+ double internal_p;
+ double answer;
+
+ TestFromFileCase(std::string s)
+ {
+ auto tokens = split_on_delim(s, ' ');
+ assert(tokens.size() == 5);
+
+ file_1 = tokens.at(0);
+ file_2 = tokens.at(1);
+ q = std::stod(tokens.at(2));
+ internal_p = std::stod(tokens.at(3));
+ answer = std::stod(tokens.at(4));
+
+ if (q < 1.0 or std::isinf(q) or
+ (internal_p != hera::get_infinity<double>() and internal_p < 1.0)) {
+ throw std::runtime_error("Bad line in test_list.txt");
+ }
+ }
+ };
+
+ inline std::ostream& operator<<(std::ostream& out, const TestFromFileCase& s)
+ {
+ out << "[" << s.file_1 << ", " << s.file_2 << ", q = " << s.q << ", norm = ";
+ if (s.internal_p != hera::get_infinity()) {
+ out << s.internal_p;
+ } else {
+ out << "infinity";
+ }
+ out << ", answer = " << s.answer << "]";
+ return out;
+ }
+} // namespace hera_test
+#endif //WASSERSTEIN_TESTS_READER_H