diff options
author | Gard Spreemann <gard.spreemann@epfl.ch> | 2019-05-10 16:20:51 +0200 |
---|---|---|
committer | Gard Spreemann <gard.spreemann@epfl.ch> | 2019-05-10 16:20:51 +0200 |
commit | a99391b04635836b028b8f875ac8d99226da4683 (patch) | |
tree | f3457a3f8b07f3332d0df67bff6682c2d356af01 | |
parent | 4c1b1727ab44a773d27b090463ac6db957267136 (diff) |
Allow for bipartite diagram sets.
-rw-r--r-- | geom_matching/wasserstein/mpi/main.cpp | 122 |
1 file changed, 89 insertions, 33 deletions
diff --git a/geom_matching/wasserstein/mpi/main.cpp b/geom_matching/wasserstein/mpi/main.cpp index c3b4e09..a1d980a 100644 --- a/geom_matching/wasserstein/mpi/main.cpp +++ b/geom_matching/wasserstein/mpi/main.cpp @@ -14,8 +14,13 @@ void print_help(const std::string & invocation) { std::cout << "Usage: " << invocation << " arguments" << std::endl; std::cout << "Arguments:" << std::endl; - std::cout << " --in-list, -i file" << std::endl; - std::cout << " Mandatory. File containing a list of persistence diagram files to process, one file per line." << std::endl; + std::cout << " --in-lists, -i file_1 file_2" << std::endl; + std::cout << " Mandatory. Files containing lists of persistence diagram files to process, one file per line." << std::endl; + std::cout << " If the same file is given twice, and contains m entries:" << std::endl; + std::cout << " The program outputs the full m×m distance matrix for all pairs of persistence diagrams given." << std::endl; + std::cout << " If different files are given, the first containing m and the second containing n entries:" << std::endl; + std::cout << " The outputs the *row-major strict upper triangle* of the m×n distance matrix corresponding " << std::endl; + std::cout << " To the diagrams in the two files." << std::endl; std::cout << " --in-type, -t type" << std::endl; std::cout << " Optional (defaults to dipha). Input file format, dipha|txt." 
<< std::endl; std::cout << " --dimension, -d dim" << std::endl; @@ -95,7 +100,8 @@ int main(int argc, char ** argv) MPI_Get_processor_name(processor_name_tmp, &name_len); std::string processor_name(processor_name_tmp); - std::string list_file_name; + std::string list_1_file_name; + std::string list_2_file_name; std::string out_file_name; int dim = -1; int in_type = file_type_dipha; @@ -114,12 +120,15 @@ int main(int argc, char ** argv) for (int i = 1; i < argc; ++i) { std::string arg(argv[i]); - if (arg == std::string("--in-list") || arg == std::string("-i")) + if (arg == std::string("--in-lists") || arg == std::string("-i")) { - if (i < argc - 1) - list_file_name = std::string(argv[++i]); + if (i < argc - 2) + { + list_1_file_name = std::string(argv[++i]); + list_2_file_name = std::string(argv[++i]); + } else - arg_fail(invocation, world_rank, "Missing argument for --in-list."); + arg_fail(invocation, world_rank, "Need two arguments for --in-lists."); } else if (arg == std::string("--in-type") || arg == std::string("-t")) { @@ -207,8 +216,8 @@ int main(int argc, char ** argv) // Argument validation. - if (list_file_name == std::string("")) - arg_fail(invocation, world_rank, "Need input file list."); + if (list_1_file_name == std::string("") || list_2_file_name == std::string("")) + arg_fail(invocation, world_rank, "Need input file lists."); if (dim < 0) arg_fail(invocation, world_rank, "Dimension must be non-negative."); if (params.wasserstein_power < 1 || !std::isfinite(params.wasserstein_power)) @@ -222,25 +231,63 @@ int main(int argc, char ** argv) if (chunk_size <= 0) arg_fail(invocation, world_rank, "Chunk size must be positive."); - std::vector<std::string> files; + if (world_rank == 0) + { + if (finitization == std::numeric_limits<double>::infinity()) + { + std::cout << "Warning: You are running without finitization. Bugs do exist in this branch of Hera if you have diagrams with infinite generators." 
<< std::endl; + } + } + + bool square = list_1_file_name == list_2_file_name; + + std::vector<std::string> files_1; + std::vector<std::string> files_2; - std::ifstream list_file(list_file_name, std::ios::in); + std::ifstream list_file(list_1_file_name, std::ios::in); std::string line; while (std::getline(list_file, line)) { - files.push_back(line); + files_1.push_back(line); } list_file.close(); - int n = files.size(); - - std::vector<std::pair<int, int>> idxs((n*(n-1))/2); - int k = 0; - for (int i = 0; i < n; ++i) + if (square) + { + files_2 = files_1; + } + else { - for (int j = i+1; j < n; ++j) + list_file.open(list_2_file_name, std::ios::in); + while (std::getline(list_file, line)) { - idxs[k++] = std::make_pair(i, j); + files_2.push_back(line); + } + list_file.close(); + } + + int m = files_1.size(); + int n = files_2.size(); + + std::vector<std::pair<int, int>> idxs; + if (square) + { + for (int i = 0; i < n; ++i) + { + for (int j = i+1; j < n; ++j) + { + idxs.push_back(std::make_pair(i, j)); + } + } + } + else + { + for (int i = 0; i < m; ++i) + { + for (int j = 0; j < n; ++j) + { + idxs.push_back(std::make_pair(i, j)); + } } } @@ -321,11 +368,11 @@ int main(int argc, char ** argv) pd_1.clear(); if (in_type == file_type_dipha) { - load_success = hera::read_diagram_dipha<double, std::vector<std::pair<double, double> > >(files[i], dim, pd_1); + load_success = hera::read_diagram_dipha<double, std::vector<std::pair<double, double> > >(files_1[i], dim, pd_1); } else if (in_type == file_type_txt) { - load_success = hera::read_diagram_point_set<double, std::vector<std::pair<double, double> > >(files[i], pd_1); + load_success = hera::read_diagram_point_set<double, std::vector<std::pair<double, double> > >(files_1[i], pd_1); } else { @@ -333,7 +380,7 @@ int main(int argc, char ** argv) } if (!load_success) - fail(world_rank, std::string("Failed to load file ") + files[i] + std::string(".")); + fail(world_rank, std::string("Failed to load file ") + files_1[i] + 
std::string(".")); hera::finitize(finitization, pd_1); } @@ -343,11 +390,11 @@ int main(int argc, char ** argv) pd_2.clear(); if (in_type == file_type_dipha) { - load_success = hera::read_diagram_dipha<double, std::vector<std::pair<double, double> > >(files[j], dim, pd_2); + load_success = hera::read_diagram_dipha<double, std::vector<std::pair<double, double> > >(files_2[j], dim, pd_2); } else if (in_type == file_type_txt) { - load_success = hera::read_diagram_point_set<double, std::vector<std::pair<double, double> > >(files[j], pd_2); + load_success = hera::read_diagram_point_set<double, std::vector<std::pair<double, double> > >(files_2[j], pd_2); } else { @@ -355,14 +402,13 @@ int main(int argc, char ** argv) } if (!load_success) - fail(world_rank, std::string("Failed to load file ") + files[i] + std::string(".")); + fail(world_rank, std::string("Failed to load file ") + files_2[j] + std::string(".")); hera::finitize(finitization, pd_2); } std::string fixme(""); results[k - work[0]] = hera::wasserstein_dist(pd_1, pd_2, params, fixme); - i_prev = i; j_prev = j; } @@ -378,18 +424,28 @@ int main(int argc, char ** argv) { std::cout << "Writing out." << std::endl; std::ofstream out_file(out_file_name, std::ios::out); - for (int i = 0; i < n; ++i) + if (square) { - for (int j = 0; j < i; ++j) + for (int i = 0; i < n; ++i) { - out_file << std::scientific << results[unroll(n, j, i)] << " "; + for (int j = 0; j < i; ++j) + { + out_file << std::scientific << results[unroll(n, j, i)] << " "; + } + out_file << "0 "; + for (int j = i + 1; j < n; ++j) + { + out_file << std::scientific << results[unroll(n, i, j)] << " "; + } + out_file << std::endl; } - out_file << "0 "; - for (int j = i + 1; j < n; ++j) + } + else + { + for (auto it = results.cbegin(); it != results.cend(); ++it) { - out_file << std::scientific << results[unroll(n, i, j)] << " "; + out_file << std::scientific << (*it) << "\n"; } - out_file << std::endl; } out_file.close(); } |