From 76a4bbcc733bdd24bb61072a341c43a14b7f83d1 Mon Sep 17 00:00:00 2001 From: Mario Mulansky Date: Fri, 8 May 2015 11:57:00 +0200 Subject: performance improvement for multivar spike sync dont compute the average profile in the function spike_sync_multi, but rather compute the overall average distance directly --- examples/performance.py | 29 +++++++++++++++++++++++------ pyspike/DiscreteFunc.py | 17 +++++++++-------- pyspike/spike_sync.py | 22 ++++++++++++++++++++-- 3 files changed, 52 insertions(+), 16 deletions(-) diff --git a/examples/performance.py b/examples/performance.py index 469b5ab..94dae25 100644 --- a/examples/performance.py +++ b/examples/performance.py @@ -6,6 +6,7 @@ from __future__ import print_function import pyspike as spk from datetime import datetime import cProfile +import pstats M = 100 # number of spike trains r = 1.0 # rate of Poisson spike times @@ -22,21 +23,37 @@ t_end = datetime.now() runtime = (t_end-t_start).total_seconds() print("Spike generation runtime: %.3fs" % runtime) +print() print("================ ISI COMPUTATIONS ================") print(" MULTIVARIATE DISTANCE") -cProfile.run('spk.isi_distance_multi(spike_trains)') +cProfile.run('spk.isi_distance_multi(spike_trains)', 'performance.stat') +p = pstats.Stats('performance.stat') +p.strip_dirs().sort_stats('tottime').print_stats(5) + print(" MULTIVARIATE PROFILE") -cProfile.run('spk.isi_profile_multi(spike_trains)') +cProfile.run('spk.isi_profile_multi(spike_trains)', 'performance.stat') +p = pstats.Stats('performance.stat') +p.strip_dirs().sort_stats('tottime').print_stats(5) print("================ SPIKE COMPUTATIONS ================") print(" MULTIVARIATE DISTANCE") -cProfile.run('spk.spike_distance_multi(spike_trains)') +cProfile.run('spk.spike_distance_multi(spike_trains)', 'performance.stat') +p = pstats.Stats('performance.stat') +p.strip_dirs().sort_stats('tottime').print_stats(5) + print(" MULTIVARIATE PROFILE") -cProfile.run('spk.spike_profile_multi(spike_trains)') +cProfile.run('spk.spike_profile_multi(spike_trains)', 'performance.stat') +p = pstats.Stats('performance.stat') +p.strip_dirs().sort_stats('tottime').print_stats(5) print("================ SPIKE-SYNC COMPUTATIONS ================") print(" MULTIVARIATE DISTANCE") -cProfile.run('spk.spike_sync_multi(spike_trains)') +cProfile.run('spk.spike_sync_multi(spike_trains)', 'performance.stat') +p = pstats.Stats('performance.stat') +p.strip_dirs().sort_stats('tottime').print_stats(5) + print(" MULTIVARIATE PROFILE") -cProfile.run('spk.spike_sync_profile_multi(spike_trains)') +cProfile.run('spk.spike_sync_profile_multi(spike_trains)', 'performance.stat') +p = pstats.Stats('performance.stat') +p.strip_dirs().sort_stats('tottime').print_stats(5) diff --git a/pyspike/DiscreteFunc.py b/pyspike/DiscreteFunc.py index 33b7a81..dfe2cab 100644 --- a/pyspike/DiscreteFunc.py +++ b/pyspike/DiscreteFunc.py @@ -125,15 +125,15 @@ class DiscreteFunc(object): def integral(self, interval=None): """ Returns the integral over the given interval. For the discrete - function, this amounts to the sum over all values divided by the total - multiplicity. + function, this amounts to two values: the sum over all values and the + sum over all multiplicities. :param interval: integration interval given as a pair of floats, or a sequence of pairs in case of multiple intervals, if None the integral over the whole function is computed. :type interval: Pair, sequence of pairs, or None. - :returns: the integral - :rtype: float + :returns: the summed values and the summed multiplicity + :rtype: pair of float """ def get_indices(ival): @@ -147,7 +147,7 @@ class DiscreteFunc(object): if interval is None: # no interval given, integrate over the whole spike train # don't count the first value, which is zero by definition - return 1.0 * np.sum(self.y[1:-1]) / np.sum(self.mp[1:-1]) + return (1.0 * np.sum(self.y[1:-1]), np.sum(self.mp[1:-1])) # check if interval is as sequence assert isinstance(interval, collections.Sequence), \ @@ -156,7 +156,7 @@ class DiscreteFunc(object): if not isinstance(interval[0], collections.Sequence): # find the indices corresponding to the interval start_ind, end_ind = get_indices(interval) - return (np.sum(self.y[start_ind:end_ind]) / + return (np.sum(self.y[start_ind:end_ind]), np.sum(self.mp[start_ind:end_ind])) else: value = 0.0 @@ -166,7 +166,7 @@ class DiscreteFunc(object): start_ind, end_ind = get_indices(ival) value += np.sum(self.y[start_ind:end_ind]) multiplicity += np.sum(self.mp[start_ind:end_ind]) - return value/multiplicity + return (value, multiplicity) def avrg(self, interval=None): """ Computes the average of the interval sequence: @@ -180,7 +180,8 @@ class DiscreteFunc(object): :returns: the average a. :rtype: float """ - return self.integral(interval) + val, mp = self.integral(interval) + return val/mp def add(self, f): """ Adds another `DiscreteFunc` function to this function. diff --git a/pyspike/spike_sync.py b/pyspike/spike_sync.py index 9d2e363..0c78228 100644 --- a/pyspike/spike_sync.py +++ b/pyspike/spike_sync.py @@ -3,6 +3,7 @@ # Copyright 2014-2015, Mario Mulansky # Distributed under the BSD License +import numpy as np from functools import partial from pyspike import DiscreteFunc from pyspike.generic import _generic_profile_multi, _generic_distance_matrix @@ -131,8 +132,25 @@ def spike_sync_multi(spike_trains, indices=None, interval=None, max_tau=None): :rtype: double """ - return spike_sync_profile_multi(spike_trains, indices, - max_tau).avrg(interval) + if indices is None: + indices = np.arange(len(spike_trains)) + indices = np.array(indices) + # check validity of indices + assert (indices < len(spike_trains)).all() and (indices >= 0).all(), \ + "Invalid index list." + # generate a list of possible index pairs + pairs = [(indices[i], j) for i in range(len(indices)) + for j in indices[i+1:]] + + coincidence = 0.0 + mp = 0.0 + for (i, j) in pairs: + profile = spike_sync_profile(spike_trains[i], spike_trains[j]) + summed_vals = profile.integral(interval) + coincidence += summed_vals[0] + mp += summed_vals[1] + + return coincidence/mp ############################################################ -- cgit v1.2.3