summaryrefslogtreecommitdiff
path: root/geom_matching/wasserstein/include/dnn/geometry/euclidean-dynamic.h
diff options
context:
space:
mode:
Diffstat (limited to 'geom_matching/wasserstein/include/dnn/geometry/euclidean-dynamic.h')
-rw-r--r--geom_matching/wasserstein/include/dnn/geometry/euclidean-dynamic.h24
1 files changed, 23 insertions, 1 deletions
diff --git a/geom_matching/wasserstein/include/dnn/geometry/euclidean-dynamic.h b/geom_matching/wasserstein/include/dnn/geometry/euclidean-dynamic.h
index 4b98309..b003906 100644
--- a/geom_matching/wasserstein/include/dnn/geometry/euclidean-dynamic.h
+++ b/geom_matching/wasserstein/include/dnn/geometry/euclidean-dynamic.h
@@ -8,6 +8,8 @@
#include <boost/serialization/vector.hpp>
#include <cmath>
+#include "hera_infinity.h"
+
namespace hera
{
namespace ws
@@ -89,7 +91,27 @@ struct DynamicPointTraits
DynamicPointTraits(unsigned dim = 0):
dim_(dim) {}
- DistanceType distance(PointType p1, PointType p2) const { return sqrt(sq_distance(p1,p2)); }
+ DistanceType distance(PointType p1, PointType p2) const
+ {
+ Real result = 0.0;
+ if (hera::is_infinity(internal_p)) {
+ // max norm
+ for (unsigned i = 0; i < dimension(); ++i)
+ result = std::max(result, fabs(coordinate(p1,i) - coordinate(p2,i)));
+ } else if (internal_p == Real(1.0)) {
+ // l1-norm
+ for (unsigned i = 0; i < dimension(); ++i)
+ result += fabs(coordinate(p1,i) - coordinate(p2,i));
+ } else if (internal_p == Real(2.0)) {
+ result = sqrt(sq_distance(p1,p2));
+ } else {
+ assert(internal_p > 1.0);
+ for (unsigned i = 0; i < dimension(); ++i)
+ result += std::pow(fabs(coordinate(p1,i) - coordinate(p2,i)), internal_p);
+ result = std::pow(result, Real(1.0) / internal_p);
+ }
+ return result;
+ }
DistanceType distance(PointHandle p1, PointHandle p2) const { return distance(PointType({p1.p}), PointType({p2.p})); }
DistanceType sq_distance(PointType p1, PointType p2) const { Real res = 0; for (unsigned i = 0; i < dimension(); ++i) { Real c1 = coordinate(p1,i), c2 = coordinate(p2,i); res += (c1 - c2)*(c1 - c2); } return res; }
DistanceType sq_distance(PointHandle p1, PointHandle p2) const { return sq_distance(PointType({p1.p}), PointType({p2.p})); }