summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorArnur Nigmetov <a.nigmetov@gmail.com>2018-06-01 16:01:00 +0200
committerArnur Nigmetov <a.nigmetov@gmail.com>2018-06-01 16:01:00 +0200
commit58642131623733ed7360fa146d106cff7f3a057c (patch)
tree7b9e4c7fc49849dc9fcf1fd9434bd377cd157800
parentc563d463ffce73b070b35e62baf980d19d0bc4ac (diff)
First trivial test for point cloud version
-rw-r--r--geom_matching/wasserstein/CMakeLists.txt2
-rw-r--r--geom_matching/wasserstein/include/auction_runner_gs.hpp1
-rw-r--r--geom_matching/wasserstein/include/auction_runner_jac.hpp1
-rw-r--r--geom_matching/wasserstein/include/hera_infinity.h2
-rw-r--r--geom_matching/wasserstein/tests/test_hera_wasserstein.cpp53
-rw-r--r--geom_matching/wasserstein/tests/test_hera_wasserstein_pure_geom.cpp111
-rw-r--r--geom_matching/wasserstein/tests/tests_reader.h67
7 files changed, 184 insertions, 53 deletions
diff --git a/geom_matching/wasserstein/CMakeLists.txt b/geom_matching/wasserstein/CMakeLists.txt
index c6fba2c..dea4550 100644
--- a/geom_matching/wasserstein/CMakeLists.txt
+++ b/geom_matching/wasserstein/CMakeLists.txt
@@ -57,6 +57,6 @@ add_executable(wasserstein_dist_point_cloud ${CMAKE_CURRENT_SOURCE_DIR}/example/
target_link_libraries(wasserstein_dist_point_cloud PUBLIC ${libraries})
# Tests
-add_executable(wasserstein_test ${CMAKE_CURRENT_SOURCE_DIR}/tests/tests_main.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tests/test_hera_wasserstein.cpp include/hera_infinity.h)
+add_executable(wasserstein_test ${CMAKE_CURRENT_SOURCE_DIR}/tests/tests_main.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tests/test_hera_wasserstein.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tests/test_hera_wasserstein_pure_geom.cpp include/hera_infinity.h tests/tests_reader.h)
#add_executable(wasserstein_test EXCLUDE_FROM_ALL ${CMAKE_CURRENT_SOURCE_DIR}/tests/test_hera_wasserstein.cpp)
target_link_libraries(wasserstein_test PUBLIC ${libraries})
diff --git a/geom_matching/wasserstein/include/auction_runner_gs.hpp b/geom_matching/wasserstein/include/auction_runner_gs.hpp
index 960c707..141cb2c 100644
--- a/geom_matching/wasserstein/include/auction_runner_gs.hpp
+++ b/geom_matching/wasserstein/include/auction_runner_gs.hpp
@@ -287,6 +287,7 @@ void AuctionRunnerGS<R, AO, PC>::run_auction()
if (num_bidders == 1) {
assign_item_to_bidder(0, 0);
wasserstein_cost = get_item_bidder_cost(0,0);
+ is_distance_computed = true;
return;
}
diff --git a/geom_matching/wasserstein/include/auction_runner_jac.hpp b/geom_matching/wasserstein/include/auction_runner_jac.hpp
index c519de1..e623f4a 100644
--- a/geom_matching/wasserstein/include/auction_runner_jac.hpp
+++ b/geom_matching/wasserstein/include/auction_runner_jac.hpp
@@ -559,6 +559,7 @@ namespace ws {
if (num_bidders == 1) {
assign_item_to_bidder(0, 0);
wasserstein_cost = get_item_bidder_cost(0,0);
+ is_distance_computed = true;
return;
}
R init_eps = (initial_epsilon > 0.0) ? initial_epsilon : oracle.max_val_ / 4.0;
diff --git a/geom_matching/wasserstein/include/hera_infinity.h b/geom_matching/wasserstein/include/hera_infinity.h
index 5a446e7..8d86dbb 100644
--- a/geom_matching/wasserstein/include/hera_infinity.h
+++ b/geom_matching/wasserstein/include/hera_infinity.h
@@ -13,7 +13,7 @@ namespace hera {
};
template<class Real = double>
- inline Real get_infinity()
+ inline constexpr Real get_infinity()
{
return Real(-1);
}
diff --git a/geom_matching/wasserstein/tests/test_hera_wasserstein.cpp b/geom_matching/wasserstein/tests/test_hera_wasserstein.cpp
index 3d5db5f..0a80d2f 100644
--- a/geom_matching/wasserstein/tests/test_hera_wasserstein.cpp
+++ b/geom_matching/wasserstein/tests/test_hera_wasserstein.cpp
@@ -8,61 +8,12 @@
#undef LOG_AUCTION
#include "wasserstein.h"
+#include "tests_reader.h"
+using namespace hera_test;
using PairVector = std::vector<std::pair<double, double>>;
-std::vector<std::string> split_on_delim(const std::string& s, char delim)
-{
- std::stringstream ss(s);
- std::string token;
- std::vector<std::string> tokens;
- while(std::getline(ss, token, delim)) {
- tokens.push_back(token);
- }
- return tokens;
-}
-
-
-// single row in a file with test cases
-struct TestFromFileCase {
-
- std::string file_1;
- std::string file_2;
- double q;
- double internal_p;
- double answer;
-
- TestFromFileCase(std::string s)
- {
- auto tokens = split_on_delim(s, ' ');
- assert(tokens.size() == 5);
-
- file_1 = tokens.at(0);
- file_2 = tokens.at(1);
- q = std::stod(tokens.at(2));
- internal_p = std::stod(tokens.at(3));
- answer = std::stod(tokens.at(4));
-
- if ( q < 1.0 or std::isinf(q) or
- (internal_p != hera::get_infinity<double>() and internal_p < 1.0)) {
- throw std::runtime_error("Bad line in test_list.txt");
- }
- }
-};
-
-std::ostream& operator<<(std::ostream& out, const TestFromFileCase& s)
-{
- out << "[" << s.file_1 << ", " << s.file_2 << ", q = " << s.q << ", norm = ";
- if (s.internal_p != hera::get_infinity()) {
- out << s.internal_p;
- } else {
- out << "infinity";
- }
- out << ", answer = " << s.answer << "]";
- return out;
-}
-
TEST_CASE("simple cases", "wasserstein_dist")
{
diff --git a/geom_matching/wasserstein/tests/test_hera_wasserstein_pure_geom.cpp b/geom_matching/wasserstein/tests/test_hera_wasserstein_pure_geom.cpp
new file mode 100644
index 0000000..9603ceb
--- /dev/null
+++ b/geom_matching/wasserstein/tests/test_hera_wasserstein_pure_geom.cpp
@@ -0,0 +1,111 @@
+#include "catch/catch.hpp"
+
+#include <sstream>
+#include <iostream>
+
+
+#undef LOG_AUCTION
+
+#include "wasserstein_pure_geom.hpp"
+#include "tests_reader.h"
+
+using namespace hera_test;
+
+TEST_CASE("simple point clouds", "wasserstein_dist_pure_geom")
+{
+// int n_points = 3;
+// int dim = 3;
+// using Traits = hera::ws::dnn::DynamicPointTraits<double>;
+// hera::ws::dnn::DynamicPointTraits<double> traits(dim);
+// hera::ws::dnn::DynamicPointVector<double> dgm_a = traits.container(n_points);
+// hera::ws::dnn::DynamicPointVector<double> dgm_b = traits.container(n_points);
+//
+// dgm_a[0][0] = 0.0;
+// dgm_a[0][1] = 0.0;
+// dgm_a[0][2] = 0.0;
+//
+// dgm_a[1][0] = 1.0;
+// dgm_a[1][1] = 0.0;
+// dgm_a[1][2] = 0.0;
+//
+// dgm_a[2][0] = 0.0;
+// dgm_a[2][1] = 1.0;
+// dgm_a[2][2] = 1.0;
+//
+// dgm_b[0][0] = 0.0;
+// dgm_b[0][1] = 0.1;
+// dgm_b[0][2] = 0.1;
+//
+// dgm_b[1][0] = 1.1;
+// dgm_b[1][1] = 0.0;
+// dgm_b[1][2] = 0.0;
+//
+// dgm_b[2][0] = 0.0;
+// dgm_b[2][1] = 1.1;
+// dgm_b[2][2] = 0.9;
+
+ const int dim = 3;
+ using Traits = hera::ws::dnn::DynamicPointTraits<double>;
+ hera::ws::dnn::DynamicPointTraits<double> traits(dim);
+ hera::AuctionParams<double> params;
+ params.dim = dim;
+ params.wasserstein_power = 1.0;
+ params.delta = 0.01;
+ params.internal_p = hera::get_infinity<double>();
+ params.initial_epsilon = 0.0;
+ params.epsilon_common_ratio = 0.0;
+ params.max_num_phases = 30;
+ params.gamma_threshold = 0.0;
+ params.max_bids_per_round = 0; // use Jacobi
+
+
+ SECTION("trivial: two single-point diagrams-1") {
+
+ int n_points = 1;
+ hera::ws::dnn::DynamicPointVector<double> dgm_a = traits.container(n_points);
+ hera::ws::dnn::DynamicPointVector<double> dgm_b = traits.container(n_points);
+
+ dgm_a[0][0] = 0.0;
+ dgm_a[0][1] = 0.0;
+ dgm_a[0][2] = 0.0;
+
+ dgm_b[0][0] = 1.0;
+ dgm_b[0][1] = 1.0;
+ dgm_b[0][2] = 1.0;
+
+ std::vector<size_t> max_bids { 1, 10, 0 };
+ std::vector<int> internal_ps{ 1, 2, static_cast<int>(hera::get_infinity()) };
+ std::vector<double> wasserstein_powers { 1, 2, 3 };
+
+ for(auto internal_p : internal_ps) {
+ // there is only one point, so the answer does not depend wasserstein power
+ double correct_answer;
+ switch (internal_p) {
+ case 1 :
+ correct_answer = 3.0;
+ break;
+ case 2 :
+ correct_answer = sqrt(3.0);
+ break;
+ case static_cast<int>(hera::get_infinity()) :
+ correct_answer = 1.0;
+ break;
+ default :
+ throw std::runtime_error("Correct answer not specified in test case");
+ }
+
+ for (auto max_bid : max_bids) {
+ for (auto wasserstein_power : wasserstein_powers) {
+ params.max_bids_per_round = max_bid;
+ params.internal_p = internal_p;
+ params.wasserstein_power = wasserstein_power;
+ double d1 = hera::ws::wasserstein_dist(dgm_a, dgm_b, params);
+ double d2 = hera::ws::wasserstein_dist(dgm_b, dgm_a, params);
+ REQUIRE(fabs(d1 - d2) <= 0.00000000001);
+ REQUIRE(fabs(d1 - correct_answer) <= 0.00000000001);
+ }
+ }
+ }
+ }
+}
+
diff --git a/geom_matching/wasserstein/tests/tests_reader.h b/geom_matching/wasserstein/tests/tests_reader.h
new file mode 100644
index 0000000..f2d5735
--- /dev/null
+++ b/geom_matching/wasserstein/tests/tests_reader.h
@@ -0,0 +1,67 @@
+#ifndef WASSERSTEIN_TESTS_READER_H
+#define WASSERSTEIN_TESTS_READER_H
+
+#include <vector>
+#include <string>
+#include <ostream>
+#include <iostream>
+#include <sstream>
+#include <cassert>
+#include <cmath>
+
+#include "hera_infinity.h"
+
+namespace hera_test {
+ inline std::vector<std::string> split_on_delim(const std::string& s, char delim)
+ {
+ std::stringstream ss(s);
+ std::string token;
+ std::vector<std::string> tokens;
+ while (std::getline(ss, token, delim)) {
+ tokens.push_back(token);
+ }
+ return tokens;
+ }
+
+
+ // single row in a file with test cases
+ struct TestFromFileCase
+ {
+
+ std::string file_1;
+ std::string file_2;
+ double q;
+ double internal_p;
+ double answer;
+
+ TestFromFileCase(std::string s)
+ {
+ auto tokens = split_on_delim(s, ' ');
+ assert(tokens.size() == 5);
+
+ file_1 = tokens.at(0);
+ file_2 = tokens.at(1);
+ q = std::stod(tokens.at(2));
+ internal_p = std::stod(tokens.at(3));
+ answer = std::stod(tokens.at(4));
+
+ if (q < 1.0 or std::isinf(q) or
+ (internal_p != hera::get_infinity<double>() and internal_p < 1.0)) {
+ throw std::runtime_error("Bad line in test_list.txt");
+ }
+ }
+ };
+
+ inline std::ostream& operator<<(std::ostream& out, const TestFromFileCase& s)
+ {
+ out << "[" << s.file_1 << ", " << s.file_2 << ", q = " << s.q << ", norm = ";
+ if (s.internal_p != hera::get_infinity()) {
+ out << s.internal_p;
+ } else {
+ out << "infinity";
+ }
+ out << ", answer = " << s.answer << "]";
+ return out;
+ }
+} // namespace hera_test
+#endif //WASSERSTEIN_TESTS_READER_H