diff options
author | mcarrier <mcarrier@636b058d-ea47-450e-bf9e-a15bfbe3eedb> | 2018-03-06 17:50:39 +0000 |
---|---|---|
committer | mcarrier <mcarrier@636b058d-ea47-450e-bf9e-a15bfbe3eedb> | 2018-03-06 17:50:39 +0000 |
commit | 784697ab263e30c062e92aacfce36d1ed4070c6f (patch) | |
tree | d1a744bac07b68b449d086591c17e917da034697 | |
parent | d574f7f65acdd6dde92150879c06db5e6e0b75a9 (diff) |
git-svn-id: svn+ssh://scm.gforge.inria.fr/svnroot/gudhi/branches/kernels@3269 636b058d-ea47-450e-bf9e-a15bfbe3eedb
Former-commit-id: 17860628d3250f689152cdf65432c5a61d76f4d2
-rw-r--r-- | src/Persistence_representations/example/sliced_wasserstein.cpp | 2 | ||||
-rw-r--r-- | src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h | 63 | ||||
-rw-r--r-- | src/cython/cython/kernels.pyx | 17 | ||||
-rw-r--r-- | src/cython/include/Kernels_interface.h | 15 |
4 files changed, 77 insertions, 20 deletions
diff --git a/src/Persistence_representations/example/sliced_wasserstein.cpp b/src/Persistence_representations/example/sliced_wasserstein.cpp index f153fbe8..2470029b 100644 --- a/src/Persistence_representations/example/sliced_wasserstein.cpp +++ b/src/Persistence_representations/example/sliced_wasserstein.cpp @@ -32,6 +32,8 @@ int main(int argc, char** argv) { std::vector<std::pair<double, double> > persistence1; std::vector<std::pair<double, double> > persistence2; + std::vector<std::vector<std::pair<double, double> > > set1; + std::vector<std::vector<std::pair<double, double> > > set2; persistence1.push_back(std::make_pair(1, 2)); persistence1.push_back(std::make_pair(6, 8)); diff --git a/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h b/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h index 6196e207..f2ec56b7 100644 --- a/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h +++ b/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h @@ -51,15 +51,47 @@ class Sliced_Wasserstein { PD diagram; int approx; double sigma; + std::vector<std::vector<double> > projections, projections_diagonal; + public: - Sliced_Wasserstein(PD _diagram){diagram = _diagram; approx = 100; sigma = 0.001;} - Sliced_Wasserstein(PD _diagram, double _sigma, int _approx){diagram = _diagram; approx = _approx; sigma = _sigma;} + void build_rep(){ + + if(approx > 0){ + + double step = pi/this->approx; + int n = diagram.size(); + + for (int i = 0; i < this->approx; i++){ + std::vector<double> l,l_diag; + for (int j = 0; j < n; j++){ + + double px = diagram[j].first; double py = diagram[j].second; + double proj_diag = (px+py)/2; + + l.push_back ( px * cos(-pi/2+i*step) + py * sin(-pi/2+i*step) ); + l_diag.push_back ( proj_diag * cos(-pi/2+i*step) + proj_diag * sin(-pi/2+i*step) ); + } + + std::sort(l.begin(), l.end()); std::sort(l_diag.begin(), l_diag.end()); + projections.push_back(l); projections_diagonal.push_back(l_diag); + + } + + } + + } + + Sliced_Wasserstein(PD _diagram){diagram = _diagram; approx = 100; sigma = 0.001; build_rep();} + Sliced_Wasserstein(PD _diagram, double _sigma, int _approx){diagram = _diagram; approx = _approx; sigma = _sigma; build_rep();} + PD get_diagram(){return this->diagram;} int get_approx(){return this->approx;} double get_sigma(){return this->sigma;} + + // ********************************** // Utils. @@ -227,28 +259,19 @@ class Sliced_Wasserstein { else{ + double step = pi/this->approx; - // Add projections onto diagonal. - int n1, n2; n1 = diagram1.size(); n2 = diagram2.size(); - for (int i = 0; i < n2; i++) - diagram1.emplace_back( (diagram2[i].first + diagram2[i].second)/2, (diagram2[i].first + diagram2[i].second)/2 ); - for (int i = 0; i < n1; i++) - diagram2.emplace_back( (diagram1[i].first + diagram1[i].second)/2, (diagram1[i].first + diagram1[i].second)/2 ); - int n = diagram1.size(); - - // Sort and compare all projections. for (int i = 0; i < this->approx; i++){ - std::vector<std::pair<int,double> > l1, l2; - for (int j = 0; j < n; j++){ - l1.emplace_back( j, diagram1[j].first*cos(-pi/2+i*step) + diagram1[j].second*sin(-pi/2+i*step) ); - l2.emplace_back( j, diagram2[j].first*cos(-pi/2+i*step) + diagram2[j].second*sin(-pi/2+i*step) ); - } - std::sort(l1.begin(),l1.end(), [=](const std::pair<int,double> & p1, const std::pair<int,double> & p2){return p1.second < p2.second;}); - std::sort(l2.begin(),l2.end(), [=](const std::pair<int,double> & p1, const std::pair<int,double> & p2){return p1.second < p2.second;}); - double f = 0; for (int j = 0; j < n; j++) f += std::abs(l1[j].second - l2[j].second); + + std::vector<double> v1; std::vector<double> l1 = this->projections[i]; std::vector<double> l1bis = second.projections_diagonal[i]; std::merge(l1.begin(), l1.end(), l1bis.begin(), l1bis.end(), std::back_inserter(v1)); + std::vector<double> v2; std::vector<double> l2 = second.projections[i]; std::vector<double> l2bis = this->projections_diagonal[i]; std::merge(l2.begin(), l2.end(), l2bis.begin(), l2bis.end(), std::back_inserter(v2)); + int n = v1.size(); double f = 0; + for (int j = 0; j < n; j++) f += std::abs(v1[j] - v2[j]); sw += f*step; + } + } return sw/pi; @@ -265,6 +288,8 @@ class Sliced_Wasserstein { } + + }; } // namespace Sliced_Wasserstein diff --git a/src/cython/cython/kernels.pyx b/src/cython/cython/kernels.pyx index 220fc6ce..f8798aab 100644 --- a/src/cython/cython/kernels.pyx +++ b/src/cython/cython/kernels.pyx @@ -30,7 +30,8 @@ __copyright__ = "Copyright (C) 2018 INRIA" __license__ = "GPL v3" cdef extern from "Kernels_interface.h" namespace "Gudhi::persistence_diagram": - double sw(vector[pair[double, double]], vector[pair[double, double]], double, int) + double sw (vector[pair[double, double]], vector[pair[double, double]], double, int) + vector[vector[double]] sw_matrix (vector[vector[pair[double, double]]], vector[vector[pair[double, double]]], double, int) def sliced_wasserstein(diagram_1, diagram_2, sigma = 1, N = 100): """ @@ -45,3 +46,17 @@ def sliced_wasserstein(diagram_1, diagram_2, sigma = 1, N = 100): :returns: the sliced wasserstein kernel. """ return sw(diagram_1, diagram_2, sigma, N) + +def sliced_wasserstein_matrix(diagrams_1, diagrams_2, sigma = 1, N = 100): + """ + + :param diagram_1: The first set of diagrams. + :type diagram_1: vector[vector[pair[double, double]]] + :param diagram_2: The second set of diagrams. + :type diagram_2: vector[vector[pair[double, double]]] + :param sigma: bandwidth of Gaussian + :param N: number of directions + + :returns: the sliced wasserstein kernel matrix. + """ + return sw_matrix(diagrams_1, diagrams_2, sigma, N) diff --git a/src/cython/include/Kernels_interface.h b/src/cython/include/Kernels_interface.h index 9eb610b0..ef136731 100644 --- a/src/cython/include/Kernels_interface.h +++ b/src/cython/include/Kernels_interface.h @@ -41,6 +41,21 @@ namespace persistence_diagram { return sw1.compute_scalar_product(sw2); } + std::vector<std::vector<double> > sw_matrix(const std::vector<std::vector<std::pair<double, double> > >& s1, + const std::vector<std::vector<std::pair<double, double> > >& s2, + double sigma, int N){ + std::vector<std::vector<double> > matrix; + std::vector<Gudhi::Persistence_representations::Sliced_Wasserstein> ss1; + int num_diag_1 = s1.size(); for(int i = 0; i < num_diag_1; i++){Gudhi::Persistence_representations::Sliced_Wasserstein sw1(s1[i], sigma, N); ss1.push_back(sw1);} + std::vector<Gudhi::Persistence_representations::Sliced_Wasserstein> ss2; + int num_diag_2 = s2.size(); for(int i = 0; i < num_diag_2; i++){Gudhi::Persistence_representations::Sliced_Wasserstein sw2(s2[i], sigma, N); ss2.push_back(sw2);} + for(int i = 0; i < num_diag_1; i++){ + std::cout << 100.0*i/num_diag_1 << " %" << std::endl; + std::vector<double> ps; for(int j = 0; j < num_diag_2; j++) ps.push_back(ss1[i].compute_scalar_product(ss2[j])); matrix.push_back(ps); + } + return matrix; + } + } // namespace persistence_diagram } // namespace Gudhi |