diff options
Diffstat (limited to 'pyspike/spike_distance.py')
-rw-r--r-- | pyspike/spike_distance.py | 155 |
1 files changed, 111 insertions, 44 deletions
diff --git a/pyspike/spike_distance.py b/pyspike/spike_distance.py index e418283..0fd86c1 100644 --- a/pyspike/spike_distance.py +++ b/pyspike/spike_distance.py @@ -13,10 +13,46 @@ from pyspike.generic import _generic_profile_multi, _generic_distance_multi, \ ############################################################ # spike_profile ############################################################ -def spike_profile(spike_train1, spike_train2): - """ Computes the spike-distance profile :math:`S(t)` of the two given spike - trains. Returns the profile as a PieceWiseLinFunc object. The SPIKE-values - are defined positive :math:`S(t)>=0`. +def spike_profile(*args, **kwargs): + """ Computes the spike-distance profile :math:`S(t)` of the given + spike trains. Returns the profile as a PieceWiseConstLin object. The + SPIKE-values are defined positive :math:`S(t)>=0`. + + Valid call structures:: + + spike_profile(st1, st2) # returns the bi-variate profile + spike_profile(st1, st2, st3) # multi-variate profile of 3 spike trains + + spike_trains = [st1, st2, st3, st4] # list of spike trains + spike_profile(spike_trains) # profile of the list of spike trains + spike_profile(spike_trains, indices=[0, 1]) # use only the spike trains + # given by the indices + + The multivariate spike-distance profile is defined as the average of all + pairs of spike-trains: + + .. math:: <S(t)> = \\frac{2}{N(N-1)} \\sum_{<i,j>} S^{i, j}`, + + where the sum goes over all pairs <i,j> + + :returns: The spike-distance profile :math:`S(t)` + :rtype: :class:`.PieceWiseConstLin` + """ + if len(args) == 1: + return spike_profile_multi(args[0], **kwargs) + elif len(args) == 2: + return spike_profile_bi(args[0], args[1]) + else: + return spike_profile_multi(args) + + +############################################################ +# spike_profile_bi +############################################################ +def spike_profile_bi(spike_train1, spike_train2): + """ Specific function to compute a bivariate SPIKE-profile. This is a + deprecated function and should not be called directly. Use + :func:`.spike_profile` to compute SPIKE-profiles. :param spike_train1: First spike train. :type spike_train1: :class:`.SpikeTrain` @@ -54,14 +90,74 @@ Falling back to slow python backend.") ############################################################ +# spike_profile_multi +############################################################ +def spike_profile_multi(spike_trains, indices=None): + """ Specific function to compute a multivariate SPIKE-profile. This is a + deprecated function and should not be called directly. Use + :func:`.spike_profile` to compute SPIKE-profiles. + + :param spike_trains: list of :class:`.SpikeTrain` + :param indices: list of indices defining which spike trains to use, + if None all given spike trains are used (default=None) + :type indices: list or None + :returns: The averaged spike profile :math:`<S>(t)` + :rtype: :class:`.PieceWiseLinFunc` + + """ + average_dist, M = _generic_profile_multi(spike_trains, spike_profile_bi, + indices) + average_dist.mul_scalar(1.0/M) # normalize + return average_dist + + +############################################################ # spike_distance ############################################################ -def spike_distance(spike_train1, spike_train2, interval=None): - """ Computes the spike-distance :math:`D_S` of the given spike trains. The +def spike_distance(*args, **kwargs): + """ Computes the SPIKE-distance :math:`D_S` of the given spike trains. The spike-distance is the integral over the spike distance profile - :math:`S(t)`: + :math:`D(t)`: + + .. math:: D_S = \\int_{T_0}^{T_1} S(t) dt. + - .. math:: D_S = \int_{T_0}^{T_1} S(t) dt. + Valid call structures:: + + spike_distance(st1, st2) # returns the bi-variate distance + spike_distance(st1, st2, st3) # multi-variate distance of 3 spike trains + + spike_trains = [st1, st2, st3, st4] # list of spike trains + spike_distance(spike_trains) # distance of the list of spike trains + spike_distance(spike_trains, indices=[0, 1]) # use only the spike trains + # given by the indices + + In the multivariate case, the spike distance is given as the integral over + the multivariate profile, that is the average profile of all spike train + pairs: + + .. math:: D_S = \\int_0^T \\frac{2}{N(N-1)} \\sum_{<i,j>} + S^{i, j} dt + + :returns: The spike-distance :math:`D_S`. + :rtype: double + """ + + if len(args) == 1: + return spike_distance_multi(args[0], **kwargs) + elif len(args) == 2: + return spike_distance_bi(args[0], args[1], **kwargs) + else: + return spike_distance_multi(args, **kwargs) + + +############################################################ +# spike_distance_bi +############################################################ +def spike_distance_bi(spike_train1, spike_train2, interval=None): + """ Specific function to compute a bivariate SPIKE-distance. This is a + deprecated function and should not be called directly. Use + :func:`.spike_distance` to compute SPIKE-distances. :param spike_train1: First spike train. :type spike_train1: :class:`.SpikeTrain` @@ -86,48 +182,19 @@ def spike_distance(spike_train1, spike_train2, interval=None): spike_train1.t_end) except ImportError: # Cython backend not available: fall back to average profile - return spike_profile(spike_train1, spike_train2).avrg(interval) + return spike_profile_bi(spike_train1, spike_train2).avrg(interval) else: # some specific interval is provided: compute the whole profile - return spike_profile(spike_train1, spike_train2).avrg(interval) - - -############################################################ -# spike_profile_multi -############################################################ -def spike_profile_multi(spike_trains, indices=None): - """ Computes the multi-variate spike distance profile for a set of spike - trains. That is the average spike-distance of all pairs of spike-trains: - - .. math:: <S(t)> = \\frac{2}{N(N-1)} \\sum_{<i,j>} S^{i, j}`, - - where the sum goes over all pairs <i,j> - - :param spike_trains: list of :class:`.SpikeTrain` - :param indices: list of indices defining which spike trains to use, - if None all given spike trains are used (default=None) - :type indices: list or None - :returns: The averaged spike profile :math:`<S>(t)` - :rtype: :class:`.PieceWiseLinFunc` - - """ - average_dist, M = _generic_profile_multi(spike_trains, spike_profile, - indices) - average_dist.mul_scalar(1.0/M) # normalize - return average_dist + return spike_profile_bi(spike_train1, spike_train2).avrg(interval) ############################################################ # spike_distance_multi ############################################################ def spike_distance_multi(spike_trains, indices=None, interval=None): - """ Computes the multi-variate spike distance for a set of spike trains. - That is the time average of the multi-variate spike profile: - - .. math:: D_S = \\int_0^T \\frac{2}{N(N-1)} \\sum_{<i,j>} - S^{i, j} dt - - where the sum goes over all pairs <i,j> + """ Specific function to compute a multivariate SPIKE-distance. This is a + deprecated function and should not be called directly. Use + :func:`.spike_distance` to compute SPIKE-distances. :param spike_trains: list of :class:`.SpikeTrain` :param indices: list of indices defining which spike trains to use, @@ -139,7 +206,7 @@ def spike_distance_multi(spike_trains, indices=None, interval=None): :returns: The averaged multi-variate spike distance :math:`D_S`. :rtype: double """ - return _generic_distance_multi(spike_trains, spike_distance, indices, + return _generic_distance_multi(spike_trains, spike_distance_bi, indices, interval) @@ -160,5 +227,5 @@ def spike_distance_matrix(spike_trains, indices=None, interval=None): :math:`D_S^{ij}` :rtype: np.array """ - return _generic_distance_matrix(spike_trains, spike_distance, + return _generic_distance_matrix(spike_trains, spike_distance_bi, indices, interval) |