summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMario Mulansky <mario.mulansky@gmx.net>2015-04-02 20:11:37 +0200
committerMario Mulansky <mario.mulansky@gmx.net>2015-04-02 20:11:37 +0200
commit4366f30d7a27a9aafdf0efc2192f4780706d439b (patch)
tree2e9839f82adc2d1346b6fa56f3123aa32414e295
parent06a72795731c69340685e4bc2a8379626343b56e (diff)
added max_tau to spike_sync functions
-rw-r--r--pyspike/cython/cython_distance.pyx14
-rw-r--r--pyspike/spike_sync.py36
-rw-r--r--test/test_distance.py4
3 files changed, 39 insertions, 15 deletions
diff --git a/pyspike/cython/cython_distance.pyx b/pyspike/cython/cython_distance.pyx
index 489aab9..2834ca5 100644
--- a/pyspike/cython/cython_distance.pyx
+++ b/pyspike/cython/cython_distance.pyx
@@ -236,7 +236,8 @@ def spike_distance_cython(double[:] t1,
############################################################
# coincidence_python
############################################################
-cdef inline double get_tau(double[:] spikes1, double[:] spikes2, int i, int j):
+cdef inline double get_tau(double[:] spikes1, double[:] spikes2,
+ int i, int j, max_tau):
cdef double m = 1E100 # some huge number
cdef int N1 = len(spikes1)-2
cdef int N2 = len(spikes2)-2
@@ -248,13 +249,16 @@ cdef inline double get_tau(double[:] spikes1, double[:] spikes2, int i, int j):
m = fmin(m, spikes1[i]-spikes1[i-1])
if j > 1:
m = fmin(m, spikes2[j]-spikes2[j-1])
- return 0.5*m
+ m *= 0.5
+ if max_tau > 0.0:
+ m = fmin(m, max_tau)
+ return m
############################################################
# coincidence_cython
############################################################
-def coincidence_cython(double[:] spikes1, double[:] spikes2):
+def coincidence_cython(double[:] spikes1, double[:] spikes2, double max_tau):
cdef int N1 = len(spikes1)
cdef int N2 = len(spikes2)
@@ -269,7 +273,7 @@ def coincidence_cython(double[:] spikes1, double[:] spikes2):
if spikes1[i+1] < spikes2[j+1]:
i += 1
n += 1
- tau = get_tau(spikes1, spikes2, i, j)
+ tau = get_tau(spikes1, spikes2, i, j, max_tau)
st[n] = spikes1[i]
if j > 0 and spikes1[i]-spikes2[j] < tau:
# coincidence between the current spike and the previous spike
@@ -279,7 +283,7 @@ def coincidence_cython(double[:] spikes1, double[:] spikes2):
elif spikes1[i+1] > spikes2[j+1]:
j += 1
n += 1
- tau = get_tau(spikes1, spikes2, i, j)
+ tau = get_tau(spikes1, spikes2, i, j, max_tau)
st[n] = spikes2[j]
if i > 0 and spikes2[j]-spikes1[i] < tau:
# coincidence between the current spike and the previous spike
diff --git a/pyspike/spike_sync.py b/pyspike/spike_sync.py
index 342bf69..e12ebb8 100644
--- a/pyspike/spike_sync.py
+++ b/pyspike/spike_sync.py
@@ -16,7 +16,7 @@ from pyspike.generic import _generic_profile_multi, _generic_distance_matrix
############################################################
# spike_sync_profile
############################################################
-def spike_sync_profile(spikes1, spikes2):
+def spike_sync_profile(spikes1, spikes2, max_tau=None):
""" Computes the spike-synchronization profile S_sync(t) of the two given
spike trains. Returns the profile as a DiscreteFunction object. The S_sync
values are either 1 or 0, indicating the presence or absence of a
@@ -26,6 +26,8 @@ def spike_sync_profile(spikes1, spikes2):
:param spikes1: ordered array of spike times with auxiliary spikes.
:param spikes2: ordered array of spike times with auxiliary spikes.
+ :param max_tau: Maximum coincidence window size. If 0 or `None`, the
+ coincidence window has no upper bound.
:returns: The spike-distance profile :math:`S_{sync}(t)`.
:rtype: :class:`pyspike.function.DiscreteFunction`
@@ -43,7 +45,11 @@ Falling back to slow python backend.")
from cython.python_backend import coincidence_python \
as coincidence_impl
- times, coincidences, multiplicity = coincidence_impl(spikes1, spikes2)
+ if max_tau is None:
+ max_tau = 0.0
+
+ times, coincidences, multiplicity = coincidence_impl(spikes1, spikes2,
+ max_tau)
return DiscreteFunc(times, coincidences, multiplicity)
@@ -51,7 +57,7 @@ Falling back to slow python backend.")
############################################################
# spike_sync
############################################################
-def spike_sync(spikes1, spikes2, interval=None):
+def spike_sync(spikes1, spikes2, interval=None, max_tau=None):
""" Computes the spike synchronization value SYNC of the given spike
trains. The spike synchronization value is the computed as the total number
of coincidences divided by the total number of spikes:
@@ -63,16 +69,18 @@ def spike_sync(spikes1, spikes2, interval=None):
:param interval: averaging interval given as a pair of floats (T0, T1),
if None the average over the whole function is computed.
:type interval: Pair of floats or None.
+ :param max_tau: Maximum coincidence window size. If 0 or `None`, the
+ coincidence window has no upper bound.
:returns: The spike synchronization value.
:rtype: double
"""
- return spike_sync_profile(spikes1, spikes2).avrg(interval)
+ return spike_sync_profile(spikes1, spikes2, max_tau).avrg(interval)
############################################################
# spike_sync_profile_multi
############################################################
-def spike_sync_profile_multi(spike_trains, indices=None):
+def spike_sync_profile_multi(spike_trains, indices=None, max_tau=None):
""" Computes the multi-variate spike synchronization profile for a set of
spike trains. For each spike in the set of spike trains, the multi-variate
profile is defined as the number of coincidences divided by the number of
@@ -83,11 +91,13 @@ def spike_sync_profile_multi(spike_trains, indices=None):
:param indices: list of indices defining which spike trains to use,
if None all given spike trains are used (default=None)
:type indices: list or None
+ :param max_tau: Maximum coincidence window size. If 0 or `None`, the
+ coincidence window has no upper bound.
:returns: The multi-variate spike sync profile :math:`<S_{sync}>(t)`
:rtype: :class:`pyspike.function.DiscreteFunction`
"""
- prof_func = partial(spike_sync_profile)
+ prof_func = partial(spike_sync_profile, max_tau=max_tau)
average_dist, M = _generic_profile_multi(spike_trains, prof_func,
indices)
# average_dist.mul_scalar(1.0/M) # no normalization here!
@@ -97,7 +107,7 @@ def spike_sync_profile_multi(spike_trains, indices=None):
############################################################
# spike_sync_multi
############################################################
-def spike_sync_multi(spike_trains, indices=None, interval=None):
+def spike_sync_multi(spike_trains, indices=None, interval=None, max_tau=None):
""" Computes the multi-variate spike synchronization value for a set of
spike trains.
@@ -108,16 +118,19 @@ def spike_sync_multi(spike_trains, indices=None, interval=None):
:param interval: averaging interval given as a pair of floats, if None
the average over the whole function is computed.
:type interval: Pair of floats or None.
+ :param max_tau: Maximum coincidence window size. If 0 or `None`, the
+ coincidence window has no upper bound.
:returns: The multi-variate spike synchronization value SYNC.
:rtype: double
"""
- return spike_sync_profile_multi(spike_trains, indices).avrg(interval)
+ return spike_sync_profile_multi(spike_trains, indices,
+ max_tau).avrg(interval)
############################################################
# spike_sync_matrix
############################################################
-def spike_sync_matrix(spike_trains, indices=None, interval=None):
+def spike_sync_matrix(spike_trains, indices=None, interval=None, max_tau=None):
""" Computes the overall spike-synchronization value of all pairs of
spike-trains.
@@ -128,9 +141,12 @@ def spike_sync_matrix(spike_trains, indices=None, interval=None):
:param interval: averaging interval given as a pair of floats, if None
the average over the whole function is computed.
:type interval: Pair of floats or None.
+ :param max_tau: Maximum coincidence window size. If 0 or `None`, the
+ coincidence window has no upper bound.
:returns: 2D array with the pair wise time spike synchronization values
:math:`SYNC_{ij}`
:rtype: np.array
"""
- return _generic_distance_matrix(spike_trains, spike_sync,
+ dist_func = partial(spike_sync, max_tau=max_tau)
+ return _generic_distance_matrix(spike_trains, dist_func,
indices, interval)
diff --git a/test/test_distance.py b/test/test_distance.py
index 41f625e..ba19f5e 100644
--- a/test/test_distance.py
+++ b/test/test_distance.py
@@ -138,6 +138,10 @@ def test_spike_sync():
assert_almost_equal(spk.spike_sync(spikes1, spikes2),
0.5, decimal=16)
+ # test with some small max_tau, spike_sync should be 0
+ assert_almost_equal(spk.spike_sync(spikes1, spikes2, max_tau=0.05),
+ 0.0, decimal=16)
+
spikes2 = np.array([3.1])
spikes2 = spk.add_auxiliary_spikes(spikes2, 4.0)
assert_almost_equal(spk.spike_sync(spikes1, spikes2),