diff options
Diffstat (limited to 'src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h')
-rw-r--r-- | src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h | 26 |
1 files changed, 16 insertions, 10 deletions
diff --git a/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h b/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h index 4fa6151f..ad1a6c42 100644 --- a/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h +++ b/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h @@ -53,11 +53,16 @@ class Sliced_Wasserstein { protected: PD diagram; + int approx; + double sigma; public: - Sliced_Wasserstein(PD _diagram){diagram = _diagram;} + Sliced_Wasserstein(PD _diagram){diagram = _diagram; approx = 100; sigma = 0.001;} + Sliced_Wasserstein(PD _diagram, double _sigma, int _approx){diagram = _diagram; approx = _approx; sigma = _sigma;} PD get_diagram(){return this->diagram;} + int get_approx(){return this->approx;} + double get_sigma(){return this->sigma;} // ********************************** @@ -130,11 +135,11 @@ class Sliced_Wasserstein { // Scalar product + distance. // ********************************** - double compute_sliced_wasserstein_distance(Sliced_Wasserstein second, int approx) { + double compute_sliced_wasserstein_distance(Sliced_Wasserstein second) { PD diagram1 = this->diagram; PD diagram2 = second.diagram; double sw = 0; - if(approx == -1){ + if(this->approx == -1){ // Add projections onto diagonal. int n1, n2; n1 = diagram1.size(); n2 = diagram2.size(); double max_ordinate = std::numeric_limits<double>::lowest(); @@ -226,7 +231,7 @@ class Sliced_Wasserstein { else{ - double step = pi/approx; + double step = pi/this->approx; // Add projections onto diagonal. int n1, n2; n1 = diagram1.size(); n2 = diagram2.size(); @@ -238,7 +243,7 @@ class Sliced_Wasserstein { // Sort and compare all projections. #ifdef GUDHI_USE_TBB - tbb::parallel_for(0, approx, [&](int i){ + tbb::parallel_for(0, this->approx, [&](int i){ std::vector<std::pair<int,double> > l1, l2; for (int j = 0; j < n; j++){ l1.emplace_back( j, diagram1[j].first*cos(-pi/2+i*step) + diagram1[j].second*sin(-pi/2+i*step) ); @@ -250,7 +255,7 @@ class Sliced_Wasserstein { sw += f*step; }); #else - for (int i = 0; i < approx; i++){ + for (int i = 0; i < this->approx; i++){ std::vector<std::pair<int,double> > l1, l2; for (int j = 0; j < n; j++){ l1.emplace_back( j, diagram1[j].first*cos(-pi/2+i*step) + diagram1[j].second*sin(-pi/2+i*step) ); @@ -268,12 +273,13 @@ class Sliced_Wasserstein { } - double compute_scalar_product(Sliced_Wasserstein second, double sigma, int approx = 100) { - return std::exp(-compute_sliced_wasserstein_distance(second, approx)/(2*sigma*sigma)); + double compute_scalar_product(Sliced_Wasserstein second){ + return std::exp(-compute_sliced_wasserstein_distance(second)/(2*this->sigma*this->sigma)); } - double distance(Sliced_Wasserstein second, double sigma, int approx = 100, double power = 1) { - return std::pow(this->compute_scalar_product(*this, sigma, approx) + second.compute_scalar_product(second, sigma, approx)-2*this->compute_scalar_product(second, sigma, approx), power/2.0); + double distance(Sliced_Wasserstein second, double power = 1) { + if(this->sigma != second.sigma || this->approx != second.approx){std::cout << "Error: different representations!" << std::endl; return 0;} + else return std::pow(this->compute_scalar_product(*this) + second.compute_scalar_product(second)-2*this->compute_scalar_product(second), power/2.0); } |