summaryrefslogtreecommitdiff
path: root/geom_matching/wasserstein/example/wasserstein_dist.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'geom_matching/wasserstein/example/wasserstein_dist.cpp')
-rw-r--r--geom_matching/wasserstein/example/wasserstein_dist.cpp83
1 files changed, 61 insertions, 22 deletions
diff --git a/geom_matching/wasserstein/example/wasserstein_dist.cpp b/geom_matching/wasserstein/example/wasserstein_dist.cpp
index a2ed234..fcbc641 100644
--- a/geom_matching/wasserstein/example/wasserstein_dist.cpp
+++ b/geom_matching/wasserstein/example/wasserstein_dist.cpp
@@ -1,5 +1,5 @@
/*
-
+
Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov
All rights reserved.
@@ -12,7 +12,7 @@ Redistribution and use in source and binary forms, with or without modification,
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
+
You are under no obligation whatsoever to provide any bug fixes, patches, or
upgrades to the features, functionality or performance of the source code
@@ -27,12 +27,14 @@ derivative works thereof, in binary and source code form.
*/
#include <iostream>
+#include <locale>
#include <iomanip>
-#include <fstream>
#include <vector>
-#include <algorithm>
-#include <limits>
-#include <random>
+
+//#define LOG_AUCTION
+
+//#include "auction_runner_fr.h"
+//#include "auction_runner_fr.hpp"
#include "wasserstein.h"
@@ -41,51 +43,88 @@ derivative works thereof, in binary and source code form.
int main(int argc, char* argv[])
{
- geom_ws::PairVector diagramA, diagramB;
+ using PairVector = std::vector<std::pair<double, double>>;
+ PairVector diagramA, diagramB;
+
+ hera::AuctionParams<double> params;
if (argc < 3 ) {
- std::cerr << "Usage: " << argv[0] << " file1 file2 [wasserstein_degree] [relative_error] [internal norm] [output_actual_error]. By default power is 1.0, relative error is 0.01, internal norm is l_infinity, actual relative error is not printed." << std::endl;
+ std::cerr << "Usage: " << argv[0] << " file1 file2 [wasserstein_degree] [relative_error] [internal norm] [initial epsilon] [epsilon_factor] [max_bids_per_round] [gamma_threshold][log_filename_prefix]. By default power is 1.0, relative error is 0.01, internal norm is l_infinity, initall epsilon is chosen automatically, epsilon factor is 5.0, Jacobi variant is used (max bids per round is maximal), gamma_threshold = 0.0." << std::endl;
return 1;
}
- if (!geom_ws::readDiagramPointSet(argv[1], diagramA)) {
+ if (!hera::read_diagram_point_set<double, PairVector>(argv[1], diagramA)) {
std::exit(1);
}
- if (!geom_ws::readDiagramPointSet(argv[2], diagramB)) {
+ if (!hera::read_diagram_point_set(argv[2], diagramB)) {
std::exit(1);
}
- double wasserPower = (4 <= argc) ? atof(argv[3]) : 1.0;
- if (wasserPower < 1.0) {
+ params.wasserstein_power = (4 <= argc) ? atof(argv[3]) : 1.0;
+ if (params.wasserstein_power < 1.0) {
std::cerr << "The third argument (wasserstein_degree) was \"" << argv[3] << "\", must be a number >= 1.0. Cannot proceed. " << std::endl;
std::exit(1);
}
- if (wasserPower == 1.0) {
- geom_ws::removeDuplicates(diagramA, diagramB);
+ if (params.wasserstein_power == 1.0) {
+ hera::remove_duplicates<double>(diagramA, diagramB);
}
//default relative error: 1%
- double delta = (5 <= argc) ? atof(argv[4]) : 0.01;
- if ( delta <= 0.0) {
+ params.delta = (5 <= argc) ? atof(argv[4]) : 0.01;
+ if ( params.delta <= 0.0) {
std::cerr << "The 4th argument (relative error) was \"" << argv[4] << "\", must be a number > 0.0. Cannot proceed. " << std::endl;
std::exit(1);
}
// default for internal metric is l_infinity
- double internal_p = ( 6 <= argc ) ? atof(argv[5]) : std::numeric_limits<double>::infinity();
- if (internal_p < 1.0) {
- std::cerr << "The 5th argument (internal norm) was \"" << argv[5] << "\", must be a number >= 1.0. Cannot proceed. " << std::endl;
+ params.internal_p = ( 6 <= argc ) ? atof(argv[5]) : hera::get_infinity<double>();
+ if (std::isinf(params.internal_p)) {
+ params.internal_p = hera::get_infinity<double>();
+ }
+
+
+ if (not hera::is_p_valid_norm<double>(params.internal_p)) {
+ std::cerr << "The 5th argument (internal norm) was \"" << argv[5] << "\", must be a number >= 1.0 or inf. Cannot proceed. " << std::endl;
std::exit(1);
}
// if you want to specify initial value for epsilon and the factor
// for epsilon-scaling
- double initialEpsilon= ( 7 <= argc ) ? atof(argv[6]) : 0.0 ;
- double epsFactor = ( 8 <= argc ) ? atof(argv[7]) : 0.0 ;
+ params.initial_epsilon= ( 7 <= argc ) ? atof(argv[6]) : 0.0 ;
+
+ if (params.initial_epsilon < 0.0) {
+ std::cerr << "The 6th argument (initial epsilon) was \"" << argv[6] << "\", must be a non-negative number. Cannot proceed." << std::endl;
+ std::exit(1);
+ }
+
+ params.epsilon_common_ratio = ( 8 <= argc ) ? atof(argv[7]) : 0.0 ;
+ if (params.epsilon_common_ratio <= 1.0 and params.epsilon_common_ratio != 0.0) {
+ std::cerr << "The 7th argument (epsilon factor) was \"" << argv[7] << "\", must be a number greater than 1. Cannot proceed." << std::endl;
+ std::exit(1);
+ }
+
+
+ params.max_bids_per_round = ( 9 <= argc ) ? atoi(argv[8]) : 0;
+ if (params.max_bids_per_round == 0)
+ params.max_bids_per_round = std::numeric_limits<size_t>::max();
+
+
+ params.gamma_threshold = (10 <= argc) ? atof(argv[9]) : 0.0;
+
+ std::string log_filename_prefix = ( 11 <= argc ) ? argv[10] : "";
+
+ params.max_num_phases = 800;
+
+#ifdef LOG_AUCTION
+ spdlog::set_level(spdlog::level::info);
+#endif
+
+ double res = hera::wasserstein_dist(diagramA, diagramB, params, log_filename_prefix);
- double res = geom_ws::wassersteinDist(diagramA, diagramB, wasserPower, delta, internal_p, initialEpsilon, epsFactor);
std::cout << std::setprecision(15) << res << std::endl;
+
return 0;
+
}