path: root/pyspike/spike_sync.py
Diffstat (limited to 'pyspike/spike_sync.py')
-rw-r--r--  pyspike/spike_sync.py  79
1 file changed, 65 insertions(+), 14 deletions(-)
diff --git a/pyspike/spike_sync.py b/pyspike/spike_sync.py
index 9d2e363..40d98d2 100644
--- a/pyspike/spike_sync.py
+++ b/pyspike/spike_sync.py
@@ -3,6 +3,7 @@
# Copyright 2014-2015, Mario Mulansky <mario.mulansky@gmx.net>
# Distributed under the BSD License
+import numpy as np
from functools import partial
from pyspike import DiscreteFunc
from pyspike.generic import _generic_profile_multi, _generic_distance_matrix
@@ -29,35 +30,68 @@ def spike_sync_profile(spike_train1, spike_train2, max_tau=None):
"""
# check whether the spike trains are defined for the same interval
assert spike_train1.t_start == spike_train2.t_start, \
- "Given spike trains seems not to have auxiliary spikes!"
+ "Given spike trains are not defined on the same interval!"
assert spike_train1.t_end == spike_train2.t_end, \
- "Given spike trains seems not to have auxiliary spikes!"
+ "Given spike trains are not defined on the same interval!"
# cython implementation
try:
- from cython.cython_distance import coincidence_cython \
- as coincidence_impl
+ from cython.cython_profiles import coincidence_profile_cython \
+ as coincidence_profile_impl
except ImportError:
print("Warning: spike_distance_cython not found. Make sure that \
PySpike is installed by running\n 'python setup.py build_ext --inplace'!\n \
Falling back to slow python backend.")
# use python backend
from cython.python_backend import coincidence_python \
- as coincidence_impl
+ as coincidence_profile_impl
if max_tau is None:
max_tau = 0.0
- times, coincidences, multiplicity = coincidence_impl(spike_train1.spikes,
- spike_train2.spikes,
- spike_train1.t_start,
- spike_train1.t_end,
- max_tau)
+ times, coincidences, multiplicity \
+ = coincidence_profile_impl(spike_train1.spikes, spike_train2.spikes,
+ spike_train1.t_start, spike_train1.t_end,
+ max_tau)
return DiscreteFunc(times, coincidences, multiplicity)
############################################################
+# _spike_sync_values
+############################################################
+def _spike_sync_values(spike_train1, spike_train2, interval, max_tau):
+    """ Internal function. Computes the summed coincidences and multiplicity
+    for spike synchronization of the two given spike trains.
+
+    Do not call this function directly, use `spike_sync` or `spike_sync_multi`
+    instead.
+    """
+    if interval is None:
+        # distance over the whole interval is requested: use specific function
+        # for optimal performance
+        try:
+            from cython.cython_distances import coincidence_value_cython \
+                as coincidence_value_impl
+            if max_tau is None:
+                max_tau = 0.0
+            c, mp = coincidence_value_impl(spike_train1.spikes,
+                                           spike_train2.spikes,
+                                           spike_train1.t_start,
+                                           spike_train1.t_end,
+                                           max_tau)
+            return c, mp
+        except ImportError:
+            # Cython backend not available: fall back to profile averaging
+            return spike_sync_profile(spike_train1, spike_train2,
+                                      max_tau).integral(interval)
+    else:
+        # some specific interval is provided: use profile
+        return spike_sync_profile(spike_train1, spike_train2,
+                                  max_tau).integral(interval)
+
+
+############################################################
# spike_sync
############################################################
def spike_sync(spike_train1, spike_train2, interval=None, max_tau=None):
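For orientation, a minimal usage sketch of the profile function modified above. The spike times and edges are made up for illustration, and the sketch assumes the public pyspike.SpikeTrain(spikes, edges) constructor exported alongside spike_sync_profile:

    from pyspike import SpikeTrain, spike_sync_profile

    # two short, illustrative spike trains defined on the same interval [0, 2]
    st1 = SpikeTrain([0.1, 0.5, 1.2], edges=(0.0, 2.0))
    st2 = SpikeTrain([0.1, 0.6, 1.5], edges=(0.0, 2.0))

    # the profile is a DiscreteFunc of coincidence values and multiplicities
    profile = spike_sync_profile(st1, st2)
    print(profile.avrg())  # average synchronization over the whole interval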
@@ -80,8 +114,8 @@ def spike_sync(spike_train1, spike_train2, interval=None, max_tau=None):
    :rtype: `double`
    """
-    return spike_sync_profile(spike_train1, spike_train2,
-                              max_tau).avrg(interval)
+    c, mp = _spike_sync_values(spike_train1, spike_train2, interval, max_tau)
+    return 1.0*c/mp
############################################################
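The replacement above computes the summed coincidences c and total multiplicity mp directly instead of averaging the profile; both routes yield the same ratio. A sketch of that equivalence with made-up spike trains (again assuming the pyspike.SpikeTrain constructor):

    from pyspike import SpikeTrain, spike_sync, spike_sync_profile

    st1 = SpikeTrain([0.2, 0.8, 1.4], edges=(0.0, 2.0))
    st2 = SpikeTrain([0.25, 0.9, 1.9], edges=(0.0, 2.0))

    direct = spike_sync(st1, st2)                   # 1.0 * c / mp
    averaged = spike_sync_profile(st1, st2).avrg()  # same ratio via the profile
    assert abs(direct - averaged) < 1e-12           # equal up to rounding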
@@ -131,8 +165,25 @@ def spike_sync_multi(spike_trains, indices=None, interval=None, max_tau=None):
    :rtype: double
    """
-    return spike_sync_profile_multi(spike_trains, indices,
-                                    max_tau).avrg(interval)
+    if indices is None:
+        indices = np.arange(len(spike_trains))
+    indices = np.array(indices)
+    # check validity of indices
+    assert (indices < len(spike_trains)).all() and (indices >= 0).all(), \
+        "Invalid index list."
+    # generate a list of possible index pairs
+    pairs = [(indices[i], j) for i in range(len(indices))
+             for j in indices[i+1:]]
+
+    coincidence = 0.0
+    mp = 0.0
+    for (i, j) in pairs:
+        c, m = _spike_sync_values(spike_trains[i], spike_trains[j],
+                                  interval, max_tau)
+        coincidence += c
+        mp += m
+
+    return coincidence/mp
############################################################
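The multivariate interface is unchanged; the result is now obtained by pooling coincidences and multiplicities over all index pairs rather than averaging the multivariate profile. A short, illustrative call (spike trains made up, SpikeTrain constructor assumed as above):

    from pyspike import SpikeTrain, spike_sync_multi

    trains = [SpikeTrain([0.1, 0.9, 1.6], edges=(0.0, 2.0)),
              SpikeTrain([0.15, 1.0, 1.8], edges=(0.0, 2.0)),
              SpikeTrain([0.5, 1.1, 1.7], edges=(0.0, 2.0))]

    print(spike_sync_multi(trains))                  # pooled over all pairs
    print(spike_sync_multi(trains, indices=[0, 2]))  # restrict to a subset of trains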