From 5a556a11fbf8434bd38fa73e05054d581018a4da Mon Sep 17 00:00:00 2001 From: Mario Mulansky Date: Sun, 31 Jan 2016 15:05:21 +0100 Subject: generalized interface to isi profile and distance isi profile and distance functionc an now compute bi-variate and multi-variate results. Therefore, it can be called with different "overloads". --- test/test_generic_interfaces.py | 64 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 test/test_generic_interfaces.py (limited to 'test/test_generic_interfaces.py') diff --git a/test/test_generic_interfaces.py b/test/test_generic_interfaces.py new file mode 100644 index 0000000..caa9ee4 --- /dev/null +++ b/test/test_generic_interfaces.py @@ -0,0 +1,64 @@ +""" test_isi_interface.py + +Tests the generic interfaces of the profile and distance functions + +Copyright 2016, Mario Mulansky + +Distributed under the BSD License + +""" + +from __future__ import print_function +from numpy.testing import assert_equal + +import pyspike as spk +from pyspike import SpikeTrain + + +class dist_from_prof: + """ Simple functor that turns profile function into distance function by + calling profile.avrg(). + """ + def __init__(self, prof_func): + self.prof_func = prof_func + + def __call__(self, *args, **kwargs): + return self.prof_func(*args, **kwargs).avrg() + + +def check_func(dist_func): + """ generic checker that tests the given distance function. + """ + # generate spike trains: + t1 = SpikeTrain([0.2, 0.4, 0.6, 0.7], 1.0) + t2 = SpikeTrain([0.3, 0.45, 0.8, 0.9, 0.95], 1.0) + t3 = SpikeTrain([0.2, 0.4, 0.6], 1.0) + t4 = SpikeTrain([0.1, 0.4, 0.5, 0.6], 1.0) + spike_trains = [t1, t2, t3, t4] + + isi12 = dist_func(t1, t2) + isi12_ = dist_func([t1, t2]) + assert_equal(isi12, isi12_) + + isi12_ = dist_func(spike_trains, indices=[0, 1]) + assert_equal(isi12, isi12_) + + isi123 = dist_func(t1, t2, t3) + isi123_ = dist_func([t1, t2, t3]) + assert_equal(isi123, isi123_) + + isi123_ = dist_func(spike_trains, indices=[0, 1, 2]) + assert_equal(isi123, isi123_) + + +def test_isi_profile(): + check_func(dist_from_prof(spk.isi_profile)) + + +def test_isi_distance(): + check_func(spk.isi_distance) + + +if __name__ == "__main__": + test_isi_profile() + test_isi_distance() -- cgit v1.2.3 From ea3709e2f4367cb539acc26ec8e05b686d6bf836 Mon Sep 17 00:00:00 2001 From: Mario Mulansky Date: Sun, 31 Jan 2016 16:40:47 +0100 Subject: generic interface for spike distance/profile spike_profile and spike_distance now have a generic interface that allows to compute bi-variate and multi-variate results with the same function. --- pyspike/isi_distance.py | 17 +++- pyspike/spike_distance.py | 124 +++++++++++++++++++++-------- test/test_generic_interfaces.py | 31 +++++++- test/test_regression/test_regression_15.py | 1 + 4 files changed, 138 insertions(+), 35 deletions(-) (limited to 'test/test_generic_interfaces.py') diff --git a/pyspike/isi_distance.py b/pyspike/isi_distance.py index a85028f..122e11d 100644 --- a/pyspike/isi_distance.py +++ b/pyspike/isi_distance.py @@ -14,7 +14,7 @@ from pyspike.generic import _generic_profile_multi, _generic_distance_multi, \ # isi_profile ############################################################ def isi_profile(*args, **kwargs): - """ Computes the isi-distance profile :math:`I(t)` of the two given + """ Computes the isi-distance profile :math:`I(t)` of the given spike trains. Returns the profile as a PieceWiseConstFunc object. The ISI-values are defined positive :math:`I(t)>=0`. @@ -22,8 +22,9 @@ def isi_profile(*args, **kwargs): isi_profile(st1, st2) # returns the bi-variate profile isi_profile(st1, st2, st3) # multi-variate profile of 3 spike trains + spike_trains = [st1, st2, st3, st4] # list of spike trains - isi_profile(spike_trains) # return the profile the list of spike trains + isi_profile(spike_trains) # profile of the list of spike trains isi_profile(spike_trains, indices=[0, 1]) # use only the spike trains # given by the indices @@ -108,13 +109,23 @@ def isi_profile_multi(spike_trains, indices=None): # isi_distance ############################################################ def isi_distance(*args, **kwargs): - # spike_trains, spike_train2, interval=None): """ Computes the ISI-distance :math:`D_I` of the given spike trains. The isi-distance is the integral over the isi distance profile :math:`I(t)`: .. math:: D_I = \\int_{T_0}^{T_1} I(t) dt. + + Valid call structures:: + + isi_distance(st1, st2) # returns the bi-variate distance + isi_distance(st1, st2, st3) # multi-variate distance of 3 spike trains + + spike_trains = [st1, st2, st3, st4] # list of spike trains + isi_distance(spike_trains) # distance of the list of spike trains + isi_distance(spike_trains, indices=[0, 1]) # use only the spike trains + # given by the indices + :returns: The isi-distance :math:`D_I`. :rtype: double """ diff --git a/pyspike/spike_distance.py b/pyspike/spike_distance.py index e418283..7acb959 100644 --- a/pyspike/spike_distance.py +++ b/pyspike/spike_distance.py @@ -13,7 +13,36 @@ from pyspike.generic import _generic_profile_multi, _generic_distance_multi, \ ############################################################ # spike_profile ############################################################ -def spike_profile(spike_train1, spike_train2): +def spike_profile(*args, **kwargs): + """ Computes the spike-distance profile :math:`S(t)` of the given + spike trains. Returns the profile as a PieceWiseConstLin object. The + SPIKE-values are defined positive :math:`S(t)>=0`. + + Valid call structures:: + + spike_profile(st1, st2) # returns the bi-variate profile + spike_profile(st1, st2, st3) # multi-variate profile of 3 spike trains + + spike_trains = [st1, st2, st3, st4] # list of spike trains + spike_profile(spike_trains) # profile of the list of spike trains + spike_profile(spike_trains, indices=[0, 1]) # use only the spike trains + # given by the indices + + :returns: The spike-distance profile :math:`S(t)` + :rtype: :class:`.PieceWiseConstLin` + """ + if len(args) == 1: + return spike_profile_multi(args[0], **kwargs) + elif len(args) == 2: + return spike_profile_bi(args[0], args[1]) + else: + return spike_profile_multi(args) + + +############################################################ +# spike_profile_bi +############################################################ +def spike_profile_bi(spike_train1, spike_train2): """ Computes the spike-distance profile :math:`S(t)` of the two given spike trains. Returns the profile as a PieceWiseLinFunc object. The SPIKE-values are defined positive :math:`S(t)>=0`. @@ -53,10 +82,68 @@ Falling back to slow python backend.") return PieceWiseLinFunc(times, y_starts, y_ends) +############################################################ +# spike_profile_multi +############################################################ +def spike_profile_multi(spike_trains, indices=None): + """ Computes the multi-variate spike distance profile for a set of spike + trains. That is the average spike-distance of all pairs of spike-trains: + + .. math:: = \\frac{2}{N(N-1)} \\sum_{} S^{i, j}`, + + where the sum goes over all pairs + + :param spike_trains: list of :class:`.SpikeTrain` + :param indices: list of indices defining which spike trains to use, + if None all given spike trains are used (default=None) + :type indices: list or None + :returns: The averaged spike profile :math:`(t)` + :rtype: :class:`.PieceWiseLinFunc` + + """ + average_dist, M = _generic_profile_multi(spike_trains, spike_profile_bi, + indices) + average_dist.mul_scalar(1.0/M) # normalize + return average_dist + + ############################################################ # spike_distance ############################################################ -def spike_distance(spike_train1, spike_train2, interval=None): +def spike_distance(*args, **kwargs): + """ Computes the SPIKE-distance :math:`D_S` of the given spike trains. The + spike-distance is the integral over the spike distance profile + :math:`D(t)`: + + .. math:: D_S = \\int_{T_0}^{T_1} S(t) dt. + + + Valid call structures:: + + spike_distance(st1, st2) # returns the bi-variate distance + spike_distance(st1, st2, st3) # multi-variate distance of 3 spike trains + + spike_trains = [st1, st2, st3, st4] # list of spike trains + spike_distance(spike_trains) # distance of the list of spike trains + spike_distance(spike_trains, indices=[0, 1]) # use only the spike trains + # given by the indices + + :returns: The spike-distance :math:`D_S`. + :rtype: double + """ + + if len(args) == 1: + return spike_distance_multi(args[0], **kwargs) + elif len(args) == 2: + return spike_distance_bi(args[0], args[1], **kwargs) + else: + return spike_distance_multi(args, **kwargs) + + +############################################################ +# spike_distance_bi +############################################################ +def spike_distance_bi(spike_train1, spike_train2, interval=None): """ Computes the spike-distance :math:`D_S` of the given spike trains. The spike-distance is the integral over the spike distance profile :math:`S(t)`: @@ -86,35 +173,10 @@ def spike_distance(spike_train1, spike_train2, interval=None): spike_train1.t_end) except ImportError: # Cython backend not available: fall back to average profile - return spike_profile(spike_train1, spike_train2).avrg(interval) + return spike_profile_bi(spike_train1, spike_train2).avrg(interval) else: # some specific interval is provided: compute the whole profile - return spike_profile(spike_train1, spike_train2).avrg(interval) - - -############################################################ -# spike_profile_multi -############################################################ -def spike_profile_multi(spike_trains, indices=None): - """ Computes the multi-variate spike distance profile for a set of spike - trains. That is the average spike-distance of all pairs of spike-trains: - - .. math:: = \\frac{2}{N(N-1)} \\sum_{} S^{i, j}`, - - where the sum goes over all pairs - - :param spike_trains: list of :class:`.SpikeTrain` - :param indices: list of indices defining which spike trains to use, - if None all given spike trains are used (default=None) - :type indices: list or None - :returns: The averaged spike profile :math:`(t)` - :rtype: :class:`.PieceWiseLinFunc` - - """ - average_dist, M = _generic_profile_multi(spike_trains, spike_profile, - indices) - average_dist.mul_scalar(1.0/M) # normalize - return average_dist + return spike_profile_bi(spike_train1, spike_train2).avrg(interval) ############################################################ @@ -139,7 +201,7 @@ def spike_distance_multi(spike_trains, indices=None, interval=None): :returns: The averaged multi-variate spike distance :math:`D_S`. :rtype: double """ - return _generic_distance_multi(spike_trains, spike_distance, indices, + return _generic_distance_multi(spike_trains, spike_distance_bi, indices, interval) @@ -160,5 +222,5 @@ def spike_distance_matrix(spike_trains, indices=None, interval=None): :math:`D_S^{ij}` :rtype: np.array """ - return _generic_distance_matrix(spike_trains, spike_distance, + return _generic_distance_matrix(spike_trains, spike_distance_bi, indices, interval) diff --git a/test/test_generic_interfaces.py b/test/test_generic_interfaces.py index caa9ee4..ee87be4 100644 --- a/test/test_generic_interfaces.py +++ b/test/test_generic_interfaces.py @@ -23,7 +23,12 @@ class dist_from_prof: self.prof_func = prof_func def __call__(self, *args, **kwargs): - return self.prof_func(*args, **kwargs).avrg() + if "interval" in kwargs: + # forward interval arg into avrg function + interval = kwargs.pop("interval") + return self.prof_func(*args, **kwargs).avrg(interval=interval) + else: + return self.prof_func(*args, **kwargs).avrg() def check_func(dist_func): @@ -50,6 +55,22 @@ def check_func(dist_func): isi123_ = dist_func(spike_trains, indices=[0, 1, 2]) assert_equal(isi123, isi123_) + # run the same test with an additional interval parameter + + isi12 = dist_func(t1, t2, interval=[0.0, 0.5]) + isi12_ = dist_func([t1, t2], interval=[0.0, 0.5]) + assert_equal(isi12, isi12_) + + isi12_ = dist_func(spike_trains, indices=[0, 1], interval=[0.0, 0.5]) + assert_equal(isi12, isi12_) + + isi123 = dist_func(t1, t2, t3, interval=[0.0, 0.5]) + isi123_ = dist_func([t1, t2, t3], interval=[0.0, 0.5]) + assert_equal(isi123, isi123_) + + isi123_ = dist_func(spike_trains, indices=[0, 1, 2], interval=[0.0, 0.5]) + assert_equal(isi123, isi123_) + def test_isi_profile(): check_func(dist_from_prof(spk.isi_profile)) @@ -59,6 +80,14 @@ def test_isi_distance(): check_func(spk.isi_distance) +def test_spike_profile(): + check_func(dist_from_prof(spk.spike_profile)) + + +def test_spike_distance(): + check_func(spk.spike_distance) + + if __name__ == "__main__": test_isi_profile() test_isi_distance() diff --git a/test/test_regression/test_regression_15.py b/test/test_regression/test_regression_15.py index dcacae2..54adf23 100644 --- a/test/test_regression/test_regression_15.py +++ b/test/test_regression/test_regression_15.py @@ -20,6 +20,7 @@ import os TEST_PATH = os.path.dirname(os.path.realpath(__file__)) TEST_DATA = os.path.join(TEST_PATH, "..", "SPIKE_Sync_Test.txt") + def test_regression_15_isi(): # load spike trains spike_trains = spk.load_spike_trains_from_txt(TEST_DATA, edges=[0, 4000]) -- cgit v1.2.3 From a57f3d51473b10d81752ad66e4c392563ca1c6f8 Mon Sep 17 00:00:00 2001 From: Mario Mulansky Date: Tue, 2 Feb 2016 17:11:12 +0100 Subject: new generic interface for spike_sync functions Similar to the isi and spike distance functions, also the spike sync functions now support the new generic interface. --- pyspike/spike_sync.py | 136 +++++++++++++++++++++++++++++----------- test/test_generic_interfaces.py | 14 ++++- 2 files changed, 114 insertions(+), 36 deletions(-) (limited to 'test/test_generic_interfaces.py') diff --git a/pyspike/spike_sync.py b/pyspike/spike_sync.py index 3dc29ff..ccb09d9 100644 --- a/pyspike/spike_sync.py +++ b/pyspike/spike_sync.py @@ -15,7 +15,40 @@ from pyspike.generic import _generic_profile_multi, _generic_distance_matrix ############################################################ # spike_sync_profile ############################################################ -def spike_sync_profile(spike_train1, spike_train2, max_tau=None): +def spike_sync_profile(*args, **kwargs): + """ Computes the spike-synchronization profile S_sync(t) of the given + spike trains. Returns the profile as a DiscreteFunction object. In the + bivariate case, he S_sync values are either 1 or 0, indicating the presence + or absence of a coincidence. For multi-variate cases, each spike in the set + of spike trains, the profile is defined as the number of coincidences + divided by the number of spike trains pairs involving the spike train of + containing this spike, which is the number of spike trains minus one (N-1). + + Valid call structures:: + + spike_sync_profile(st1, st2) # returns the bi-variate profile + spike_sync_profile(st1, st2, st3) # multi-variate profile of 3 sts + + sts = [st1, st2, st3, st4] # list of spike trains + spike_sync_profile(sts) # profile of the list of spike trains + spike_sync_profile(sts, indices=[0, 1]) # use only the spike trains + # given by the indices + + :returns: The spike-sync profile :math:`S_{sync}(t)`. + :rtype: :class:`pyspike.function.DiscreteFunction` + """ + if len(args) == 1: + return spike_sync_profile_multi(args[0], **kwargs) + elif len(args) == 2: + return spike_sync_profile_bi(args[0], args[1]) + else: + return spike_sync_profile_multi(args) + + +############################################################ +# spike_sync_profile_bi +############################################################ +def spike_sync_profile_bi(spike_train1, spike_train2, max_tau=None): """ Computes the spike-synchronization profile S_sync(t) of the two given spike trains. Returns the profile as a DiscreteFunction object. The S_sync values are either 1 or 0, indicating the presence or absence of a @@ -27,7 +60,7 @@ def spike_sync_profile(spike_train1, spike_train2, max_tau=None): :type spike_train2: :class:`pyspike.SpikeTrain` :param max_tau: Maximum coincidence window size. If 0 or `None`, the coincidence window has no upper bound. - :returns: The spike-distance profile :math:`S_{sync}(t)`. + :returns: The spike-sync profile :math:`S_{sync}(t)`. :rtype: :class:`pyspike.function.DiscreteFunction` """ @@ -61,6 +94,33 @@ Falling back to slow python backend.") return DiscreteFunc(times, coincidences, multiplicity) +############################################################ +# spike_sync_profile_multi +############################################################ +def spike_sync_profile_multi(spike_trains, indices=None, max_tau=None): + """ Computes the multi-variate spike synchronization profile for a set of + spike trains. For each spike in the set of spike trains, the multi-variate + profile is defined as the number of coincidences divided by the number of + spike trains pairs involving the spike train of containing this spike, + which is the number of spike trains minus one (N-1). + + :param spike_trains: list of :class:`pyspike.SpikeTrain` + :param indices: list of indices defining which spike trains to use, + if None all given spike trains are used (default=None) + :type indices: list or None + :param max_tau: Maximum coincidence window size. If 0 or `None`, the + coincidence window has no upper bound. + :returns: The multi-variate spike sync profile :math:`(t)` + :rtype: :class:`pyspike.function.DiscreteFunction` + + """ + prof_func = partial(spike_sync_profile_bi, max_tau=max_tau) + average_prof, M = _generic_profile_multi(spike_trains, prof_func, + indices) + # average_dist.mul_scalar(1.0/M) # no normalization here! + return average_prof + + ############################################################ # _spike_sync_values ############################################################ @@ -87,18 +147,51 @@ def _spike_sync_values(spike_train1, spike_train2, interval, max_tau): return c, mp except ImportError: # Cython backend not available: fall back to profile averaging - return spike_sync_profile(spike_train1, spike_train2, - max_tau).integral(interval) + return spike_sync_profile_bi(spike_train1, spike_train2, + max_tau).integral(interval) else: # some specific interval is provided: use profile - return spike_sync_profile(spike_train1, spike_train2, - max_tau).integral(interval) + return spike_sync_profile_bi(spike_train1, spike_train2, + max_tau).integral(interval) ############################################################ # spike_sync ############################################################ -def spike_sync(spike_train1, spike_train2, interval=None, max_tau=None): +def spike_sync(*args, **kwargs): + """ Computes the spike synchronization value SYNC of the given spike + trains. The spike synchronization value is the computed as the total number + of coincidences divided by the total number of spikes: + + .. math:: SYNC = \sum_n C_n / N. + + + Valid call structures:: + + spike_sync(st1, st2) # returns the bi-variate spike synchronization + spike_sync(st1, st2, st3) # multi-variate result for 3 spike trains + + spike_trains = [st1, st2, st3, st4] # list of spike trains + spike_sync(spike_trains) # spike-sync of the list of spike trains + spike_sync(spike_trains, indices=[0, 1]) # use only the spike trains + # given by the indices + + :returns: The spike synchronization value. + :rtype: `double` + """ + + if len(args) == 1: + return spike_sync_multi(args[0], **kwargs) + elif len(args) == 2: + return spike_sync_bi(args[0], args[1], **kwargs) + else: + return spike_sync_multi(args, **kwargs) + + +############################################################ +# spike_sync_bi +############################################################ +def spike_sync_bi(spike_train1, spike_train2, interval=None, max_tau=None): """ Computes the spike synchronization value SYNC of the given spike trains. The spike synchronization value is the computed as the total number of coincidences divided by the total number of spikes: @@ -122,33 +215,6 @@ def spike_sync(spike_train1, spike_train2, interval=None, max_tau=None): return 1.0*c/mp -############################################################ -# spike_sync_profile_multi -############################################################ -def spike_sync_profile_multi(spike_trains, indices=None, max_tau=None): - """ Computes the multi-variate spike synchronization profile for a set of - spike trains. For each spike in the set of spike trains, the multi-variate - profile is defined as the number of coincidences divided by the number of - spike trains pairs involving the spike train of containing this spike, - which is the number of spike trains minus one (N-1). - - :param spike_trains: list of :class:`pyspike.SpikeTrain` - :param indices: list of indices defining which spike trains to use, - if None all given spike trains are used (default=None) - :type indices: list or None - :param max_tau: Maximum coincidence window size. If 0 or `None`, the - coincidence window has no upper bound. - :returns: The multi-variate spike sync profile :math:`(t)` - :rtype: :class:`pyspike.function.DiscreteFunction` - - """ - prof_func = partial(spike_sync_profile, max_tau=max_tau) - average_prof, M = _generic_profile_multi(spike_trains, prof_func, - indices) - # average_dist.mul_scalar(1.0/M) # no normalization here! - return average_prof - - ############################################################ # spike_sync_multi ############################################################ @@ -211,6 +277,6 @@ def spike_sync_matrix(spike_trains, indices=None, interval=None, max_tau=None): :rtype: np.array """ - dist_func = partial(spike_sync, max_tau=max_tau) + dist_func = partial(spike_sync_bi, max_tau=max_tau) return _generic_distance_matrix(spike_trains, dist_func, indices, interval) diff --git a/test/test_generic_interfaces.py b/test/test_generic_interfaces.py index ee87be4..7f08067 100644 --- a/test/test_generic_interfaces.py +++ b/test/test_generic_interfaces.py @@ -1,4 +1,4 @@ -""" test_isi_interface.py +""" test_generic_interface.py Tests the generic interfaces of the profile and distance functions @@ -88,6 +88,18 @@ def test_spike_distance(): check_func(spk.spike_distance) +def test_spike_sync_profile(): + check_func(dist_from_prof(spk.spike_sync_profile)) + + +def test_spike_sync(): + check_func(spk.spike_sync) + + if __name__ == "__main__": test_isi_profile() test_isi_distance() + test_spike_profile() + test_spike_distance() + test_spike_sync_profile() + test_spike_sync() -- cgit v1.2.3