summaryrefslogtreecommitdiff
path: root/src/python/gudhi/hera.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/hera.cc')
-rw-r--r--src/python/gudhi/hera.cc6
1 files changed, 4 insertions, 2 deletions
diff --git a/src/python/gudhi/hera.cc b/src/python/gudhi/hera.cc
index 50d49c77..63bbb075 100644
--- a/src/python/gudhi/hera.cc
+++ b/src/python/gudhi/hera.cc
@@ -45,10 +45,12 @@ double wasserstein_distance(
throw std::runtime_error("Diagram 1 must be an array of size n x 2");
if((buf2.ndim!=2 || buf2.shape[1]!=2) && (buf2.ndim!=1 || buf2.shape[0]!=0))
throw std::runtime_error("Diagram 2 must be an array of size n x 2");
+ ssize_t stride11 = buf1.ndim == 2 ? buf1.strides[1] : 0;
+ ssize_t stride21 = buf2.ndim == 2 ? buf2.strides[1] : 0;
auto cnt1 = boost::counting_range<ssize_t>(0, buf1.shape[0]);
- auto diag1 = boost::adaptors::transform(cnt1, pairify(buf1.ptr, buf1.strides[0], buf1.strides[1]));
+ auto diag1 = boost::adaptors::transform(cnt1, pairify(buf1.ptr, buf1.strides[0], stride11));
auto cnt2 = boost::counting_range<ssize_t>(0, buf2.shape[0]);
- auto diag2 = boost::adaptors::transform(cnt2, pairify(buf2.ptr, buf2.strides[0], buf2.strides[1]));
+ auto diag2 = boost::adaptors::transform(cnt2, pairify(buf2.ptr, buf2.strides[0], stride21));
hera::AuctionParams<double> params;
params.wasserstein_power = wasserstein_power;