diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2020-04-13 15:21:06 +0200 |
---|---|---|
committer | Marc Glisse <marc.glisse@inria.fr> | 2020-04-13 15:21:06 +0200 |
commit | 3a86402b733a48d9c25a4995325e72c7438c06c0 (patch) | |
tree | 0a339795c6d3db5f8aca79a69aae2f19631e978b /src/python/gudhi/point_cloud | |
parent | 2f1576a23cf4ac055565875d384ca604c0ff6844 (diff) |
Fix NaN gradient with pytorch
Diffstat (limited to 'src/python/gudhi/point_cloud')
-rw-r--r-- | src/python/gudhi/point_cloud/knn.py | 9 |
1 files changed, 4 insertions, 5 deletions
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py index ab3447d4..185a7764 100644 --- a/src/python/gudhi/point_cloud/knn.py +++ b/src/python/gudhi/point_cloud/knn.py @@ -236,12 +236,11 @@ class KNearestNeighbors: # Work around https://github.com/pytorch/pytorch/issues/34452 neighbor_pts = Y[neighbors,] diff = neighbor_pts - X[:, None, :] - if p == numpy.inf: - distances = diff.abs().max(-1) - elif p == 2: - distances = (diff ** 2).sum(-1).sqrt() + if isinstance(diff, ep.PyTorchTensor): + # https://github.com/jonasrauber/eagerpy/issues/6 + distances = ep.astensor(diff.raw.norm(p, -1)) else: - distances = (diff.abs() ** p).sum(-1) ** (1.0 / p) + distances = diff.norms.lp(p, -1) if self.return_index: return neighbors, distances.raw else: |