summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormcarrier <mcarrier@636b058d-ea47-450e-bf9e-a15bfbe3eedb>2018-06-14 09:25:16 +0000
committermcarrier <mcarrier@636b058d-ea47-450e-bf9e-a15bfbe3eedb>2018-06-14 09:25:16 +0000
commit10c6f6be72a2631cd1a1d28ed61343d55bd2b759 (patch)
tree1b3d2fd6332e1f083ed8c8f65e696fc0a38c4052
parent9b3f3e610646b9a2d35369bdb7a6f272e816eb34 (diff)
small modif on PWGK
git-svn-id: svn+ssh://scm.gforge.inria.fr/svnroot/gudhi/branches/kernels@3612 636b058d-ea47-450e-bf9e-a15bfbe3eedb Former-commit-id: c0b8c70acfbf1a7f4b7bddb69a161086fb249c76
-rw-r--r--src/cython/cython/kernels.pyx26
-rw-r--r--src/cython/cython/vectors.pyx3
-rw-r--r--src/cython/include/Kernels_interface.h21
3 files changed, 30 insertions, 20 deletions
diff --git a/src/cython/cython/kernels.pyx b/src/cython/cython/kernels.pyx
index 0cb296ec..cb8fc0fd 100644
--- a/src/cython/cython/kernels.pyx
+++ b/src/cython/cython/kernels.pyx
@@ -34,8 +34,8 @@ cdef extern from "Kernels_interface.h" namespace "Gudhi::persistence_diagram":
vector[vector[double]] sw_matrix (vector[vector[pair[double, double]]], vector[vector[pair[double, double]]], double, int)
double pss (vector[pair[double, double]], vector[pair[double, double]], double, int)
vector[vector[double]] pss_matrix (vector[vector[pair[double, double]]], vector[vector[pair[double, double]]], double, int)
- double pwg (vector[pair[double, double]], vector[pair[double, double]], double, int, double, double)
- vector[vector[double]] pwg_matrix (vector[vector[pair[double, double]]], vector[vector[pair[double, double]]], double, int, double, double)
+ double pwg (vector[pair[double, double]], vector[pair[double, double]], int, string, double, double, double)
+ vector[vector[double]] pwg_matrix (vector[vector[pair[double, double]]], vector[vector[pair[double, double]]], int, string, double, double, double)
def sliced_wasserstein(diagram_1, diagram_2, sigma = 1, N = 100):
"""
@@ -65,37 +65,39 @@ def sliced_wasserstein_matrix(diagrams_1, diagrams_2, sigma = 1, N = 100):
"""
return sw_matrix(diagrams_1, diagrams_2, sigma, N)
-def persistence_weighted_gaussian(diagram_1, diagram_2, sigma = 1, N = 100, C = 1, p = 1):
+def persistence_weighted_gaussian(diagram_1, diagram_2, N = 100, weight = "arctan", sigma = 1.0, C = 1.0, p = 1.0):
"""
:param diagram_1: The first diagram.
:type diagram_1: vector[pair[double, double]]
:param diagram_2: The second diagram.
:type diagram_2: vector[pair[double, double]]
- :param sigma: bandwidth of Gaussian
:param N: number of Fourier features
- :param C: cost of persistence weight
- :param p: power of persistence weight
+ :param weight: weight to use for the diagram points
+ :param sigma: bandwidth of Gaussian
+ :param C: cost of arctan persistence weight
+ :param p: power of arctan persistence weight
:returns: the persistence weighted gaussian kernel.
"""
- return pwg(diagram_1, diagram_2, sigma, N, C, p)
+ return pwg(diagram_1, diagram_2, N, weight, sigma, C, p)
-def persistence_weighted_gaussian_matrix(diagrams_1, diagrams_2, sigma = 1, N = 100, C = 1, p = 1):
+def persistence_weighted_gaussian_matrix(diagrams_1, diagrams_2, N = 100, weight = "arctan", sigma = 1.0, C = 1.0, p = 1.0):
"""
:param diagram_1: The first set of diagrams.
:type diagram_1: vector[vector[pair[double, double]]]
:param diagram_2: The second set of diagrams.
:type diagram_2: vector[vector[pair[double, double]]]
- :param sigma: bandwidth of Gaussian
:param N: number of Fourier features
- :param C: cost of persistence weight
- :param p: power of persistence weight
+ :param weight: weight to use for the diagram points
+ :param sigma: bandwidth of Gaussian
+ :param C: cost of arctan persistence weight
+ :param p: power of arctan persistence weight
:returns: the persistence weighted gaussian kernel matrix.
"""
- return pwg_matrix(diagrams_1, diagrams_2, sigma, N, C, p)
+ return pwg_matrix(diagrams_1, diagrams_2, N, weight, sigma, C, p)
def persistence_scale_space(diagram_1, diagram_2, sigma = 1, N = 100):
"""
diff --git a/src/cython/cython/vectors.pyx b/src/cython/cython/vectors.pyx
index 42390ae6..af53f739 100644
--- a/src/cython/cython/vectors.pyx
+++ b/src/cython/cython/vectors.pyx
@@ -58,7 +58,10 @@ def persistence_image(diagram, min_x = 0.0, max_x = 1.0, res_x = 10, min_y = 0.0
:param min_x: Minimum ordinate
:param max_x: Maximum ordinate
:param res_x: Number of ordinate pixels
+ :param weight: Weight to use for the diagram points
:param sigma: bandwidth of Gaussian
+ :param C: cost of arctan persistence weight
+ :param p: power of arctan persistence weight
:returns: the persistence image
"""
diff --git a/src/cython/include/Kernels_interface.h b/src/cython/include/Kernels_interface.h
index dd46656f..a07d7820 100644
--- a/src/cython/include/Kernels_interface.h
+++ b/src/cython/include/Kernels_interface.h
@@ -23,6 +23,7 @@
#ifndef INCLUDE_KERNELS_INTERFACE_H_
#define INCLUDE_KERNELS_INTERFACE_H_
+#include <gudhi/common_persistence_representations.h>
#include <gudhi/Sliced_Wasserstein.h>
#include <gudhi/Persistence_weighted_gaussian.h>
#include <gudhi/Weight_functions.h>
@@ -46,9 +47,13 @@ namespace persistence_diagram {
return sw1.compute_scalar_product(sw2);
}
- double pwg(const std::vector<std::pair<double, double>>& diag1, const std::vector<std::pair<double, double>>& diag2, double sigma, int N, double C, double p) {
- Gudhi::Persistence_representations::Persistence_weighted_gaussian pwg1(diag1, sigma, N, Gudhi::Persistence_representations::arctan_weight(C,p));
- Gudhi::Persistence_representations::Persistence_weighted_gaussian pwg2(diag2, sigma, N, Gudhi::Persistence_representations::arctan_weight(C,p));
+ double pwg(const std::vector<std::pair<double, double>>& diag1, const std::vector<std::pair<double, double>>& diag2, int N, std::string weight, double sigma, double C, double p) {
+ Gudhi::Persistence_representations::Weight weight_fn;
+ if(weight.compare("linear") == 0) weight_fn = Gudhi::Persistence_representations::linear_weight;
+ if(weight.compare("arctan") == 0) weight_fn = Gudhi::Persistence_representations::arctan_weight(C,p);
+ if(weight.compare("const") == 0) weight_fn = Gudhi::Persistence_representations::const_weight;
+ Gudhi::Persistence_representations::Persistence_weighted_gaussian pwg1(diag1, sigma, N, weight_fn);
+ Gudhi::Persistence_representations::Persistence_weighted_gaussian pwg2(diag2, sigma, N, weight_fn);
return pwg1.compute_scalar_product(pwg2);
}
@@ -87,11 +92,11 @@ namespace persistence_diagram {
return matrix;
}
- std::vector<std::vector<double> > pwg_matrix(const std::vector<std::vector<std::pair<double, double> > >& s1, const std::vector<std::vector<std::pair<double, double> > >& s2, double sigma, int N, double C, double p){
+ std::vector<std::vector<double> > pwg_matrix(const std::vector<std::vector<std::pair<double, double> > >& s1, const std::vector<std::vector<std::pair<double, double> > >& s2, int N, std::string weight, double sigma, double C, double p){
std::vector<std::vector<double> > matrix; int num_diag_1 = s1.size(); int num_diag_2 = s2.size();
for(int i = 0; i < num_diag_1; i++){
std::cout << 100.0*i/num_diag_1 << " %" << std::endl;
- std::vector<double> ps; for(int j = 0; j < num_diag_2; j++) ps.push_back(pwg(s1[i], s2[j], sigma, N, C, p)); matrix.push_back(ps);
+ std::vector<double> ps; for(int j = 0; j < num_diag_2; j++) ps.push_back(pwg(s1[i], s2[j], N, weight, sigma, C, p)); matrix.push_back(ps);
}
return matrix;
}
@@ -99,13 +104,13 @@ namespace persistence_diagram {
std::vector<std::vector<double> > pss_matrix(const std::vector<std::vector<std::pair<double, double> > >& s1, const std::vector<std::vector<std::pair<double, double> > >& s2, double sigma, int N){
std::vector<std::vector<std::pair<double, double> > > ss1, ss2; std::vector<std::vector<double> > matrix; int num_diag_1 = s1.size(); int num_diag_2 = s2.size();
for(int i = 0; i < num_diag_1; i++){
- std::vector<std::pair<double, double>> pd1 = s1[i]; int numpts = s1[i].size();
+ std::vector<std::pair<double, double>> pd1 = s1[i]; int numpts = s1[i].size();
for(int j = 0; j < numpts; j++) pd1.emplace_back(s1[i][j].second,s1[i][j].first);
ss1.push_back(pd1);
}
-
+
for(int i = 0; i < num_diag_2; i++){
- std::vector<std::pair<double, double>> pd2 = s2[i]; int numpts = s2[i].size();
+ std::vector<std::pair<double, double>> pd2 = s2[i]; int numpts = s2[i].size();
for(int j = 0; j < numpts; j++) pd2.emplace_back(s2[i][j].second,s2[i][j].first);
ss2.push_back(pd2);
}