diff options
Diffstat (limited to 'src/python/gudhi')
-rw-r--r-- | src/python/gudhi/persistence_graphical_tools.py | 12 | ||||
-rw-r--r-- | src/python/gudhi/wasserstein/barycenter.py | 49 |
2 files changed, 23 insertions, 38 deletions
diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py index d59e51a0..6a74a6ca 100644 --- a/src/python/gudhi/persistence_graphical_tools.py +++ b/src/python/gudhi/persistence_graphical_tools.py @@ -109,9 +109,6 @@ def plot_persistence_barcode( plt.rc('text', usetex=True) plt.rc('font', family='serif') - - persistence = _array_handler(persistence) - if persistence_file != "": if path.isfile(persistence_file): # Reset persistence @@ -126,6 +123,8 @@ def plot_persistence_barcode( print("file " + persistence_file + " not found.") return None + persistence = _array_handler(persistence) + if max_barcodes != 1000: print("Deprecated parameter. It has been replaced by max_intervals") max_intervals = max_barcodes @@ -255,8 +254,6 @@ def plot_persistence_diagram( plt.rc('text', usetex=True) plt.rc('font', family='serif') - persistence = _array_handler(persistence) - if persistence_file != "": if path.isfile(persistence_file): # Reset persistence @@ -271,6 +268,8 @@ def plot_persistence_diagram( print("file " + persistence_file + " not found.") return None + persistence = _array_handler(persistence) + if max_plots != 1000: print("Deprecated parameter. It has been replaced by max_intervals") max_intervals = max_plots @@ -427,8 +426,6 @@ def plot_persistence_density( plt.rc('text', usetex=True) plt.rc('font', family='serif') - persistence = _array_handler(persistence) - if persistence_file != "": if dimension is None: # All dimension case @@ -442,6 +439,7 @@ def plot_persistence_density( return None if len(persistence) > 0: + persistence = _array_handler(persistence) persistence_dim = np.array( [ (dim_interval[1][0], dim_interval[1][1]) diff --git a/src/python/gudhi/wasserstein/barycenter.py b/src/python/gudhi/wasserstein/barycenter.py index de7aea81..d67bcde7 100644 --- a/src/python/gudhi/wasserstein/barycenter.py +++ b/src/python/gudhi/wasserstein/barycenter.py @@ -18,8 +18,7 @@ from gudhi.wasserstein import wasserstein_distance def _mean(x, m): ''' :param x: a list of 2D-points, off diagonal, x_0... x_{k-1} - :param m: total amount of points taken into account, - that is we have (m-k) copies of diagonal + :param m: total amount of points taken into account, that is we have (m-k) copies of diagonal :returns: the weighted mean of x with (m-k) copies of the diagonal ''' k = len(x) @@ -33,37 +32,26 @@ def _mean(x, m): def lagrangian_barycenter(pdiagset, init=None, verbose=False): ''' - :param pdiagset: a list of ``numpy.array`` of shape `(n x 2)` - (`n` can variate), encoding a set of - persistence diagrams with only finite coordinates. + :param pdiagset: a list of ``numpy.array`` of shape `(n x 2)` (`n` can variate), encoding a set of persistence + diagrams with only finite coordinates. :param init: The initial value for barycenter estimate. - If ``None``, init is made on a random diagram from the dataset. - Otherwise, it can be an ``int`` - (then initialization is made on ``pdiagset[init]``) - or a `(n x 2)` ``numpy.array`` enconding - a persistence diagram with `n` points. + If ``None``, init is made on a random diagram from the dataset. + Otherwise, it can be an ``int`` (then initialization is made on ``pdiagset[init]``) + or a `(n x 2)` ``numpy.array`` enconding a persistence diagram with `n` points. :type init: ``int``, or (n x 2) ``np.array`` - :param verbose: if ``True``, returns additional information about the - barycenter. + :param verbose: if ``True``, returns additional information about the barycenter. :type verbose: boolean - :returns: If not verbose (default), a ``numpy.array`` encoding - the barycenter estimate of pdiagset - (local minimum of the energy function). - If ``pdiagset`` is empty, returns ``None``. - If verbose, returns a couple ``(Y, log)`` - where ``Y`` is the barycenter estimate, - and ``log`` is a ``dict`` that contains additional informations: - - - `"groupings"`, a list of list of pairs ``(i,j)``. - Namely, ``G[k] = [...(i, j)...]``, where ``(i,j)`` indicates - that ``pdiagset[k][i]`` is matched to ``Y[j]`` - if ``i = -1`` or ``j = -1``, it means they - represent the diagonal. - - - `"energy"`, ``float`` representing the Frechet energy value obtained. - It is the mean of squared distances of observations to the output. - - - `"nb_iter"`, ``int`` number of iterations performed before convergence of the algorithm. + :returns: If not verbose (default), a ``numpy.array`` encoding the barycenter estimate of pdiagset + (local minimum of the energy function). + If ``pdiagset`` is empty, returns ``None``. + If verbose, returns a couple ``(Y, log)`` where ``Y`` is the barycenter estimate, + and ``log`` is a ``dict`` that contains additional informations: + + - `"groupings"`, a list of list of pairs ``(i,j)``. Namely, ``G[k] = [...(i, j)...]``, where ``(i,j)`` indicates that `pdiagset[k][i]`` is matched to ``Y[j]`` if ``i = -1`` or ``j = -1``, it means they represent the diagonal. + + - `"energy"`, ``float`` representing the Frechet energy value obtained. It is the mean of squared distances of observations to the output. + + - `"nb_iter"`, ``int`` number of iterations performed before convergence of the algorithm. ''' X = pdiagset # to shorten notations, not a copy m = len(X) # number of diagrams we are averaging @@ -156,4 +144,3 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False): return Y, log else: return Y - |