summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMario Mulansky <mario.mulansky@gmx.net>2015-01-19 17:25:46 +0100
committerMario Mulansky <mario.mulansky@gmx.net>2015-01-19 17:25:46 +0100
commite0a3b5468364342d4468e07029e4daf2cacfd6b9 (patch)
tree13403fdaa3f4a39aebf4d469bd3071b5d47f0e09
parentfed0ceec753fc1a7e5a1e20632de5a9800fe4fb1 (diff)
cython implementation of spike-sync
-rw-r--r--pyspike/cython_distance.pyx81
-rw-r--r--pyspike/distances.py6
-rw-r--r--test/test_distance.py6
3 files changed, 87 insertions, 6 deletions
diff --git a/pyspike/cython_distance.pyx b/pyspike/cython_distance.pyx
index 779ff94..489aab9 100644
--- a/pyspike/cython_distance.pyx
+++ b/pyspike/cython_distance.pyx
@@ -33,6 +33,7 @@ cimport numpy as np
from libc.math cimport fabs
from libc.math cimport fmax
+from libc.math cimport fmin
DTYPE = np.float
ctypedef np.float_t DTYPE_t
@@ -229,3 +230,83 @@ def spike_distance_cython(double[:] t1,
# use only the data added above
# could be less than original length due to equal spike times
return spike_events[:index+1], y_starts[:index], y_ends[:index]
+
+
+
+############################################################
+# coincidence_python
+############################################################
+cdef inline double get_tau(double[:] spikes1, double[:] spikes2, int i, int j):
+ cdef double m = 1E100 # some huge number
+ cdef int N1 = len(spikes1)-2
+ cdef int N2 = len(spikes2)-2
+ if i < N1:
+ m = fmin(m, spikes1[i+1]-spikes1[i])
+ if j < N2:
+ m = fmin(m, spikes2[j+1]-spikes2[j])
+ if i > 1:
+ m = fmin(m, spikes1[i]-spikes1[i-1])
+ if j > 1:
+ m = fmin(m, spikes2[j]-spikes2[j-1])
+ return 0.5*m
+
+
+############################################################
+# coincidence_cython
+############################################################
+def coincidence_cython(double[:] spikes1, double[:] spikes2):
+
+ cdef int N1 = len(spikes1)
+ cdef int N2 = len(spikes2)
+ cdef int i = 0
+ cdef int j = 0
+ cdef int n = 0
+ cdef double[:] st = np.zeros(N1 + N2 - 2) # spike times
+ cdef double[:] c = np.zeros(N1 + N2 - 2) # coincidences
+ cdef double[:] mp = np.ones(N1 + N2 - 2) # multiplicity
+ cdef double tau
+ while n < N1 + N2 - 2:
+ if spikes1[i+1] < spikes2[j+1]:
+ i += 1
+ n += 1
+ tau = get_tau(spikes1, spikes2, i, j)
+ st[n] = spikes1[i]
+ if j > 0 and spikes1[i]-spikes2[j] < tau:
+ # coincidence between the current spike and the previous spike
+ # both get marked with 1
+ c[n] = 1
+ c[n-1] = 1
+ elif spikes1[i+1] > spikes2[j+1]:
+ j += 1
+ n += 1
+ tau = get_tau(spikes1, spikes2, i, j)
+ st[n] = spikes2[j]
+ if i > 0 and spikes2[j]-spikes1[i] < tau:
+ # coincidence between the current spike and the previous spike
+ # both get marked with 1
+ c[n] = 1
+ c[n-1] = 1
+ else: # spikes1[i+1] = spikes2[j+1]
+ # advance in both spike trains
+ j += 1
+ i += 1
+ if i == N1-1 or j == N2-1:
+ break
+ n += 1
+ # add only one event, but with coincidence 2 and multiplicity 2
+ st[n] = spikes1[i]
+ c[n] = 2
+ mp[n] = 2
+
+ st = st[:n+2]
+ c = c[:n+2]
+ mp = mp[:n+2]
+
+ st[0] = spikes1[0]
+ st[len(st)-1] = spikes1[len(spikes1)-1]
+ c[0] = c[1]
+ c[len(c)-1] = c[len(c)-2]
+ mp[0] = mp[1]
+ mp[len(mp)-1] = mp[len(mp)-2]
+
+ return st, c, mp
diff --git a/pyspike/distances.py b/pyspike/distances.py
index 5ee8261..8bde724 100644
--- a/pyspike/distances.py
+++ b/pyspike/distances.py
@@ -139,9 +139,9 @@ def spike_sync_profile(spikes1, spikes2):
from cython_distance import coincidence_cython \
as coincidence_impl
except ImportError:
-# print("Warning: spike_distance_cython not found. Make sure that \
-# PySpike is installed by running\n 'python setup.py build_ext --inplace'!\n \
-# Falling back to slow python backend.")
+ print("Warning: spike_distance_cython not found. Make sure that \
+PySpike is installed by running\n 'python setup.py build_ext --inplace'!\n \
+Falling back to slow python backend.")
# use python backend
from python_backend import coincidence_python \
as coincidence_impl
diff --git a/test/test_distance.py b/test/test_distance.py
index 4f8f6e8..6bdb049 100644
--- a/test/test_distance.py
+++ b/test/test_distance.py
@@ -199,11 +199,11 @@ def test_multi_spike():
def test_multi_spike_sync():
# some basic multivariate check
spikes1 = np.array([100, 300, 400, 405, 410, 500, 700, 800,
- 805, 810, 815, 900])
+ 805, 810, 815, 900], dtype=float)
spikes2 = np.array([100, 200, 205, 210, 295, 350, 400, 510,
- 600, 605, 700, 910])
+ 600, 605, 700, 910], dtype=float)
spikes3 = np.array([100, 180, 198, 295, 412, 420, 510, 640,
- 695, 795, 820, 920])
+ 695, 795, 820, 920], dtype=float)
spikes1 = spk.add_auxiliary_spikes(spikes1, 1000)
spikes2 = spk.add_auxiliary_spikes(spikes2, 1000)
spikes3 = spk.add_auxiliary_spikes(spikes3, 1000)