diff options
Diffstat (limited to 'src/Persistence_representations/include')
-rw-r--r-- | src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h | 63 |
1 files changed, 44 insertions, 19 deletions
diff --git a/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h b/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h index 6196e207..f2ec56b7 100644 --- a/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h +++ b/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h @@ -51,15 +51,47 @@ class Sliced_Wasserstein { PD diagram; int approx; double sigma; + std::vector<std::vector<double> > projections, projections_diagonal; + public: - Sliced_Wasserstein(PD _diagram){diagram = _diagram; approx = 100; sigma = 0.001;} - Sliced_Wasserstein(PD _diagram, double _sigma, int _approx){diagram = _diagram; approx = _approx; sigma = _sigma;} + void build_rep(){ + + if(approx > 0){ + + double step = pi/this->approx; + int n = diagram.size(); + + for (int i = 0; i < this->approx; i++){ + std::vector<double> l,l_diag; + for (int j = 0; j < n; j++){ + + double px = diagram[j].first; double py = diagram[j].second; + double proj_diag = (px+py)/2; + + l.push_back ( px * cos(-pi/2+i*step) + py * sin(-pi/2+i*step) ); + l_diag.push_back ( proj_diag * cos(-pi/2+i*step) + proj_diag * sin(-pi/2+i*step) ); + } + + std::sort(l.begin(), l.end()); std::sort(l_diag.begin(), l_diag.end()); + projections.push_back(l); projections_diagonal.push_back(l_diag); + + } + + } + + } + + Sliced_Wasserstein(PD _diagram){diagram = _diagram; approx = 100; sigma = 0.001; build_rep();} + Sliced_Wasserstein(PD _diagram, double _sigma, int _approx){diagram = _diagram; approx = _approx; sigma = _sigma; build_rep();} + PD get_diagram(){return this->diagram;} int get_approx(){return this->approx;} double get_sigma(){return this->sigma;} + + // ********************************** // Utils. @@ -227,28 +259,19 @@ class Sliced_Wasserstein { else{ + double step = pi/this->approx; - // Add projections onto diagonal. - int n1, n2; n1 = diagram1.size(); n2 = diagram2.size(); - for (int i = 0; i < n2; i++) - diagram1.emplace_back( (diagram2[i].first + diagram2[i].second)/2, (diagram2[i].first + diagram2[i].second)/2 ); - for (int i = 0; i < n1; i++) - diagram2.emplace_back( (diagram1[i].first + diagram1[i].second)/2, (diagram1[i].first + diagram1[i].second)/2 ); - int n = diagram1.size(); - - // Sort and compare all projections. for (int i = 0; i < this->approx; i++){ - std::vector<std::pair<int,double> > l1, l2; - for (int j = 0; j < n; j++){ - l1.emplace_back( j, diagram1[j].first*cos(-pi/2+i*step) + diagram1[j].second*sin(-pi/2+i*step) ); - l2.emplace_back( j, diagram2[j].first*cos(-pi/2+i*step) + diagram2[j].second*sin(-pi/2+i*step) ); - } - std::sort(l1.begin(),l1.end(), [=](const std::pair<int,double> & p1, const std::pair<int,double> & p2){return p1.second < p2.second;}); - std::sort(l2.begin(),l2.end(), [=](const std::pair<int,double> & p1, const std::pair<int,double> & p2){return p1.second < p2.second;}); - double f = 0; for (int j = 0; j < n; j++) f += std::abs(l1[j].second - l2[j].second); + + std::vector<double> v1; std::vector<double> l1 = this->projections[i]; std::vector<double> l1bis = second.projections_diagonal[i]; std::merge(l1.begin(), l1.end(), l1bis.begin(), l1bis.end(), std::back_inserter(v1)); + std::vector<double> v2; std::vector<double> l2 = second.projections[i]; std::vector<double> l2bis = this->projections_diagonal[i]; std::merge(l2.begin(), l2.end(), l2bis.begin(), l2bis.end(), std::back_inserter(v2)); + int n = v1.size(); double f = 0; + for (int j = 0; j < n; j++) f += std::abs(v1[j] - v2[j]); sw += f*step; + } + } return sw/pi; @@ -265,6 +288,8 @@ class Sliced_Wasserstein { } + + }; } // namespace Sliced_Wasserstein |