#include #include #include #include #include #include #include #include #include #include "debug.hpp" #include "wasserstein.h" void print_help(const std::string & invocation) { std::cout << "Usage: " << invocation << " arguments" << std::endl; std::cout << "Arguments:" << std::endl; std::cout << " --in-lists, -i file_1 file_2" << std::endl; std::cout << " Mandatory. Files containing lists of persistence diagram files to process, one file per line." << std::endl; std::cout << " If the same file is given twice, and contains m entries:" << std::endl; std::cout << " The program outputs the full m×m distance matrix for all pairs of persistence diagrams given." << std::endl; std::cout << " If different files are given, the first containing m and the second containing n entries:" << std::endl; std::cout << " The outputs the *row-major strict upper triangle* of the m×n distance matrix corresponding " << std::endl; std::cout << " To the diagrams in the two files." << std::endl; std::cout << " --in-type, -t type" << std::endl; std::cout << " Optional (defaults to dipha). Input file format, dipha|txt." << std::endl; std::cout << " --dimension, -d dim" << std::endl; std::cout << " Mandatory if the input is a DIPHA persistence diagram file. Ignored otherwise. Integer." << std::endl; std::cout << " --outer-norm, -p p" << std::endl; std::cout << " Mandatory. Floating point. Outer norm. In the interval [1, ∞)." << std::endl; std::cout << " --inner-norm, -q q" << std::endl; std::cout << " Mandatory. Floating point. Inner norm. In the interval [1, ∞]. Use inf for infinity." << std::endl; std::cout << " --error, -e e" << std::endl; std::cout << " Mandatory. Relative error. Positive floating point." << std::endl; std::cout << " --finitize, -f f" << std::endl; std::cout << " Optional. Make infinite intervals die at f." << std::endl; std::cout << " --out, -o file" << std::endl; std::cout << " Mandatory. Output file name. Plain text." << std::endl; std::cout << " --chunk, -c c" << std::endl; std::cout << " Optional. Size of work chunk to send off to each computational node. Too small a value yields a lot of overhead, too large a value can cause an unbalanced load. Increase if there are many small computations. Default: 100." << std::endl; std::cout << " --help, -h" << std::endl; std::cout << " Print this help text." << std::endl; } void arg_fail(const std::string & invocation, int rank, const std::string & message) { if (rank == 0) { std::cout << message << std::endl; print_help(invocation); } MPI_Finalize(); exit(1); } void fail(int rank, const std::string & message) { if (rank == 0) { std::cout << message << std::endl; } MPI_Finalize(); exit(1); } inline int unroll(int n, int i, int j) { IFDEBUG( assert(i < n); assert(j < n); assert(n > 0); assert(i < j); assert(i >= 0); assert(j >= 0); ); return i*n + j - (i*(i+1))/2 - i - 1; } enum Message_tag { tag_result, tag_work }; enum File_type { file_type_dipha, file_type_txt }; int main(int argc, char ** argv) { MPI_Init(NULL, NULL); std::string invocation(argv[0]); int world_size; MPI_Comm_size(MPI_COMM_WORLD, &world_size); int world_rank; MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); if (world_size < 2) { arg_fail(invocation, world_rank, "Currently there is no support for running with just one process. Please run at least 2 MPI jobs."); } char processor_name_tmp[MPI_MAX_PROCESSOR_NAME]; int name_len; MPI_Get_processor_name(processor_name_tmp, &name_len); std::string processor_name(processor_name_tmp); std::string list_1_file_name; std::string list_2_file_name; std::string out_file_name; int dim = -1; int in_type = file_type_dipha; double finitization = std::numeric_limits::infinity(); int chunk_size = 100; hera::AuctionParams params; params.wasserstein_power = std::numeric_limits::quiet_NaN(); params.delta = std::numeric_limits::quiet_NaN(); params.internal_p = std::numeric_limits::quiet_NaN(); params.initial_epsilon = 0.0; // Default value taken from upstream example code. params.epsilon_common_ratio = 0.0; // Default value taken from upstream example code. params.max_bids_per_round = std::numeric_limits::max(); // Default value taken from upstream example code. params.gamma_threshold = 0.0; // Default value taken from upstream example code. params.max_num_phases = 800; // Default value taken from upstream example code. for (int i = 1; i < argc; ++i) { std::string arg(argv[i]); if (arg == std::string("--in-lists") || arg == std::string("-i")) { if (i < argc - 2) { list_1_file_name = std::string(argv[++i]); list_2_file_name = std::string(argv[++i]); } else arg_fail(invocation, world_rank, "Need two arguments for --in-lists."); } else if (arg == std::string("--in-type") || arg == std::string("-t")) { if (i < argc - 1) { std::string argnext(argv[++i]); if (argnext == std::string("txt")) in_type = file_type_txt; else if (argnext == std::string("dipha")) in_type = file_type_dipha; else arg_fail(invocation, world_rank, "Invalid argument for --in-type."); } else arg_fail(invocation, world_rank, "Missing argument for --in-type."); } else if (arg == std::string("--dimension") || arg == std::string("-d")) { if (i < argc - 1) dim = std::atoi(argv[++i]); else arg_fail(invocation, world_rank, "Missing argument for --dimension."); } else if (arg == std::string("--outer-norm") || arg == std::string("-p")) { if (i < argc - 1) params.wasserstein_power = std::atof(argv[++i]); else arg_fail(invocation, world_rank, "Missing argument for --outer-norm."); } else if (arg == std::string("--inner-norm") || arg == std::string("-q")) { if (i < argc - 1) { std::string argnext(argv[++i]); if (argnext == "inf") params.internal_p = -1; else params.internal_p = std::stod(argnext); } else arg_fail(invocation, world_rank, "Missing argument for --inner-norm."); } else if (arg == std::string("--error") || arg == std::string("-e")) { if (i < argc - 1) params.delta = std::atof(argv[++i]); else arg_fail(invocation, world_rank, "Missing argument for --error."); } else if (arg == std::string("--finitize") || arg == std::string("-f")) { if (i < argc - 1) finitization = std::atof(argv[++i]); else arg_fail(invocation, world_rank, "Missing argument for --finitize."); } else if (arg == std::string("--out") || arg == std::string("-o")) { if (i < argc - 1) out_file_name = std::string(argv[++i]); else arg_fail(invocation, world_rank, "Missing argument for --out."); } else if (arg == std::string("--chunk") || arg == std::string("-c")) { if (i < argc - 1) chunk_size = std::atoi(argv[++i]); else arg_fail(invocation, world_rank, "Missing argument for --chunk."); } else if (arg == std::string("--help") || arg == std::string("-h")) { if (world_rank == 0) print_help(invocation); MPI_Finalize(); exit(0); } else { arg_fail(invocation, world_rank, "Incorrect argument."); } } // Argument validation. if (list_1_file_name == std::string("") || list_2_file_name == std::string("")) arg_fail(invocation, world_rank, "Need input file lists."); if (dim < 0) arg_fail(invocation, world_rank, "Dimension must be non-negative."); if (params.wasserstein_power < 1 || !std::isfinite(params.wasserstein_power)) arg_fail(invocation, world_rank, "Outer norm power must be in the interval [1, ∞)."); if (params.internal_p < 1 || std::isnan(params.internal_p)) arg_fail(invocation, world_rank, "Inner norm power must be in ther interval [1, ∞]."); if (params.delta <= 0 || !std::isfinite(params.delta)) arg_fail(invocation, world_rank, "Error must be positive."); if (out_file_name == std::string("")) arg_fail(invocation, world_rank, "Need output file."); if (chunk_size <= 0) arg_fail(invocation, world_rank, "Chunk size must be positive."); if (world_rank == 0) { if (finitization == std::numeric_limits::infinity()) { std::cout << "Warning: You are running without finitization. Bugs do exist in this branch of Hera if you have diagrams with infinite generators." << std::endl; } } bool square = list_1_file_name == list_2_file_name; std::vector files_1; std::vector files_2; std::ifstream list_file(list_1_file_name, std::ios::in); std::string line; while (std::getline(list_file, line)) { files_1.push_back(line); } list_file.close(); if (square) { files_2 = files_1; } else { list_file.open(list_2_file_name, std::ios::in); while (std::getline(list_file, line)) { files_2.push_back(line); } list_file.close(); } int m = files_1.size(); int n = files_2.size(); std::vector> idxs; if (square) { for (int i = 0; i < n; ++i) { for (int j = i+1; j < n; ++j) { idxs.push_back(std::make_pair(i, j)); } } } else { for (int i = 0; i < m; ++i) { for (int j = 0; j < n; ++j) { idxs.push_back(std::make_pair(i, j)); } } } std::vector results; if (world_rank == 0) { results.resize(idxs.size(), std::numeric_limits::quiet_NaN()); double unused_buf = std::numeric_limits::quiet_NaN(); std::vector result_reqs(world_size); // Element zero not used, index by actual rank. int done = 0; int next_chunk = 0; std::vector> assigned(world_size, std::make_pair(-2, -2)); std::cout << "Total things to compute: " << idxs.size() << std::endl; for (int r = 1; r < world_size; ++r) { MPI_Irecv(&unused_buf, 0, MPI_DOUBLE, r, tag_result, MPI_COMM_WORLD, &(result_reqs[r])); } while ((size_t)done < idxs.size()) { int respondent_index = -1; MPI_Status status; MPI_Waitany(world_size - 1, &(result_reqs[1]), &respondent_index, &status); int r = 1 + respondent_index; std::cout << "Heard back from rank " << r << "." << std::endl; int recv_count = -1; MPI_Get_count(&status, MPI_DOUBLE, &recv_count); std::cout << "Received " << recv_count << " elements from rank " << r << "." << std::endl; assert(recv_count == assigned[r].second - assigned[r].first); done += recv_count; int work[2] = {-1, -1}; if ((size_t)next_chunk*chunk_size < idxs.size()) { work[0] = next_chunk*chunk_size; work[1] = std::min((int)idxs.size(), (next_chunk + 1)*chunk_size); MPI_Irecv(&(results[work[0]]), work[1] - work[0], MPI_DOUBLE, r, tag_result, MPI_COMM_WORLD, &(result_reqs[r])); std::cout << "Rank will get new work." << std::endl; } else std::cout << "Rank will terminate." << std::endl; assigned[r] = std::make_pair(work[0], work[1]); ++next_chunk; MPI_Send(work, 2, MPI_INT, r, tag_work, MPI_COMM_WORLD); std::cout << 100*(double)done/(double)idxs.size() << "% complete." << std::endl; std::cout << "------------------" << std::endl; } } else // Slaves { results.resize(chunk_size, std::numeric_limits::quiet_NaN()); int work[2] = {-2, -2}; while (work[0] != -1) { int i_prev = -1; int j_prev = -1; int i = -1; int j = -1; std::vector > pd_1; std::vector > pd_2; for (int k = work[0]; k < work[1]; ++k) { i = idxs[k].first; j = idxs[k].second; bool load_success = false; if (i != i_prev) { pd_1.clear(); if (in_type == file_type_dipha) { load_success = hera::read_diagram_dipha > >(files_1[i], dim, pd_1); } else if (in_type == file_type_txt) { load_success = hera::read_diagram_point_set > >(files_1[i], pd_1); } else { fail(world_rank, "Boo"); } if (!load_success) fail(world_rank, std::string("Failed to load file ") + files_1[i] + std::string(".")); hera::finitize(finitization, pd_1); } if (j != j_prev) { pd_2.clear(); if (in_type == file_type_dipha) { load_success = hera::read_diagram_dipha > >(files_2[j], dim, pd_2); } else if (in_type == file_type_txt) { load_success = hera::read_diagram_point_set > >(files_2[j], pd_2); } else { fail(world_rank, "Boo"); } if (!load_success) fail(world_rank, std::string("Failed to load file ") + files_2[j] + std::string(".")); hera::finitize(finitization, pd_2); } std::string fixme(""); results[k - work[0]] = hera::wasserstein_dist(pd_1, pd_2, params, fixme); i_prev = i; j_prev = j; } MPI_Send(results.data(), work[1] - work[0], MPI_DOUBLE, 0, tag_result, MPI_COMM_WORLD); MPI_Recv(work, 2, MPI_INT, 0, tag_work, MPI_COMM_WORLD, MPI_STATUS_IGNORE); } } if (world_rank == 0) { std::cout << "Writing out." << std::endl; std::ofstream out_file(out_file_name, std::ios::out); if (square) { for (int i = 0; i < n; ++i) { for (int j = 0; j < i; ++j) { out_file << std::scientific << results[unroll(n, j, i)] << " "; } out_file << "0 "; for (int j = i + 1; j < n; ++j) { out_file << std::scientific << results[unroll(n, i, j)] << " "; } out_file << std::endl; } } else { for (auto it = results.cbegin(); it != results.cend(); ++it) { out_file << std::scientific << (*it) << "\n"; } } out_file.close(); } MPI_Finalize(); }