diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2020-04-13 15:01:51 +0200 |
---|---|---|
committer | Marc Glisse <marc.glisse@inria.fr> | 2020-04-13 15:01:51 +0200 |
commit | 2f1576a23cf4ac055565875d384ca604c0ff6844 (patch) | |
tree | d36bd276d9ffcede0ea0b81c6e5642e5657e3ddb /src/python/gudhi/point_cloud/knn.py | |
parent | 280eb9d2323837619db1ae013b929adb9b45013b (diff) |
Small autodiff tweaks
Diffstat (limited to 'src/python/gudhi/point_cloud/knn.py')
-rw-r--r-- | src/python/gudhi/point_cloud/knn.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py index 79362c09..ab3447d4 100644 --- a/src/python/gudhi/point_cloud/knn.py +++ b/src/python/gudhi/point_cloud/knn.py @@ -233,16 +233,17 @@ class KNearestNeighbors: # Another strategy would be to compute the whole distance matrix with torch.cdist # and use neighbors as indices into it. neighbors = ep.astensor(mat.argKmin(k, dim=1)).numpy() - neighbor_pts = Y[neighbors] + # Work around https://github.com/pytorch/pytorch/issues/34452 + neighbor_pts = Y[neighbors,] diff = neighbor_pts - X[:, None, :] if p == numpy.inf: distances = diff.abs().max(-1) elif p == 2: - distances = (diff ** 2).sum(-1) ** 0.5 + distances = (diff ** 2).sum(-1).sqrt() else: distances = (diff.abs() ** p).sum(-1) ** (1.0 / p) if self.return_index: - return neighbors.raw, distances.raw + return neighbors, distances.raw else: return distances.raw if self.return_index: |