author    Mario Mulansky <mario.mulansky@gmx.net>  2015-10-10 20:45:09 +0200
committer Mario Mulansky <mario.mulansky@gmx.net>  2018-06-02 12:59:43 -0700
commit    18ea80e2d01e9eb4ceee17219f91098efbcdf67c (patch)
tree      d7819736b059e9885d53c14e28160d6487d93e6c
parent    a5e6a12a619cb9528a4cf7f3ef8f082e5eb877c2 (diff)
spike sync filtering, cython sim ann
Added a function for filtering out events based on a threshold for the spike-sync values, useful for focusing on synchronous events during directionality analysis. Also added a cython version of simulated annealing for performance.
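
A minimal usage sketch of the new filter (spike times are made up, taken from the test file added below; `filter_by_spike_sync` and `SpikeTrain` come from the changes in this commit):

    import numpy as np
    import pyspike as spk
    from pyspike import SpikeTrain

    # three short spike trains on the interval [0, 5] (hypothetical values)
    st1 = SpikeTrain(np.array([1.0, 2.0, 3.0, 4.0]), 5.0)
    st2 = SpikeTrain(np.array([1.1, 2.1, 3.8]), 5.0)
    st3 = SpikeTrain(np.array([0.9, 3.1, 4.1]), 5.0)

    # keep only spikes whose multivariate spike-sync value exceeds 0.75
    filtered = spk.filter_by_spike_sync([st1, st2, st3], 0.75)
    for st in filtered:
        print(st.spikes)
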
-rw-r--r--  pyspike/__init__.py                              7
-rw-r--r--  pyspike/cython/cython_distances.pyx            200
-rw-r--r--  pyspike/cython/cython_profiles.pyx              14
-rw-r--r--  pyspike/cython/cython_simulated_annealing.pyx   82
-rw-r--r--  pyspike/spike_directionality.py                 54
-rw-r--r--  pyspike/spike_sync.py                           19
-rw-r--r--  setup.py                                        14
-rw-r--r--  test/test_sync_filter.py                        61
8 files changed, 408 insertions(+), 43 deletions(-)
diff --git a/pyspike/__init__.py b/pyspike/__init__.py
index 10d2936..61c5c4f 100644
--- a/pyspike/__init__.py
+++ b/pyspike/__init__.py
@@ -19,9 +19,10 @@ from .isi_distance import isi_profile, isi_distance, isi_profile_multi,\
isi_distance_multi, isi_distance_matrix
from .spike_distance import spike_profile, spike_distance, spike_profile_multi,\
spike_distance_multi, spike_distance_matrix
-from .spike_sync import spike_sync_profile, spike_sync,\
-    spike_sync_profile_multi, spike_sync_multi, spike_sync_matrix
+from .spike_sync import spike_sync_profile, spike_sync,\
+    spike_sync_profile_multi, spike_sync_multi, spike_sync_matrix,\
+    filter_by_spike_sync
from .psth import psth
from .spikes import load_spike_trains_from_txt, save_spike_trains_to_txt, \
spike_train_from_string, import_spike_trains_from_time_series, \
diff --git a/pyspike/cython/cython_distances.pyx b/pyspike/cython/cython_distances.pyx
index ac5f226..d4070ae 100644
--- a/pyspike/cython/cython_distances.pyx
+++ b/pyspike/cython/cython_distances.pyx
@@ -178,6 +178,8 @@ cdef inline double isi_avrg_cython(double isi1, double isi2) nogil:
return 0.5*(isi1+isi2)*(isi1+isi2)
# alternative definition to obtain <S> ~ 0.5 for Poisson spikes
# return 0.5*(isi1*isi1+isi2*isi2)
+ # another alternative definition without second normalization
+ # return 0.5*(isi1+isi2)
############################################################
@@ -248,6 +250,8 @@ def spike_distance_cython(double[:] t1, double[:] t2,
index2 = 0
y_start = (s1*isi2 + s2*isi1) / isi_avrg_cython(isi1, isi2)
+ # alternative definition without second normalization
+ # y_start = (s1 + s2) / isi_avrg_cython(isi1, isi2)
index = 1
while index1+index2 < N1+N2-2:
@@ -267,6 +271,8 @@ def spike_distance_cython(double[:] t1, double[:] t2,
t_curr = t_p1
s2 = (dt_p2*(t_f2-t_p1) + dt_f2*(t_p1-t_p2)) / isi2
y_end = (s1*isi2 + s2*isi1)/isi_avrg_cython(isi1, isi2)
+ # alternative definition without second normalization
+ # y_end = (s1 + s2) / isi_avrg_cython(isi1, isi2)
spike_value += 0.5*(y_start + y_end) * (t_curr - t_last)
@@ -286,6 +292,8 @@ def spike_distance_cython(double[:] t1, double[:] t2,
s1 = dt_p1
# s2 is the same as above, thus we can compute y2 immediately
y_start = (s1*isi2 + s2*isi1)/isi_avrg_cython(isi1, isi2)
+ # alternative definition without second normalization
+ # y_start = (s1 + s2) / isi_avrg_cython(isi1, isi2)
elif (index2 < N2-1) and (t_f1 > t_f2 or index1 == N1-1):
index2 += 1
# first calculate the previous interval end value
@@ -301,6 +309,8 @@ def spike_distance_cython(double[:] t1, double[:] t2,
t_curr = t_p2
s1 = (dt_p1*(t_f1-t_p2) + dt_f1*(t_p2-t_p1)) / isi1
y_end = (s1*isi2 + s2*isi1) / isi_avrg_cython(isi1, isi2)
+ # alternative definition without second normalization
+ # y_end = (s1 + s2) / isi_avrg_cython(isi1, isi2)
spike_value += 0.5*(y_start + y_end) * (t_curr - t_last)
@@ -320,6 +330,9 @@ def spike_distance_cython(double[:] t1, double[:] t2,
s2 = dt_p2
# s1 is the same as above, thus we can compute y2 immediately
y_start = (s1*isi2 + s2*isi1)/isi_avrg_cython(isi1, isi2)
+ # alternative definition without second normalization
+ # y_start = (s1 + s2) / isi_avrg_cython(isi1, isi2)
+
else: # t_f1 == t_f2 - generate only one event
index1 += 1
index2 += 1
@@ -358,6 +371,193 @@ def spike_distance_cython(double[:] t1, double[:] t2,
s1 = dt_f1 # *(t_end-t1[N1-1])/isi1
s2 = dt_f2 # *(t_end-t2[N2-1])/isi2
y_end = (s1*isi2 + s2*isi1) / isi_avrg_cython(isi1, isi2)
+ # alternative definition without second normalization
+ # y_end = (s1 + s2) / isi_avrg_cython(isi1, isi2)
+
+ spike_value += 0.5*(y_start + y_end) * (t_end - t_last)
+ # end nogil
+
+ # use only the data added above
+ # could be less than original length due to equal spike times
+ return spike_value / (t_end-t_start)
+
+
+############################################################
+# isi_avrg_rf_cython
+############################################################
+cdef inline double isi_avrg_rf_cython(double isi1, double isi2) nogil:
+ # rate free version
+ return (isi1+isi2)
+
+
+############################################################
+# spike_distance_rf_cython
+############################################################
+def spike_distance_rf_cython(double[:] t1, double[:] t2,
+ double t_start, double t_end):
+
+ cdef int N1, N2, index1, index2, index
+ cdef double t_p1, t_f1, t_p2, t_f2, dt_p1, dt_p2, dt_f1, dt_f2
+ cdef double isi1, isi2, s1, s2
+ cdef double y_start, y_end, t_last, t_current, spike_value
+
+ spike_value = 0.0
+
+ N1 = len(t1)
+ N2 = len(t2)
+
+ with nogil: # release the interpreter to allow multithreading
+ t_last = t_start
+ t_p1 = t_start
+ t_p2 = t_start
+ if t1[0] > t_start:
+ # dt_p1 = t2[0]-t_start
+ t_f1 = t1[0]
+ dt_f1 = get_min_dist_cython(t_f1, t2, N2, 0, t_start, t_end)
+ isi1 = fmax(t_f1-t_start, t1[1]-t1[0])
+ dt_p1 = dt_f1
+ s1 = dt_p1*(t_f1-t_start)/isi1
+ index1 = -1
+ else:
+ t_f1 = t1[1]
+ dt_f1 = get_min_dist_cython(t_f1, t2, N2, 0, t_start, t_end)
+ dt_p1 = 0.0
+ isi1 = t1[1]-t1[0]
+ s1 = dt_p1
+ index1 = 0
+ if t2[0] > t_start:
+ # dt_p1 = t2[0]-t_start
+ t_f2 = t2[0]
+ dt_f2 = get_min_dist_cython(t_f2, t1, N1, 0, t_start, t_end)
+ dt_p2 = dt_f2
+ isi2 = fmax(t_f2-t_start, t2[1]-t2[0])
+ s2 = dt_p2*(t_f2-t_start)/isi2
+ index2 = -1
+ else:
+ t_f2 = t2[1]
+ dt_f2 = get_min_dist_cython(t_f2, t1, N1, 0, t_start, t_end)
+ dt_p2 = 0.0
+ isi2 = t2[1]-t2[0]
+ s2 = dt_p2
+ index2 = 0
+
+ # y_start = (s1*isi2 + s2*isi1) / isi_avrg_cython(isi1, isi2)
+ # alternative definition without second normalization
+ y_start = (s1 + s2) / isi_avrg_rf_cython(isi1, isi2)
+ index = 1
+
+ while index1+index2 < N1+N2-2:
+ # print(index, index1, index2)
+ if (index1 < N1-1) and (t_f1 < t_f2 or index2 == N2-1):
+ index1 += 1
+ # first calculate the previous interval end value
+ s1 = dt_f1*(t_f1-t_p1) / isi1
+ # the previous time now was the following time before:
+ dt_p1 = dt_f1
+ t_p1 = t_f1 # t_p1 contains the current time point
+ # get the next time
+ if index1 < N1-1:
+ t_f1 = t1[index1+1]
+ else:
+ t_f1 = t_end
+ t_curr = t_p1
+ s2 = (dt_p2*(t_f2-t_p1) + dt_f2*(t_p1-t_p2)) / isi2
+ # y_end = (s1*isi2 + s2*isi1)/isi_avrg_cython(isi1, isi2)
+ # alternative definition without second normalization
+ y_end = (s1 + s2) / isi_avrg_rf_cython(isi1, isi2)
+
+ spike_value += 0.5*(y_start + y_end) * (t_curr - t_last)
+
+ # now the next interval start value
+ if index1 < N1-1:
+ dt_f1 = get_min_dist_cython(t_f1, t2, N2, index2,
+ t_start, t_end)
+ isi1 = t_f1-t_p1
+ s1 = dt_p1
+ else:
+ dt_f1 = dt_p1
+ isi1 = fmax(t_end-t1[N1-1], t1[N1-1]-t1[N1-2])
+ # s1 needs adjustment due to change of isi1
+ s1 = dt_p1*(t_end-t1[N1-1])/isi1
+ # s2 is the same as above, thus we can compute y2 immediately
+ # y_start = (s1*isi2 + s2*isi1)/isi_avrg_cython(isi1, isi2)
+ # alternative definition without second normalization
+ y_start = (s1 + s2) / isi_avrg_rf_cython(isi1, isi2)
+ elif (index2 < N2-1) and (t_f1 > t_f2 or index1 == N1-1):
+ index2 += 1
+ # first calculate the previous interval end value
+ s2 = dt_f2*(t_f2-t_p2) / isi2
+ # the previous time now was the following time before:
+ dt_p2 = dt_f2
+ t_p2 = t_f2 # t_p2 contains the current time point
+ # get the next time
+ if index2 < N2-1:
+ t_f2 = t2[index2+1]
+ else:
+ t_f2 = t_end
+ t_curr = t_p2
+ s1 = (dt_p1*(t_f1-t_p2) + dt_f1*(t_p2-t_p1)) / isi1
+ # y_end = (s1*isi2 + s2*isi1) / isi_avrg_cython(isi1, isi2)
+ # alternative definition without second normalization
+ y_end = (s1 + s2) / isi_avrg_rf_cython(isi1, isi2)
+
+ spike_value += 0.5*(y_start + y_end) * (t_curr - t_last)
+
+ # now the next interval start value
+ if index2 < N2-1:
+ dt_f2 = get_min_dist_cython(t_f2, t1, N1, index1,
+ t_start, t_end)
+ isi2 = t_f2-t_p2
+ s2 = dt_p2
+ else:
+ dt_f2 = dt_p2
+ isi2 = fmax(t_end-t2[N2-1], t2[N2-1]-t2[N2-2])
+ # s2 needs adjustment due to change of isi2
+ s2 = dt_p2*(t_end-t2[N2-1])/isi2
+ # s1 is the same as above, thus we can compute y2 immediately
+ # y_start = (s1*isi2 + s2*isi1)/isi_avrg_cython(isi1, isi2)
+ # alternative definition without second normalization
+ y_start = (s1 + s2) / isi_avrg_rf_cython(isi1, isi2)
+
+ else: # t_f1 == t_f2 - generate only one event
+ index1 += 1
+ index2 += 1
+ t_p1 = t_f1
+ t_p2 = t_f2
+ dt_p1 = 0.0
+ dt_p2 = 0.0
+ t_curr = t_f1
+ y_end = 0.0
+ spike_value += 0.5*(y_start + y_end) * (t_curr - t_last)
+ y_start = 0.0
+ if index1 < N1-1:
+ t_f1 = t1[index1+1]
+ dt_f1 = get_min_dist_cython(t_f1, t2, N2, index2,
+ t_start, t_end)
+ isi1 = t_f1 - t_p1
+ else:
+ t_f1 = t_end
+ dt_f1 = dt_p1
+ isi1 = fmax(t_end-t1[N1-1], t1[N1-1]-t1[N1-2])
+ if index2 < N2-1:
+ t_f2 = t2[index2+1]
+ dt_f2 = get_min_dist_cython(t_f2, t1, N1, index1,
+ t_start, t_end)
+ isi2 = t_f2 - t_p2
+ else:
+ t_f2 = t_end
+ dt_f2 = dt_p2
+ isi2 = fmax(t_end-t2[N2-1], t2[N2-1]-t2[N2-2])
+ index += 1
+ t_last = t_curr
+ # isi1 = max(t_end-t1[N1-1], t1[N1-1]-t1[N1-2])
+ # isi2 = max(t_end-t2[N2-1], t2[N2-1]-t2[N2-2])
+ s1 = dt_f1*(t_end-t1[N1-1])/isi1
+ s2 = dt_f2*(t_end-t2[N2-1])/isi2
+ # y_end = (s1*isi2 + s2*isi1) / isi_avrg_cython(isi1, isi2)
+ # alternative definition without second normalization
+ y_end = (s1 + s2) / isi_avrg_rf_cython(isi1, isi2)
+
spike_value += 0.5*(y_start + y_end) * (t_end - t_last)
# end nogil
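
For reference, the two weightings contrasted throughout this file, written out in plain Python (illustrative function names, not part of the library):

    def y_default(s1, s2, isi1, isi2):
        # standard weighting: each local distance is weighted by the other
        # train's interspike interval, normalized by 0.5*(isi1+isi2)**2
        return (s1 * isi2 + s2 * isi1) / (0.5 * (isi1 + isi2) ** 2)

    def y_rate_free(s1, s2, isi1, isi2):
        # "rate-free" alternative without the second normalization:
        # plain sum of the local distances over the summed intervals
        return (s1 + s2) / (isi1 + isi2)

    # for equal interspike intervals the two definitions coincide
    assert abs(y_default(0.2, 0.4, 1.0, 1.0)
               - y_rate_free(0.2, 0.4, 1.0, 1.0)) < 1e-12
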
diff --git a/pyspike/cython/cython_profiles.pyx b/pyspike/cython/cython_profiles.pyx
index fe08cb7..eb4d157 100644
--- a/pyspike/cython/cython_profiles.pyx
+++ b/pyspike/cython/cython_profiles.pyx
@@ -466,18 +466,20 @@ def coincidence_single_profile_cython(double[:] spikes1, double[:] spikes2,
cdef double tau
for i in xrange(N1):
while j < N2-1 and spikes2[j+1] < spikes1[i]:
+ # move forward until spikes2[j] is the last spike before spikes1[i]
+ # note that if spikes2[j] is after spikes1[i] we don't do anything
j += 1
tau = get_tau(spikes1, spikes2, i, j, interval, max_tau)
- print i, j, spikes1[i], spikes2[j], tau
- if j > -1 and spikes1[i]-spikes2[j] < tau:
+ if j > -1 and fabs(spikes1[i]-spikes2[j]) < tau:
# current spike in st1 is coincident
c[i] = 1
- if j < N2-1:
+ if j < N2-1 and spikes2[j] < spikes1[i]:
+ # in case spikes2[j] is before spikes1[i] it has to be the one
+ # right before (see above), hence we move one forward and also
+ # check the next spike
j += 1
tau = get_tau(spikes1, spikes2, i, j, interval, max_tau)
- print i, j, spikes1[i], spikes2[j], tau
- if spikes2[j]-spikes1[i] < tau:
+ if fabs(spikes2[j]-spikes1[i]) < tau:
# current spike in st1 is coincident
c[i] = 1
-
return c
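
The corrected two-sided coincidence test, sketched in plain Python with a fixed coincidence window in place of the adaptive `get_tau` (a simplification for illustration, not the library code):

    def coincidence_flags(spikes1, spikes2, tau):
        """Return a 0/1 flag per spike of spikes1 indicating whether a
        spike of spikes2 lies within +/- tau (fixed window)."""
        n2 = len(spikes2)
        c = [0] * len(spikes1)
        j = 0
        for i, t in enumerate(spikes1):
            # advance j to the last spike of spikes2 before t
            while j < n2 - 1 and spikes2[j + 1] < t:
                j += 1
            if abs(t - spikes2[j]) < tau:
                c[i] = 1
            # the following spike of spikes2 may also be coincident
            elif j < n2 - 1 and spikes2[j] < t and abs(spikes2[j + 1] - t) < tau:
                c[i] = 1
        return c

    print(coincidence_flags([1.0, 2.0, 3.0, 4.0], [1.1, 2.1, 3.8], 0.25))
    # -> [1, 1, 0, 1]
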
diff --git a/pyspike/cython/cython_simulated_annealing.pyx b/pyspike/cython/cython_simulated_annealing.pyx
new file mode 100644
index 0000000..be9423c
--- /dev/null
+++ b/pyspike/cython/cython_simulated_annealing.pyx
@@ -0,0 +1,82 @@
+#cython: boundscheck=False
+#cython: wraparound=False
+#cython: cdivision=True
+
+"""
+cython_simulated_annealing.pyx
+
+cython implementation of a simulated annealing algorithm to find the optimal
+spike train order
+
+Note: using cython memoryviews (e.g. double[:]) instead of ndarray objects
+improves the performance of spike_distance by a factor of 10!
+
+Copyright 2015, Mario Mulansky <mario.mulansky@gmx.net>
+
+Distributed under the BSD License
+
+"""
+
+"""
+To test whether things can be optimized: remove all yellow stuff
+in the html output::
+
+ cython -a cython_simulated_annealing.pyx
+
+which gives:
+
+ cython_simulated_annealing.html
+
+"""
+
+import numpy as np
+cimport numpy as np
+
+from libc.math cimport exp
+from libc.math cimport fmod
+from libc.stdlib cimport rand
+from libc.stdlib cimport RAND_MAX
+
+DTYPE = np.float
+ctypedef np.float_t DTYPE_t
+
+
+def sim_ann_cython(double[:, :] D, double T_start, double T_end, double alpha):
+
+ cdef long N = len(D)
+ cdef double A = np.sum(np.triu(D, 0))
+ cdef long[:] p = np.arange(N)
+ cdef double T = T_start
+ cdef long iterations
+ cdef long succ_iter
+ cdef long total_iter = 0
+ cdef double delta_A
+ cdef long ind1
+ cdef long ind2
+
+ while T > T_end:
+ iterations = 0
+ succ_iter = 0
+ # equilibrate for 100*N steps or 10*N successful steps
+ while iterations < 100*N and succ_iter < 10*N:
+ # exchange two rows and cols
+ # ind1 = np.random.randint(N-1)
+ ind1 = rand() % (N-1)
+ if ind1 < N-1:
+ ind2 = ind1+1
+ else: # this can never happen!
+ ind2 = 0
+ delta_A = -2*D[p[ind1], p[ind2]]
+ if delta_A > 0.0 or exp(delta_A/T) > ((1.0*rand()) / RAND_MAX):
+ # swap indices
+ p[ind1], p[ind2] = p[ind2], p[ind1]
+ A += delta_A
+ succ_iter += 1
+ iterations += 1
+ total_iter += iterations
+ T *= alpha # cool down
+ if succ_iter == 0:
+ # no successful step -> we believe we have converged
+ break
+
+ return p, A, total_iter
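
Assuming the extension has been built (via 'python setup.py build_ext --inplace'), the new kernel can be called directly; the matrix below is a made-up antisymmetric directionality matrix:

    import numpy as np
    from pyspike.cython.cython_simulated_annealing import sim_ann_cython

    D = np.array([[ 0.0,  0.5,  0.3],
                  [-0.5,  0.0,  0.2],
                  [-0.3, -0.2,  0.0]])

    T_start = 2 * np.max(D)   # starting temperature
    T_end = 1e-5 * T_start    # final temperature
    alpha = 0.9               # cooling factor

    p, A, total_iter = sim_ann_cython(D, T_start, T_end, alpha)
    print(np.asarray(p), A, total_iter)
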
diff --git a/pyspike/spike_directionality.py b/pyspike/spike_directionality.py
index cda7fe3..e1f5f16 100644
--- a/pyspike/spike_directionality.py
+++ b/pyspike/spike_directionality.py
@@ -242,27 +242,39 @@ def optimal_spike_train_order_from_matrix(D, full_output=False):
p = np.arange(N)
- T = 2*np.max(D) # starting temperature
- T_end = 1E-5 * T # final temperature
- alpha = 0.9 # cooling factor
- total_iter = 0
- while T > T_end:
- iterations = 0
- succ_iter = 0
- while iterations < 100*N and succ_iter < 10*N:
- # exchange two rows and cols
- ind1 = np.random.randint(N-1)
- delta_A = -2*D[p[ind1], p[ind1+1]]
- if delta_A > 0.0 or exp(delta_A/T) > np.random.random():
- # swap indices
- p[ind1], p[ind1+1] = p[ind1+1], p[ind1]
- A += delta_A
- succ_iter += 1
- iterations += 1
- total_iter += iterations
- T *= alpha # cool down
- if succ_iter == 0:
- break
+ T_start = 2*np.max(D) # starting temperature
+ T_end = 1E-5 * T_start # final temperature
+ alpha = 0.9 # cooling factor
+
+ from .cython.cython_simulated_annealing import sim_ann_cython as sim_ann
+
+ p, A, total_iter = sim_ann(D, T_start, T_end, alpha)
+
+ # T = T_start
+ # total_iter = 0
+ # while T > T_end:
+ # iterations = 0
+ # succ_iter = 0
+ # # equilibrate for 100*N steps or 10*N successful steps
+ # while iterations < 100*N and succ_iter < 10*N:
+ # # exchange two rows and cols
+ # ind1 = np.random.randint(N-1)
+ # if ind1 < N-1:
+ # ind2 = ind1+1
+ # else: # this can never happen
+ # ind2 = 0
+ # delta_A = -2*D[p[ind1], p[ind2]]
+ # if delta_A > 0.0 or exp(delta_A/T) > np.random.random():
+ # # swap indices
+ # p[ind1], p[ind2] = p[ind2], p[ind1]
+ # A += delta_A
+ # succ_iter += 1
+ # iterations += 1
+ # total_iter += iterations
+ # T *= alpha # cool down
+ # if succ_iter == 0:
+ # break
+
if full_output:
return p, A, total_iter
else:
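
The public entry point keeps its signature, so existing callers are unaffected; a quick sketch with hypothetical matrix values (assuming the module layout shown in this diff):

    import numpy as np
    from pyspike.spike_directionality import optimal_spike_train_order_from_matrix

    D = np.array([[ 0.0,  0.5, -0.1],
                  [-0.5,  0.0,  0.4],
                  [ 0.1, -0.4,  0.0]])

    # permutation p that (approximately) maximizes the upper-triangular
    # sum A of the reordered matrix, found via simulated annealing
    p, A, total_iter = optimal_spike_train_order_from_matrix(D, full_output=True)
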
diff --git a/pyspike/spike_sync.py b/pyspike/spike_sync.py
index d37731f..1d2ecdb 100644
--- a/pyspike/spike_sync.py
+++ b/pyspike/spike_sync.py
@@ -295,12 +295,14 @@ def spike_sync_matrix(spike_trains, indices=None, interval=None, max_tau=None):
############################################################
# filter_by_spike_sync
############################################################
-def filter_by_spike_sync(spike_trains, threshold, indices=None, max_tau=None):
+def filter_by_spike_sync(spike_trains, threshold, indices=None, max_tau=None,
+ return_removed_spikes=False):
""" Removes the spikes with a multi-variate spike_sync value below
threshold.
"""
N = len(spike_trains)
filtered_spike_trains = []
+ removed_spike_trains = []
# cython implementation
try:
@@ -308,7 +310,7 @@ def filter_by_spike_sync(spike_trains, threshold, indices=None, max_tau=None):
as coincidence_impl
except ImportError:
if not(pyspike.disable_backend_warning):
- print("Warning: coincidence_single_profile_cytho not found. Make \
+ print("Warning: coincidence_single_profile_cython not found. Make \
sure that PySpike is installed by running\n \
'python setup.py build_ext --inplace'!\n \
Falling back to slow python backend.")
@@ -321,10 +323,19 @@ Falling back to slow python backend.")
for i, st in enumerate(spike_trains):
coincidences = np.zeros_like(st)
- for j in range(N).remove(i):
+ for j in xrange(N):
+ if i == j:
+ continue
coincidences += coincidence_impl(st.spikes, spike_trains[j].spikes,
st.t_start, st.t_end, max_tau)
filtered_spikes = st[coincidences > threshold*(N-1)]
filtered_spike_trains.append(SpikeTrain(filtered_spikes,
[st.t_start, st.t_end]))
- return filtered_spike_trains
+ if return_removed_spikes:
+ removed_spikes = st[coincidences <= threshold*(N-1)]
+ removed_spike_trains.append(SpikeTrain(removed_spikes,
+ [st.t_start, st.t_end]))
+ if return_removed_spikes:
+ return [filtered_spike_trains, removed_spike_trains]
+ else:
+ return filtered_spike_trains
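
With the new keyword, both the kept and the removed spikes are returned (a sketch; spike times hypothetical):

    import numpy as np
    import pyspike as spk
    from pyspike import SpikeTrain

    st1 = SpikeTrain(np.array([1.0, 2.0, 3.0, 4.0]), 5.0)
    st2 = SpikeTrain(np.array([1.1, 2.1, 3.8]), 5.0)

    kept, removed = spk.filter_by_spike_sync([st1, st2], 0.5,
                                             return_removed_spikes=True)
    # 'kept' holds the synchronous events, 'removed' the complementary
    # spikes; both are lists of SpikeTrain objects with the original edges
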
diff --git a/setup.py b/setup.py
index 9ba1da6..808a122 100644
--- a/setup.py
+++ b/setup.py
@@ -31,7 +31,8 @@ class numpy_include(object):
if os.path.isfile("pyspike/cython/cython_add.c") and \
os.path.isfile("pyspike/cython/cython_profiles.c") and \
os.path.isfile("pyspike/cython/cython_distances.c") and \
- os.path.isfile("pyspike/cython/cython_directionality.c"):
+ os.path.isfile("pyspike/cython/cython_directionality.c") and \
+ os.path.isfile("pyspike/cython/cython_simulated_annealing.c"):
use_c = True
else:
use_c = False
@@ -48,7 +49,9 @@ if use_cython: # Cython is available, compile .pyx -> .c
Extension("pyspike.cython.cython_distances",
["pyspike/cython/cython_distances.pyx"]),
Extension("pyspike.cython.cython_directionality",
- ["pyspike/cython/cython_directionality.pyx"])
+ ["pyspike/cython/cython_directionality.pyx"]),
+ Extension("pyspike.cython.cython_simulated_annealing",
+ ["pyspike/cython/cython_simulated_annealing.pyx"])
]
cmdclass.update({'build_ext': build_ext})
elif use_c: # c files are there, compile to binaries
@@ -60,7 +63,9 @@ elif use_c: # c files are there, compile to binaries
Extension("pyspike.cython.cython_distances",
["pyspike/cython/cython_distances.c"]),
Extension("pyspike.cython.cython_directionality",
- ["pyspike/cython/cython_directionality.c"])
+ ["pyspike/cython/cython_directionality.c"]),
+ Extension("pyspike.cython.cython_simulated_annealing",
+ ["pyspike/cython/cython_simulated_annealing.c"])
]
# neither cython nor c files available -> automatic fall-back to python backend
@@ -105,7 +110,8 @@ train similarity',
package_data={
'pyspike': ['cython/cython_add.c', 'cython/cython_profiles.c',
'cython/cython_distances.c',
- 'cython/cython_directionality.c'],
+ 'cython/cython_directionality.c',
+ 'cython/cython_simulated_annealing.c'],
'test': ['Spike_testdata.txt']
}
)
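
With the new extension registered in all three places (the cython source list, the c-file fallback, and the package data), rebuilding via 'python setup.py build_ext --inplace' picks up cython_simulated_annealing automatically; otherwise the code falls back to the commented-out pure-Python loop shown in spike_directionality.py above.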
diff --git a/test/test_sync_filter.py b/test/test_sync_filter.py
index ce03b23..66ffcb6 100644
--- a/test/test_sync_filter.py
+++ b/test/test_sync_filter.py
@@ -17,17 +17,18 @@ import pyspike as spk
from pyspike import SpikeTrain
-def test_cython():
+def test_single_prof():
st1 = np.array([1.0, 2.0, 3.0, 4.0])
st2 = np.array([1.1, 2.1, 3.8])
+ st3 = np.array([0.9, 3.1, 4.1])
# cython implementation
try:
- from pyspike.cython.cython_profiles import coincidence_single_profile_cython \
- as coincidence_impl
+ from pyspike.cython.cython_profiles import \
+ coincidence_single_profile_cython as coincidence_impl
except ImportError:
- from pyspike.cython.python_backend import coincidence_single_profile_python \
- as coincidence_impl
+ from pyspike.cython.python_backend import \
+ coincidence_single_profile_python as coincidence_impl
sync_prof = spk.spike_sync_profile(SpikeTrain(st1, 5.0),
SpikeTrain(st2, 5.0))
@@ -41,3 +42,53 @@ def test_cython():
for i, t in enumerate(st2):
assert_equal(coincidences[i], sync_prof.y[sync_prof.x == t],
"At index %d" % i)
+
+ sync_prof = spk.spike_sync_profile(SpikeTrain(st1, 5.0),
+ SpikeTrain(st3, 5.0))
+
+ coincidences = np.array(coincidence_impl(st1, st3, 0, 5.0, 0.0))
+ for i, t in enumerate(st1):
+ assert_equal(coincidences[i], sync_prof.y[sync_prof.x == t],
+ "At index %d" % i)
+
+ st1 = np.array([1.0, 2.0, 3.0, 4.0])
+ st2 = np.array([1.0, 2.0, 4.0])
+
+ sync_prof = spk.spike_sync_profile(SpikeTrain(st1, 5.0),
+ SpikeTrain(st2, 5.0))
+
+ coincidences = np.array(coincidence_impl(st1, st2, 0, 5.0, 0.0))
+ for i, t in enumerate(st1):
+ expected = sync_prof.y[sync_prof.x == t]/sync_prof.mp[sync_prof.x == t]
+ assert_equal(coincidences[i], expected,
+ "At index %d" % i)
+
+
+def test_filter():
+ st1 = SpikeTrain(np.array([1.0, 2.0, 3.0, 4.0]), 5.0)
+ st2 = SpikeTrain(np.array([1.1, 2.1, 3.8]), 5.0)
+ st3 = SpikeTrain(np.array([0.9, 3.1, 4.1]), 5.0)
+
+ # filtered_spike_trains = spk.filter_by_spike_sync([st1, st2], 0.5)
+
+ # assert_equal(filtered_spike_trains[0].spikes, [1.0, 2.0, 4.0])
+ # assert_equal(filtered_spike_trains[1].spikes, [1.1, 2.1, 3.8])
+
+ # filtered_spike_trains = spk.filter_by_spike_sync([st2, st1], 0.5)
+
+ # assert_equal(filtered_spike_trains[0].spikes, [1.1, 2.1, 3.8])
+ # assert_equal(filtered_spike_trains[1].spikes, [1.0, 2.0, 4.0])
+
+ filtered_spike_trains = spk.filter_by_spike_sync([st1, st2, st3], 0.75)
+
+ for st in filtered_spike_trains:
+ print(st.spikes)
+
+ assert_equal(filtered_spike_trains[0].spikes, [1.0, 4.0])
+ assert_equal(filtered_spike_trains[1].spikes, [1.1, 3.8])
+ assert_equal(filtered_spike_trains[2].spikes, [0.9, 4.1])
+
+
+if __name__ == "__main__":
+ test_single_prof()
+ test_filter()