summaryrefslogtreecommitdiff
path: root/pyspike
diff options
context:
space:
mode:
authorMario Mulansky <mario.mulansky@gmx.net>2014-09-30 12:50:56 +0200
committerMario Mulansky <mario.mulansky@gmx.net>2014-09-30 12:50:56 +0200
commit68f2f8e6297be829d29ef428784ac0002348877b (patch)
tree2fed86db66edc2373521e279e2a1c2226a49322a /pyspike
parentaeec6cfafed8df110e60743073cff6d778f65af0 (diff)
cython version of spike dist, using memory views
Diffstat (limited to 'pyspike')
-rw-r--r--pyspike/cython_distance.pyx148
-rw-r--r--pyspike/distances.py190
2 files changed, 254 insertions, 84 deletions
diff --git a/pyspike/cython_distance.pyx b/pyspike/cython_distance.pyx
index 330eea4..1a6d24a 100644
--- a/pyspike/cython_distance.pyx
+++ b/pyspike/cython_distance.pyx
@@ -3,8 +3,17 @@
#cython: cdivision=True
"""
-Doc
+cython_distances.py
+cython implementation of the isi- and spike-distance
+
+Note: using cython memoryviews (e.g. double[:]) instead of ndarray objects
+improves the performance of spike_distance by a factor of 10!
+
+Copyright 2014, Mario Mulansky <mario.mulansky@gmx.net>
+"""
+
+"""
To test whether things can be optimized: remove all yellow stuff
in the html output::
@@ -14,20 +23,25 @@ which gives::
cython_distance.html
-
"""
import numpy as np
cimport numpy as np
+from libc.math cimport fabs
+
DTYPE = np.float
ctypedef np.float_t DTYPE_t
-def isi_distance_cython(np.ndarray[DTYPE_t, ndim=1] s1, np.ndarray[DTYPE_t, ndim=1] s2):
- cdef np.ndarray[DTYPE_t, ndim=1] spike_events
- # the values have one entry less - the number of intervals between events
- cdef np.ndarray[DTYPE_t, ndim=1] isi_values
+############################################################
+# isi_distance_cython
+############################################################
+def isi_distance_cython(double[:] s1,
+ double[:] s2):
+
+ cdef double[:] spike_events
+ cdef double[:] isi_values
cdef int index1, index2, index
cdef int N1, N2
cdef double nu1, nu2
@@ -38,6 +52,7 @@ def isi_distance_cython(np.ndarray[DTYPE_t, ndim=1] s1, np.ndarray[DTYPE_t, ndim
nu2 = s2[1]-s2[0]
spike_events = np.empty(N1+N2)
spike_events[0] = s1[0]
+ # the values have one entry less - the number of intervals between events
isi_values = np.empty(N1+N2-1)
isi_values[0] = (nu1-nu2)/max(nu1,nu2)
index1 = 0
@@ -73,3 +88,124 @@ def isi_distance_cython(np.ndarray[DTYPE_t, ndim=1] s1, np.ndarray[DTYPE_t, ndim
spike_events[index] = s1[N1]
return spike_events[:index+1], isi_values[:index]
+
+
+############################################################
+# get_min_dist_cython
+############################################################
+cdef inline double get_min_dist_cython(double spike_time,
+ double[:] spike_train,
+ # use memory view to ensure inlining
+ # np.ndarray[DTYPE_t,ndim=1] spike_train,
+ int N,
+ int start_index=0):
+ """ Returns the minimal distance |spike_time - spike_train[i]|
+ with i>=start_index.
+ """
+ cdef double d, d_temp
+ d = fabs(spike_time - spike_train[start_index])
+ start_index += 1
+ while start_index < N:
+ d_temp = fabs(spike_time - spike_train[start_index])
+ if d_temp > d:
+ break
+ else:
+ d = d_temp
+ start_index += 1
+ return d
+
+
+############################################################
+# spike_distance_cython
+############################################################
+def spike_distance_cython(double[:] t1,
+ double[:] t2):
+
+ cdef double[:] spike_events
+ cdef double[:] y_starts
+ cdef double[:] y_ends
+
+ cdef int N1, N2, index1, index2, index
+ cdef double dt_p1, dt_p2, dt_f1, dt_f2, isi1, isi2, s1, s2
+
+ N1 = len(t1)
+ N2 = len(t2)
+
+ spike_events = np.empty(N1+N2-2)
+ spike_events[0] = t1[0]
+ y_starts = np.empty(len(spike_events)-1)
+ y_ends = np.empty(len(spike_events)-1)
+
+ index1 = 0
+ index2 = 0
+ index = 1
+ dt_p1 = 0.0
+ dt_f1 = get_min_dist_cython(t1[1], t2, N2, 0)
+ dt_p2 = 0.0
+ dt_f2 = get_min_dist_cython(t2[1], t1, N1, 0)
+ isi1 = max(t1[1]-t1[0], t1[2]-t1[1])
+ isi2 = max(t2[1]-t2[0], t2[2]-t2[1])
+ s1 = dt_f1*(t1[1]-t1[0])/isi1
+ s2 = dt_f2*(t2[1]-t2[0])/isi2
+ y_starts[0] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2)
+ while True:
+ # print(index, index1, index2)
+ if t1[index1+1] < t2[index2+1]:
+ index1 += 1
+ # break condition relies on existence of spikes at T_end
+ if index1+1 >= N1:
+ break
+ spike_events[index] = t1[index1]
+ # first calculate the previous interval end value
+ dt_p1 = dt_f1 # the previous time now was the following time before
+ s1 = dt_p1
+ s2 = (dt_p2*(t2[index2+1]-t1[index1]) + dt_f2*(t1[index1]-t2[index2])) / isi2
+ y_ends[index-1] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2)
+ # now the next interval start value
+ dt_f1 = get_min_dist_cython(t1[index1+1], t2, N2, index2)
+ isi1 = t1[index1+1]-t1[index1]
+ # s2 is the same as above, thus we can compute y2 immediately
+ y_starts[index] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2)
+ elif t1[index1+1] > t2[index2+1]:
+ index2 += 1
+ if index2+1 >= N2:
+ break
+ spike_events[index] = t2[index2]
+ # first calculate the previous interval end value
+ dt_p2 = dt_f2 # the previous time now was the following time before
+ s1 = (dt_p1*(t1[index1+1]-t2[index2]) + dt_f1*(t2[index2]-t1[index1])) / isi1
+ s2 = dt_p2
+ y_ends[index-1] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2)
+ # now the next interval start value
+ dt_f2 = get_min_dist_cython(t2[index2+1], t1, N1, index1)
+ #s2 = dt_f2
+ isi2 = t2[index2+1]-t2[index2]
+ # s2 is the same as above, thus we can compute y2 immediately
+ y_starts[index] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2)
+ else: # t1[index1+1] == t2[index2+1] - generate only one event
+ index1 += 1
+ index2 += 1
+ if (index1+1 >= N1) or (index2+1 >= N2):
+ break
+ spike_events[index] = t1[index1]
+ y_ends[index-1] = 0.0
+ y_starts[index] = 0.0
+ dt_p1 = 0.0
+ dt_p2 = 0.0
+ dt_f1 = get_min_dist_cython(t1[index1+1], t2, N2, index2)
+ dt_f2 = get_min_dist_cython(t2[index2+1], t1, N1, index1)
+ isi1 = t1[index1+1]-t1[index1]
+ isi2 = t2[index2+1]-t2[index2]
+ index += 1
+ # the last event is the interval end
+ spike_events[index] = t1[N1-1]
+ # the ending value of the last interval
+ isi1 = max(t1[N1-1]-t1[N1-2], t1[N1-2]-t1[N1-3])
+ isi2 = max(t2[N2-1]-t2[N2-2], t2[N2-2]-t2[N2-3])
+ s1 = dt_p1*(t1[N1-1]-t1[N1-2])/isi1
+ s2 = dt_p2*(t2[N2-1]-t2[N2-2])/isi2
+ y_ends[index-1] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2)
+ # use only the data added above
+ # could be less than original length due to equal spike times
+
+ return spike_events[:index+1], y_starts[:index], y_ends[:index]
diff --git a/pyspike/distances.py b/pyspike/distances.py
index 76dcd83..7a85471 100644
--- a/pyspike/distances.py
+++ b/pyspike/distances.py
@@ -63,6 +63,116 @@ def isi_distance(spikes1, spikes2):
############################################################
+# spike_distance
+############################################################
+def spike_distance(spikes1, spikes2):
+ """ Computes the instantaneous spike-distance S_spike (t) of the two given
+ spike trains. The spike trains are expected to have auxiliary spikes at the
+ beginning and end of the interval. Use the function add_auxiliary_spikes to
+ add those spikes to the spike train.
+ Args:
+ - spikes1, spikes2: ordered arrays of spike times with auxiliary spikes.
+ Returns:
+ - PieceWiseLinFunc describing the spike-distance.
+ """
+ # check for auxiliary spikes - first and last spikes should be identical
+ assert spikes1[0]==spikes2[0], \
+ "Given spike trains seems not to have auxiliary spikes!"
+ assert spikes1[-1]==spikes2[-1], \
+ "Given spike trains seems not to have auxiliary spikes!"
+
+ # compile and load cython implementation
+ import pyximport
+ pyximport.install(setup_args={'include_dirs': [np.get_include()]})
+ from cython_distance import spike_distance_cython
+
+ times, y_starts, y_ends = spike_distance_cython(spikes1, spikes2)
+
+ return PieceWiseLinFunc(times, y_starts, y_ends)
+
+
+############################################################
+# multi_distance
+############################################################
+def multi_distance(spike_trains, pair_distance_func, indices=None):
+ """ Internal implementation detail, use isi_distance_multi or
+ spike_distance_multi.
+
+ Computes the multi-variate distance for a set of spike-trains using the
+ pair_dist_func to compute pair-wise distances. That is it computes the
+ average distance of all pairs of spike-trains:
+ S(t) = 2/((N(N-1)) sum_{<i,j>} S_{i,j},
+ where the sum goes over all pairs <i,j>.
+ Args:
+ - spike_trains: list of spike trains
+ - pair_distance_func: function computing the distance of two spike trains
+ - indices: list of indices defining which spike trains to use,
+ if None all given spike trains are used (default=None)
+ Returns:
+ - The averaged multi-variate distance of all pairs
+ """
+ if indices==None:
+ indices = np.arange(len(spike_trains))
+ indices = np.array(indices)
+ # check validity of indices
+ assert (indices < len(spike_trains)).all() and (indices >= 0).all(), \
+ "Invalid index list."
+ # generate a list of possible index pairs
+ pairs = [(i,j) for i in indices for j in indices[i+1:]]
+ # start with first pair
+ (i,j) = pairs[0]
+ average_dist = pair_distance_func(spike_trains[i], spike_trains[j])
+ for (i,j) in pairs[1:]:
+ current_dist = pair_distance_func(spike_trains[i], spike_trains[j])
+ average_dist.add(current_dist) # add to the average
+ average_dist.mul_scalar(1.0/len(pairs)) # normalize
+ return average_dist
+
+
+############################################################
+# isi_distance_multi
+############################################################
+def isi_distance_multi(spike_trains, indices=None):
+ """ computes the multi-variate isi-distance for a set of spike-trains. That
+ is the average isi-distance of all pairs of spike-trains:
+ S(t) = 2/((N(N-1)) sum_{<i,j>} S_{i,j},
+ where the sum goes over all pairs <i,j>
+ Args:
+ - spike_trains: list of spike trains
+ - indices: list of indices defining which spike trains to use,
+ if None all given spike trains are used (default=None)
+ Returns:
+ - A PieceWiseConstFunc representing the averaged isi distance S
+ """
+ return multi_distance(spike_trains, isi_distance, indices)
+
+
+############################################################
+# spike_distance_multi
+############################################################
+def spike_distance_multi(spike_trains, indices=None):
+ """ computes the multi-variate spike-distance for a set of spike-trains.
+ That is the average spike-distance of all pairs of spike-trains:
+ S(t) = 2/((N(N-1)) sum_{<i,j>} S_{i,j},
+ where the sum goes over all pairs <i,j>
+ Args:
+ - spike_trains: list of spike trains
+ - indices: list of indices defining which spike trains to use,
+ if None all given spike trains are used (default=None)
+ Returns:
+ - A PieceWiseLinFunc representing the averaged spike distance S
+ """
+ return multi_distance(spike_trains, spike_distance, indices)
+
+
+############################################################
+############################################################
+# VANILLA PYTHON IMPLEMENTATIONS OF ISI AND SPIKE DISTANCE
+############################################################
+############################################################
+
+
+############################################################
# isi_distance_python
############################################################
def isi_distance_python(s1, s2):
@@ -134,9 +244,9 @@ def get_min_dist(spike_time, spike_train, start_index=0):
############################################################
-# spike_distance
+# spike_distance_python
############################################################
-def spike_distance(spikes1, spikes2):
+def spike_distance_python(spikes1, spikes2):
""" Computes the instantaneous spike-distance S_spike (t) of the two given
spike trains. The spike trains are expected to have auxiliary spikes at the
beginning and end of the interval. Use the function add_auxiliary_spikes to
@@ -235,79 +345,3 @@ def spike_distance(spikes1, spikes2):
# could be less than original length due to equal spike times
return PieceWiseLinFunc(spike_events[:index+1],
y_starts[:index], y_ends[:index])
-
-
-
-
-############################################################
-# multi_distance
-############################################################
-def multi_distance(spike_trains, pair_distance_func, indices=None):
- """ Internal implementation detail, use isi_distance_multi or
- spike_distance_multi.
-
- Computes the multi-variate distance for a set of spike-trains using the
- pair_dist_func to compute pair-wise distances. That is it computes the
- average distance of all pairs of spike-trains:
- S(t) = 2/((N(N-1)) sum_{<i,j>} S_{i,j},
- where the sum goes over all pairs <i,j>.
- Args:
- - spike_trains: list of spike trains
- - pair_distance_func: function computing the distance of two spike trains
- - indices: list of indices defining which spike trains to use,
- if None all given spike trains are used (default=None)
- Returns:
- - The averaged multi-variate distance of all pairs
- """
- if indices==None:
- indices = np.arange(len(spike_trains))
- indices = np.array(indices)
- # check validity of indices
- assert (indices < len(spike_trains)).all() and (indices >= 0).all(), \
- "Invalid index list."
- # generate a list of possible index pairs
- pairs = [(i,j) for i in indices for j in indices[i+1:]]
- # start with first pair
- (i,j) = pairs[0]
- average_dist = pair_distance_func(spike_trains[i], spike_trains[j])
- for (i,j) in pairs[1:]:
- current_dist = pair_distance_func(spike_trains[i], spike_trains[j])
- average_dist.add(current_dist) # add to the average
- average_dist.mul_scalar(1.0/len(pairs)) # normalize
- return average_dist
-
-
-############################################################
-# isi_distance_multi
-############################################################
-def isi_distance_multi(spike_trains, indices=None):
- """ computes the multi-variate isi-distance for a set of spike-trains. That
- is the average isi-distance of all pairs of spike-trains:
- S(t) = 2/((N(N-1)) sum_{<i,j>} S_{i,j},
- where the sum goes over all pairs <i,j>
- Args:
- - spike_trains: list of spike trains
- - indices: list of indices defining which spike trains to use,
- if None all given spike trains are used (default=None)
- Returns:
- - A PieceWiseConstFunc representing the averaged isi distance S
- """
- return multi_distance(spike_trains, isi_distance, indices)
-
-
-############################################################
-# spike_distance_multi
-############################################################
-def spike_distance_multi(spike_trains, indices=None):
- """ computes the multi-variate spike-distance for a set of spike-trains.
- That is the average spike-distance of all pairs of spike-trains:
- S(t) = 2/((N(N-1)) sum_{<i,j>} S_{i,j},
- where the sum goes over all pairs <i,j>
- Args:
- - spike_trains: list of spike trains
- - indices: list of indices defining which spike trains to use,
- if None all given spike trains are used (default=None)
- Returns:
- - A PieceWiseLinFunc representing the averaged spike distance S
- """
- return multi_distance(spike_trains, spike_distance, indices)