From b2177e897b575e0c8d17b8ae5ed3259541a06bea Mon Sep 17 00:00:00 2001 From: MathieuCarriere Date: Wed, 29 Apr 2020 19:16:50 -0400 Subject: small modifs --- src/python/gudhi/representations/kernel_methods.py | 3 ++- src/python/gudhi/representations/metrics.py | 9 ++++----- 2 files changed, 6 insertions(+), 6 deletions(-) (limited to 'src/python/gudhi/representations') diff --git a/src/python/gudhi/representations/kernel_methods.py b/src/python/gudhi/representations/kernel_methods.py index edd1382a..596f4f07 100644 --- a/src/python/gudhi/representations/kernel_methods.py +++ b/src/python/gudhi/representations/kernel_methods.py @@ -67,7 +67,8 @@ def pairwise_persistence_diagram_kernels(X, Y=None, kernel="sliced_wasserstein", Parameters: X (list of n numpy arrays of shape (numx2)): first list of persistence diagrams. Y (list of m numpy arrays of shape (numx2)): second list of persistence diagrams (optional). If None, pairwise kernel values are computed from the first list only. - kernel: kernel to use. It can be either a string ("sliced_wasserstein", "persistence_scale_space", "persistence_weighted_gaussian", "persistence_fisher") or a function taking two numpy arrays of shape (nx2) and (mx2) as inputs. + kernel: kernel to use. It can be either a string ("sliced_wasserstein", "persistence_scale_space", "persistence_weighted_gaussian", "persistence_fisher") or a function taking two numpy arrays of shape (nx2) and (mx2) as inputs. If it is a function, make sure that it is symmetric. + **kwargs: optional keyword parameters. Any further parameters are passed directly to the kernel function. See the docs of the various kernel classes in this module. Returns: numpy array of shape (nxm): kernel matrix. diff --git a/src/python/gudhi/representations/metrics.py b/src/python/gudhi/representations/metrics.py index a4bf19a6..ce416fb1 100644 --- a/src/python/gudhi/representations/metrics.py +++ b/src/python/gudhi/representations/metrics.py @@ -32,11 +32,9 @@ def _sliced_wasserstein_distance(D1, D2, num_directions): thetas = np.linspace(-np.pi/2, np.pi/2, num=num_directions+1)[np.newaxis,:-1] lines = np.concatenate([np.cos(thetas), np.sin(thetas)], axis=0) approx1 = np.matmul(D1, lines) - diag_proj1 = (1./2) * np.ones((2,2)) - approx_diag1 = np.matmul(np.matmul(D1, diag_proj1), lines) + approx_diag1 = np.matmul(np.broadcast_to(D1.sum(-1,keepdims=True)/2,(len(D1),2)), lines) approx2 = np.matmul(D2, lines) - diag_proj2 = (1./2) * np.ones((2,2)) - approx_diag2 = np.matmul(np.matmul(D2, diag_proj2), lines) + approx_diag2 = np.matmul(np.broadcast_to(D2.sum(-1,keepdims=True)/2,(len(D2),2)), lines) A = np.sort(np.concatenate([approx1, approx_diag2], axis=0), axis=0) B = np.sort(np.concatenate([approx2, approx_diag1], axis=0), axis=0) L1 = np.sum(np.abs(A-B), axis=0) @@ -143,7 +141,8 @@ def pairwise_persistence_diagram_distances(X, Y=None, metric="bottleneck", **kwa Parameters: X (list of n numpy arrays of shape (numx2)): first list of persistence diagrams. Y (list of m numpy arrays of shape (numx2)): second list of persistence diagrams (optional). If None, pairwise distances are computed from the first list only. - metric: distance to use. It can be either a string ("sliced_wasserstein", "wasserstein", "hera_wasserstein" (Wasserstein distance computed with Hera---note that Hera is also used for the default option "wasserstein"), "pot_wasserstein" (Wasserstein distance computed with POT), "bottleneck", "persistence_fisher") or a symmetric function taking two numpy arrays of shape (nx2) and (mx2) as inputs. + metric: distance to use. It can be either a string ("sliced_wasserstein", "wasserstein", "hera_wasserstein" (Wasserstein distance computed with Hera---note that Hera is also used for the default option "wasserstein"), "pot_wasserstein" (Wasserstein distance computed with POT), "bottleneck", "persistence_fisher") or a function taking two numpy arrays of shape (nx2) and (mx2) as inputs. If it is a function, make sure that it is symmetric and that it outputs 0 if called on the same two arrays. + **kwargs: optional keyword parameters. Any further parameters are passed directly to the distance function. See the docs of the various distance classes in this module. Returns: numpy array of shape (nxm): distance matrix -- cgit v1.2.3