summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMario Mulansky <mario.mulansky@gmx.net>2016-01-31 15:05:21 +0100
committerMario Mulansky <mario.mulansky@gmx.net>2016-01-31 15:05:21 +0100
commit5a556a11fbf8434bd38fa73e05054d581018a4da (patch)
tree6114d1c64d29484a4f4fe3a7309f16f8d437f372
parent2f48f27b55f63726216b6e674fb88b3790b59147 (diff)
generalized interface to isi profile and distance
isi profile and distance functionc an now compute bi-variate and multi-variate results. Therefore, it can be called with different "overloads".
-rw-r--r--pyspike/isi_distance.py116
-rw-r--r--test/test_generic_interfaces.py64
2 files changed, 148 insertions, 32 deletions
diff --git a/pyspike/isi_distance.py b/pyspike/isi_distance.py
index 0ae7393..a85028f 100644
--- a/pyspike/isi_distance.py
+++ b/pyspike/isi_distance.py
@@ -13,11 +13,40 @@ from pyspike.generic import _generic_profile_multi, _generic_distance_multi, \
############################################################
# isi_profile
############################################################
-def isi_profile(spike_train1, spike_train2):
+def isi_profile(*args, **kwargs):
""" Computes the isi-distance profile :math:`I(t)` of the two given
- spike trains. Retruns the profile as a PieceWiseConstFunc object. The
+ spike trains. Returns the profile as a PieceWiseConstFunc object. The
ISI-values are defined positive :math:`I(t)>=0`.
+ Valid call structures::
+
+ isi_profile(st1, st2) # returns the bi-variate profile
+ isi_profile(st1, st2, st3) # multi-variate profile of 3 spike trains
+ spike_trains = [st1, st2, st3, st4] # list of spike trains
+ isi_profile(spike_trains) # return the profile the list of spike trains
+ isi_profile(spike_trains, indices=[0, 1]) # use only the spike trains
+ # given by the indices
+
+ :returns: The isi-distance profile :math:`I(t)`
+ :rtype: :class:`.PieceWiseConstFunc`
+ """
+ if len(args) == 1:
+ return isi_profile_multi(args[0], **kwargs)
+ elif len(args) == 2:
+ return isi_profile_bi(args[0], args[1])
+ else:
+ return isi_profile_multi(args)
+
+
+############################################################
+# isi_profile_bi
+############################################################
+def isi_profile_bi(spike_train1, spike_train2):
+ """ Bi-variate ISI-profile.
+ Computes the isi-distance profile :math:`I(t)` of the two given
+ spike trains. Returns the profile as a PieceWiseConstFunc object.
+ See :func:`.isi_profile`.
+
:param spike_train1: First spike train.
:type spike_train1: :class:`.SpikeTrain`
:param spike_train2: Second spike train.
@@ -52,9 +81,56 @@ Falling back to slow python backend.")
############################################################
+# isi_profile_multi
+############################################################
+def isi_profile_multi(spike_trains, indices=None):
+ """ computes the multi-variate isi distance profile for a set of spike
+ trains. That is the average isi-distance of all pairs of spike-trains:
+
+ .. math:: <I(t)> = \\frac{2}{N(N-1)} \\sum_{<i,j>} I^{i,j},
+
+ where the sum goes over all pairs <i,j>
+
+ :param spike_trains: list of :class:`.SpikeTrain`
+ :param indices: list of indices defining which spike trains to use,
+ if None all given spike trains are used (default=None)
+ :type state: list or None
+ :returns: The averaged isi profile :math:`<I(t)>`
+ :rtype: :class:`.PieceWiseConstFunc`
+ """
+ average_dist, M = _generic_profile_multi(spike_trains, isi_profile_bi,
+ indices)
+ average_dist.mul_scalar(1.0/M) # normalize
+ return average_dist
+
+
+############################################################
# isi_distance
############################################################
-def isi_distance(spike_train1, spike_train2, interval=None):
+def isi_distance(*args, **kwargs):
+ # spike_trains, spike_train2, interval=None):
+ """ Computes the ISI-distance :math:`D_I` of the given spike trains. The
+ isi-distance is the integral over the isi distance profile
+ :math:`I(t)`:
+
+ .. math:: D_I = \\int_{T_0}^{T_1} I(t) dt.
+
+ :returns: The isi-distance :math:`D_I`.
+ :rtype: double
+ """
+
+ if len(args) == 1:
+ return isi_distance_multi(args[0], **kwargs)
+ elif len(args) == 2:
+ return isi_distance_bi(args[0], args[1], **kwargs)
+ else:
+ return isi_distance_multi(args, **kwargs)
+
+
+############################################################
+# isi_distance_bi
+############################################################
+def isi_distance_bi(spike_train1, spike_train2, interval=None):
""" Computes the ISI-distance :math:`D_I` of the given spike trains. The
isi-distance is the integral over the isi distance profile
:math:`I(t)`:
@@ -84,34 +160,10 @@ def isi_distance(spike_train1, spike_train2, interval=None):
spike_train1.t_start, spike_train1.t_end)
except ImportError:
# Cython backend not available: fall back to profile averaging
- return isi_profile(spike_train1, spike_train2).avrg(interval)
+ return isi_profile_bi(spike_train1, spike_train2).avrg(interval)
else:
# some specific interval is provided: use profile
- return isi_profile(spike_train1, spike_train2).avrg(interval)
-
-
-############################################################
-# isi_profile_multi
-############################################################
-def isi_profile_multi(spike_trains, indices=None):
- """ computes the multi-variate isi distance profile for a set of spike
- trains. That is the average isi-distance of all pairs of spike-trains:
-
- .. math:: <I(t)> = \\frac{2}{N(N-1)} \\sum_{<i,j>} I^{i,j},
-
- where the sum goes over all pairs <i,j>
-
- :param spike_trains: list of :class:`.SpikeTrain`
- :param indices: list of indices defining which spike trains to use,
- if None all given spike trains are used (default=None)
- :type state: list or None
- :returns: The averaged isi profile :math:`<I(t)>`
- :rtype: :class:`.PieceWiseConstFunc`
- """
- average_dist, M = _generic_profile_multi(spike_trains, isi_profile,
- indices)
- average_dist.mul_scalar(1.0/M) # normalize
- return average_dist
+ return isi_profile_bi(spike_train1, spike_train2).avrg(interval)
############################################################
@@ -134,7 +186,7 @@ def isi_distance_multi(spike_trains, indices=None, interval=None):
:returns: The time-averaged multivariate ISI distance :math:`D_I`
:rtype: double
"""
- return _generic_distance_multi(spike_trains, isi_distance, indices,
+ return _generic_distance_multi(spike_trains, isi_distance_bi, indices,
interval)
@@ -155,5 +207,5 @@ def isi_distance_matrix(spike_trains, indices=None, interval=None):
:math:`D_{I}^{ij}`
:rtype: np.array
"""
- return _generic_distance_matrix(spike_trains, isi_distance,
- indices, interval)
+ return _generic_distance_matrix(spike_trains, isi_distance_bi,
+ indices=indices, interval=interval)
diff --git a/test/test_generic_interfaces.py b/test/test_generic_interfaces.py
new file mode 100644
index 0000000..caa9ee4
--- /dev/null
+++ b/test/test_generic_interfaces.py
@@ -0,0 +1,64 @@
+""" test_isi_interface.py
+
+Tests the generic interfaces of the profile and distance functions
+
+Copyright 2016, Mario Mulansky <mario.mulansky@gmx.net>
+
+Distributed under the BSD License
+
+"""
+
+from __future__ import print_function
+from numpy.testing import assert_equal
+
+import pyspike as spk
+from pyspike import SpikeTrain
+
+
+class dist_from_prof:
+ """ Simple functor that turns profile function into distance function by
+ calling profile.avrg().
+ """
+ def __init__(self, prof_func):
+ self.prof_func = prof_func
+
+ def __call__(self, *args, **kwargs):
+ return self.prof_func(*args, **kwargs).avrg()
+
+
+def check_func(dist_func):
+ """ generic checker that tests the given distance function.
+ """
+ # generate spike trains:
+ t1 = SpikeTrain([0.2, 0.4, 0.6, 0.7], 1.0)
+ t2 = SpikeTrain([0.3, 0.45, 0.8, 0.9, 0.95], 1.0)
+ t3 = SpikeTrain([0.2, 0.4, 0.6], 1.0)
+ t4 = SpikeTrain([0.1, 0.4, 0.5, 0.6], 1.0)
+ spike_trains = [t1, t2, t3, t4]
+
+ isi12 = dist_func(t1, t2)
+ isi12_ = dist_func([t1, t2])
+ assert_equal(isi12, isi12_)
+
+ isi12_ = dist_func(spike_trains, indices=[0, 1])
+ assert_equal(isi12, isi12_)
+
+ isi123 = dist_func(t1, t2, t3)
+ isi123_ = dist_func([t1, t2, t3])
+ assert_equal(isi123, isi123_)
+
+ isi123_ = dist_func(spike_trains, indices=[0, 1, 2])
+ assert_equal(isi123, isi123_)
+
+
+def test_isi_profile():
+ check_func(dist_from_prof(spk.isi_profile))
+
+
+def test_isi_distance():
+ check_func(spk.isi_distance)
+
+
+if __name__ == "__main__":
+ test_isi_profile()
+ test_isi_distance()