summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMario Mulansky <mario.mulansky@gmx.net>2018-07-15 15:33:22 -0700
committerMario Mulansky <mario.mulansky@gmx.net>2018-07-15 15:33:22 -0700
commitaed1f8185cf2a3b3f9330e93681fa12b365d287b (patch)
treeb2468d59187ce753cbb670bfdba8c7f6e30476b6
parent4fe9f4373a83902cedd297bd44bf8a97b7357df0 (diff)
Clean up directionality module, add doxy.
-rw-r--r--doc/pyspike.rst1
-rw-r--r--pyspike/__init__.py17
-rw-r--r--pyspike/spike_directionality.py446
-rw-r--r--pyspike/spike_sync.py4
-rw-r--r--test/test_directionality.py14
5 files changed, 349 insertions, 133 deletions
diff --git a/doc/pyspike.rst b/doc/pyspike.rst
index 9552fa6..3b10d2a 100644
--- a/doc/pyspike.rst
+++ b/doc/pyspike.rst
@@ -65,6 +65,7 @@ PSTH
:show-inheritance:
Directionality
+........................................
.. automodule:: pyspike.spike_directionality
:members:
:undoc-members:
diff --git a/pyspike/__init__.py b/pyspike/__init__.py
index 3cf416f..3897d18 100644
--- a/pyspike/__init__.py
+++ b/pyspike/__init__.py
@@ -1,5 +1,5 @@
"""
-Copyright 2014-2015, Mario Mulansky <mario.mulansky@gmx.net>
+Copyright 2014-2018, Mario Mulansky <mario.mulansky@gmx.net>
Distributed under the BSD License
"""
@@ -29,16 +29,11 @@ from .spikes import load_spike_trains_from_txt, save_spike_trains_to_txt, \
merge_spike_trains, generate_poisson_spikes
from .spike_directionality import spike_directionality, \
- spike_directionality_profiles, spike_directionality_matrix, \
- spike_train_order_profile, spike_train_order, \
- spike_train_order_profile_multi, optimal_spike_train_order_from_matrix, \
- optimal_spike_train_order, permutate_matrix
-
-from .spike_directionality import spike_directionality, \
- spike_directionality_profiles, spike_directionality_matrix, \
- spike_train_order_profile, spike_train_order, \
- spike_train_order_profile_multi, optimal_spike_train_order_from_matrix, \
- optimal_spike_train_order, permutate_matrix
+ spike_directionality_values, spike_directionality_matrix, \
+ spike_train_order_profile, spike_train_order_profile_bi, \
+ spike_train_order_profile_multi, spike_train_order, \
+ spike_train_order_bi, spike_train_order_multi, \
+ optimal_spike_train_sorting, permutate_matrix
# define the __version__ following
# http://stackoverflow.com/questions/17583443
diff --git a/pyspike/spike_directionality.py b/pyspike/spike_directionality.py
index d1a525e..3a71f0b 100644
--- a/pyspike/spike_directionality.py
+++ b/pyspike/spike_directionality.py
@@ -13,11 +13,108 @@ from pyspike.generic import _generic_profile_multi
############################################################
+# spike_directionality_values
+############################################################
+def spike_directionality_values(*args, **kwargs):
+ """ Computes the spike directionality value for each spike in
+ each spike train. Returns a list containing an array of spike directionality
+ values for every given spike train.
+
+ Valid call structures::
+
+ spike_directionality_values(st1, st2) # returns the bi-variate profile
+ spike_directionality_values(st1, st2, st3) # multi-variate profile of 3
+ # spike trains
+
+ spike_trains = [st1, st2, st3, st4] # list of spike trains
+ spike_directionality_values(spike_trains) # profile of the list of spike trains
+ spike_directionality_values(spike_trains, indices=[0, 1]) # use only the spike trains
+ # given by the indices
+
+ Additonal arguments:
+ :param max_tau: Upper bound for coincidence window (default=None).
+ :param indices: list of indices defining which spike trains to use,
+ if None all given spike trains are used (default=None)
+
+ :returns: The spike directionality values :math:`D^n_i` as a list of arrays.
+ """
+ if len(args) == 1:
+ return _spike_directionality_values_impl(args[0], **kwargs)
+ else:
+ return _spike_directionality_values_impl(args, **kwargs)
+
+
+def _spike_directionality_values_impl(spike_trains, indices=None,
+ interval=None, max_tau=None):
+ """ Computes the multi-variate spike directionality profile
+ of the given spike trains.
+
+ :param spike_trains: List of spike trains.
+ :type spike_trains: List of :class:`pyspike.SpikeTrain`
+ :param indices: list of indices defining which spike trains to use,
+ if None all given spike trains are used (default=None)
+ :type indices: list or None
+ :param max_tau: Maximum coincidence window size. If 0 or `None`, the
+ coincidence window has no upper bound.
+ :returns: The spike-directionality values.
+ """
+ if interval is not None:
+ raise NotImplementedError("Parameter `interval` not supported.")
+ if indices is None:
+ indices = np.arange(len(spike_trains))
+ indices = np.array(indices)
+ # check validity of indices
+ assert (indices < len(spike_trains)).all() and (indices >= 0).all(), \
+ "Invalid index list."
+ # list of arrays for resulting asymmetry values
+ asymmetry_list = [np.zeros_like(spike_trains[n].spikes) for n in indices]
+ # generate a list of possible index pairs
+ pairs = [(indices[i], j) for i in range(len(indices))
+ for j in indices[i+1:]]
+
+ # cython implementation
+ try:
+ from .cython.cython_directionality import \
+ spike_directionality_profiles_cython as profile_impl
+ except ImportError:
+ if not(pyspike.disable_backend_warning):
+ print("Warning: spike_distance_cython not found. Make sure that \
+PySpike is installed by running\n 'python setup.py build_ext --inplace'!\n \
+Falling back to slow python backend.")
+ # use python backend
+ from .cython.directionality_python_backend import \
+ spike_directionality_profile_python as profile_impl
+
+ if max_tau is None:
+ max_tau = 0.0
+
+ for i, j in pairs:
+ d1, d2 = profile_impl(spike_trains[i].spikes, spike_trains[j].spikes,
+ spike_trains[i].t_start, spike_trains[i].t_end,
+ max_tau)
+ asymmetry_list[i] += d1
+ asymmetry_list[j] += d2
+ for a in asymmetry_list:
+ a /= len(spike_trains)-1
+ return asymmetry_list
+
+
+############################################################
# spike_directionality
############################################################
def spike_directionality(spike_train1, spike_train2, normalize=True,
interval=None, max_tau=None):
- """ Computes the overall spike directionality for two spike trains.
+ """ Computes the overall spike directionality of the first spike train with
+ respect to the second spike train.
+
+ :param spike_train1: First spike train.
+ :type spike_train1: :class:`pyspike.SpikeTrain`
+ :param spike_train2: Second spike train.
+ :type spike_train2: :class:`pyspike.SpikeTrain`
+ :param normalize: Normalize by the number of spikes (multiplicity).
+ :param max_tau: Maximum coincidence window size. If 0 or `None`, the
+ coincidence window has no upper bound.
+ :returns: The spike train order profile :math:`E(t)`.
"""
if interval is None:
# distance over the whole interval is requested: use specific function
@@ -34,9 +131,14 @@ def spike_directionality(spike_train1, spike_train2, normalize=True,
max_tau)
c = len(spike_train1.spikes)
except ImportError:
- d1, x = spike_directionality_profiles([spike_train1, spike_train2],
- interval=interval,
- max_tau=max_tau)
+ if not(pyspike.disable_backend_warning):
+ print("Warning: spike_distance_cython not found. Make sure that \
+PySpike is installed by running\n 'python setup.py build_ext --inplace'!\n \
+Falling back to slow python backend.")
+ # use profile.
+ d1, x = spike_directionality_values([spike_train1, spike_train2],
+ interval=interval,
+ max_tau=max_tau)
d = np.sum(d1)
c = len(spike_train1.spikes)
if normalize:
@@ -45,7 +147,7 @@ def spike_directionality(spike_train1, spike_train2, normalize=True,
return d
else:
# some specific interval is provided: not yet implemented
- raise NotImplementedError()
+ raise NotImplementedError("Parameter `interval` not supported.")
############################################################
@@ -53,7 +155,17 @@ def spike_directionality(spike_train1, spike_train2, normalize=True,
############################################################
def spike_directionality_matrix(spike_trains, normalize=True, indices=None,
interval=None, max_tau=None):
- """ Computes the spike directionaity matrix for the given spike trains.
+ """ Computes the spike directionality matrix for the given spike trains.
+
+ :param spike_trains: List of spike trains.
+ :type spike_trains: List of :class:`pyspike.SpikeTrain`
+ :param normalize: Normalize by the number of spikes (multiplicity).
+ :param indices: list of indices defining which spike trains to use,
+ if None all given spike trains are used (default=None)
+ :type indices: list or None
+ :param max_tau: Maximum coincidence window size. If 0 or `None`, the
+ coincidence window has no upper bound.
+ :returns: The spike-directionality values.
"""
if indices is None:
indices = np.arange(len(spike_trains))
@@ -75,65 +187,53 @@ def spike_directionality_matrix(spike_trains, normalize=True, indices=None,
############################################################
-# spike_directionality_profiles
+# spike_train_order_profile
############################################################
-def spike_directionality_profiles(spike_trains, indices=None,
- interval=None, max_tau=None):
- """ Computes the spike directionality value for each spike in each spike
- train.
- """
- if indices is None:
- indices = np.arange(len(spike_trains))
- indices = np.array(indices)
- # check validity of indices
- assert (indices < len(spike_trains)).all() and (indices >= 0).all(), \
- "Invalid index list."
- # list of arrays for reulting asymmetry values
- asymmetry_list = [np.zeros_like(st.spikes) for st in spike_trains]
- # generate a list of possible index pairs
- pairs = [(indices[i], j) for i in range(len(indices))
- for j in indices[i+1:]]
+def spike_train_order_profile(*args, **kwargs):
+ """ Computes the spike train order profile :math:`E(t)` of the given
+ spike trains. Returns the profile as a DiscreteFunction object.
- # cython implementation
- try:
- from .cython.cython_directionality import \
- spike_directionality_profiles_cython as profile_impl
- except ImportError:
- if not(pyspike.disable_backend_warning):
- print("Warning: spike_distance_cython not found. Make sure that \
-PySpike is installed by running\n 'python setup.py build_ext --inplace'!\n \
-Falling back to slow python backend.")
- # use python backend
- from .cython.directionality_python_backend import \
- spike_directionality_profile_python as profile_impl
+ Valid call structures::
- if max_tau is None:
- max_tau = 0.0
+ spike_train_order_profile(st1, st2) # returns the bi-variate profile
+ spike_train_order_profile(st1, st2, st3) # multi-variate profile of 3
+ # spike trains
- for i, j in pairs:
- d1, d2 = profile_impl(spike_trains[i].spikes, spike_trains[j].spikes,
- spike_trains[i].t_start, spike_trains[i].t_end,
- max_tau)
- asymmetry_list[i] += d1
- asymmetry_list[j] += d2
- for a in asymmetry_list:
- a /= len(spike_trains)-1
- return asymmetry_list
+ spike_trains = [st1, st2, st3, st4] # list of spike trains
+ spike_train_order_profile(spike_trains) # profile of the list of spike trains
+ spike_train_order_profile(spike_trains, indices=[0, 1]) # use only the spike trains
+ # given by the indices
+
+ Additonal arguments:
+ :param max_tau: Upper bound for coincidence window, `default=None`.
+ :param indices: list of indices defining which spike trains to use,
+ if None all given spike trains are used (default=None)
+
+ :returns: The spike train order profile :math:`E(t)`
+ :rtype: :class:`.DiscreteFunction`
+ """
+ if len(args) == 1:
+ return spike_train_order_profile_multi(args[0], **kwargs)
+ elif len(args) == 2:
+ return spike_train_order_profile_bi(args[0], args[1], **kwargs)
+ else:
+ return spike_train_order_profile_multi(args, **kwargs)
############################################################
-# spike_train_order_profile
+# spike_train_order_profile_bi
############################################################
-def spike_train_order_profile(spike_train1, spike_train2, max_tau=None):
+def spike_train_order_profile_bi(spike_train1, spike_train2, max_tau=None):
""" Computes the spike train order profile P(t) of the two given
spike trains. Returns the profile as a DiscreteFunction object.
+
:param spike_train1: First spike train.
:type spike_train1: :class:`pyspike.SpikeTrain`
:param spike_train2: Second spike train.
:type spike_train2: :class:`pyspike.SpikeTrain`
:param max_tau: Maximum coincidence window size. If 0 or `None`, the
coincidence window has no upper bound.
- :returns: The spike-distance profile :math:`S_{sync}(t)`.
+ :returns: The spike train order profile :math:`E(t)`.
:rtype: :class:`pyspike.function.DiscreteFunction`
"""
# check whether the spike trains are defined for the same interval
@@ -171,21 +271,56 @@ Falling back to slow python backend.")
############################################################
-# spike_train_order
+# spike_train_order_profile_multi
+############################################################
+def spike_train_order_profile_multi(spike_trains, indices=None,
+ max_tau=None):
+ """ Computes the multi-variate spike train order profile for a set of
+ spike trains. For each spike in the set of spike trains, the multi-variate
+ profile is defined as the sum of asymmetry values divided by the number of
+ spike trains pairs involving the spike train of containing this spike,
+ which is the number of spike trains minus one (N-1).
+
+ :param spike_trains: list of :class:`pyspike.SpikeTrain`
+ :param indices: list of indices defining which spike trains to use,
+ if None all given spike trains are used (default=None)
+ :type indices: list or None
+ :param max_tau: Maximum coincidence window size. If 0 or `None`, the
+ coincidence window has no upper bound.
+ :returns: The multi-variate spike sync profile :math:`<S_{sync}>(t)`
+ :rtype: :class:`pyspike.function.DiscreteFunction`
+ """
+ prof_func = partial(spike_train_order_profile_bi, max_tau=max_tau)
+ average_prof, M = _generic_profile_multi(spike_trains, prof_func,
+ indices)
+ return average_prof
+
+
+
+############################################################
+# _spike_train_order_impl
############################################################
-def spike_train_order(spike_train1, spike_train2, normalize=True,
- interval=None, max_tau=None):
- """ Computes the overall spike delay asymmetry value for two spike trains.
+def _spike_train_order_impl(spike_train1, spike_train2,
+ interval=None, max_tau=None):
+ """ Implementation of bi-variatae spike train order value (Synfire Indicator).
+
+ :param spike_train1: First spike train.
+ :type spike_train1: :class:`pyspike.SpikeTrain`
+ :param spike_train2: Second spike train.
+ :type spike_train2: :class:`pyspike.SpikeTrain`
+ :param max_tau: Maximum coincidence window size. If 0 or `None`, the
+ coincidence window has no upper bound.
+ :returns: The spike train order value (Synfire Indicator)
"""
if interval is None:
# distance over the whole interval is requested: use specific function
# for optimal performance
try:
from .cython.cython_directionality import \
- spike_train_order_cython as spike_train_order_impl
+ spike_train_order_cython as spike_train_order_func
if max_tau is None:
max_tau = 0.0
- c, mp = spike_train_order_impl(spike_train1.spikes,
+ c, mp = spike_train_order_func(spike_train1.spikes,
spike_train2.spikes,
spike_train1.t_start,
spike_train1.t_end,
@@ -194,49 +329,127 @@ def spike_train_order(spike_train1, spike_train2, normalize=True,
# Cython backend not available: fall back to profile averaging
c, mp = spike_train_order_profile(spike_train1, spike_train2,
max_tau).integral(interval)
- if normalize:
- return 1.0*c/mp
- else:
- return c
+ return c, mp
else:
# some specific interval is provided: not yet implemented
- raise NotImplementedError()
+ raise NotImplementedError("Parameter `interval` not supported.")
############################################################
-# spike_train_order_profile_multi
+# spike_train_order
############################################################
-def spike_train_order_profile_multi(spike_trains, indices=None,
- max_tau=None):
- """ Computes the multi-variate spike delay asymmetry profile for a set of
- spike trains. For each spike in the set of spike trains, the multi-variate
- profile is defined as the sum of asymmetry values divided by the number of
- spike trains pairs involving the spike train of containing this spike,
- which is the number of spike trains minus one (N-1).
- :param spike_trains: list of :class:`pyspike.SpikeTrain`
+def spike_train_order(*args, **kwargs):
+ """ Computes the spike train order (Synfire Indicator) of the given
+ spike trains.
+
+ Valid call structures::
+
+ spike_train_order(st1, st2, normalize=True) # normalized bi-variate
+ # spike train order
+ spike_train_order(st1, st2, st3) # multi-variate result of 3 spike trains
+
+ spike_trains = [st1, st2, st3, st4] # list of spike trains
+ spike_train_order(spike_trains) # result for the list of spike trains
+ spike_train_order(spike_trains, indices=[0, 1]) # use only the spike trains
+ # given by the indices
+
+ Additonal arguments:
+ - `max_tau` Upper bound for coincidence window, `default=None`.
+ - `normalize` Flag indicating if the reslut should be normalized by the
+ number of spikes , default=`False`
+
+
+ :returns: The spike train order value (Synfire Indicator)
+ """
+ if len(args) == 1:
+ return spike_train_order_multi(args[0], **kwargs)
+ elif len(args) == 2:
+ return spike_train_order_bi(args[0], args[1], **kwargs)
+ else:
+ return spike_train_order_multi(args, **kwargs)
+
+
+############################################################
+# spike_train_order_bi
+############################################################
+def spike_train_order_bi(spike_train1, spike_train2, normalize=True,
+ interval=None, max_tau=None):
+ """ Computes the overall spike train order value (Synfire Indicator)
+ for two spike trains.
+
+ :param spike_train1: First spike train.
+ :type spike_train1: :class:`pyspike.SpikeTrain`
+ :param spike_train2: Second spike train.
+ :type spike_train2: :class:`pyspike.SpikeTrain`
+ :param normalize: Normalize by the number of spikes (multiplicity).
+ :param max_tau: Maximum coincidence window size. If 0 or `None`, the
+ coincidence window has no upper bound.
+ :returns: The spike train order value (Synfire Indicator)
+ """
+ c, mp = _spike_train_order_impl(spike_train1, spike_train2, interval, max_tau)
+ if normalize:
+ return 1.0*c/mp
+ else:
+ return c
+
+############################################################
+# spike_train_order_multi
+############################################################
+def spike_train_order_multi(spike_trains, indices=None, normalize=True,
+ interval=None, max_tau=None):
+ """ Computes the overall spike train order value (Synfire Indicator)
+ for many spike trains.
+
+ :param spike_trains: list of :class:`.SpikeTrain`
:param indices: list of indices defining which spike trains to use,
if None all given spike trains are used (default=None)
- :type indices: list or None
+ :param normalize: Normalize by the number of spike (multiplicity).
+ :param interval: averaging interval given as a pair of floats, if None
+ the average over the whole function is computed.
+ :type interval: Pair of floats or None.
:param max_tau: Maximum coincidence window size. If 0 or `None`, the
coincidence window has no upper bound.
- :returns: The multi-variate spike sync profile :math:`<S_{sync}>(t)`
- :rtype: :class:`pyspike.function.DiscreteFunction`
+ :returns: Spike train order values (Synfire Indicator) F for the given spike trains.
+ :rtype: double
"""
- prof_func = partial(spike_train_order_profile, max_tau=max_tau)
- average_prof, M = _generic_profile_multi(spike_trains, prof_func,
- indices)
- # average_dist.mul_scalar(1.0/M) # no normalization here!
- return average_prof
+ if indices is None:
+ indices = np.arange(len(spike_trains))
+ indices = np.array(indices)
+ # check validity of indices
+ assert (indices < len(spike_trains)).all() and (indices >= 0).all(), \
+ "Invalid index list."
+ # generate a list of possible index pairs
+ pairs = [(indices[i], j) for i in range(len(indices))
+ for j in indices[i+1:]]
+
+ e_total = 0.0
+ m_total = 0.0
+ for (i, j) in pairs:
+ e, m = _spike_train_order_impl(spike_trains[i], spike_trains[j],
+ interval, max_tau)
+ e_total += e
+ m_total += m
+
+ if m == 0.0:
+ return 1.0
+ else:
+ return e_total/m_total
+
############################################################
-# optimal_spike_train_order_from_matrix
+# optimal_spike_train_sorting_from_matrix
############################################################
-def optimal_spike_train_order_from_matrix(D, full_output=False):
- """ finds the best sorting via simulated annealing.
+def _optimal_spike_train_sorting_from_matrix(D, full_output=False):
+ """ Finds the best sorting via simulated annealing.
Returns the optimal permutation p and A value.
- Internal function, don't call directly! Use optimal_asymmetry_order
- instead.
+ Not for direct use, call :func:`.optimal_spike_train_sorting` instead.
+
+ :param D: The directionality (Spike-ORDER) matrix.
+ :param full_output: If true, then function will additionally return the
+ number of performed iterations (default=False)
+ :return: (p, F) - tuple with the optimal permutation and synfire indicator.
+ if `full_output=True` , (p, F, iter) is returned.
"""
N = len(D)
A = np.sum(np.triu(D, 0))
@@ -247,35 +460,14 @@ def optimal_spike_train_order_from_matrix(D, full_output=False):
T_end = 1E-5 * T_start # final temperature
alpha = 0.9 # cooling factor
- from .cython.cython_simulated_annealing import sim_ann_cython as sim_ann
+ try:
+ from .cython.cython_simulated_annealing import sim_ann_cython as sim_ann
+ except ImportError:
+ raise NotImplementedError("PySpike with Cython required for computing spike train"
+ " sorting!")
p, A, total_iter = sim_ann(D, T_start, T_end, alpha)
- # T = T_start
- # total_iter = 0
- # while T > T_end:
- # iterations = 0
- # succ_iter = 0
- # # equilibrate for 100*N steps or 10*N successful steps
- # while iterations < 100*N and succ_iter < 10*N:
- # # exchange two rows and cols
- # ind1 = np.random.randint(N-1)
- # if ind1 < N-1:
- # ind2 = ind1+1
- # else: # this can never happend
- # ind2 = 0
- # delta_A = -2*D[p[ind1], p[ind2]]
- # if delta_A > 0.0 or exp(delta_A/T) > np.random.random():
- # # swap indices
- # p[ind1], p[ind2] = p[ind2], p[ind1]
- # A += delta_A
- # succ_iter += 1
- # iterations += 1
- # total_iter += iterations
- # T *= alpha # cool down
- # if succ_iter == 0:
- # break
-
if full_output:
return p, A, total_iter
else:
@@ -283,26 +475,44 @@ def optimal_spike_train_order_from_matrix(D, full_output=False):
############################################################
-# optimal_spike_train_order
+# optimal_spike_train_sorting
############################################################
-def optimal_spike_train_order(spike_trains, indices=None, interval=None,
- max_tau=None, full_output=False):
- """ finds the best sorting of the given spike trains via simulated
- annealing.
- Returns the optimal permutation p and A value.
+def optimal_spike_train_sorting(spike_trains, indices=None, interval=None,
+ max_tau=None, full_output=False):
+ """ Finds the best sorting of the given spike trains by computing the spike
+ directionality matrix and optimize the order using simulated annealing.
+ For a detailed description of the algorithm see:
+ `http://iopscience.iop.org/article/10.1088/1367-2630/aa68c3/meta`
+
+ :param spike_trains: list of :class:`.SpikeTrain`
+ :param indices: list of indices defining which spike trains to use,
+ if None all given spike trains are used (default=None)
+ :type indices: list or None
+ :param interval: time interval filter given as a pair of floats, if None
+ the full spike trains are used (default=None).
+ :type interval: Pair of floats or None.
+ :param max_tau: Maximum coincidence window size. If 0 or `None`, the
+ coincidence window has no upper bound (default=None).
+ :param full_output: If true, then function will additionally return the
+ number of performed iterations (default=False)
+ :return: (p, F) - tuple with the optimal permutation and synfire indicator.
+ if `full_output=True` , (p, F, iter) is returned.
"""
D = spike_directionality_matrix(spike_trains, normalize=False,
indices=indices, interval=interval,
max_tau=max_tau)
- return optimal_spike_train_order_from_matrix(D, full_output)
-
+ return _optimal_spike_train_order_from_matrix(D, full_output)
############################################################
# permutate_matrix
############################################################
def permutate_matrix(D, p):
- """ Applies the permutation p to the columns and rows of matrix D.
- Return the new permutated matrix.
+ """ Helper function that applies the permutation p to the columns and rows
+ of matrix D. Return the permutated matrix :math:`D'[n,m] = D[p[n], p[m]]`.
+
+ :param D: The matrix.
+ :param d: The permutation.
+ :return: The permuated matrix D', ie :math:`D'[n,m] = D[p[n], p[m]]`
"""
N = len(D)
D_p = np.empty_like(D)
diff --git a/pyspike/spike_sync.py b/pyspike/spike_sync.py
index 32e6bf4..95ef454 100644
--- a/pyspike/spike_sync.py
+++ b/pyspike/spike_sync.py
@@ -45,9 +45,9 @@ def spike_sync_profile(*args, **kwargs):
if len(args) == 1:
return spike_sync_profile_multi(args[0], **kwargs)
elif len(args) == 2:
- return spike_sync_profile_bi(args[0], args[1])
+ return spike_sync_profile_bi(args[0], args[1], **kwargs)
else:
- return spike_sync_profile_multi(args)
+ return spike_sync_profile_multi(args, **kwargs)
############################################################
diff --git a/test/test_directionality.py b/test/test_directionality.py
index 63865cc..5c7e917 100644
--- a/test/test_directionality.py
+++ b/test/test_directionality.py
@@ -14,7 +14,6 @@ from numpy.testing import assert_equal, assert_almost_equal, \
import pyspike as spk
from pyspike import SpikeTrain, DiscreteFunc
-# from pyspike.spike_directionality import _spike_directionality_profile
def test_spike_directionality():
@@ -39,7 +38,7 @@ def test_spike_directionality():
D_expected = np.array([[0, 2.0, 0.0], [-2.0, 0.0, -1.0], [0.0, 1.0, 0.0]])
assert_array_equal(D, D_expected)
- dir_profs = spk.spike_directionality_profiles([st1, st2, st3])
+ dir_profs = spk.spike_directionality_values([st1, st2, st3])
assert_array_equal(dir_profs[0], [1.0, 0.0, 0.0])
assert_array_equal(dir_profs[1], [-0.5, -1.0, 0.0])
@@ -87,3 +86,14 @@ def test_spike_train_order():
assert_array_equal(f.x, expected_x)
assert_array_equal(f.y, expected_y)
assert_array_equal(f.mp, expected_mp)
+
+ # Averaging the profile should be the same as computing the synfire indicator directly.
+ assert_almost_equal(f.avrg(), spk.spike_train_order([st1, st2, st3]))
+
+ # We can also compute the synfire indicator from the Directionality Matrix:
+ D_matrix = spk.spike_directionality_matrix([st1, st2, st3], normalize=False)
+ print D_matrix
+ num_spikes = np.sum(len(st) for st in [st1, st2, st3])
+ print f.avrg(), num_spikes
+ syn_fire = np.sum(np.triu(D_matrix)) / num_spikes
+ assert_almost_equal(f.avrg(), syn_fire)