diff options
author | mcarrier <mcarrier@636b058d-ea47-450e-bf9e-a15bfbe3eedb> | 2018-08-13 23:17:08 +0000 |
---|---|---|
committer | mcarrier <mcarrier@636b058d-ea47-450e-bf9e-a15bfbe3eedb> | 2018-08-13 23:17:08 +0000 |
commit | 5f5a7a21e9db73eaf9dc2604cb0de3066f7a4fb6 (patch) | |
tree | 0e68f4ae883d8e2e7e57b01bce1413173ba3124e /src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h | |
parent | 4560e97df7abb106c420c7f05747d26f2972b5aa (diff) | |
parent | 0784baddd1392727289a972b8374b3c2dca940a9 (diff) |
git-svn-id: svn+ssh://scm.gforge.inria.fr/svnroot/gudhi/branches/kernels@3778 636b058d-ea47-450e-bf9e-a15bfbe3eedb
Former-commit-id: 189ac5572f69842e1d8d1cec68ca6a4f62e39bd4
Diffstat (limited to 'src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h')
-rw-r--r-- | src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h | 87 |
1 files changed, 32 insertions, 55 deletions
diff --git a/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h b/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h index 8c92ab54..a3c0dc2f 100644 --- a/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h +++ b/src/Persistence_representations/include/gudhi/Sliced_Wasserstein.h @@ -28,18 +28,6 @@ #include <gudhi/common_persistence_representations.h> #include <gudhi/Debug_utils.h> -// standard include -#include <cmath> -#include <iostream> -#include <vector> -#include <limits> -#include <fstream> -#include <sstream> -#include <algorithm> -#include <string> -#include <utility> -#include <functional> - namespace Gudhi { namespace Persistence_representations { @@ -70,12 +58,15 @@ namespace Persistence_representations { class Sliced_Wasserstein { protected: - Persistence_diagram diagram; - int approx; - double sigma; - std::vector<std::vector<double> > projections, projections_diagonal; - public: + Persistence_diagram diagram; + int approx; + double sigma; + std::vector<std::vector<double> > projections, projections_diagonal; + + // ********************************** + // Utils. + // ********************************** void build_rep(){ @@ -96,28 +87,14 @@ class Sliced_Wasserstein { } std::sort(l.begin(), l.end()); std::sort(l_diag.begin(), l_diag.end()); - projections.push_back(l); projections_diagonal.push_back(l_diag); + projections.push_back(std::move(l)); projections_diagonal.push_back(std::move(l_diag)); } + diagram.clear(); } - } - /** \brief Sliced Wasserstein kernel constructor. - * \ingroup Sliced_Wasserstein - * - * @param[in] _diagram persistence diagram. - * @param[in] _sigma bandwidth parameter. - * @param[in] _approx number of directions used to approximate the integral in the Sliced Wasserstein distance, set to -1 for exact computation. - * - */ - Sliced_Wasserstein(const Persistence_diagram & _diagram, double _sigma = 1.0, int _approx = 100){diagram = _diagram; approx = _approx; sigma = _sigma; build_rep();} - - // ********************************** - // Utils. - // ********************************** - // Compute the angle formed by two points of a PD double compute_angle(const Persistence_diagram & diag, int i, int j) const { std::pair<double,double> vect; double x1,y1, x2,y2; @@ -177,21 +154,7 @@ class Sliced_Wasserstein { return norm*integral; } - - - - // ********************************** - // Scalar product + distance. - // ********************************** - - /** \brief Evaluation of the Sliced Wasserstein Distance between a pair of diagrams. - * \ingroup Sliced_Wasserstein - * - * @pre approx attribute needs to be the same for both instances. - * @param[in] second other instance of class Sliced_Wasserstein. - * - * - */ + // Evaluation of the Sliced Wasserstein Distance between a pair of diagrams. double compute_sliced_wasserstein_distance(const Sliced_Wasserstein & second) const { GUDHI_CHECK(this->approx != second.approx, std::invalid_argument("Error: different approx values for representations")); @@ -232,8 +195,8 @@ class Sliced_Wasserstein { } // Sort angles. - std::sort(angles1.begin(), angles1.end(), [=](const std::pair<double, std::pair<int,int> >& p1, const std::pair<double, std::pair<int,int> >& p2){return (p1.first < p2.first);}); - std::sort(angles2.begin(), angles2.end(), [=](const std::pair<double, std::pair<int,int> >& p1, const std::pair<double, std::pair<int,int> >& p2){return (p1.first < p2.first);}); + std::sort(angles1.begin(), angles1.end(), [](const std::pair<double, std::pair<int,int> >& p1, const std::pair<double, std::pair<int,int> >& p2){return (p1.first < p2.first);}); + std::sort(angles2.begin(), angles2.end(), [](const std::pair<double, std::pair<int,int> >& p1, const std::pair<double, std::pair<int,int> >& p2){return (p1.first < p2.first);}); // Initialize orders of the points of both PDs (given by ordinates when theta = -pi/2). std::vector<int> orderp1, orderp2; @@ -291,11 +254,13 @@ class Sliced_Wasserstein { else{ - double step = pi/this->approx; + double step = pi/this->approx; std::vector<double> v1, v2; for (int i = 0; i < this->approx; i++){ - std::vector<double> v1; std::vector<double> l1 = this->projections[i]; std::vector<double> l1bis = second.projections_diagonal[i]; std::merge(l1.begin(), l1.end(), l1bis.begin(), l1bis.end(), std::back_inserter(v1)); - std::vector<double> v2; std::vector<double> l2 = second.projections[i]; std::vector<double> l2bis = this->projections_diagonal[i]; std::merge(l2.begin(), l2.end(), l2bis.begin(), l2bis.end(), std::back_inserter(v2)); + v1.clear(); v2.clear(); + std::merge(this->projections[i].begin(), this->projections[i].end(), second.projections_diagonal[i].begin(), second.projections_diagonal[i].end(), std::back_inserter(v1)); + std::merge(second.projections[i].begin(), second.projections[i].end(), this->projections_diagonal[i].begin(), this->projections_diagonal[i].end(), std::back_inserter(v2)); + int n = v1.size(); double f = 0; for (int j = 0; j < n; j++) f += std::abs(v1[j] - v2[j]); sw += f*step; @@ -306,6 +271,20 @@ class Sliced_Wasserstein { return sw/pi; } + public: + + /** \brief Sliced Wasserstein kernel constructor. + * \ingroup Sliced_Wasserstein + * + * @param[in] _diagram persistence diagram. + * @param[in] _sigma bandwidth parameter. + * @param[in] _approx number of directions used to approximate the integral in the Sliced Wasserstein distance, set to -1 for exact computation. If positive, then projections of the diagram + * points on all directions are stored in memory to reduce computation time. + * + */ + // This class implements the following concepts: Topological_data_with_distances, Real_valued_topological_data, Topological_data_with_scalar_product + Sliced_Wasserstein(const Persistence_diagram & _diagram, double _sigma = 1.0, int _approx = 10):diagram(_diagram), approx(_approx), sigma(_sigma) {build_rep();} + /** \brief Evaluation of the kernel on a pair of diagrams. * \ingroup Sliced_Wasserstein * @@ -331,8 +310,6 @@ class Sliced_Wasserstein { } - - }; // class Sliced_Wasserstein } // namespace Persistence_representations } // namespace Gudhi |