diff options
Diffstat (limited to 'src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h')
-rw-r--r-- | src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h | 19 |
1 files changed, 12 insertions, 7 deletions
diff --git a/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h b/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h index a0191dd7..8fc4bd15 100644 --- a/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h +++ b/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h @@ -150,24 +150,29 @@ class Sliced_Wasserstein { if(this->approx == -1){ // Add projections onto diagonal. - int n1, n2; n1 = diagram1.size(); n2 = diagram2.size(); double max_ordinate = 0; double max_abscissa = 0; + int n1, n2; n1 = diagram1.size(); n2 = diagram2.size(); + double min_ordinate = std::numeric_limits<double>::max(); double min_abscissa = std::numeric_limits<double>::max(); + double max_ordinate = std::numeric_limits<double>::lowest(); double max_abscissa = std::numeric_limits<double>::lowest(); for (int i = 0; i < n2; i++){ - max_ordinate = std::max(max_ordinate, std::abs(diagram2[i].second)); max_abscissa = std::max(max_abscissa, std::abs(diagram2[i].first)); + min_ordinate = std::min(min_ordinate, diagram2[i].second); min_abscissa = std::min(min_abscissa, diagram2[i].first); + max_ordinate = std::max(max_ordinate, diagram2[i].second); max_abscissa = std::max(max_abscissa, diagram2[i].first); diagram1.emplace_back( (diagram2[i].first+diagram2[i].second)/2, (diagram2[i].first+diagram2[i].second)/2 ); } for (int i = 0; i < n1; i++){ - max_ordinate = std::max(max_ordinate, std::abs(diagram1[i].second)); max_abscissa = std::max(max_abscissa, std::abs(diagram1[i].first)); + min_ordinate = std::min(min_ordinate, diagram1[i].second); min_abscissa = std::min(min_abscissa, diagram1[i].first); + max_ordinate = std::max(max_ordinate, diagram1[i].second); max_abscissa = std::max(max_abscissa, diagram1[i].first); diagram2.emplace_back( (diagram1[i].first+diagram1[i].second)/2, (diagram1[i].first+diagram1[i].second)/2 ); } int num_pts_dgm = diagram1.size(); // Slightly perturb the points so that the PDs are in generic positions. - double thresh_y = max_ordinate * 0.00001; double thresh_x = max_abscissa * 0.00001; + double epsilon = 0.0001; + double thresh_y = (max_ordinate-min_ordinate) * epsilon; double thresh_x = (max_abscissa-min_abscissa) * epsilon; std::random_device rd; std::default_random_engine re(rd()); std::uniform_real_distribution<double> uni(-1,1); - double epsilon = uni(re); for (int i = 0; i < num_pts_dgm; i++){ - diagram1[i].first += thresh_x*epsilon; diagram1[i].second += thresh_y*epsilon; - diagram2[i].first += thresh_x*epsilon; diagram2[i].second += thresh_y*epsilon; + double u = uni(re); + diagram1[i].first += u*thresh_x; diagram1[i].second += u*thresh_y; + diagram2[i].first += u*thresh_x; diagram2[i].second += u*thresh_y; } // Compute all angles in both PDs. |