summaryrefslogtreecommitdiff
path: root/src/Persistence_representations
diff options
context:
space:
mode:
authormcarrier <mcarrier@636b058d-ea47-450e-bf9e-a15bfbe3eedb>2018-10-31 22:53:43 +0000
committermcarrier <mcarrier@636b058d-ea47-450e-bf9e-a15bfbe3eedb>2018-10-31 22:53:43 +0000
commit78a118fabcee13aab0ca66cf8738b20c95f5d8dd (patch)
treed8fb372866aa6eca17315e228b999c67d5d8bfad /src/Persistence_representations
parent230eac46395aeb406ca79280a8d62fc35b6f41a3 (diff)
git-svn-id: svn+ssh://scm.gforge.inria.fr/svnroot/gudhi/branches/kernels@3970 636b058d-ea47-450e-bf9e-a15bfbe3eedb
Former-commit-id: 161d2d6200b09ba799dbb3694be6c7afe353abb5
Diffstat (limited to 'src/Persistence_representations')
-rw-r--r--src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h19
1 files changed, 12 insertions, 7 deletions
diff --git a/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h b/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h
index a0191dd7..8fc4bd15 100644
--- a/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h
+++ b/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h
@@ -150,24 +150,29 @@ class Sliced_Wasserstein {
if(this->approx == -1){
// Add projections onto diagonal.
- int n1, n2; n1 = diagram1.size(); n2 = diagram2.size(); double max_ordinate = 0; double max_abscissa = 0;
+ int n1, n2; n1 = diagram1.size(); n2 = diagram2.size();
+ double min_ordinate = std::numeric_limits<double>::max(); double min_abscissa = std::numeric_limits<double>::max();
+ double max_ordinate = std::numeric_limits<double>::lowest(); double max_abscissa = std::numeric_limits<double>::lowest();
for (int i = 0; i < n2; i++){
- max_ordinate = std::max(max_ordinate, std::abs(diagram2[i].second)); max_abscissa = std::max(max_abscissa, std::abs(diagram2[i].first));
+ min_ordinate = std::min(min_ordinate, diagram2[i].second); min_abscissa = std::min(min_abscissa, diagram2[i].first);
+ max_ordinate = std::max(max_ordinate, diagram2[i].second); max_abscissa = std::max(max_abscissa, diagram2[i].first);
diagram1.emplace_back( (diagram2[i].first+diagram2[i].second)/2, (diagram2[i].first+diagram2[i].second)/2 );
}
for (int i = 0; i < n1; i++){
- max_ordinate = std::max(max_ordinate, std::abs(diagram1[i].second)); max_abscissa = std::max(max_abscissa, std::abs(diagram1[i].first));
+ min_ordinate = std::min(min_ordinate, diagram1[i].second); min_abscissa = std::min(min_abscissa, diagram1[i].first);
+ max_ordinate = std::max(max_ordinate, diagram1[i].second); max_abscissa = std::max(max_abscissa, diagram1[i].first);
diagram2.emplace_back( (diagram1[i].first+diagram1[i].second)/2, (diagram1[i].first+diagram1[i].second)/2 );
}
int num_pts_dgm = diagram1.size();
// Slightly perturb the points so that the PDs are in generic positions.
- double thresh_y = max_ordinate * 0.00001; double thresh_x = max_abscissa * 0.00001;
+ double epsilon = 0.0001;
+ double thresh_y = (max_ordinate-min_ordinate) * epsilon; double thresh_x = (max_abscissa-min_abscissa) * epsilon;
std::random_device rd; std::default_random_engine re(rd()); std::uniform_real_distribution<double> uni(-1,1);
- double epsilon = uni(re);
for (int i = 0; i < num_pts_dgm; i++){
- diagram1[i].first += thresh_x*epsilon; diagram1[i].second += thresh_y*epsilon;
- diagram2[i].first += thresh_x*epsilon; diagram2[i].second += thresh_y*epsilon;
+ double u = uni(re);
+ diagram1[i].first += u*thresh_x; diagram1[i].second += u*thresh_y;
+ diagram2[i].first += u*thresh_x; diagram2[i].second += u*thresh_y;
}
// Compute all angles in both PDs.