summary refs log tree commit diff
diff options
context:
space:
mode:
author Gard Spreemann <gard.spreemann@epfl.ch> 2019-05-10 16:20:51 +0200
committer Gard Spreemann <gard.spreemann@epfl.ch> 2019-05-10 16:20:51 +0200
commit a99391b04635836b028b8f875ac8d99226da4683 (patch)
tree f3457a3f8b07f3332d0df67bff6682c2d356af01
parent 4c1b1727ab44a773d27b090463ac6db957267136 (diff)
Allow for bipartite diagram sets.
-rw-r--r-- geom_matching/wasserstein/mpi/main.cpp 122
1 files changed, 89 insertions, 33 deletions
diff --git a/geom_matching/wasserstein/mpi/main.cpp b/geom_matching/wasserstein/mpi/main.cpp
index c3b4e09..a1d980a 100644
--- a/geom_matching/wasserstein/mpi/main.cpp
+++ b/geom_matching/wasserstein/mpi/main.cpp
@@ -14,8 +14,13 @@ void print_help(const std::string & invocation)
{
std::cout << "Usage: " << invocation << " arguments" << std::endl;
std::cout << "Arguments:" << std::endl;
- std::cout << " --in-list, -i file" << std::endl;
- std::cout << " Mandatory. File containing a list of persistence diagram files to process, one file per line." << std::endl;
+ std::cout << " --in-lists, -i file_1 file_2" << std::endl;
+ std::cout << " Mandatory. Files containing lists of persistence diagram files to process, one file per line." << std::endl;
+ std::cout << " If the same file is given twice, and contains m entries:" << std::endl;
+ std::cout << " The program outputs the full m×m distance matrix for all pairs of persistence diagrams given." << std::endl;
+ std::cout << " If different files are given, the first containing m and the second containing n entries:" << std::endl;
+ std::cout << " The outputs the *row-major strict upper triangle* of the m×n distance matrix corresponding " << std::endl;
+ std::cout << " To the diagrams in the two files." << std::endl;
std::cout << " --in-type, -t type" << std::endl;
std::cout << " Optional (defaults to dipha). Input file format, dipha|txt." << std::endl;
std::cout << " --dimension, -d dim" << std::endl;
@@ -95,7 +100,8 @@ int main(int argc, char ** argv)
MPI_Get_processor_name(processor_name_tmp, &name_len);
std::string processor_name(processor_name_tmp);
- std::string list_file_name;
+ std::string list_1_file_name;
+ std::string list_2_file_name;
std::string out_file_name;
int dim = -1;
int in_type = file_type_dipha;
@@ -114,12 +120,15 @@ int main(int argc, char ** argv)
for (int i = 1; i < argc; ++i)
{
std::string arg(argv[i]);
- if (arg == std::string("--in-list") || arg == std::string("-i"))
+ if (arg == std::string("--in-lists") || arg == std::string("-i"))
{
- if (i < argc - 1)
- list_file_name = std::string(argv[++i]);
+ if (i < argc - 2)
+ {
+ list_1_file_name = std::string(argv[++i]);
+ list_2_file_name = std::string(argv[++i]);
+ }
else
- arg_fail(invocation, world_rank, "Missing argument for --in-list.");
+ arg_fail(invocation, world_rank, "Need two arguments for --in-lists.");
}
else if (arg == std::string("--in-type") || arg == std::string("-t"))
{
@@ -207,8 +216,8 @@ int main(int argc, char ** argv)
// Argument validation.
- if (list_file_name == std::string(""))
- arg_fail(invocation, world_rank, "Need input file list.");
+ if (list_1_file_name == std::string("") || list_2_file_name == std::string(""))
+ arg_fail(invocation, world_rank, "Need input file lists.");
if (dim < 0)
arg_fail(invocation, world_rank, "Dimension must be non-negative.");
if (params.wasserstein_power < 1 || !std::isfinite(params.wasserstein_power))
@@ -222,25 +231,63 @@ int main(int argc, char ** argv)
if (chunk_size <= 0)
arg_fail(invocation, world_rank, "Chunk size must be positive.");
- std::vector<std::string> files;
+ if (world_rank == 0)
+ {
+ if (finitization == std::numeric_limits<double>::infinity())
+ {
+ std::cout << "Warning: You are running without finitization. Bugs do exist in this branch of Hera if you have diagrams with infinite generators." << std::endl;
+ }
+ }
+
+ bool square = list_1_file_name == list_2_file_name;
+
+ std::vector<std::string> files_1;
+ std::vector<std::string> files_2;
- std::ifstream list_file(list_file_name, std::ios::in);
+ std::ifstream list_file(list_1_file_name, std::ios::in);
std::string line;
while (std::getline(list_file, line))
{
- files.push_back(line);
+ files_1.push_back(line);
}
list_file.close();
- int n = files.size();
-
- std::vector<std::pair<int, int>> idxs((n*(n-1))/2);
- int k = 0;
- for (int i = 0; i < n; ++i)
+ if (square)
+ {
+ files_2 = files_1;
+ }
+ else
{
- for (int j = i+1; j < n; ++j)
+ list_file.open(list_2_file_name, std::ios::in);
+ while (std::getline(list_file, line))
{
- idxs[k++] = std::make_pair(i, j);
+ files_2.push_back(line);
+ }
+ list_file.close();
+ }
+
+ int m = files_1.size();
+ int n = files_2.size();
+
+ std::vector<std::pair<int, int>> idxs;
+ if (square)
+ {
+ for (int i = 0; i < n; ++i)
+ {
+ for (int j = i+1; j < n; ++j)
+ {
+ idxs.push_back(std::make_pair(i, j));
+ }
+ }
+ }
+ else
+ {
+ for (int i = 0; i < m; ++i)
+ {
+ for (int j = 0; j < n; ++j)
+ {
+ idxs.push_back(std::make_pair(i, j));
+ }
}
}
@@ -321,11 +368,11 @@ int main(int argc, char ** argv)
pd_1.clear();
if (in_type == file_type_dipha)
{
- load_success = hera::read_diagram_dipha<double, std::vector<std::pair<double, double> > >(files[i], dim, pd_1);
+ load_success = hera::read_diagram_dipha<double, std::vector<std::pair<double, double> > >(files_1[i], dim, pd_1);
}
else if (in_type == file_type_txt)
{
- load_success = hera::read_diagram_point_set<double, std::vector<std::pair<double, double> > >(files[i], pd_1);
+ load_success = hera::read_diagram_point_set<double, std::vector<std::pair<double, double> > >(files_1[i], pd_1);
}
else
{
@@ -333,7 +380,7 @@ int main(int argc, char ** argv)
}
if (!load_success)
- fail(world_rank, std::string("Failed to load file ") + files[i] + std::string("."));
+ fail(world_rank, std::string("Failed to load file ") + files_1[i] + std::string("."));
hera::finitize(finitization, pd_1);
}
@@ -343,11 +390,11 @@ int main(int argc, char ** argv)
pd_2.clear();
if (in_type == file_type_dipha)
{
- load_success = hera::read_diagram_dipha<double, std::vector<std::pair<double, double> > >(files[j], dim, pd_2);
+ load_success = hera::read_diagram_dipha<double, std::vector<std::pair<double, double> > >(files_2[j], dim, pd_2);
}
else if (in_type == file_type_txt)
{
- load_success = hera::read_diagram_point_set<double, std::vector<std::pair<double, double> > >(files[j], pd_2);
+ load_success = hera::read_diagram_point_set<double, std::vector<std::pair<double, double> > >(files_2[j], pd_2);
}
else
{
@@ -355,14 +402,13 @@ int main(int argc, char ** argv)
}
if (!load_success)
- fail(world_rank, std::string("Failed to load file ") + files[i] + std::string("."));
+ fail(world_rank, std::string("Failed to load file ") + files_2[j] + std::string("."));
hera::finitize(finitization, pd_2);
}
std::string fixme("");
results[k - work[0]] = hera::wasserstein_dist(pd_1, pd_2, params, fixme);
-
i_prev = i;
j_prev = j;
}
@@ -378,18 +424,28 @@ int main(int argc, char ** argv)
{
std::cout << "Writing out." << std::endl;
std::ofstream out_file(out_file_name, std::ios::out);
- for (int i = 0; i < n; ++i)
+ if (square)
{
- for (int j = 0; j < i; ++j)
+ for (int i = 0; i < n; ++i)
{
- out_file << std::scientific << results[unroll(n, j, i)] << " ";
+ for (int j = 0; j < i; ++j)
+ {
+ out_file << std::scientific << results[unroll(n, j, i)] << " ";
+ }
+ out_file << "0 ";
+ for (int j = i + 1; j < n; ++j)
+ {
+ out_file << std::scientific << results[unroll(n, i, j)] << " ";
+ }
+ out_file << std::endl;
}
- out_file << "0 ";
- for (int j = i + 1; j < n; ++j)
+ }
+ else
+ {
+ for (auto it = results.cbegin(); it != results.cend(); ++it)
{
- out_file << std::scientific << results[unroll(n, i, j)] << " ";
+ out_file << std::scientific << (*it) << "\n";
}
- out_file << std::endl;
}
out_file.close();
}