From a5e6a12a619cb9528a4cf7f3ef8f082e5eb877c2 Mon Sep 17 00:00:00 2001 From: Mario Mulansky Date: Wed, 9 Sep 2015 17:51:03 +0200 Subject: added SPIKE-Sync based filtering new function filter_by_spike_sync removes spikes that have a multi-variate Spike Sync value below some threshold not yet fully tested, python backend missing. --- pyspike/cython/cython_profiles.pyx | 31 +++++++++++++++++++++++++++ pyspike/spike_sync.py | 40 ++++++++++++++++++++++++++++++++++- test/test_sync_filter.py | 43 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 113 insertions(+), 1 deletion(-) create mode 100644 test/test_sync_filter.py diff --git a/pyspike/cython/cython_profiles.pyx b/pyspike/cython/cython_profiles.pyx index 4a42cdb..fe08cb7 100644 --- a/pyspike/cython/cython_profiles.pyx +++ b/pyspike/cython/cython_profiles.pyx @@ -450,3 +450,34 @@ def coincidence_profile_cython(double[:] spikes1, double[:] spikes2, c[1] = 1 return st, c, mp + + +############################################################ +# coincidence_single_profile_cython +############################################################ +def coincidence_single_profile_cython(double[:] spikes1, double[:] spikes2, + double t_start, double t_end, double max_tau): + + cdef int N1 = len(spikes1) + cdef int N2 = len(spikes2) + cdef int j = -1 + cdef double[:] c = np.zeros(N1) # coincidences + cdef double interval = t_end - t_start + cdef double tau + for i in xrange(N1): + while j < N2-1 and spikes2[j+1] < spikes1[i]: + j += 1 + tau = get_tau(spikes1, spikes2, i, j, interval, max_tau) + print i, j, spikes1[i], spikes2[j], tau + if j > -1 and spikes1[i]-spikes2[j] < tau: + # current spike in st1 is coincident + c[i] = 1 + if j < N2-1: + j += 1 + tau = get_tau(spikes1, spikes2, i, j, interval, max_tau) + print i, j, spikes1[i], spikes2[j], tau + if spikes2[j]-spikes1[i] < tau: + # current spike in st1 is coincident + c[i] = 1 + + return c diff --git a/pyspike/spike_sync.py b/pyspike/spike_sync.py index 80f7805..d37731f 100644 --- a/pyspike/spike_sync.py +++ b/pyspike/spike_sync.py @@ -8,7 +8,7 @@ from __future__ import absolute_import import numpy as np from functools import partial import pyspike -from pyspike import DiscreteFunc +from pyspike import DiscreteFunc, SpikeTrain from pyspike.generic import _generic_profile_multi, _generic_distance_matrix @@ -290,3 +290,41 @@ def spike_sync_matrix(spike_trains, indices=None, interval=None, max_tau=None): dist_func = partial(spike_sync_bi, max_tau=max_tau) return _generic_distance_matrix(spike_trains, dist_func, indices, interval) + + +############################################################ +# filter_by_spike_sync +############################################################ +def filter_by_spike_sync(spike_trains, threshold, indices=None, max_tau=None): + """ Removes the spikes with a multi-variate spike_sync value below + threshold. + """ + N = len(spike_trains) + filtered_spike_trains = [] + + # cython implementation + try: + from cython.cython_profiles import coincidence_single_profile_cython \ + as coincidence_impl + except ImportError: + if not(pyspike.disable_backend_warning): + print("Warning: coincidence_single_profile_cytho not found. Make \ +sure that PySpike is installed by running\n \ +'python setup.py build_ext --inplace'!\n \ +Falling back to slow python backend.") + # use python backend + from cython.python_backend import coincidence_single_profile_python \ + as coincidence_impl + + if max_tau is None: + max_tau = 0.0 + + for i, st in enumerate(spike_trains): + coincidences = np.zeros_like(st) + for j in range(N).remove(i): + coincidences += coincidence_impl(st.spikes, spike_trains[j].spikes, + st.t_start, st.t_end, max_tau) + filtered_spikes = st[coincidences > threshold*(N-1)] + filtered_spike_trains.append(SpikeTrain(filtered_spikes, + [st.t_start, st.t_end])) + return filtered_spike_trains diff --git a/test/test_sync_filter.py b/test/test_sync_filter.py new file mode 100644 index 0000000..ce03b23 --- /dev/null +++ b/test/test_sync_filter.py @@ -0,0 +1,43 @@ +""" test_sync_filter.py + +Tests the spike sync based filtering + +Copyright 2015, Mario Mulansky + +Distributed under the BSD License + +""" + +from __future__ import print_function +import numpy as np +from numpy.testing import assert_equal, assert_almost_equal, \ + assert_array_almost_equal + +import pyspike as spk +from pyspike import SpikeTrain + + +def test_cython(): + st1 = np.array([1.0, 2.0, 3.0, 4.0]) + st2 = np.array([1.1, 2.1, 3.8]) + + # cython implementation + try: + from pyspike.cython.cython_profiles import coincidence_single_profile_cython \ + as coincidence_impl + except ImportError: + from pyspike.cython.python_backend import coincidence_single_profile_python \ + as coincidence_impl + + sync_prof = spk.spike_sync_profile(SpikeTrain(st1, 5.0), + SpikeTrain(st2, 5.0)) + + coincidences = np.array(coincidence_impl(st1, st2, 0, 5.0, 0.0)) + for i, t in enumerate(st1): + assert_equal(coincidences[i], sync_prof.y[sync_prof.x == t], + "At index %d" % i) + + coincidences = np.array(coincidence_impl(st2, st1, 0, 5.0, 0.0)) + for i, t in enumerate(st2): + assert_equal(coincidences[i], sync_prof.y[sync_prof.x == t], + "At index %d" % i) -- cgit v1.2.3