From 784697ab263e30c062e92aacfce36d1ed4070c6f Mon Sep 17 00:00:00 2001 From: mcarrier Date: Tue, 6 Mar 2018 17:50:39 +0000 Subject: git-svn-id: svn+ssh://scm.gforge.inria.fr/svnroot/gudhi/branches/kernels@3269 636b058d-ea47-450e-bf9e-a15bfbe3eedb Former-commit-id: 17860628d3250f689152cdf65432c5a61d76f4d2 --- .../example/sliced_wasserstein.cpp | 2 + .../include/gudhi/Sliced_Wasserstein.h | 63 +++++++++++++++------- 2 files changed, 46 insertions(+), 19 deletions(-) (limited to 'src/Persistence_representations') diff --git a/src/Persistence_representations/example/sliced_wasserstein.cpp b/src/Persistence_representations/example/sliced_wasserstein.cpp index f153fbe8..2470029b 100644 --- a/src/Persistence_representations/example/sliced_wasserstein.cpp +++ b/src/Persistence_representations/example/sliced_wasserstein.cpp @@ -32,6 +32,8 @@ int main(int argc, char** argv) { std::vector > persistence1; std::vector > persistence2; + std::vector > > set1; + std::vector > > set2; persistence1.push_back(std::make_pair(1, 2)); persistence1.push_back(std::make_pair(6, 8)); diff --git a/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h b/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h index 6196e207..f2ec56b7 100644 --- a/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h +++ b/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h @@ -51,15 +51,47 @@ class Sliced_Wasserstein { PD diagram; int approx; double sigma; + std::vector > projections, projections_diagonal; + public: - Sliced_Wasserstein(PD _diagram){diagram = _diagram; approx = 100; sigma = 0.001;} - Sliced_Wasserstein(PD _diagram, double _sigma, int _approx){diagram = _diagram; approx = _approx; sigma = _sigma;} + void build_rep(){ + + if(approx > 0){ + + double step = pi/this->approx; + int n = diagram.size(); + + for (int i = 0; i < this->approx; i++){ + std::vector l,l_diag; + for (int j = 0; j < n; j++){ + + double px = diagram[j].first; double py = diagram[j].second; + double proj_diag = (px+py)/2; + + l.push_back ( px * cos(-pi/2+i*step) + py * sin(-pi/2+i*step) ); + l_diag.push_back ( proj_diag * cos(-pi/2+i*step) + proj_diag * sin(-pi/2+i*step) ); + } + + std::sort(l.begin(), l.end()); std::sort(l_diag.begin(), l_diag.end()); + projections.push_back(l); projections_diagonal.push_back(l_diag); + + } + + } + + } + + Sliced_Wasserstein(PD _diagram){diagram = _diagram; approx = 100; sigma = 0.001; build_rep();} + Sliced_Wasserstein(PD _diagram, double _sigma, int _approx){diagram = _diagram; approx = _approx; sigma = _sigma; build_rep();} + PD get_diagram(){return this->diagram;} int get_approx(){return this->approx;} double get_sigma(){return this->sigma;} + + // ********************************** // Utils. @@ -227,28 +259,19 @@ class Sliced_Wasserstein { else{ + double step = pi/this->approx; - // Add projections onto diagonal. - int n1, n2; n1 = diagram1.size(); n2 = diagram2.size(); - for (int i = 0; i < n2; i++) - diagram1.emplace_back( (diagram2[i].first + diagram2[i].second)/2, (diagram2[i].first + diagram2[i].second)/2 ); - for (int i = 0; i < n1; i++) - diagram2.emplace_back( (diagram1[i].first + diagram1[i].second)/2, (diagram1[i].first + diagram1[i].second)/2 ); - int n = diagram1.size(); - - // Sort and compare all projections. for (int i = 0; i < this->approx; i++){ - std::vector > l1, l2; - for (int j = 0; j < n; j++){ - l1.emplace_back( j, diagram1[j].first*cos(-pi/2+i*step) + diagram1[j].second*sin(-pi/2+i*step) ); - l2.emplace_back( j, diagram2[j].first*cos(-pi/2+i*step) + diagram2[j].second*sin(-pi/2+i*step) ); - } - std::sort(l1.begin(),l1.end(), [=](const std::pair & p1, const std::pair & p2){return p1.second < p2.second;}); - std::sort(l2.begin(),l2.end(), [=](const std::pair & p1, const std::pair & p2){return p1.second < p2.second;}); - double f = 0; for (int j = 0; j < n; j++) f += std::abs(l1[j].second - l2[j].second); + + std::vector v1; std::vector l1 = this->projections[i]; std::vector l1bis = second.projections_diagonal[i]; std::merge(l1.begin(), l1.end(), l1bis.begin(), l1bis.end(), std::back_inserter(v1)); + std::vector v2; std::vector l2 = second.projections[i]; std::vector l2bis = this->projections_diagonal[i]; std::merge(l2.begin(), l2.end(), l2bis.begin(), l2bis.end(), std::back_inserter(v2)); + int n = v1.size(); double f = 0; + for (int j = 0; j < n; j++) f += std::abs(v1[j] - v2[j]); sw += f*step; + } + } return sw/pi; @@ -265,6 +288,8 @@ class Sliced_Wasserstein { } + + }; } // namespace Sliced_Wasserstein -- cgit v1.2.3