From 827be39a1646f1e518f7210b8943006f5741144d Mon Sep 17 00:00:00 2001 From: Mario Mulansky Date: Wed, 26 Aug 2015 17:40:09 +0200 Subject: further refactoring of directionality --- pyspike/spike_directionality.py | 232 +++++++++++++++++++++++++++++----------- 1 file changed, 169 insertions(+), 63 deletions(-) (limited to 'pyspike/spike_directionality.py') diff --git a/pyspike/spike_directionality.py b/pyspike/spike_directionality.py index 0e69cb5..f608ecc 100644 --- a/pyspike/spike_directionality.py +++ b/pyspike/spike_directionality.py @@ -7,6 +7,8 @@ import numpy as np from math import exp import pyspike from pyspike import DiscreteFunc +from functools import partial +from pyspike.generic import _generic_profile_multi ############################################################ @@ -21,23 +23,21 @@ def spike_directionality(spike_train1, spike_train2, normalize=True, # for optimal performance try: from cython.cython_directionality import \ - spike_train_order_cython as spike_train_order_impl + spike_directionality_cython as spike_directionality_impl if max_tau is None: max_tau = 0.0 - c, mp = spike_train_order_impl(spike_train1.spikes, - spike_train2.spikes, - spike_train1.t_start, - spike_train1.t_end, - max_tau) + d = spike_directionality_impl(spike_train1.spikes, + spike_train2.spikes, + spike_train1.t_start, + spike_train1.t_end, + max_tau) + c = len(spike_train1.spikes) except ImportError: - # Cython backend not available: fall back to profile averaging - c, mp = _spike_directionality_profile(spike_train1, - spike_train2, - max_tau).integral(interval) + raise NotImplementedError() if normalize: - return 1.0*c/mp + return 1.0*d/c else: - return c + return d else: # some specific interval is provided: not yet implemented raise NotImplementedError() @@ -70,11 +70,11 @@ def spike_directionality_matrix(spike_trains, normalize=True, indices=None, ############################################################ -# spike_train_order_profile +# spike_directionality_profiles ############################################################ -def spike_train_order_profile(spike_trains, indices=None, - interval=None, max_tau=None): - """ Computes the spike train symmetry value for each spike in each spike +def spike_directionality_profiles(spike_trains, indices=None, + interval=None, max_tau=None): + """ Computes the spike directionality value for each spike in each spike train. """ if indices is None: @@ -92,7 +92,7 @@ def spike_train_order_profile(spike_trains, indices=None, # cython implementation try: from cython.cython_directionality import \ - spike_order_values_cython as spike_order_values_impl + spike_directionality_profiles_cython as profile_impl except ImportError: raise NotImplementedError() # if not(pyspike.disable_backend_warning): @@ -107,11 +107,9 @@ def spike_train_order_profile(spike_trains, indices=None, max_tau = 0.0 for i, j in pairs: - a1, a2 = spike_order_values_impl(spike_trains[i].spikes, - spike_trains[j].spikes, - spike_trains[i].t_start, - spike_trains[i].t_end, - max_tau) + a1, a2 = profile_impl(spike_trains[i].spikes, spike_trains[j].spikes, + spike_trains[i].t_start, spike_trains[i].t_end, + max_tau) asymmetry_list[i] += a1 asymmetry_list[j] += a2 for a in asymmetry_list: @@ -119,6 +117,114 @@ def spike_train_order_profile(spike_trains, indices=None, return asymmetry_list +############################################################ +# spike_train_order_profile +############################################################ +def spike_train_order_profile(spike_train1, spike_train2, max_tau=None): + """ Computes the spike train order profile P(t) of the two given + spike trains. Returns the profile as a DiscreteFunction object. + :param spike_train1: First spike train. + :type spike_train1: :class:`pyspike.SpikeTrain` + :param spike_train2: Second spike train. + :type spike_train2: :class:`pyspike.SpikeTrain` + :param max_tau: Maximum coincidence window size. If 0 or `None`, the + coincidence window has no upper bound. + :returns: The spike-distance profile :math:`S_{sync}(t)`. + :rtype: :class:`pyspike.function.DiscreteFunction` + """ + # check whether the spike trains are defined for the same interval + assert spike_train1.t_start == spike_train2.t_start, \ + "Given spike trains are not defined on the same interval!" + assert spike_train1.t_end == spike_train2.t_end, \ + "Given spike trains are not defined on the same interval!" + + # cython implementation + try: + from cython.cython_directionality import \ + spike_train_order_profile_cython as \ + spike_train_order_profile_impl + except ImportError: + # raise NotImplementedError() + if not(pyspike.disable_backend_warning): + print("Warning: spike_distance_cython not found. Make sure that \ +PySpike is installed by running\n 'python setup.py build_ext --inplace'!\n \ +Falling back to slow python backend.") + # use python backend + from cython.directionality_python_backend import \ + spike_train_order_python as spike_train_order_profile_impl + + if max_tau is None: + max_tau = 0.0 + + times, coincidences, multiplicity \ + = spike_train_order_profile_impl(spike_train1.spikes, + spike_train2.spikes, + spike_train1.t_start, + spike_train1.t_end, + max_tau) + + return DiscreteFunc(times, coincidences, multiplicity) + + +############################################################ +# spike_train_order +############################################################ +def spike_train_order(spike_train1, spike_train2, normalize=True, + interval=None, max_tau=None): + """ Computes the overall spike delay asymmetry value for two spike trains. + """ + if interval is None: + # distance over the whole interval is requested: use specific function + # for optimal performance + try: + from cython.cython_directionality import \ + spike_train_order_cython as spike_train_order_impl + if max_tau is None: + max_tau = 0.0 + c, mp = spike_train_order_impl(spike_train1.spikes, + spike_train2.spikes, + spike_train1.t_start, + spike_train1.t_end, + max_tau) + except ImportError: + # Cython backend not available: fall back to profile averaging + c, mp = spike_train_order_profile(spike_train1, spike_train2, + max_tau).integral(interval) + if normalize: + return 1.0*c/mp + else: + return c + else: + # some specific interval is provided: not yet implemented + raise NotImplementedError() + + +############################################################ +# spike_train_order_profile_multi +############################################################ +def spike_train_order_profile_multi(spike_trains, indices=None, + max_tau=None): + """ Computes the multi-variate spike delay asymmetry profile for a set of + spike trains. For each spike in the set of spike trains, the multi-variate + profile is defined as the sum of asymmetry values divided by the number of + spike trains pairs involving the spike train of containing this spike, + which is the number of spike trains minus one (N-1). + :param spike_trains: list of :class:`pyspike.SpikeTrain` + :param indices: list of indices defining which spike trains to use, + if None all given spike trains are used (default=None) + :type indices: list or None + :param max_tau: Maximum coincidence window size. If 0 or `None`, the + coincidence window has no upper bound. + :returns: The multi-variate spike sync profile :math:`(t)` + :rtype: :class:`pyspike.function.DiscreteFunction` + """ + prof_func = partial(spike_train_order_profile, max_tau=max_tau) + average_prof, M = _generic_profile_multi(spike_trains, prof_func, + indices) + # average_dist.mul_scalar(1.0/M) # no normalization here! + return average_prof + + ############################################################ # optimal_spike_train_order_from_matrix ############################################################ @@ -195,50 +301,50 @@ def permutate_matrix(D, p): ############################################################ # _spike_directionality_profile ############################################################ -def _spike_directionality_profile(spike_train1, spike_train2, - max_tau=None): - """ Computes the spike delay asymmetry profile A(t) of the two given - spike trains. Returns the profile as a DiscreteFunction object. +# def _spike_directionality_profile(spike_train1, spike_train2, +# max_tau=None): +# """ Computes the spike delay asymmetry profile A(t) of the two given +# spike trains. Returns the profile as a DiscreteFunction object. - :param spike_train1: First spike train. - :type spike_train1: :class:`pyspike.SpikeTrain` - :param spike_train2: Second spike train. - :type spike_train2: :class:`pyspike.SpikeTrain` - :param max_tau: Maximum coincidence window size. If 0 or `None`, the - coincidence window has no upper bound. - :returns: The spike-distance profile :math:`S_{sync}(t)`. - :rtype: :class:`pyspike.function.DiscreteFunction` +# :param spike_train1: First spike train. +# :type spike_train1: :class:`pyspike.SpikeTrain` +# :param spike_train2: Second spike train. +# :type spike_train2: :class:`pyspike.SpikeTrain` +# :param max_tau: Maximum coincidence window size. If 0 or `None`, the +# coincidence window has no upper bound. +# :returns: The spike-distance profile :math:`S_{sync}(t)`. +# :rtype: :class:`pyspike.function.DiscreteFunction` - """ - # check whether the spike trains are defined for the same interval - assert spike_train1.t_start == spike_train2.t_start, \ - "Given spike trains are not defined on the same interval!" - assert spike_train1.t_end == spike_train2.t_end, \ - "Given spike trains are not defined on the same interval!" +# """ +# # check whether the spike trains are defined for the same interval +# assert spike_train1.t_start == spike_train2.t_start, \ +# "Given spike trains are not defined on the same interval!" +# assert spike_train1.t_end == spike_train2.t_end, \ +# "Given spike trains are not defined on the same interval!" - # cython implementation - try: - from cython.cython_directionality import \ - spike_train_order_profile_cython as \ - spike_train_order_profile_impl - except ImportError: - # raise NotImplementedError() - if not(pyspike.disable_backend_warning): - print("Warning: spike_distance_cython not found. Make sure that \ -PySpike is installed by running\n 'python setup.py build_ext --inplace'!\n \ -Falling back to slow python backend.") - # use python backend - from cython.directionality_python_backend import \ - spike_train_order_python as spike_train_order_profile_impl +# # cython implementation +# try: +# from cython.cython_directionality import \ +# spike_train_order_profile_cython as \ +# spike_train_order_profile_impl +# except ImportError: +# # raise NotImplementedError() +# if not(pyspike.disable_backend_warning): +# print("Warning: spike_distance_cython not found. Make sure that \ +# PySpike is installed by running\n 'python setup.py build_ext --inplace'!\n \ +# Falling back to slow python backend.") +# # use python backend +# from cython.directionality_python_backend import \ +# spike_train_order_python as spike_train_order_profile_impl - if max_tau is None: - max_tau = 0.0 +# if max_tau is None: +# max_tau = 0.0 - times, coincidences, multiplicity \ - = spike_train_order_profile_impl(spike_train1.spikes, - spike_train2.spikes, - spike_train1.t_start, - spike_train1.t_end, - max_tau) +# times, coincidences, multiplicity \ +# = spike_train_order_profile_impl(spike_train1.spikes, +# spike_train2.spikes, +# spike_train1.t_start, +# spike_train1.t_end, +# max_tau) - return DiscreteFunc(times, coincidences, multiplicity) +# return DiscreteFunc(times, coincidences, multiplicity) -- cgit v1.2.3