author    Mario Mulansky <mario.mulansky@gmx.net>    2015-08-26 12:10:01 +0200
committer Mario Mulansky <mario.mulansky@gmx.net>    2018-06-02 12:58:46 -0700
commit    0c710a2450a055091fe6de2313d6240671d1c74d (patch)
tree      71f6d604779c636c2aa448887de500af07a89673
parent    44d23620d2faa78ca74437fbd3f1b95da722a853 (diff)
reorganized directionality module
-rw-r--r--  pyspike/__init__.py                                9
-rw-r--r--  pyspike/cython/cython_directionality.pyx         223
-rw-r--r--  pyspike/cython/directionality_python_backend.py   89
-rw-r--r--  pyspike/spike_directionality.py                  244
-rw-r--r--  setup.py                                          19
-rw-r--r--  test/test_directionality.py                       41
6 files changed, 619 insertions, 6 deletions
diff --git a/pyspike/__init__.py b/pyspike/__init__.py
index 08253fb..7fa5265 100644
--- a/pyspike/__init__.py
+++ b/pyspike/__init__.py
@@ -7,8 +7,8 @@ Distributed under the BSD License
from __future__ import absolute_import
__all__ = ["isi_distance", "spike_distance", "spike_sync", "psth",
- "spikes", "SpikeTrain", "PieceWiseConstFunc", "PieceWiseLinFunc",
- "DiscreteFunc", "directionality"]
+ "spikes", "spike_directionality", "SpikeTrain",
+ "PieceWiseConstFunc", "PieceWiseLinFunc", "DiscreteFunc"]
from .PieceWiseConstFunc import PieceWiseConstFunc
from .PieceWiseLinFunc import PieceWiseLinFunc
@@ -27,6 +27,11 @@ from .spikes import load_spike_trains_from_txt, save_spike_trains_to_txt, \
spike_train_from_string, import_spike_trains_from_time_series, \
merge_spike_trains, generate_poisson_spikes
+from .spike_directionality import spike_directionality, \
+ spike_directionality_matrix, spike_train_order_profile, \
+ optimal_spike_train_order_from_matrix, optimal_spike_train_order, \
+ permutate_matrix
+
# define the __version__ following
# http://stackoverflow.com/questions/17583443
from pkg_resources import get_distribution, DistributionNotFound
diff --git a/pyspike/cython/cython_directionality.pyx b/pyspike/cython/cython_directionality.pyx
new file mode 100644
index 0000000..e1f63c4
--- /dev/null
+++ b/pyspike/cython/cython_directionality.pyx
@@ -0,0 +1,223 @@
+#cython: boundscheck=False
+#cython: wraparound=False
+#cython: cdivision=True
+
+"""
+cython_directionality.pyx
+
+cython implementation of the spike delay asymmetry measures
+
+Copyright 2015, Mario Mulansky <mario.mulansky@gmx.net>
+
+Distributed under the BSD License
+
+"""
+
+"""
+To test whether things can be optimized: remove all yellow stuff
+in the html output::
+
+ cython -a cython_directionality.pyx
+
+which gives::
+
+ cython_directionality.html
+
+"""
+
+import numpy as np
+cimport numpy as np
+
+from libc.math cimport fabs
+from libc.math cimport fmax
+from libc.math cimport fmin
+
+# from pyspike.cython.cython_distances cimport get_tau
+
+DTYPE = np.float
+ctypedef np.float_t DTYPE_t
+
+
+############################################################
+# get_tau
+############################################################
+cdef inline double get_tau(double[:] spikes1, double[:] spikes2,
+ int i, int j, double interval, double max_tau):
+ cdef double m = interval # use interval length as initial tau
+ cdef int N1 = spikes1.shape[0]-1 # len(spikes1)-1
+ cdef int N2 = spikes2.shape[0]-1 # len(spikes2)-1
+ if i < N1 and i > -1:
+ m = fmin(m, spikes1[i+1]-spikes1[i])
+ if j < N2 and j > -1:
+ m = fmin(m, spikes2[j+1]-spikes2[j])
+ if i > 0:
+ m = fmin(m, spikes1[i]-spikes1[i-1])
+ if j > 0:
+ m = fmin(m, spikes2[j]-spikes2[j-1])
+ m *= 0.5
+ if max_tau > 0.0:
+ m = fmin(m, max_tau)
+ return m
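+# Illustrative values (matching test/test_directionality.py): for
+# spikes1 = [100, 200, 300], spikes2 = [105, 205, 300] and i = j = 0, the
+# neighbouring inter-spike intervals are both 100, so tau = 0.5 * 100 = 50,
+# further capped by max_tau if max_tau > 0.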
+
+
+############################################################
+# spike_train_order_profile_cython
+############################################################
+def spike_train_order_profile_cython(double[:] spikes1, double[:] spikes2,
+ double t_start, double t_end,
+ double max_tau):
+
+ cdef int N1 = len(spikes1)
+ cdef int N2 = len(spikes2)
+ cdef int i = -1
+ cdef int j = -1
+ cdef int n = 0
+ cdef double[:] st = np.zeros(N1 + N2 + 2) # spike times
+ cdef double[:] a = np.zeros(N1 + N2 + 2) # asymmetry values
+ cdef double[:] mp = np.ones(N1 + N2 + 2) # multiplicity
+ cdef double interval = t_end - t_start
+ cdef double tau
+ while i + j < N1 + N2 - 2:
+ if (i < N1-1) and (j == N2-1 or spikes1[i+1] < spikes2[j+1]):
+ i += 1
+ n += 1
+ tau = get_tau(spikes1, spikes2, i, j, interval, max_tau)
+ st[n] = spikes1[i]
+ if j > -1 and spikes1[i]-spikes2[j] < tau:
+ # coincidence between the current spike and the previous spike
+ # spike from spike train 1 after spike train 2
+ # both get marked with -1
+ a[n] = -1
+ a[n-1] = -1
+ elif (j < N2-1) and (i == N1-1 or spikes1[i+1] > spikes2[j+1]):
+ j += 1
+ n += 1
+ tau = get_tau(spikes1, spikes2, i, j, interval, max_tau)
+ st[n] = spikes2[j]
+ if i > -1 and spikes2[j]-spikes1[i] < tau:
+ # coincidence between the current spike and the previous spike
+ # spike from spike train 1 before spike train 2
+ # both get marked with 1
+ a[n] = 1
+ a[n-1] = 1
+        else: # spikes1[i+1] == spikes2[j+1]
+ # advance in both spike trains
+ j += 1
+ i += 1
+ n += 1
+ # add only one event with zero asymmetry value and multiplicity 2
+ st[n] = spikes1[i]
+ a[n] = 0
+ mp[n] = 2
+
+ st = st[:n+2]
+ a = a[:n+2]
+ mp = mp[:n+2]
+
+ st[0] = t_start
+ st[len(st)-1] = t_end
+ if N1 + N2 > 0:
+ a[0] = a[1]
+ a[len(a)-1] = a[len(a)-2]
+ mp[0] = mp[1]
+ mp[len(mp)-1] = mp[len(mp)-2]
+ else:
+ a[0] = 1
+ a[1] = 1
+
+ return st, a, mp
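+
+# The three returned arrays (spike times, asymmetry values, multiplicities)
+# are wrapped into a pyspike.DiscreteFunc by the calling code in
+# pyspike/spike_directionality.py.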
+
+
+
+############################################################
+# spike_order_values_cython
+############################################################
+def spike_order_values_cython(double[:] spikes1,
+ double[:] spikes2,
+ double t_start, double t_end,
+ double max_tau):
+
+ cdef int N1 = len(spikes1)
+ cdef int N2 = len(spikes2)
+ cdef int i = -1
+ cdef int j = -1
+ cdef double[:] a1 = np.zeros(N1) # asymmetry values
+ cdef double[:] a2 = np.zeros(N2) # asymmetry values
+ cdef double interval = t_end - t_start
+ cdef double tau
+ while i + j < N1 + N2 - 2:
+ if (i < N1-1) and (j == N2-1 or spikes1[i+1] < spikes2[j+1]):
+ i += 1
+ tau = get_tau(spikes1, spikes2, i, j, interval, max_tau)
+ if j > -1 and spikes1[i]-spikes2[j] < tau:
+ # coincidence between the current spike and the previous spike
+ # spike from spike train 1 after spike train 2
+ # leading spike gets +1, following spike -1
+ a1[i] = -1
+ a2[j] = +1
+ elif (j < N2-1) and (i == N1-1 or spikes1[i+1] > spikes2[j+1]):
+ j += 1
+ tau = get_tau(spikes1, spikes2, i, j, interval, max_tau)
+ if i > -1 and spikes2[j]-spikes1[i] < tau:
+ # coincidence between the current spike and the previous spike
+ # spike from spike train 1 before spike train 2
+ # leading spike gets +1, following spike -1
+ a1[i] = +1
+ a2[j] = -1
+        else: # spikes1[i+1] == spikes2[j+1]
+ # advance in both spike trains
+ j += 1
+ i += 1
+ # equal spike times: zero asymmetry value
+ a1[i] = 0
+ a2[j] = 0
+
+ return a1, a2
+
+
+############################################################
+# spike_train_order_cython
+############################################################
+def spike_train_order_cython(double[:] spikes1, double[:] spikes2,
+ double t_start, double t_end, double max_tau):
+
+ cdef int N1 = len(spikes1)
+ cdef int N2 = len(spikes2)
+ cdef int i = -1
+ cdef int j = -1
+ cdef int asym = 0
+ cdef int mp = 0
+ cdef double interval = t_end - t_start
+ cdef double tau
+ while i + j < N1 + N2 - 2:
+ if (i < N1-1) and (j == N2-1 or spikes1[i+1] < spikes2[j+1]):
+ i += 1
+ mp += 1
+ tau = get_tau(spikes1, spikes2, i, j, interval, max_tau)
+ if j > -1 and spikes1[i]-spikes2[j] < tau:
+ # coincidence between the current spike and the previous spike
+ # spike in spike train 2 appeared before spike in spike train 1
+ # mark with -1
+ asym -= 2
+ elif (j < N2-1) and (i == N1-1 or spikes1[i+1] > spikes2[j+1]):
+ j += 1
+ mp += 1
+ tau = get_tau(spikes1, spikes2, i, j, interval, max_tau)
+ if i > -1 and spikes2[j]-spikes1[i] < tau:
+ # coincidence between the current spike and the previous spike
+ # spike in spike train 1 appeared before spike in spike train 2
+ # mark with +1
+ asym += 2
+        else: # spikes1[i+1] == spikes2[j+1]
+ # advance in both spike trains
+ j += 1
+ i += 1
+ # add only one event with multiplicity 2, but no asymmetry counting
+ mp += 2
+
+ if asym == 0 and mp == 0:
+ # empty spike trains -> spike sync = 1 by definition
+ asym = 1
+ mp = 1
+
+ return asym, mp
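+
+# Note: the normalized ratio asym/mp lies in [-1, 1]; it equals +1 if every
+# spike of the first train leads its coincident partner in the second train
+# and -1 if every spike follows it.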
diff --git a/pyspike/cython/directionality_python_backend.py b/pyspike/cython/directionality_python_backend.py
new file mode 100644
index 0000000..e14238f
--- /dev/null
+++ b/pyspike/cython/directionality_python_backend.py
@@ -0,0 +1,89 @@
+""" directionality_python_backend.py
+
+Collection of python functions that can be used instead of the cython
+implementation.
+
+Copyright 2015, Mario Mulansky <mario.mulansky@gmx.net>
+
+Distributed under the BSD License
+
+"""
+
+import numpy as np
+
+
+############################################################
+# spike_train_order_python
+############################################################
+def spike_train_order_python(spikes1, spikes2, t_start, t_end, max_tau):
+
+ def get_tau(spikes1, spikes2, i, j, max_tau):
+ m = t_end - t_start # use interval as initial tau
+ if i < len(spikes1)-1 and i > -1:
+ m = min(m, spikes1[i+1]-spikes1[i])
+ if j < len(spikes2)-1 and j > -1:
+ m = min(m, spikes2[j+1]-spikes2[j])
+ if i > 0:
+ m = min(m, spikes1[i]-spikes1[i-1])
+ if j > 0:
+ m = min(m, spikes2[j]-spikes2[j-1])
+ m *= 0.5
+ if max_tau > 0.0:
+ m = min(m, max_tau)
+ return m
+
+ N1 = len(spikes1)
+ N2 = len(spikes2)
+ i = -1
+ j = -1
+ n = 0
+ st = np.zeros(N1 + N2 + 2) # spike times
+ a = np.zeros(N1 + N2 + 2) # coincidences
+ mp = np.ones(N1 + N2 + 2) # multiplicity
+ while i + j < N1 + N2 - 2:
+ if (i < N1-1) and (j == N2-1 or spikes1[i+1] < spikes2[j+1]):
+ i += 1
+ n += 1
+ tau = get_tau(spikes1, spikes2, i, j, max_tau)
+ st[n] = spikes1[i]
+ if j > -1 and spikes1[i]-spikes2[j] < tau:
+                # coincidence between the current spike and the previous spike
+                # spike from spike train 1 after spike train 2
+                # both get marked with -1
+ a[n] = -1
+ a[n-1] = -1
+ elif (j < N2-1) and (i == N1-1 or spikes1[i+1] > spikes2[j+1]):
+ j += 1
+ n += 1
+ tau = get_tau(spikes1, spikes2, i, j, max_tau)
+ st[n] = spikes2[j]
+ if i > -1 and spikes2[j]-spikes1[i] < tau:
+ # coincidence between the current spike and the previous spike
+ # both get marked with 1
+ a[n] = 1
+ a[n-1] = 1
+        else: # spikes1[i+1] == spikes2[j+1]
+ # advance in both spike trains
+ j += 1
+ i += 1
+ n += 1
+ # add only one event with zero asymmetry value and multiplicity 2
+ st[n] = spikes1[i]
+ a[n] = 0
+ mp[n] = 2
+
+ st = st[:n+2]
+ a = a[:n+2]
+ mp = mp[:n+2]
+
+ st[0] = t_start
+ st[len(st)-1] = t_end
+ if N1 + N2 > 0:
+ a[0] = a[1]
+ a[len(a)-1] = a[len(a)-2]
+ mp[0] = mp[1]
+ mp[len(mp)-1] = mp[len(mp)-2]
+ else:
+ a[0] = 1
+ a[1] = 1
+
+ return st, a, mp
diff --git a/pyspike/spike_directionality.py b/pyspike/spike_directionality.py
new file mode 100644
index 0000000..0e69cb5
--- /dev/null
+++ b/pyspike/spike_directionality.py
@@ -0,0 +1,244 @@
+# Module containing functions to compute the SPIKE directionality and the
+# spike train order profile
+# Copyright 2015, Mario Mulansky <mario.mulansky@gmx.net>
+# Distributed under the BSD License
+
+import numpy as np
+from math import exp
+import pyspike
+from pyspike import DiscreteFunc
+
+
+############################################################
+# spike_directionality
+############################################################
+def spike_directionality(spike_train1, spike_train2, normalize=True,
+ interval=None, max_tau=None):
+ """ Computes the overall spike directionality for two spike trains.
+ """
+ if interval is None:
+ # distance over the whole interval is requested: use specific function
+ # for optimal performance
+ try:
+            from .cython.cython_directionality import \
+ spike_train_order_cython as spike_train_order_impl
+ if max_tau is None:
+ max_tau = 0.0
+ c, mp = spike_train_order_impl(spike_train1.spikes,
+ spike_train2.spikes,
+ spike_train1.t_start,
+ spike_train1.t_end,
+ max_tau)
+ except ImportError:
+ # Cython backend not available: fall back to profile averaging
+ c, mp = _spike_directionality_profile(spike_train1,
+ spike_train2,
+ max_tau).integral(interval)
+ if normalize:
+ return 1.0*c/mp
+ else:
+ return c
+ else:
+ # some specific interval is provided: not yet implemented
+ raise NotImplementedError()
+
+
+############################################################
+# spike_directionality_matrix
+############################################################
+def spike_directionality_matrix(spike_trains, normalize=True, indices=None,
+ interval=None, max_tau=None):
+ """ Computes the spike directionaity matrix for the given spike trains.
+ """
+ if indices is None:
+ indices = np.arange(len(spike_trains))
+ indices = np.array(indices)
+ # check validity of indices
+ assert (indices < len(spike_trains)).all() and (indices >= 0).all(), \
+ "Invalid index list."
+ # generate a list of possible index pairs
+ pairs = [(indices[i], j) for i in range(len(indices))
+ for j in indices[i+1:]]
+
+ distance_matrix = np.zeros((len(indices), len(indices)))
+ for i, j in pairs:
+ d = spike_directionality(spike_trains[i], spike_trains[j], normalize,
+ interval, max_tau=max_tau)
+ distance_matrix[i, j] = d
+ distance_matrix[j, i] = -d
+ return distance_matrix
+
+
+############################################################
+# spike_train_order_profile
+############################################################
+def spike_train_order_profile(spike_trains, indices=None,
+ interval=None, max_tau=None):
+ """ Computes the spike train symmetry value for each spike in each spike
+ train.
+ """
+ if indices is None:
+ indices = np.arange(len(spike_trains))
+ indices = np.array(indices)
+ # check validity of indices
+ assert (indices < len(spike_trains)).all() and (indices >= 0).all(), \
+ "Invalid index list."
+    # list of arrays for resulting asymmetry values
+ asymmetry_list = [np.zeros_like(st.spikes) for st in spike_trains]
+ # generate a list of possible index pairs
+ pairs = [(indices[i], j) for i in range(len(indices))
+ for j in indices[i+1:]]
+
+ # cython implementation
+ try:
+        from .cython.cython_directionality import \
+ spike_order_values_cython as spike_order_values_impl
+ except ImportError:
+ raise NotImplementedError()
+# if not(pyspike.disable_backend_warning):
+# print("Warning: spike_distance_cython not found. Make sure that \
+# PySpike is installed by running\n 'python setup.py build_ext --inplace'!\n \
+# Falling back to slow python backend.")
+# # use python backend
+# from cython.python_backend import coincidence_python \
+# as coincidence_profile_impl
+
+ if max_tau is None:
+ max_tau = 0.0
+
+ for i, j in pairs:
+ a1, a2 = spike_order_values_impl(spike_trains[i].spikes,
+ spike_trains[j].spikes,
+ spike_trains[i].t_start,
+ spike_trains[i].t_end,
+ max_tau)
+ asymmetry_list[i] += a1
+ asymmetry_list[j] += a2
+ for a in asymmetry_list:
+ a /= len(spike_trains)-1
+ return asymmetry_list
+
+
+############################################################
+# optimal_spike_train_order_from_matrix
+############################################################
+def optimal_spike_train_order_from_matrix(D, full_output=False):
+ """ finds the best sorting via simulated annealing.
+ Returns the optimal permutation p and A value.
+ Internal function, don't call directly! Use optimal_asymmetry_order
+ instead.
+ """
+ N = len(D)
+ A = np.sum(np.triu(D, 0))
+
+ p = np.arange(N)
+
+ T = 2*np.max(D) # starting temperature
+ T_end = 1E-5 * T # final temperature
+ alpha = 0.9 # cooling factor
+ total_iter = 0
+ while T > T_end:
+ iterations = 0
+ succ_iter = 0
+ while iterations < 100*N and succ_iter < 10*N:
+ # exchange two rows and cols
+ ind1 = np.random.randint(N-1)
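+            # swapping the adjacent positions ind1 and ind1+1 only changes
+            # which of D[p[ind1], p[ind1+1]] and D[p[ind1+1], p[ind1]] lies in
+            # the upper triangle; since D is antisymmetric, the sum A changes
+            # by -2*D[p[ind1], p[ind1+1]]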
+ delta_A = -2*D[p[ind1], p[ind1+1]]
+ if delta_A > 0.0 or exp(delta_A/T) > np.random.random():
+ # swap indices
+ p[ind1], p[ind1+1] = p[ind1+1], p[ind1]
+ A += delta_A
+ succ_iter += 1
+ iterations += 1
+ total_iter += iterations
+ T *= alpha # cool down
+ if succ_iter == 0:
+ break
+ if full_output:
+ return p, A, total_iter
+ else:
+ return p, A
+
+
+############################################################
+# optimal_spike_train_order
+############################################################
+def optimal_spike_train_order(spike_trains, indices=None, interval=None,
+ max_tau=None, full_output=False):
+ """ finds the best sorting of the given spike trains via simulated
+ annealing.
+ Returns the optimal permutation p and A value.
+ """
+ D = spike_directionality_matrix(spike_trains, normalize=False,
+ indices=indices, interval=interval,
+ max_tau=max_tau)
+ return optimal_spike_train_order_from_matrix(D, full_output)
+
+
+############################################################
+# permutate_matrix
+############################################################
+def permutate_matrix(D, p):
+ """ Applies the permutation p to the columns and rows of matrix D.
+ Return the new permutated matrix.
+ """
+ N = len(D)
+ D_p = np.empty_like(D)
+    for n in range(N):
+        for m in range(N):
+ D_p[n, m] = D[p[n], p[m]]
+ return D_p
+
+
+# internal helper functions
+
+############################################################
+# _spike_directionality_profile
+############################################################
+def _spike_directionality_profile(spike_train1, spike_train2,
+ max_tau=None):
+ """ Computes the spike delay asymmetry profile A(t) of the two given
+ spike trains. Returns the profile as a DiscreteFunction object.
+
+ :param spike_train1: First spike train.
+ :type spike_train1: :class:`pyspike.SpikeTrain`
+ :param spike_train2: Second spike train.
+ :type spike_train2: :class:`pyspike.SpikeTrain`
+ :param max_tau: Maximum coincidence window size. If 0 or `None`, the
+ coincidence window has no upper bound.
+    :returns: The spike directionality profile :math:`A(t)`.
+ :rtype: :class:`pyspike.function.DiscreteFunction`
+
+ """
+ # check whether the spike trains are defined for the same interval
+ assert spike_train1.t_start == spike_train2.t_start, \
+ "Given spike trains are not defined on the same interval!"
+ assert spike_train1.t_end == spike_train2.t_end, \
+ "Given spike trains are not defined on the same interval!"
+
+ # cython implementation
+ try:
+        from .cython.cython_directionality import \
+ spike_train_order_profile_cython as \
+ spike_train_order_profile_impl
+ except ImportError:
+ # raise NotImplementedError()
+ if not(pyspike.disable_backend_warning):
+ print("Warning: spike_distance_cython not found. Make sure that \
+PySpike is installed by running\n 'python setup.py build_ext --inplace'!\n \
+Falling back to slow python backend.")
+ # use python backend
+        from .cython.directionality_python_backend import \
+ spike_train_order_python as spike_train_order_profile_impl
+
+ if max_tau is None:
+ max_tau = 0.0
+
+ times, coincidences, multiplicity \
+ = spike_train_order_profile_impl(spike_train1.spikes,
+ spike_train2.spikes,
+ spike_train1.t_start,
+ spike_train1.t_end,
+ max_tau)
+
+ return DiscreteFunc(times, coincidences, multiplicity)
diff --git a/setup.py b/setup.py
index 5b9e677..9ba1da6 100644
--- a/setup.py
+++ b/setup.py
@@ -30,7 +30,8 @@ class numpy_include(object):
if os.path.isfile("pyspike/cython/cython_add.c") and \
os.path.isfile("pyspike/cython/cython_profiles.c") and \
- os.path.isfile("pyspike/cython/cython_distances.c"):
+ os.path.isfile("pyspike/cython/cython_distances.c") and \
+ os.path.isfile("pyspike/cython/cython_directionality.c"):
use_c = True
else:
use_c = False
@@ -45,7 +46,9 @@ if use_cython: # Cython is available, compile .pyx -> .c
Extension("pyspike.cython.cython_profiles",
["pyspike/cython/cython_profiles.pyx"]),
Extension("pyspike.cython.cython_distances",
- ["pyspike/cython/cython_distances.pyx"])
+ ["pyspike/cython/cython_distances.pyx"]),
+ Extension("pyspike.cython.cython_directionality",
+ ["pyspike/cython/cython_directionality.pyx"])
]
cmdclass.update({'build_ext': build_ext})
elif use_c: # c files are there, compile to binaries
@@ -55,7 +58,9 @@ elif use_c: # c files are there, compile to binaries
Extension("pyspike.cython.cython_profiles",
["pyspike/cython/cython_profiles.c"]),
Extension("pyspike.cython.cython_distances",
- ["pyspike/cython/cython_distances.c"])
+ ["pyspike/cython/cython_distances.c"]),
+ Extension("pyspike.cython.cython_directionality",
+ ["pyspike/cython/cython_directionality.c"])
]
# neither cython nor c files available -> automatic fall-back to python backend
@@ -96,5 +101,11 @@ train similarity',
'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6'
- ]
+ ],
+ package_data={
+ 'pyspike': ['cython/cython_add.c', 'cython/cython_profiles.c',
+ 'cython/cython_distances.c',
+ 'cython/cython_directionality.c'],
+ 'test': ['Spike_testdata.txt']
+ }
)
diff --git a/test/test_directionality.py b/test/test_directionality.py
new file mode 100644
index 0000000..5c3da00
--- /dev/null
+++ b/test/test_directionality.py
@@ -0,0 +1,41 @@
+""" test_spike_delay_asymmetry.py
+
+Tests the asymmetry functions
+
+Copyright 2015, Mario Mulansky <mario.mulansky@gmx.net>
+
+Distributed under the BSD License
+
+"""
+
+import numpy as np
+from numpy.testing import assert_equal, assert_almost_equal, \
+ assert_array_equal
+
+import pyspike as spk
+from pyspike import SpikeTrain, DiscreteFunc
+from pyspike.spike_directionality import _spike_directionality_profile
+
+
+def test_profile():
+ st1 = SpikeTrain([100, 200, 300], [0, 1000])
+ st2 = SpikeTrain([105, 205, 300], [0, 1000])
+ expected_x = np.array([0, 100, 105, 200, 205, 300, 1000])
+ expected_y = np.array([1, 1, 1, 1, 1, 0, 0])
+ expected_mp = np.array([1, 1, 1, 1, 1, 2, 2])
+
+ f = _spike_directionality_profile(st1, st2)
+
+ assert f.almost_equal(DiscreteFunc(expected_x, expected_y, expected_mp))
+ assert_almost_equal(f.avrg(), 2.0/3.0)
+ assert_almost_equal(spk.spike_directionality(st1, st2), 2.0/3.0)
+ assert_almost_equal(spk.spike_directionality(st1, st2, normalize=False),
+ 4.0)
+
+ st3 = SpikeTrain([105, 195, 500], [0, 1000])
+ expected_x = np.array([0, 100, 105, 195, 200, 300, 500, 1000])
+ expected_y = np.array([1, 1, 1, -1, -1, 0, 0, 0])
+ expected_mp = np.array([1, 1, 1, 1, 1, 1, 1, 1])
+
+ f = _spike_directionality_profile(st1, st3)
+ assert f.almost_equal(DiscreteFunc(expected_x, expected_y, expected_mp))