From 5a556a11fbf8434bd38fa73e05054d581018a4da Mon Sep 17 00:00:00 2001 From: Mario Mulansky Date: Sun, 31 Jan 2016 15:05:21 +0100 Subject: generalized interface to isi profile and distance isi profile and distance functionc an now compute bi-variate and multi-variate results. Therefore, it can be called with different "overloads". --- pyspike/isi_distance.py | 116 +++++++++++++++++++++++++++++----------- test/test_generic_interfaces.py | 64 ++++++++++++++++++++++ 2 files changed, 148 insertions(+), 32 deletions(-) create mode 100644 test/test_generic_interfaces.py diff --git a/pyspike/isi_distance.py b/pyspike/isi_distance.py index 0ae7393..a85028f 100644 --- a/pyspike/isi_distance.py +++ b/pyspike/isi_distance.py @@ -13,11 +13,40 @@ from pyspike.generic import _generic_profile_multi, _generic_distance_multi, \ ############################################################ # isi_profile ############################################################ -def isi_profile(spike_train1, spike_train2): +def isi_profile(*args, **kwargs): """ Computes the isi-distance profile :math:`I(t)` of the two given - spike trains. Retruns the profile as a PieceWiseConstFunc object. The + spike trains. Returns the profile as a PieceWiseConstFunc object. The ISI-values are defined positive :math:`I(t)>=0`. + Valid call structures:: + + isi_profile(st1, st2) # returns the bi-variate profile + isi_profile(st1, st2, st3) # multi-variate profile of 3 spike trains + spike_trains = [st1, st2, st3, st4] # list of spike trains + isi_profile(spike_trains) # return the profile the list of spike trains + isi_profile(spike_trains, indices=[0, 1]) # use only the spike trains + # given by the indices + + :returns: The isi-distance profile :math:`I(t)` + :rtype: :class:`.PieceWiseConstFunc` + """ + if len(args) == 1: + return isi_profile_multi(args[0], **kwargs) + elif len(args) == 2: + return isi_profile_bi(args[0], args[1]) + else: + return isi_profile_multi(args) + + +############################################################ +# isi_profile_bi +############################################################ +def isi_profile_bi(spike_train1, spike_train2): + """ Bi-variate ISI-profile. + Computes the isi-distance profile :math:`I(t)` of the two given + spike trains. Returns the profile as a PieceWiseConstFunc object. + See :func:`.isi_profile`. + :param spike_train1: First spike train. :type spike_train1: :class:`.SpikeTrain` :param spike_train2: Second spike train. @@ -51,10 +80,57 @@ Falling back to slow python backend.") return PieceWiseConstFunc(times, values) +############################################################ +# isi_profile_multi +############################################################ +def isi_profile_multi(spike_trains, indices=None): + """ computes the multi-variate isi distance profile for a set of spike + trains. That is the average isi-distance of all pairs of spike-trains: + + .. math:: = \\frac{2}{N(N-1)} \\sum_{} I^{i,j}, + + where the sum goes over all pairs + + :param spike_trains: list of :class:`.SpikeTrain` + :param indices: list of indices defining which spike trains to use, + if None all given spike trains are used (default=None) + :type state: list or None + :returns: The averaged isi profile :math:`` + :rtype: :class:`.PieceWiseConstFunc` + """ + average_dist, M = _generic_profile_multi(spike_trains, isi_profile_bi, + indices) + average_dist.mul_scalar(1.0/M) # normalize + return average_dist + + ############################################################ # isi_distance ############################################################ -def isi_distance(spike_train1, spike_train2, interval=None): +def isi_distance(*args, **kwargs): + # spike_trains, spike_train2, interval=None): + """ Computes the ISI-distance :math:`D_I` of the given spike trains. The + isi-distance is the integral over the isi distance profile + :math:`I(t)`: + + .. math:: D_I = \\int_{T_0}^{T_1} I(t) dt. + + :returns: The isi-distance :math:`D_I`. + :rtype: double + """ + + if len(args) == 1: + return isi_distance_multi(args[0], **kwargs) + elif len(args) == 2: + return isi_distance_bi(args[0], args[1], **kwargs) + else: + return isi_distance_multi(args, **kwargs) + + +############################################################ +# isi_distance_bi +############################################################ +def isi_distance_bi(spike_train1, spike_train2, interval=None): """ Computes the ISI-distance :math:`D_I` of the given spike trains. The isi-distance is the integral over the isi distance profile :math:`I(t)`: @@ -84,34 +160,10 @@ def isi_distance(spike_train1, spike_train2, interval=None): spike_train1.t_start, spike_train1.t_end) except ImportError: # Cython backend not available: fall back to profile averaging - return isi_profile(spike_train1, spike_train2).avrg(interval) + return isi_profile_bi(spike_train1, spike_train2).avrg(interval) else: # some specific interval is provided: use profile - return isi_profile(spike_train1, spike_train2).avrg(interval) - - -############################################################ -# isi_profile_multi -############################################################ -def isi_profile_multi(spike_trains, indices=None): - """ computes the multi-variate isi distance profile for a set of spike - trains. That is the average isi-distance of all pairs of spike-trains: - - .. math:: = \\frac{2}{N(N-1)} \\sum_{} I^{i,j}, - - where the sum goes over all pairs - - :param spike_trains: list of :class:`.SpikeTrain` - :param indices: list of indices defining which spike trains to use, - if None all given spike trains are used (default=None) - :type state: list or None - :returns: The averaged isi profile :math:`` - :rtype: :class:`.PieceWiseConstFunc` - """ - average_dist, M = _generic_profile_multi(spike_trains, isi_profile, - indices) - average_dist.mul_scalar(1.0/M) # normalize - return average_dist + return isi_profile_bi(spike_train1, spike_train2).avrg(interval) ############################################################ @@ -134,7 +186,7 @@ def isi_distance_multi(spike_trains, indices=None, interval=None): :returns: The time-averaged multivariate ISI distance :math:`D_I` :rtype: double """ - return _generic_distance_multi(spike_trains, isi_distance, indices, + return _generic_distance_multi(spike_trains, isi_distance_bi, indices, interval) @@ -155,5 +207,5 @@ def isi_distance_matrix(spike_trains, indices=None, interval=None): :math:`D_{I}^{ij}` :rtype: np.array """ - return _generic_distance_matrix(spike_trains, isi_distance, - indices, interval) + return _generic_distance_matrix(spike_trains, isi_distance_bi, + indices=indices, interval=interval) diff --git a/test/test_generic_interfaces.py b/test/test_generic_interfaces.py new file mode 100644 index 0000000..caa9ee4 --- /dev/null +++ b/test/test_generic_interfaces.py @@ -0,0 +1,64 @@ +""" test_isi_interface.py + +Tests the generic interfaces of the profile and distance functions + +Copyright 2016, Mario Mulansky + +Distributed under the BSD License + +""" + +from __future__ import print_function +from numpy.testing import assert_equal + +import pyspike as spk +from pyspike import SpikeTrain + + +class dist_from_prof: + """ Simple functor that turns profile function into distance function by + calling profile.avrg(). + """ + def __init__(self, prof_func): + self.prof_func = prof_func + + def __call__(self, *args, **kwargs): + return self.prof_func(*args, **kwargs).avrg() + + +def check_func(dist_func): + """ generic checker that tests the given distance function. + """ + # generate spike trains: + t1 = SpikeTrain([0.2, 0.4, 0.6, 0.7], 1.0) + t2 = SpikeTrain([0.3, 0.45, 0.8, 0.9, 0.95], 1.0) + t3 = SpikeTrain([0.2, 0.4, 0.6], 1.0) + t4 = SpikeTrain([0.1, 0.4, 0.5, 0.6], 1.0) + spike_trains = [t1, t2, t3, t4] + + isi12 = dist_func(t1, t2) + isi12_ = dist_func([t1, t2]) + assert_equal(isi12, isi12_) + + isi12_ = dist_func(spike_trains, indices=[0, 1]) + assert_equal(isi12, isi12_) + + isi123 = dist_func(t1, t2, t3) + isi123_ = dist_func([t1, t2, t3]) + assert_equal(isi123, isi123_) + + isi123_ = dist_func(spike_trains, indices=[0, 1, 2]) + assert_equal(isi123, isi123_) + + +def test_isi_profile(): + check_func(dist_from_prof(spk.isi_profile)) + + +def test_isi_distance(): + check_func(spk.isi_distance) + + +if __name__ == "__main__": + test_isi_profile() + test_isi_distance() -- cgit v1.2.3