From e0a3b5468364342d4468e07029e4daf2cacfd6b9 Mon Sep 17 00:00:00 2001 From: Mario Mulansky Date: Mon, 19 Jan 2015 17:25:46 +0100 Subject: cython implementation of spike-sync --- pyspike/cython_distance.pyx | 81 +++++++++++++++++++++++++++++++++++++++++++++ pyspike/distances.py | 6 ++-- test/test_distance.py | 6 ++-- 3 files changed, 87 insertions(+), 6 deletions(-) diff --git a/pyspike/cython_distance.pyx b/pyspike/cython_distance.pyx index 779ff94..489aab9 100644 --- a/pyspike/cython_distance.pyx +++ b/pyspike/cython_distance.pyx @@ -33,6 +33,7 @@ cimport numpy as np from libc.math cimport fabs from libc.math cimport fmax +from libc.math cimport fmin DTYPE = np.float ctypedef np.float_t DTYPE_t @@ -229,3 +230,83 @@ def spike_distance_cython(double[:] t1, # use only the data added above # could be less than original length due to equal spike times return spike_events[:index+1], y_starts[:index], y_ends[:index] + + + +############################################################ +# coincidence_python +############################################################ +cdef inline double get_tau(double[:] spikes1, double[:] spikes2, int i, int j): + cdef double m = 1E100 # some huge number + cdef int N1 = len(spikes1)-2 + cdef int N2 = len(spikes2)-2 + if i < N1: + m = fmin(m, spikes1[i+1]-spikes1[i]) + if j < N2: + m = fmin(m, spikes2[j+1]-spikes2[j]) + if i > 1: + m = fmin(m, spikes1[i]-spikes1[i-1]) + if j > 1: + m = fmin(m, spikes2[j]-spikes2[j-1]) + return 0.5*m + + +############################################################ +# coincidence_cython +############################################################ +def coincidence_cython(double[:] spikes1, double[:] spikes2): + + cdef int N1 = len(spikes1) + cdef int N2 = len(spikes2) + cdef int i = 0 + cdef int j = 0 + cdef int n = 0 + cdef double[:] st = np.zeros(N1 + N2 - 2) # spike times + cdef double[:] c = np.zeros(N1 + N2 - 2) # coincidences + cdef double[:] mp = np.ones(N1 + N2 - 2) # multiplicity + cdef double tau + while n < N1 + N2 - 2: + if spikes1[i+1] < spikes2[j+1]: + i += 1 + n += 1 + tau = get_tau(spikes1, spikes2, i, j) + st[n] = spikes1[i] + if j > 0 and spikes1[i]-spikes2[j] < tau: + # coincidence between the current spike and the previous spike + # both get marked with 1 + c[n] = 1 + c[n-1] = 1 + elif spikes1[i+1] > spikes2[j+1]: + j += 1 + n += 1 + tau = get_tau(spikes1, spikes2, i, j) + st[n] = spikes2[j] + if i > 0 and spikes2[j]-spikes1[i] < tau: + # coincidence between the current spike and the previous spike + # both get marked with 1 + c[n] = 1 + c[n-1] = 1 + else: # spikes1[i+1] = spikes2[j+1] + # advance in both spike trains + j += 1 + i += 1 + if i == N1-1 or j == N2-1: + break + n += 1 + # add only one event, but with coincidence 2 and multiplicity 2 + st[n] = spikes1[i] + c[n] = 2 + mp[n] = 2 + + st = st[:n+2] + c = c[:n+2] + mp = mp[:n+2] + + st[0] = spikes1[0] + st[len(st)-1] = spikes1[len(spikes1)-1] + c[0] = c[1] + c[len(c)-1] = c[len(c)-2] + mp[0] = mp[1] + mp[len(mp)-1] = mp[len(mp)-2] + + return st, c, mp diff --git a/pyspike/distances.py b/pyspike/distances.py index 5ee8261..8bde724 100644 --- a/pyspike/distances.py +++ b/pyspike/distances.py @@ -139,9 +139,9 @@ def spike_sync_profile(spikes1, spikes2): from cython_distance import coincidence_cython \ as coincidence_impl except ImportError: -# print("Warning: spike_distance_cython not found. Make sure that \ -# PySpike is installed by running\n 'python setup.py build_ext --inplace'!\n \ -# Falling back to slow python backend.") + print("Warning: spike_distance_cython not found. Make sure that \ +PySpike is installed by running\n 'python setup.py build_ext --inplace'!\n \ +Falling back to slow python backend.") # use python backend from python_backend import coincidence_python \ as coincidence_impl diff --git a/test/test_distance.py b/test/test_distance.py index 4f8f6e8..6bdb049 100644 --- a/test/test_distance.py +++ b/test/test_distance.py @@ -199,11 +199,11 @@ def test_multi_spike(): def test_multi_spike_sync(): # some basic multivariate check spikes1 = np.array([100, 300, 400, 405, 410, 500, 700, 800, - 805, 810, 815, 900]) + 805, 810, 815, 900], dtype=float) spikes2 = np.array([100, 200, 205, 210, 295, 350, 400, 510, - 600, 605, 700, 910]) + 600, 605, 700, 910], dtype=float) spikes3 = np.array([100, 180, 198, 295, 412, 420, 510, 640, - 695, 795, 820, 920]) + 695, 795, 820, 920], dtype=float) spikes1 = spk.add_auxiliary_spikes(spikes1, 1000) spikes2 = spk.add_auxiliary_spikes(spikes2, 1000) spikes3 = spk.add_auxiliary_spikes(spikes3, 1000) -- cgit v1.2.3