summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMario Mulansky <mario.mulansky@gmx.net>2016-01-31 16:40:47 +0100
committerMario Mulansky <mario.mulansky@gmx.net>2016-01-31 16:40:47 +0100
commitea3709e2f4367cb539acc26ec8e05b686d6bf836 (patch)
tree421a1901ca6c598f5e762fcc78126183da833946
parent5a556a11fbf8434bd38fa73e05054d581018a4da (diff)
generic interface for spike distance/profile
spike_profile and spike_distance now have a generic interface that allows to compute bi-variate and multi-variate results with the same function.
-rw-r--r--pyspike/isi_distance.py17
-rw-r--r--pyspike/spike_distance.py124
-rw-r--r--test/test_generic_interfaces.py31
-rw-r--r--test/test_regression/test_regression_15.py1
4 files changed, 138 insertions, 35 deletions
diff --git a/pyspike/isi_distance.py b/pyspike/isi_distance.py
index a85028f..122e11d 100644
--- a/pyspike/isi_distance.py
+++ b/pyspike/isi_distance.py
@@ -14,7 +14,7 @@ from pyspike.generic import _generic_profile_multi, _generic_distance_multi, \
# isi_profile
############################################################
def isi_profile(*args, **kwargs):
- """ Computes the isi-distance profile :math:`I(t)` of the two given
+ """ Computes the isi-distance profile :math:`I(t)` of the given
spike trains. Returns the profile as a PieceWiseConstFunc object. The
ISI-values are defined positive :math:`I(t)>=0`.
@@ -22,8 +22,9 @@ def isi_profile(*args, **kwargs):
isi_profile(st1, st2) # returns the bi-variate profile
isi_profile(st1, st2, st3) # multi-variate profile of 3 spike trains
+
spike_trains = [st1, st2, st3, st4] # list of spike trains
- isi_profile(spike_trains) # return the profile the list of spike trains
+ isi_profile(spike_trains) # profile of the list of spike trains
isi_profile(spike_trains, indices=[0, 1]) # use only the spike trains
# given by the indices
@@ -108,13 +109,23 @@ def isi_profile_multi(spike_trains, indices=None):
# isi_distance
############################################################
def isi_distance(*args, **kwargs):
- # spike_trains, spike_train2, interval=None):
""" Computes the ISI-distance :math:`D_I` of the given spike trains. The
isi-distance is the integral over the isi distance profile
:math:`I(t)`:
.. math:: D_I = \\int_{T_0}^{T_1} I(t) dt.
+
+ Valid call structures::
+
+ isi_distance(st1, st2) # returns the bi-variate distance
+ isi_distance(st1, st2, st3) # multi-variate distance of 3 spike trains
+
+ spike_trains = [st1, st2, st3, st4] # list of spike trains
+ isi_distance(spike_trains) # distance of the list of spike trains
+ isi_distance(spike_trains, indices=[0, 1]) # use only the spike trains
+ # given by the indices
+
:returns: The isi-distance :math:`D_I`.
:rtype: double
"""
diff --git a/pyspike/spike_distance.py b/pyspike/spike_distance.py
index e418283..7acb959 100644
--- a/pyspike/spike_distance.py
+++ b/pyspike/spike_distance.py
@@ -13,7 +13,36 @@ from pyspike.generic import _generic_profile_multi, _generic_distance_multi, \
############################################################
# spike_profile
############################################################
-def spike_profile(spike_train1, spike_train2):
+def spike_profile(*args, **kwargs):
+ """ Computes the spike-distance profile :math:`S(t)` of the given
+ spike trains. Returns the profile as a PieceWiseConstLin object. The
+ SPIKE-values are defined positive :math:`S(t)>=0`.
+
+ Valid call structures::
+
+ spike_profile(st1, st2) # returns the bi-variate profile
+ spike_profile(st1, st2, st3) # multi-variate profile of 3 spike trains
+
+ spike_trains = [st1, st2, st3, st4] # list of spike trains
+ spike_profile(spike_trains) # profile of the list of spike trains
+ spike_profile(spike_trains, indices=[0, 1]) # use only the spike trains
+ # given by the indices
+
+ :returns: The spike-distance profile :math:`S(t)`
+ :rtype: :class:`.PieceWiseConstLin`
+ """
+ if len(args) == 1:
+ return spike_profile_multi(args[0], **kwargs)
+ elif len(args) == 2:
+ return spike_profile_bi(args[0], args[1])
+ else:
+ return spike_profile_multi(args)
+
+
+############################################################
+# spike_profile_bi
+############################################################
+def spike_profile_bi(spike_train1, spike_train2):
""" Computes the spike-distance profile :math:`S(t)` of the two given spike
trains. Returns the profile as a PieceWiseLinFunc object. The SPIKE-values
are defined positive :math:`S(t)>=0`.
@@ -54,9 +83,67 @@ Falling back to slow python backend.")
############################################################
+# spike_profile_multi
+############################################################
+def spike_profile_multi(spike_trains, indices=None):
+ """ Computes the multi-variate spike distance profile for a set of spike
+ trains. That is the average spike-distance of all pairs of spike-trains:
+
+ .. math:: <S(t)> = \\frac{2}{N(N-1)} \\sum_{<i,j>} S^{i, j}`,
+
+ where the sum goes over all pairs <i,j>
+
+ :param spike_trains: list of :class:`.SpikeTrain`
+ :param indices: list of indices defining which spike trains to use,
+ if None all given spike trains are used (default=None)
+ :type indices: list or None
+ :returns: The averaged spike profile :math:`<S>(t)`
+ :rtype: :class:`.PieceWiseLinFunc`
+
+ """
+ average_dist, M = _generic_profile_multi(spike_trains, spike_profile_bi,
+ indices)
+ average_dist.mul_scalar(1.0/M) # normalize
+ return average_dist
+
+
+############################################################
# spike_distance
############################################################
-def spike_distance(spike_train1, spike_train2, interval=None):
+def spike_distance(*args, **kwargs):
+ """ Computes the SPIKE-distance :math:`D_S` of the given spike trains. The
+ spike-distance is the integral over the spike distance profile
+ :math:`D(t)`:
+
+ .. math:: D_S = \\int_{T_0}^{T_1} S(t) dt.
+
+
+ Valid call structures::
+
+ spike_distance(st1, st2) # returns the bi-variate distance
+ spike_distance(st1, st2, st3) # multi-variate distance of 3 spike trains
+
+ spike_trains = [st1, st2, st3, st4] # list of spike trains
+ spike_distance(spike_trains) # distance of the list of spike trains
+ spike_distance(spike_trains, indices=[0, 1]) # use only the spike trains
+ # given by the indices
+
+ :returns: The spike-distance :math:`D_S`.
+ :rtype: double
+ """
+
+ if len(args) == 1:
+ return spike_distance_multi(args[0], **kwargs)
+ elif len(args) == 2:
+ return spike_distance_bi(args[0], args[1], **kwargs)
+ else:
+ return spike_distance_multi(args, **kwargs)
+
+
+############################################################
+# spike_distance_bi
+############################################################
+def spike_distance_bi(spike_train1, spike_train2, interval=None):
""" Computes the spike-distance :math:`D_S` of the given spike trains. The
spike-distance is the integral over the spike distance profile
:math:`S(t)`:
@@ -86,35 +173,10 @@ def spike_distance(spike_train1, spike_train2, interval=None):
spike_train1.t_end)
except ImportError:
# Cython backend not available: fall back to average profile
- return spike_profile(spike_train1, spike_train2).avrg(interval)
+ return spike_profile_bi(spike_train1, spike_train2).avrg(interval)
else:
# some specific interval is provided: compute the whole profile
- return spike_profile(spike_train1, spike_train2).avrg(interval)
-
-
-############################################################
-# spike_profile_multi
-############################################################
-def spike_profile_multi(spike_trains, indices=None):
- """ Computes the multi-variate spike distance profile for a set of spike
- trains. That is the average spike-distance of all pairs of spike-trains:
-
- .. math:: <S(t)> = \\frac{2}{N(N-1)} \\sum_{<i,j>} S^{i, j}`,
-
- where the sum goes over all pairs <i,j>
-
- :param spike_trains: list of :class:`.SpikeTrain`
- :param indices: list of indices defining which spike trains to use,
- if None all given spike trains are used (default=None)
- :type indices: list or None
- :returns: The averaged spike profile :math:`<S>(t)`
- :rtype: :class:`.PieceWiseLinFunc`
-
- """
- average_dist, M = _generic_profile_multi(spike_trains, spike_profile,
- indices)
- average_dist.mul_scalar(1.0/M) # normalize
- return average_dist
+ return spike_profile_bi(spike_train1, spike_train2).avrg(interval)
############################################################
@@ -139,7 +201,7 @@ def spike_distance_multi(spike_trains, indices=None, interval=None):
:returns: The averaged multi-variate spike distance :math:`D_S`.
:rtype: double
"""
- return _generic_distance_multi(spike_trains, spike_distance, indices,
+ return _generic_distance_multi(spike_trains, spike_distance_bi, indices,
interval)
@@ -160,5 +222,5 @@ def spike_distance_matrix(spike_trains, indices=None, interval=None):
:math:`D_S^{ij}`
:rtype: np.array
"""
- return _generic_distance_matrix(spike_trains, spike_distance,
+ return _generic_distance_matrix(spike_trains, spike_distance_bi,
indices, interval)
diff --git a/test/test_generic_interfaces.py b/test/test_generic_interfaces.py
index caa9ee4..ee87be4 100644
--- a/test/test_generic_interfaces.py
+++ b/test/test_generic_interfaces.py
@@ -23,7 +23,12 @@ class dist_from_prof:
self.prof_func = prof_func
def __call__(self, *args, **kwargs):
- return self.prof_func(*args, **kwargs).avrg()
+ if "interval" in kwargs:
+ # forward interval arg into avrg function
+ interval = kwargs.pop("interval")
+ return self.prof_func(*args, **kwargs).avrg(interval=interval)
+ else:
+ return self.prof_func(*args, **kwargs).avrg()
def check_func(dist_func):
@@ -50,6 +55,22 @@ def check_func(dist_func):
isi123_ = dist_func(spike_trains, indices=[0, 1, 2])
assert_equal(isi123, isi123_)
+ # run the same test with an additional interval parameter
+
+ isi12 = dist_func(t1, t2, interval=[0.0, 0.5])
+ isi12_ = dist_func([t1, t2], interval=[0.0, 0.5])
+ assert_equal(isi12, isi12_)
+
+ isi12_ = dist_func(spike_trains, indices=[0, 1], interval=[0.0, 0.5])
+ assert_equal(isi12, isi12_)
+
+ isi123 = dist_func(t1, t2, t3, interval=[0.0, 0.5])
+ isi123_ = dist_func([t1, t2, t3], interval=[0.0, 0.5])
+ assert_equal(isi123, isi123_)
+
+ isi123_ = dist_func(spike_trains, indices=[0, 1, 2], interval=[0.0, 0.5])
+ assert_equal(isi123, isi123_)
+
def test_isi_profile():
check_func(dist_from_prof(spk.isi_profile))
@@ -59,6 +80,14 @@ def test_isi_distance():
check_func(spk.isi_distance)
+def test_spike_profile():
+ check_func(dist_from_prof(spk.spike_profile))
+
+
+def test_spike_distance():
+ check_func(spk.spike_distance)
+
+
if __name__ == "__main__":
test_isi_profile()
test_isi_distance()
diff --git a/test/test_regression/test_regression_15.py b/test/test_regression/test_regression_15.py
index dcacae2..54adf23 100644
--- a/test/test_regression/test_regression_15.py
+++ b/test/test_regression/test_regression_15.py
@@ -20,6 +20,7 @@ import os
TEST_PATH = os.path.dirname(os.path.realpath(__file__))
TEST_DATA = os.path.join(TEST_PATH, "..", "SPIKE_Sync_Test.txt")
+
def test_regression_15_isi():
# load spike trains
spike_trains = spk.load_spike_trains_from_txt(TEST_DATA, edges=[0, 4000])