summaryrefslogtreecommitdiff
path: root/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h
diff options
context:
space:
mode:
authormcarrier <mcarrier@636b058d-ea47-450e-bf9e-a15bfbe3eedb>2018-02-16 15:43:29 +0000
committermcarrier <mcarrier@636b058d-ea47-450e-bf9e-a15bfbe3eedb>2018-02-16 15:43:29 +0000
commitff0dc023588e3b33bc4bc7f26ce1f68c647ae441 (patch)
treea6f839885acbbefe07ffeeca996eea77dc136e96 /src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h
parent69c683e663329d8410ca77c371f877bcc3bef906 (diff)
git-svn-id: svn+ssh://scm.gforge.inria.fr/svnroot/gudhi/branches/kernels@3251 636b058d-ea47-450e-bf9e-a15bfbe3eedb
Former-commit-id: 80f084fc990df6e5c6b60ac83514220aba2ceb5c
Diffstat (limited to 'src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h')
-rw-r--r--src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h26
1 files changed, 16 insertions, 10 deletions
diff --git a/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h b/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h
index 4fa6151f..ad1a6c42 100644
--- a/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h
+++ b/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h
@@ -53,11 +53,16 @@ class Sliced_Wasserstein {
protected:
PD diagram;
+ int approx;
+ double sigma;
public:
- Sliced_Wasserstein(PD _diagram){diagram = _diagram;}
+ Sliced_Wasserstein(PD _diagram){diagram = _diagram; approx = 100; sigma = 0.001;}
+ Sliced_Wasserstein(PD _diagram, double _sigma, int _approx){diagram = _diagram; approx = _approx; sigma = _sigma;}
PD get_diagram(){return this->diagram;}
+ int get_approx(){return this->approx;}
+ double get_sigma(){return this->sigma;}
// **********************************
@@ -130,11 +135,11 @@ class Sliced_Wasserstein {
// Scalar product + distance.
// **********************************
- double compute_sliced_wasserstein_distance(Sliced_Wasserstein second, int approx) {
+ double compute_sliced_wasserstein_distance(Sliced_Wasserstein second) {
PD diagram1 = this->diagram; PD diagram2 = second.diagram; double sw = 0;
- if(approx == -1){
+ if(this->approx == -1){
// Add projections onto diagonal.
int n1, n2; n1 = diagram1.size(); n2 = diagram2.size(); double max_ordinate = std::numeric_limits<double>::lowest();
@@ -226,7 +231,7 @@ class Sliced_Wasserstein {
else{
- double step = pi/approx;
+ double step = pi/this->approx;
// Add projections onto diagonal.
int n1, n2; n1 = diagram1.size(); n2 = diagram2.size();
@@ -238,7 +243,7 @@ class Sliced_Wasserstein {
// Sort and compare all projections.
#ifdef GUDHI_USE_TBB
- tbb::parallel_for(0, approx, [&](int i){
+ tbb::parallel_for(0, this->approx, [&](int i){
std::vector<std::pair<int,double> > l1, l2;
for (int j = 0; j < n; j++){
l1.emplace_back( j, diagram1[j].first*cos(-pi/2+i*step) + diagram1[j].second*sin(-pi/2+i*step) );
@@ -250,7 +255,7 @@ class Sliced_Wasserstein {
sw += f*step;
});
#else
- for (int i = 0; i < approx; i++){
+ for (int i = 0; i < this->approx; i++){
std::vector<std::pair<int,double> > l1, l2;
for (int j = 0; j < n; j++){
l1.emplace_back( j, diagram1[j].first*cos(-pi/2+i*step) + diagram1[j].second*sin(-pi/2+i*step) );
@@ -268,12 +273,13 @@ class Sliced_Wasserstein {
}
- double compute_scalar_product(Sliced_Wasserstein second, double sigma, int approx = 100) {
- return std::exp(-compute_sliced_wasserstein_distance(second, approx)/(2*sigma*sigma));
+ double compute_scalar_product(Sliced_Wasserstein second){
+ return std::exp(-compute_sliced_wasserstein_distance(second)/(2*this->sigma*this->sigma));
}
- double distance(Sliced_Wasserstein second, double sigma, int approx = 100, double power = 1) {
- return std::pow(this->compute_scalar_product(*this, sigma, approx) + second.compute_scalar_product(second, sigma, approx)-2*this->compute_scalar_product(second, sigma, approx), power/2.0);
+ double distance(Sliced_Wasserstein second, double power = 1) {
+ if(this->sigma != second.sigma || this->approx != second.approx){std::cout << "Error: different representations!" << std::endl; return 0;}
+ else return std::pow(this->compute_scalar_product(*this) + second.compute_scalar_product(second)-2*this->compute_scalar_product(second), power/2.0);
}