diff options
author | mcarrier <mcarrier@636b058d-ea47-450e-bf9e-a15bfbe3eedb> | 2018-06-14 09:25:16 +0000 |
---|---|---|
committer | mcarrier <mcarrier@636b058d-ea47-450e-bf9e-a15bfbe3eedb> | 2018-06-14 09:25:16 +0000 |
commit | 10c6f6be72a2631cd1a1d28ed61343d55bd2b759 (patch) | |
tree | 1b3d2fd6332e1f083ed8c8f65e696fc0a38c4052 /src/cython/include/Kernels_interface.h | |
parent | 9b3f3e610646b9a2d35369bdb7a6f272e816eb34 (diff) |
small modif on PWGK
git-svn-id: svn+ssh://scm.gforge.inria.fr/svnroot/gudhi/branches/kernels@3612 636b058d-ea47-450e-bf9e-a15bfbe3eedb
Former-commit-id: c0b8c70acfbf1a7f4b7bddb69a161086fb249c76
Diffstat (limited to 'src/cython/include/Kernels_interface.h')
-rw-r--r-- | src/cython/include/Kernels_interface.h | 21 |
1 files changed, 13 insertions, 8 deletions
diff --git a/src/cython/include/Kernels_interface.h b/src/cython/include/Kernels_interface.h index dd46656f..a07d7820 100644 --- a/src/cython/include/Kernels_interface.h +++ b/src/cython/include/Kernels_interface.h @@ -23,6 +23,7 @@ #ifndef INCLUDE_KERNELS_INTERFACE_H_ #define INCLUDE_KERNELS_INTERFACE_H_ +#include <gudhi/common_persistence_representations.h> #include <gudhi/Sliced_Wasserstein.h> #include <gudhi/Persistence_weighted_gaussian.h> #include <gudhi/Weight_functions.h> @@ -46,9 +47,13 @@ namespace persistence_diagram { return sw1.compute_scalar_product(sw2); } - double pwg(const std::vector<std::pair<double, double>>& diag1, const std::vector<std::pair<double, double>>& diag2, double sigma, int N, double C, double p) { - Gudhi::Persistence_representations::Persistence_weighted_gaussian pwg1(diag1, sigma, N, Gudhi::Persistence_representations::arctan_weight(C,p)); - Gudhi::Persistence_representations::Persistence_weighted_gaussian pwg2(diag2, sigma, N, Gudhi::Persistence_representations::arctan_weight(C,p)); + double pwg(const std::vector<std::pair<double, double>>& diag1, const std::vector<std::pair<double, double>>& diag2, int N, std::string weight, double sigma, double C, double p) { + Gudhi::Persistence_representations::Weight weight_fn; + if(weight.compare("linear") == 0) weight_fn = Gudhi::Persistence_representations::linear_weight; + if(weight.compare("arctan") == 0) weight_fn = Gudhi::Persistence_representations::arctan_weight(C,p); + if(weight.compare("const") == 0) weight_fn = Gudhi::Persistence_representations::const_weight; + Gudhi::Persistence_representations::Persistence_weighted_gaussian pwg1(diag1, sigma, N, weight_fn); + Gudhi::Persistence_representations::Persistence_weighted_gaussian pwg2(diag2, sigma, N, weight_fn); return pwg1.compute_scalar_product(pwg2); } @@ -87,11 +92,11 @@ namespace persistence_diagram { return matrix; } - std::vector<std::vector<double> > pwg_matrix(const std::vector<std::vector<std::pair<double, double> > >& s1, const std::vector<std::vector<std::pair<double, double> > >& s2, double sigma, int N, double C, double p){ + std::vector<std::vector<double> > pwg_matrix(const std::vector<std::vector<std::pair<double, double> > >& s1, const std::vector<std::vector<std::pair<double, double> > >& s2, int N, std::string weight, double sigma, double C, double p){ std::vector<std::vector<double> > matrix; int num_diag_1 = s1.size(); int num_diag_2 = s2.size(); for(int i = 0; i < num_diag_1; i++){ std::cout << 100.0*i/num_diag_1 << " %" << std::endl; - std::vector<double> ps; for(int j = 0; j < num_diag_2; j++) ps.push_back(pwg(s1[i], s2[j], sigma, N, C, p)); matrix.push_back(ps); + std::vector<double> ps; for(int j = 0; j < num_diag_2; j++) ps.push_back(pwg(s1[i], s2[j], N, weight, sigma, C, p)); matrix.push_back(ps); } return matrix; } @@ -99,13 +104,13 @@ namespace persistence_diagram { std::vector<std::vector<double> > pss_matrix(const std::vector<std::vector<std::pair<double, double> > >& s1, const std::vector<std::vector<std::pair<double, double> > >& s2, double sigma, int N){ std::vector<std::vector<std::pair<double, double> > > ss1, ss2; std::vector<std::vector<double> > matrix; int num_diag_1 = s1.size(); int num_diag_2 = s2.size(); for(int i = 0; i < num_diag_1; i++){ - std::vector<std::pair<double, double>> pd1 = s1[i]; int numpts = s1[i].size(); + std::vector<std::pair<double, double>> pd1 = s1[i]; int numpts = s1[i].size(); for(int j = 0; j < numpts; j++) pd1.emplace_back(s1[i][j].second,s1[i][j].first); ss1.push_back(pd1); } - + for(int i = 0; i < num_diag_2; i++){ - std::vector<std::pair<double, double>> pd2 = s2[i]; int numpts = s2[i].size(); + std::vector<std::pair<double, double>> pd2 = s2[i]; int numpts = s2[i].size(); for(int j = 0; j < numpts; j++) pd2.emplace_back(s2[i][j].second,s2[i][j].first); ss2.push_back(pd2); } |