summaryrefslogtreecommitdiff
path: root/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h')
-rw-r--r--src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h87
1 files changed, 32 insertions, 55 deletions
diff --git a/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h b/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h
index 8c92ab54..a3c0dc2f 100644
--- a/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h
+++ b/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h
@@ -28,18 +28,6 @@
#include <gudhi/common_persistence_representations.h>
#include <gudhi/Debug_utils.h>
-// standard include
-#include <cmath>
-#include <iostream>
-#include <vector>
-#include <limits>
-#include <fstream>
-#include <sstream>
-#include <algorithm>
-#include <string>
-#include <utility>
-#include <functional>
-
namespace Gudhi {
namespace Persistence_representations {
@@ -70,12 +58,15 @@ namespace Persistence_representations {
class Sliced_Wasserstein {
protected:
- Persistence_diagram diagram;
- int approx;
- double sigma;
- std::vector<std::vector<double> > projections, projections_diagonal;
- public:
+ Persistence_diagram diagram;
+ int approx;
+ double sigma;
+ std::vector<std::vector<double> > projections, projections_diagonal;
+
+ // **********************************
+ // Utils.
+ // **********************************
void build_rep(){
@@ -96,28 +87,14 @@ class Sliced_Wasserstein {
}
std::sort(l.begin(), l.end()); std::sort(l_diag.begin(), l_diag.end());
- projections.push_back(l); projections_diagonal.push_back(l_diag);
+ projections.push_back(std::move(l)); projections_diagonal.push_back(std::move(l_diag));
}
+ diagram.clear();
}
-
}
- /** \brief Sliced Wasserstein kernel constructor.
- * \ingroup Sliced_Wasserstein
- *
- * @param[in] _diagram persistence diagram.
- * @param[in] _sigma bandwidth parameter.
- * @param[in] _approx number of directions used to approximate the integral in the Sliced Wasserstein distance, set to -1 for exact computation.
- *
- */
- Sliced_Wasserstein(const Persistence_diagram & _diagram, double _sigma = 1.0, int _approx = 100){diagram = _diagram; approx = _approx; sigma = _sigma; build_rep();}
-
- // **********************************
- // Utils.
- // **********************************
-
// Compute the angle formed by two points of a PD
double compute_angle(const Persistence_diagram & diag, int i, int j) const {
std::pair<double,double> vect; double x1,y1, x2,y2;
@@ -177,21 +154,7 @@ class Sliced_Wasserstein {
return norm*integral;
}
-
-
-
- // **********************************
- // Scalar product + distance.
- // **********************************
-
- /** \brief Evaluation of the Sliced Wasserstein Distance between a pair of diagrams.
- * \ingroup Sliced_Wasserstein
- *
- * @pre approx attribute needs to be the same for both instances.
- * @param[in] second other instance of class Sliced_Wasserstein.
- *
- *
- */
+ // Evaluation of the Sliced Wasserstein Distance between a pair of diagrams.
double compute_sliced_wasserstein_distance(const Sliced_Wasserstein & second) const {
GUDHI_CHECK(this->approx != second.approx, std::invalid_argument("Error: different approx values for representations"));
@@ -232,8 +195,8 @@ class Sliced_Wasserstein {
}
// Sort angles.
- std::sort(angles1.begin(), angles1.end(), [=](const std::pair<double, std::pair<int,int> >& p1, const std::pair<double, std::pair<int,int> >& p2){return (p1.first < p2.first);});
- std::sort(angles2.begin(), angles2.end(), [=](const std::pair<double, std::pair<int,int> >& p1, const std::pair<double, std::pair<int,int> >& p2){return (p1.first < p2.first);});
+ std::sort(angles1.begin(), angles1.end(), [](const std::pair<double, std::pair<int,int> >& p1, const std::pair<double, std::pair<int,int> >& p2){return (p1.first < p2.first);});
+ std::sort(angles2.begin(), angles2.end(), [](const std::pair<double, std::pair<int,int> >& p1, const std::pair<double, std::pair<int,int> >& p2){return (p1.first < p2.first);});
// Initialize orders of the points of both PDs (given by ordinates when theta = -pi/2).
std::vector<int> orderp1, orderp2;
@@ -291,11 +254,13 @@ class Sliced_Wasserstein {
else{
- double step = pi/this->approx;
+ double step = pi/this->approx; std::vector<double> v1, v2;
for (int i = 0; i < this->approx; i++){
- std::vector<double> v1; std::vector<double> l1 = this->projections[i]; std::vector<double> l1bis = second.projections_diagonal[i]; std::merge(l1.begin(), l1.end(), l1bis.begin(), l1bis.end(), std::back_inserter(v1));
- std::vector<double> v2; std::vector<double> l2 = second.projections[i]; std::vector<double> l2bis = this->projections_diagonal[i]; std::merge(l2.begin(), l2.end(), l2bis.begin(), l2bis.end(), std::back_inserter(v2));
+ v1.clear(); v2.clear();
+ std::merge(this->projections[i].begin(), this->projections[i].end(), second.projections_diagonal[i].begin(), second.projections_diagonal[i].end(), std::back_inserter(v1));
+ std::merge(second.projections[i].begin(), second.projections[i].end(), this->projections_diagonal[i].begin(), this->projections_diagonal[i].end(), std::back_inserter(v2));
+
int n = v1.size(); double f = 0;
for (int j = 0; j < n; j++) f += std::abs(v1[j] - v2[j]);
sw += f*step;
@@ -306,6 +271,20 @@ class Sliced_Wasserstein {
return sw/pi;
}
+ public:
+
+ /** \brief Sliced Wasserstein kernel constructor.
+ * \ingroup Sliced_Wasserstein
+ *
+ * @param[in] _diagram persistence diagram.
+ * @param[in] _sigma bandwidth parameter.
+ * @param[in] _approx number of directions used to approximate the integral in the Sliced Wasserstein distance, set to -1 for exact computation. If positive, then projections of the diagram
+ * points on all directions are stored in memory to reduce computation time.
+ *
+ */
+ // This class implements the following concepts: Topological_data_with_distances, Real_valued_topological_data, Topological_data_with_scalar_product
+ Sliced_Wasserstein(const Persistence_diagram & _diagram, double _sigma = 1.0, int _approx = 10):diagram(_diagram), approx(_approx), sigma(_sigma) {build_rep();}
+
/** \brief Evaluation of the kernel on a pair of diagrams.
* \ingroup Sliced_Wasserstein
*
@@ -331,8 +310,6 @@ class Sliced_Wasserstein {
}
-
-
}; // class Sliced_Wasserstein
} // namespace Persistence_representations
} // namespace Gudhi