summaryrefslogtreecommitdiff
path: root/src/Persistence_representations
diff options
context:
space:
mode:
authormcarrier <mcarrier@636b058d-ea47-450e-bf9e-a15bfbe3eedb>2018-03-06 17:50:39 +0000
committermcarrier <mcarrier@636b058d-ea47-450e-bf9e-a15bfbe3eedb>2018-03-06 17:50:39 +0000
commit784697ab263e30c062e92aacfce36d1ed4070c6f (patch)
treed1a744bac07b68b449d086591c17e917da034697 /src/Persistence_representations
parentd574f7f65acdd6dde92150879c06db5e6e0b75a9 (diff)
git-svn-id: svn+ssh://scm.gforge.inria.fr/svnroot/gudhi/branches/kernels@3269 636b058d-ea47-450e-bf9e-a15bfbe3eedb
Former-commit-id: 17860628d3250f689152cdf65432c5a61d76f4d2
Diffstat (limited to 'src/Persistence_representations')
-rw-r--r--src/Persistence_representations/example/sliced_wasserstein.cpp2
-rw-r--r--src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h63
2 files changed, 46 insertions, 19 deletions
diff --git a/src/Persistence_representations/example/sliced_wasserstein.cpp b/src/Persistence_representations/example/sliced_wasserstein.cpp
index f153fbe8..2470029b 100644
--- a/src/Persistence_representations/example/sliced_wasserstein.cpp
+++ b/src/Persistence_representations/example/sliced_wasserstein.cpp
@@ -32,6 +32,8 @@ int main(int argc, char** argv) {
std::vector<std::pair<double, double> > persistence1;
std::vector<std::pair<double, double> > persistence2;
+ std::vector<std::vector<std::pair<double, double> > > set1;
+ std::vector<std::vector<std::pair<double, double> > > set2;
persistence1.push_back(std::make_pair(1, 2));
persistence1.push_back(std::make_pair(6, 8));
diff --git a/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h b/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h
index 6196e207..f2ec56b7 100644
--- a/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h
+++ b/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h
@@ -51,15 +51,47 @@ class Sliced_Wasserstein {
PD diagram;
int approx;
double sigma;
+ std::vector<std::vector<double> > projections, projections_diagonal;
+
public:
- Sliced_Wasserstein(PD _diagram){diagram = _diagram; approx = 100; sigma = 0.001;}
- Sliced_Wasserstein(PD _diagram, double _sigma, int _approx){diagram = _diagram; approx = _approx; sigma = _sigma;}
+ void build_rep(){
+
+ if(approx > 0){
+
+ double step = pi/this->approx;
+ int n = diagram.size();
+
+ for (int i = 0; i < this->approx; i++){
+ std::vector<double> l,l_diag;
+ for (int j = 0; j < n; j++){
+
+ double px = diagram[j].first; double py = diagram[j].second;
+ double proj_diag = (px+py)/2;
+
+ l.push_back ( px * cos(-pi/2+i*step) + py * sin(-pi/2+i*step) );
+ l_diag.push_back ( proj_diag * cos(-pi/2+i*step) + proj_diag * sin(-pi/2+i*step) );
+ }
+
+ std::sort(l.begin(), l.end()); std::sort(l_diag.begin(), l_diag.end());
+ projections.push_back(l); projections_diagonal.push_back(l_diag);
+
+ }
+
+ }
+
+ }
+
+ Sliced_Wasserstein(PD _diagram){diagram = _diagram; approx = 100; sigma = 0.001; build_rep();}
+ Sliced_Wasserstein(PD _diagram, double _sigma, int _approx){diagram = _diagram; approx = _approx; sigma = _sigma; build_rep();}
+
PD get_diagram(){return this->diagram;}
int get_approx(){return this->approx;}
double get_sigma(){return this->sigma;}
+
+
// **********************************
// Utils.
@@ -227,28 +259,19 @@ class Sliced_Wasserstein {
else{
+
double step = pi/this->approx;
- // Add projections onto diagonal.
- int n1, n2; n1 = diagram1.size(); n2 = diagram2.size();
- for (int i = 0; i < n2; i++)
- diagram1.emplace_back( (diagram2[i].first + diagram2[i].second)/2, (diagram2[i].first + diagram2[i].second)/2 );
- for (int i = 0; i < n1; i++)
- diagram2.emplace_back( (diagram1[i].first + diagram1[i].second)/2, (diagram1[i].first + diagram1[i].second)/2 );
- int n = diagram1.size();
-
- // Sort and compare all projections.
for (int i = 0; i < this->approx; i++){
- std::vector<std::pair<int,double> > l1, l2;
- for (int j = 0; j < n; j++){
- l1.emplace_back( j, diagram1[j].first*cos(-pi/2+i*step) + diagram1[j].second*sin(-pi/2+i*step) );
- l2.emplace_back( j, diagram2[j].first*cos(-pi/2+i*step) + diagram2[j].second*sin(-pi/2+i*step) );
- }
- std::sort(l1.begin(),l1.end(), [=](const std::pair<int,double> & p1, const std::pair<int,double> & p2){return p1.second < p2.second;});
- std::sort(l2.begin(),l2.end(), [=](const std::pair<int,double> & p1, const std::pair<int,double> & p2){return p1.second < p2.second;});
- double f = 0; for (int j = 0; j < n; j++) f += std::abs(l1[j].second - l2[j].second);
+
+ std::vector<double> v1; std::vector<double> l1 = this->projections[i]; std::vector<double> l1bis = second.projections_diagonal[i]; std::merge(l1.begin(), l1.end(), l1bis.begin(), l1bis.end(), std::back_inserter(v1));
+ std::vector<double> v2; std::vector<double> l2 = second.projections[i]; std::vector<double> l2bis = this->projections_diagonal[i]; std::merge(l2.begin(), l2.end(), l2bis.begin(), l2bis.end(), std::back_inserter(v2));
+ int n = v1.size(); double f = 0;
+ for (int j = 0; j < n; j++) f += std::abs(v1[j] - v2[j]);
sw += f*step;
+
}
+
}
return sw/pi;
@@ -265,6 +288,8 @@ class Sliced_Wasserstein {
}
+
+
};
} // namespace Sliced_Wasserstein