From ea3709e2f4367cb539acc26ec8e05b686d6bf836 Mon Sep 17 00:00:00 2001 From: Mario Mulansky Date: Sun, 31 Jan 2016 16:40:47 +0100 Subject: generic interface for spike distance/profile spike_profile and spike_distance now have a generic interface that allows to compute bi-variate and multi-variate results with the same function. --- pyspike/isi_distance.py | 17 +++- pyspike/spike_distance.py | 124 +++++++++++++++++++++-------- test/test_generic_interfaces.py | 31 +++++++- test/test_regression/test_regression_15.py | 1 + 4 files changed, 138 insertions(+), 35 deletions(-) diff --git a/pyspike/isi_distance.py b/pyspike/isi_distance.py index a85028f..122e11d 100644 --- a/pyspike/isi_distance.py +++ b/pyspike/isi_distance.py @@ -14,7 +14,7 @@ from pyspike.generic import _generic_profile_multi, _generic_distance_multi, \ # isi_profile ############################################################ def isi_profile(*args, **kwargs): - """ Computes the isi-distance profile :math:`I(t)` of the two given + """ Computes the isi-distance profile :math:`I(t)` of the given spike trains. Returns the profile as a PieceWiseConstFunc object. The ISI-values are defined positive :math:`I(t)>=0`. @@ -22,8 +22,9 @@ def isi_profile(*args, **kwargs): isi_profile(st1, st2) # returns the bi-variate profile isi_profile(st1, st2, st3) # multi-variate profile of 3 spike trains + spike_trains = [st1, st2, st3, st4] # list of spike trains - isi_profile(spike_trains) # return the profile the list of spike trains + isi_profile(spike_trains) # profile of the list of spike trains isi_profile(spike_trains, indices=[0, 1]) # use only the spike trains # given by the indices @@ -108,13 +109,23 @@ def isi_profile_multi(spike_trains, indices=None): # isi_distance ############################################################ def isi_distance(*args, **kwargs): - # spike_trains, spike_train2, interval=None): """ Computes the ISI-distance :math:`D_I` of the given spike trains. The isi-distance is the integral over the isi distance profile :math:`I(t)`: .. math:: D_I = \\int_{T_0}^{T_1} I(t) dt. + + Valid call structures:: + + isi_distance(st1, st2) # returns the bi-variate distance + isi_distance(st1, st2, st3) # multi-variate distance of 3 spike trains + + spike_trains = [st1, st2, st3, st4] # list of spike trains + isi_distance(spike_trains) # distance of the list of spike trains + isi_distance(spike_trains, indices=[0, 1]) # use only the spike trains + # given by the indices + :returns: The isi-distance :math:`D_I`. :rtype: double """ diff --git a/pyspike/spike_distance.py b/pyspike/spike_distance.py index e418283..7acb959 100644 --- a/pyspike/spike_distance.py +++ b/pyspike/spike_distance.py @@ -13,7 +13,36 @@ from pyspike.generic import _generic_profile_multi, _generic_distance_multi, \ ############################################################ # spike_profile ############################################################ -def spike_profile(spike_train1, spike_train2): +def spike_profile(*args, **kwargs): + """ Computes the spike-distance profile :math:`S(t)` of the given + spike trains. Returns the profile as a PieceWiseConstLin object. The + SPIKE-values are defined positive :math:`S(t)>=0`. + + Valid call structures:: + + spike_profile(st1, st2) # returns the bi-variate profile + spike_profile(st1, st2, st3) # multi-variate profile of 3 spike trains + + spike_trains = [st1, st2, st3, st4] # list of spike trains + spike_profile(spike_trains) # profile of the list of spike trains + spike_profile(spike_trains, indices=[0, 1]) # use only the spike trains + # given by the indices + + :returns: The spike-distance profile :math:`S(t)` + :rtype: :class:`.PieceWiseConstLin` + """ + if len(args) == 1: + return spike_profile_multi(args[0], **kwargs) + elif len(args) == 2: + return spike_profile_bi(args[0], args[1]) + else: + return spike_profile_multi(args) + + +############################################################ +# spike_profile_bi +############################################################ +def spike_profile_bi(spike_train1, spike_train2): """ Computes the spike-distance profile :math:`S(t)` of the two given spike trains. Returns the profile as a PieceWiseLinFunc object. The SPIKE-values are defined positive :math:`S(t)>=0`. @@ -53,10 +82,68 @@ Falling back to slow python backend.") return PieceWiseLinFunc(times, y_starts, y_ends) +############################################################ +# spike_profile_multi +############################################################ +def spike_profile_multi(spike_trains, indices=None): + """ Computes the multi-variate spike distance profile for a set of spike + trains. That is the average spike-distance of all pairs of spike-trains: + + .. math:: = \\frac{2}{N(N-1)} \\sum_{} S^{i, j}`, + + where the sum goes over all pairs + + :param spike_trains: list of :class:`.SpikeTrain` + :param indices: list of indices defining which spike trains to use, + if None all given spike trains are used (default=None) + :type indices: list or None + :returns: The averaged spike profile :math:`(t)` + :rtype: :class:`.PieceWiseLinFunc` + + """ + average_dist, M = _generic_profile_multi(spike_trains, spike_profile_bi, + indices) + average_dist.mul_scalar(1.0/M) # normalize + return average_dist + + ############################################################ # spike_distance ############################################################ -def spike_distance(spike_train1, spike_train2, interval=None): +def spike_distance(*args, **kwargs): + """ Computes the SPIKE-distance :math:`D_S` of the given spike trains. The + spike-distance is the integral over the spike distance profile + :math:`D(t)`: + + .. math:: D_S = \\int_{T_0}^{T_1} S(t) dt. + + + Valid call structures:: + + spike_distance(st1, st2) # returns the bi-variate distance + spike_distance(st1, st2, st3) # multi-variate distance of 3 spike trains + + spike_trains = [st1, st2, st3, st4] # list of spike trains + spike_distance(spike_trains) # distance of the list of spike trains + spike_distance(spike_trains, indices=[0, 1]) # use only the spike trains + # given by the indices + + :returns: The spike-distance :math:`D_S`. + :rtype: double + """ + + if len(args) == 1: + return spike_distance_multi(args[0], **kwargs) + elif len(args) == 2: + return spike_distance_bi(args[0], args[1], **kwargs) + else: + return spike_distance_multi(args, **kwargs) + + +############################################################ +# spike_distance_bi +############################################################ +def spike_distance_bi(spike_train1, spike_train2, interval=None): """ Computes the spike-distance :math:`D_S` of the given spike trains. The spike-distance is the integral over the spike distance profile :math:`S(t)`: @@ -86,35 +173,10 @@ def spike_distance(spike_train1, spike_train2, interval=None): spike_train1.t_end) except ImportError: # Cython backend not available: fall back to average profile - return spike_profile(spike_train1, spike_train2).avrg(interval) + return spike_profile_bi(spike_train1, spike_train2).avrg(interval) else: # some specific interval is provided: compute the whole profile - return spike_profile(spike_train1, spike_train2).avrg(interval) - - -############################################################ -# spike_profile_multi -############################################################ -def spike_profile_multi(spike_trains, indices=None): - """ Computes the multi-variate spike distance profile for a set of spike - trains. That is the average spike-distance of all pairs of spike-trains: - - .. math:: = \\frac{2}{N(N-1)} \\sum_{} S^{i, j}`, - - where the sum goes over all pairs - - :param spike_trains: list of :class:`.SpikeTrain` - :param indices: list of indices defining which spike trains to use, - if None all given spike trains are used (default=None) - :type indices: list or None - :returns: The averaged spike profile :math:`(t)` - :rtype: :class:`.PieceWiseLinFunc` - - """ - average_dist, M = _generic_profile_multi(spike_trains, spike_profile, - indices) - average_dist.mul_scalar(1.0/M) # normalize - return average_dist + return spike_profile_bi(spike_train1, spike_train2).avrg(interval) ############################################################ @@ -139,7 +201,7 @@ def spike_distance_multi(spike_trains, indices=None, interval=None): :returns: The averaged multi-variate spike distance :math:`D_S`. :rtype: double """ - return _generic_distance_multi(spike_trains, spike_distance, indices, + return _generic_distance_multi(spike_trains, spike_distance_bi, indices, interval) @@ -160,5 +222,5 @@ def spike_distance_matrix(spike_trains, indices=None, interval=None): :math:`D_S^{ij}` :rtype: np.array """ - return _generic_distance_matrix(spike_trains, spike_distance, + return _generic_distance_matrix(spike_trains, spike_distance_bi, indices, interval) diff --git a/test/test_generic_interfaces.py b/test/test_generic_interfaces.py index caa9ee4..ee87be4 100644 --- a/test/test_generic_interfaces.py +++ b/test/test_generic_interfaces.py @@ -23,7 +23,12 @@ class dist_from_prof: self.prof_func = prof_func def __call__(self, *args, **kwargs): - return self.prof_func(*args, **kwargs).avrg() + if "interval" in kwargs: + # forward interval arg into avrg function + interval = kwargs.pop("interval") + return self.prof_func(*args, **kwargs).avrg(interval=interval) + else: + return self.prof_func(*args, **kwargs).avrg() def check_func(dist_func): @@ -50,6 +55,22 @@ def check_func(dist_func): isi123_ = dist_func(spike_trains, indices=[0, 1, 2]) assert_equal(isi123, isi123_) + # run the same test with an additional interval parameter + + isi12 = dist_func(t1, t2, interval=[0.0, 0.5]) + isi12_ = dist_func([t1, t2], interval=[0.0, 0.5]) + assert_equal(isi12, isi12_) + + isi12_ = dist_func(spike_trains, indices=[0, 1], interval=[0.0, 0.5]) + assert_equal(isi12, isi12_) + + isi123 = dist_func(t1, t2, t3, interval=[0.0, 0.5]) + isi123_ = dist_func([t1, t2, t3], interval=[0.0, 0.5]) + assert_equal(isi123, isi123_) + + isi123_ = dist_func(spike_trains, indices=[0, 1, 2], interval=[0.0, 0.5]) + assert_equal(isi123, isi123_) + def test_isi_profile(): check_func(dist_from_prof(spk.isi_profile)) @@ -59,6 +80,14 @@ def test_isi_distance(): check_func(spk.isi_distance) +def test_spike_profile(): + check_func(dist_from_prof(spk.spike_profile)) + + +def test_spike_distance(): + check_func(spk.spike_distance) + + if __name__ == "__main__": test_isi_profile() test_isi_distance() diff --git a/test/test_regression/test_regression_15.py b/test/test_regression/test_regression_15.py index dcacae2..54adf23 100644 --- a/test/test_regression/test_regression_15.py +++ b/test/test_regression/test_regression_15.py @@ -20,6 +20,7 @@ import os TEST_PATH = os.path.dirname(os.path.realpath(__file__)) TEST_DATA = os.path.join(TEST_PATH, "..", "SPIKE_Sync_Test.txt") + def test_regression_15_isi(): # load spike trains spike_trains = spk.load_spike_trains_from_txt(TEST_DATA, edges=[0, 4000]) -- cgit v1.2.3