author     Mario Mulansky <mario.mulansky@gmx.net>  2015-05-08 11:57:00 +0200
committer  Mario Mulansky <mario.mulansky@gmx.net>  2015-05-08 11:57:00 +0200
commit     76a4bbcc733bdd24bb61072a341c43a14b7f83d1 (patch)
tree       a7612697e0b299f087ea829b120ddffa44347c35
parent     a0262fc04e4b084f4dd270a75938d4ad029783d4 (diff)
performance improvement for multivar spike sync
Don't compute the average profile in the function spike_sync_multi, but rather compute the overall average distance directly.
-rw-r--r--  examples/performance.py  | 29
-rw-r--r--  pyspike/DiscreteFunc.py  | 17
-rw-r--r--  pyspike/spike_sync.py    | 22
3 files changed, 52 insertions(+), 16 deletions(-)
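
The core of the change: instead of merging all pairwise profiles into one multivariate profile and averaging that, the summed coincidence values and summed multiplicities of each pairwise profile are accumulated and divided once at the end. A minimal sketch of this pooling idea (the numbers are made up, only the arithmetic matters):

import numpy as np

# hypothetical (summed value, summed multiplicity) pairs, one per
# spike-train pair, as now returned by DiscreteFunc.integral()
pair_integrals = np.array([(3.0, 4.0), (2.0, 4.0), (4.0, 5.0)])

coincidence = pair_integrals[:, 0].sum()  # total coincidence count
mp = pair_integrals[:, 1].sum()           # total multiplicity
print(coincidence / mp)                   # overall average: 9/13 ~= 0.692
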
diff --git a/examples/performance.py b/examples/performance.py
index 469b5ab..94dae25 100644
--- a/examples/performance.py
+++ b/examples/performance.py
@@ -6,6 +6,7 @@ from __future__ import print_function
import pyspike as spk
from datetime import datetime
import cProfile
+import pstats
M = 100 # number of spike trains
r = 1.0 # rate of Poisson spike times
@@ -22,21 +23,37 @@ t_end = datetime.now()
runtime = (t_end-t_start).total_seconds()
print("Spike generation runtime: %.3fs" % runtime)
+print()
print("================ ISI COMPUTATIONS ================")
print(" MULTIVARIATE DISTANCE")
-cProfile.run('spk.isi_distance_multi(spike_trains)')
+cProfile.run('spk.isi_distance_multi(spike_trains)', 'performance.stat')
+p = pstats.Stats('performance.stat')
+p.strip_dirs().sort_stats('tottime').print_stats(5)
+
print(" MULTIVARIATE PROFILE")
-cProfile.run('spk.isi_profile_multi(spike_trains)')
+cProfile.run('spk.isi_profile_multi(spike_trains)', 'performance.stat')
+p = pstats.Stats('performance.stat')
+p.strip_dirs().sort_stats('tottime').print_stats(5)
print("================ SPIKE COMPUTATIONS ================")
print(" MULTIVARIATE DISTANCE")
-cProfile.run('spk.spike_distance_multi(spike_trains)')
+cProfile.run('spk.spike_distance_multi(spike_trains)', 'performance.stat')
+p = pstats.Stats('performance.stat')
+p.strip_dirs().sort_stats('tottime').print_stats(5)
+
print(" MULTIVARIATE PROFILE")
-cProfile.run('spk.spike_profile_multi(spike_trains)')
+cProfile.run('spk.spike_profile_multi(spike_trains)', 'performance.stat')
+p = pstats.Stats('performance.stat')
+p.strip_dirs().sort_stats('tottime').print_stats(5)
print("================ SPIKE-SYNC COMPUTATIONS ================")
print(" MULTIVARIATE DISTANCE")
-cProfile.run('spk.spike_sync_multi(spike_trains)')
+cProfile.run('spk.spike_sync_multi(spike_trains)', 'performance.stat')
+p = pstats.Stats('performance.stat')
+p.strip_dirs().sort_stats('tottime').print_stats(5)
+
print(" MULTIVARIATE PROFILE")
-cProfile.run('spk.spike_sync_profile_multi(spike_trains)')
+cProfile.run('spk.spike_sync_profile_multi(spike_trains)', 'performance.stat')
+p = pstats.Stats('performance.stat')
+p.strip_dirs().sort_stats('tottime').print_stats(5)
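
For reference, the profiling pattern introduced above in isolation: cProfile.run writes its statistics to a file, and pstats.Stats reads them back, strips directory prefixes, sorts by total time, and prints only the five most expensive functions. The work() function here is just a placeholder workload, not part of pyspike.

import cProfile
import pstats

def work():
    # placeholder workload, stands in for the pyspike calls above
    return sum(i * i for i in range(10**6))

cProfile.run('work()', 'performance.stat')             # dump stats to file
p = pstats.Stats('performance.stat')
p.strip_dirs().sort_stats('tottime').print_stats(5)    # top 5 by total time
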
diff --git a/pyspike/DiscreteFunc.py b/pyspike/DiscreteFunc.py
index 33b7a81..dfe2cab 100644
--- a/pyspike/DiscreteFunc.py
+++ b/pyspike/DiscreteFunc.py
@@ -125,15 +125,15 @@ class DiscreteFunc(object):
def integral(self, interval=None):
""" Returns the integral over the given interval. For the discrete
- function, this amounts to the sum over all values divided by the total
- multiplicity.
+ function, this amounts to two values: the sum over all values and the
+ sum over all multiplicities.
:param interval: integration interval given as a pair of floats, or a
sequence of pairs in case of multiple intervals, if
None the integral over the whole function is computed.
:type interval: Pair, sequence of pairs, or None.
- :returns: the integral
- :rtype: float
+ :returns: the summed values and the summed multiplicity
+ :rtype: pair of float
"""
def get_indices(ival):
@@ -147,7 +147,7 @@ class DiscreteFunc(object):
if interval is None:
# no interval given, integrate over the whole spike train
# don't count the first value, which is zero by definition
- return 1.0 * np.sum(self.y[1:-1]) / np.sum(self.mp[1:-1])
+ return (1.0 * np.sum(self.y[1:-1]), np.sum(self.mp[1:-1]))
# check if interval is as sequence
assert isinstance(interval, collections.Sequence), \
@@ -156,7 +156,7 @@ class DiscreteFunc(object):
if not isinstance(interval[0], collections.Sequence):
# find the indices corresponding to the interval
start_ind, end_ind = get_indices(interval)
- return (np.sum(self.y[start_ind:end_ind]) /
+ return (np.sum(self.y[start_ind:end_ind]),
np.sum(self.mp[start_ind:end_ind]))
else:
value = 0.0
@@ -166,7 +166,7 @@ class DiscreteFunc(object):
start_ind, end_ind = get_indices(ival)
value += np.sum(self.y[start_ind:end_ind])
multiplicity += np.sum(self.mp[start_ind:end_ind])
- return value/multiplicity
+ return (value, multiplicity)
def avrg(self, interval=None):
""" Computes the average of the interval sequence:
@@ -180,7 +180,8 @@ class DiscreteFunc(object):
:returns: the average a.
:rtype: float
"""
- return self.integral(interval)
+ val, mp = self.integral(interval)
+ return val/mp
def add(self, f):
""" Adds another `DiscreteFunc` function to this function.
diff --git a/pyspike/spike_sync.py b/pyspike/spike_sync.py
index 9d2e363..0c78228 100644
--- a/pyspike/spike_sync.py
+++ b/pyspike/spike_sync.py
@@ -3,6 +3,7 @@
# Copyright 2014-2015, Mario Mulansky <mario.mulansky@gmx.net>
# Distributed under the BSD License
+import numpy as np
from functools import partial
from pyspike import DiscreteFunc
from pyspike.generic import _generic_profile_multi, _generic_distance_matrix
@@ -131,8 +132,25 @@ def spike_sync_multi(spike_trains, indices=None, interval=None, max_tau=None):
:rtype: double
"""
- return spike_sync_profile_multi(spike_trains, indices,
- max_tau).avrg(interval)
+ if indices is None:
+ indices = np.arange(len(spike_trains))
+ indices = np.array(indices)
+ # check validity of indices
+ assert (indices < len(spike_trains)).all() and (indices >= 0).all(), \
+ "Invalid index list."
+ # generate a list of possible index pairs
+ pairs = [(indices[i], j) for i in range(len(indices))
+ for j in indices[i+1:]]
+
+ coincidence = 0.0
+ mp = 0.0
+ for (i, j) in pairs:
+        profile = spike_sync_profile(spike_trains[i], spike_trains[j],
+                                     max_tau=max_tau)
+ summed_vals = profile.integral(interval)
+ coincidence += summed_vals[0]
+ mp += summed_vals[1]
+
+ return coincidence/mp
############################################################
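
The new spike_sync_multi loops over all unordered index pairs (i, j) with i < j and pools their integrals, which yields the same overall average that the previous spike_sync_profile_multi(...).avrg(interval) call produced, without ever building the merged profile. The pair enumeration in isolation:

import numpy as np

indices = np.array([0, 1, 2, 3])
pairs = [(indices[i], j) for i in range(len(indices))
         for j in indices[i+1:]]
# all unordered pairs: (0,1), (0,2), (0,3), (1,2), (1,3), (2,3)
print(pairs)
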