summaryrefslogtreecommitdiff
path: root/src/cython/include/Kernels_interface.h
diff options
context:
space:
mode:
authormcarrier <mcarrier@636b058d-ea47-450e-bf9e-a15bfbe3eedb>2018-06-14 09:25:16 +0000
committermcarrier <mcarrier@636b058d-ea47-450e-bf9e-a15bfbe3eedb>2018-06-14 09:25:16 +0000
commit10c6f6be72a2631cd1a1d28ed61343d55bd2b759 (patch)
tree1b3d2fd6332e1f083ed8c8f65e696fc0a38c4052 /src/cython/include/Kernels_interface.h
parent9b3f3e610646b9a2d35369bdb7a6f272e816eb34 (diff)
small modif on PWGK
git-svn-id: svn+ssh://scm.gforge.inria.fr/svnroot/gudhi/branches/kernels@3612 636b058d-ea47-450e-bf9e-a15bfbe3eedb Former-commit-id: c0b8c70acfbf1a7f4b7bddb69a161086fb249c76
Diffstat (limited to 'src/cython/include/Kernels_interface.h')
-rw-r--r--src/cython/include/Kernels_interface.h21
1 files changed, 13 insertions, 8 deletions
diff --git a/src/cython/include/Kernels_interface.h b/src/cython/include/Kernels_interface.h
index dd46656f..a07d7820 100644
--- a/src/cython/include/Kernels_interface.h
+++ b/src/cython/include/Kernels_interface.h
@@ -23,6 +23,7 @@
#ifndef INCLUDE_KERNELS_INTERFACE_H_
#define INCLUDE_KERNELS_INTERFACE_H_
+#include <gudhi/common_persistence_representations.h>
#include <gudhi/Sliced_Wasserstein.h>
#include <gudhi/Persistence_weighted_gaussian.h>
#include <gudhi/Weight_functions.h>
@@ -46,9 +47,13 @@ namespace persistence_diagram {
return sw1.compute_scalar_product(sw2);
}
- double pwg(const std::vector<std::pair<double, double>>& diag1, const std::vector<std::pair<double, double>>& diag2, double sigma, int N, double C, double p) {
- Gudhi::Persistence_representations::Persistence_weighted_gaussian pwg1(diag1, sigma, N, Gudhi::Persistence_representations::arctan_weight(C,p));
- Gudhi::Persistence_representations::Persistence_weighted_gaussian pwg2(diag2, sigma, N, Gudhi::Persistence_representations::arctan_weight(C,p));
+ double pwg(const std::vector<std::pair<double, double>>& diag1, const std::vector<std::pair<double, double>>& diag2, int N, std::string weight, double sigma, double C, double p) {
+ Gudhi::Persistence_representations::Weight weight_fn;
+ if(weight.compare("linear") == 0) weight_fn = Gudhi::Persistence_representations::linear_weight;
+ if(weight.compare("arctan") == 0) weight_fn = Gudhi::Persistence_representations::arctan_weight(C,p);
+ if(weight.compare("const") == 0) weight_fn = Gudhi::Persistence_representations::const_weight;
+ Gudhi::Persistence_representations::Persistence_weighted_gaussian pwg1(diag1, sigma, N, weight_fn);
+ Gudhi::Persistence_representations::Persistence_weighted_gaussian pwg2(diag2, sigma, N, weight_fn);
return pwg1.compute_scalar_product(pwg2);
}
@@ -87,11 +92,11 @@ namespace persistence_diagram {
return matrix;
}
- std::vector<std::vector<double> > pwg_matrix(const std::vector<std::vector<std::pair<double, double> > >& s1, const std::vector<std::vector<std::pair<double, double> > >& s2, double sigma, int N, double C, double p){
+ std::vector<std::vector<double> > pwg_matrix(const std::vector<std::vector<std::pair<double, double> > >& s1, const std::vector<std::vector<std::pair<double, double> > >& s2, int N, std::string weight, double sigma, double C, double p){
std::vector<std::vector<double> > matrix; int num_diag_1 = s1.size(); int num_diag_2 = s2.size();
for(int i = 0; i < num_diag_1; i++){
std::cout << 100.0*i/num_diag_1 << " %" << std::endl;
- std::vector<double> ps; for(int j = 0; j < num_diag_2; j++) ps.push_back(pwg(s1[i], s2[j], sigma, N, C, p)); matrix.push_back(ps);
+ std::vector<double> ps; for(int j = 0; j < num_diag_2; j++) ps.push_back(pwg(s1[i], s2[j], N, weight, sigma, C, p)); matrix.push_back(ps);
}
return matrix;
}
@@ -99,13 +104,13 @@ namespace persistence_diagram {
std::vector<std::vector<double> > pss_matrix(const std::vector<std::vector<std::pair<double, double> > >& s1, const std::vector<std::vector<std::pair<double, double> > >& s2, double sigma, int N){
std::vector<std::vector<std::pair<double, double> > > ss1, ss2; std::vector<std::vector<double> > matrix; int num_diag_1 = s1.size(); int num_diag_2 = s2.size();
for(int i = 0; i < num_diag_1; i++){
- std::vector<std::pair<double, double>> pd1 = s1[i]; int numpts = s1[i].size();
+ std::vector<std::pair<double, double>> pd1 = s1[i]; int numpts = s1[i].size();
for(int j = 0; j < numpts; j++) pd1.emplace_back(s1[i][j].second,s1[i][j].first);
ss1.push_back(pd1);
}
-
+
for(int i = 0; i < num_diag_2; i++){
- std::vector<std::pair<double, double>> pd2 = s2[i]; int numpts = s2[i].size();
+ std::vector<std::pair<double, double>> pd2 = s2[i]; int numpts = s2[i].size();
for(int j = 0; j < numpts; j++) pd2.emplace_back(s2[i][j].second,s2[i][j].first);
ss2.push_back(pd2);
}