summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMario Mulansky <mario.mulansky@gmx.net>2015-05-11 12:06:19 +0200
committerMario Mulansky <mario.mulansky@gmx.net>2015-05-11 12:06:19 +0200
commitad7961c9f73d77b87ee23349169618d128418a51 (patch)
tree8c1869a96ea51c998df2d308c18223ec9e3cba6c
parentf688dc2e8616f914040746de845646abb158125d (diff)
performance improvement for spike sync
Additional cython implementation for overall spike sync values. It is not necessary to compute the profile anymore if only the spike sync value is required. 3x performance gain.
-rw-r--r--pyspike/cython/cython_distances.pyx64
-rw-r--r--pyspike/spike_sync.py46
2 files changed, 104 insertions, 6 deletions
diff --git a/pyspike/cython/cython_distances.pyx b/pyspike/cython/cython_distances.pyx
index 65c2872..bf90638 100644
--- a/pyspike/cython/cython_distances.pyx
+++ b/pyspike/cython/cython_distances.pyx
@@ -326,3 +326,67 @@ def spike_distance_cython(double[:] t1, double[:] t2,
# use only the data added above
# could be less than original length due to equal spike times
return spike_value / (t_end-t_start)
+
+
+
+############################################################
+# get_tau
+############################################################
+cdef inline double get_tau(double[:] spikes1, double[:] spikes2,
+ int i, int j, double max_tau):
+ cdef double m = 1E100 # some huge number
+ cdef int N1 = spikes1.shape[0]-1 # len(spikes1)-1
+ cdef int N2 = spikes2.shape[0]-1 # len(spikes2)-1
+ if i < N1 and i > -1:
+ m = fmin(m, spikes1[i+1]-spikes1[i])
+ if j < N2 and j > -1:
+ m = fmin(m, spikes2[j+1]-spikes2[j])
+ if i > 0:
+ m = fmin(m, spikes1[i]-spikes1[i-1])
+ if j > 0:
+ m = fmin(m, spikes2[j]-spikes2[j-1])
+ m *= 0.5
+ if max_tau > 0.0:
+ m = fmin(m, max_tau)
+ return m
+
+
+############################################################
+# coincidence_value_cython
+############################################################
+def coincidence_value_cython(double[:] spikes1, double[:] spikes2,
+ double t_start, double t_end, double max_tau):
+
+ cdef int N1 = len(spikes1)
+ cdef int N2 = len(spikes2)
+ cdef int i = -1
+ cdef int j = -1
+ cdef double coinc = 0.0
+ cdef double mp = 0.0
+ cdef double tau
+ while i + j < N1 + N2 - 2:
+ if (i < N1-1) and (j == N2-1 or spikes1[i+1] < spikes2[j+1]):
+ i += 1
+ mp += 1
+ tau = get_tau(spikes1, spikes2, i, j, max_tau)
+ if j > -1 and spikes1[i]-spikes2[j] < tau:
+ # coincidence between the current spike and the previous spike
+ # both get marked with 1
+ coinc += 2
+ elif (j < N2-1) and (i == N1-1 or spikes1[i+1] > spikes2[j+1]):
+ j += 1
+ mp += 1
+ tau = get_tau(spikes1, spikes2, i, j, max_tau)
+ if i > -1 and spikes2[j]-spikes1[i] < tau:
+ # coincidence between the current spike and the previous spike
+ # both get marked with 1
+ coinc += 2
+ else: # spikes1[i+1] = spikes2[j+1]
+ # advance in both spike trains
+ j += 1
+ i += 1
+ # add only one event, but with coincidence 2 and multiplicity 2
+ mp += 2
+ coinc += 2
+
+ return coinc, mp
diff --git a/pyspike/spike_sync.py b/pyspike/spike_sync.py
index 107734d..40d98d2 100644
--- a/pyspike/spike_sync.py
+++ b/pyspike/spike_sync.py
@@ -58,6 +58,40 @@ Falling back to slow python backend.")
############################################################
+# _spike_sync_values
+############################################################
+def _spike_sync_values(spike_train1, spike_train2, interval, max_tau):
+ """" Internal function. Computes the summed coincidences and multiplicity
+ for spike synchronization of the two given spike trains.
+
+ Do not call this function directly, use `spike_sync` or `spike_sync_multi`
+ instead.
+ """
+ if interval is None:
+ # distance over the whole interval is requested: use specific function
+ # for optimal performance
+ try:
+ from cython.cython_distances import coincidence_value_cython \
+ as coincidence_value_impl
+ if max_tau is None:
+ max_tau = 0.0
+ c, mp = coincidence_value_impl(spike_train1.spikes,
+ spike_train2.spikes,
+ spike_train1.t_start,
+ spike_train1.t_end,
+ max_tau)
+ return c, mp
+ except ImportError:
+ # Cython backend not available: fall back to profile averaging
+ return spike_sync_profile(spike_train1, spike_train2,
+ max_tau).integral(interval)
+ else:
+ # some specific interval is provided: use profile
+ return spike_sync_profile(spike_train1, spike_train2,
+ max_tau).integral(interval)
+
+
+############################################################
# spike_sync
############################################################
def spike_sync(spike_train1, spike_train2, interval=None, max_tau=None):
@@ -80,8 +114,8 @@ def spike_sync(spike_train1, spike_train2, interval=None, max_tau=None):
:rtype: `double`
"""
- return spike_sync_profile(spike_train1, spike_train2,
- max_tau).avrg(interval)
+ c, mp = _spike_sync_values(spike_train1, spike_train2, interval, max_tau)
+ return 1.0*c/mp
############################################################
@@ -144,10 +178,10 @@ def spike_sync_multi(spike_trains, indices=None, interval=None, max_tau=None):
coincidence = 0.0
mp = 0.0
for (i, j) in pairs:
- profile = spike_sync_profile(spike_trains[i], spike_trains[j])
- summed_vals = profile.integral(interval)
- coincidence += summed_vals[0]
- mp += summed_vals[1]
+ c, m = _spike_sync_values(spike_trains[i], spike_trains[j],
+ interval, max_tau)
+ coincidence += c
+ mp += m
return coincidence/mp