From 68f2f8e6297be829d29ef428784ac0002348877b Mon Sep 17 00:00:00 2001 From: Mario Mulansky Date: Tue, 30 Sep 2014 12:50:56 +0200 Subject: cython version of spike dist, using memory views --- pyspike/cython_distance.pyx | 148 ++++++++++++++++++++++++++++++++-- pyspike/distances.py | 190 ++++++++++++++++++++++++++------------------ 2 files changed, 254 insertions(+), 84 deletions(-) (limited to 'pyspike') diff --git a/pyspike/cython_distance.pyx b/pyspike/cython_distance.pyx index 330eea4..1a6d24a 100644 --- a/pyspike/cython_distance.pyx +++ b/pyspike/cython_distance.pyx @@ -3,8 +3,17 @@ #cython: cdivision=True """ -Doc +cython_distances.py +cython implementation of the isi- and spike-distance + +Note: using cython memoryviews (e.g. double[:]) instead of ndarray objects +improves the performance of spike_distance by a factor of 10! + +Copyright 2014, Mario Mulansky +""" + +""" To test whether things can be optimized: remove all yellow stuff in the html output:: @@ -14,20 +23,25 @@ which gives:: cython_distance.html - """ import numpy as np cimport numpy as np +from libc.math cimport fabs + DTYPE = np.float ctypedef np.float_t DTYPE_t -def isi_distance_cython(np.ndarray[DTYPE_t, ndim=1] s1, np.ndarray[DTYPE_t, ndim=1] s2): - cdef np.ndarray[DTYPE_t, ndim=1] spike_events - # the values have one entry less - the number of intervals between events - cdef np.ndarray[DTYPE_t, ndim=1] isi_values +############################################################ +# isi_distance_cython +############################################################ +def isi_distance_cython(double[:] s1, + double[:] s2): + + cdef double[:] spike_events + cdef double[:] isi_values cdef int index1, index2, index cdef int N1, N2 cdef double nu1, nu2 @@ -38,6 +52,7 @@ def isi_distance_cython(np.ndarray[DTYPE_t, ndim=1] s1, np.ndarray[DTYPE_t, ndim nu2 = s2[1]-s2[0] spike_events = np.empty(N1+N2) spike_events[0] = s1[0] + # the values have one entry less - the number of intervals between events isi_values = np.empty(N1+N2-1) isi_values[0] = (nu1-nu2)/max(nu1,nu2) index1 = 0 @@ -73,3 +88,124 @@ def isi_distance_cython(np.ndarray[DTYPE_t, ndim=1] s1, np.ndarray[DTYPE_t, ndim spike_events[index] = s1[N1] return spike_events[:index+1], isi_values[:index] + + +############################################################ +# get_min_dist_cython +############################################################ +cdef inline double get_min_dist_cython(double spike_time, + double[:] spike_train, + # use memory view to ensure inlining + # np.ndarray[DTYPE_t,ndim=1] spike_train, + int N, + int start_index=0): + """ Returns the minimal distance |spike_time - spike_train[i]| + with i>=start_index. 
+ """ + cdef double d, d_temp + d = fabs(spike_time - spike_train[start_index]) + start_index += 1 + while start_index < N: + d_temp = fabs(spike_time - spike_train[start_index]) + if d_temp > d: + break + else: + d = d_temp + start_index += 1 + return d + + +############################################################ +# spike_distance_cython +############################################################ +def spike_distance_cython(double[:] t1, + double[:] t2): + + cdef double[:] spike_events + cdef double[:] y_starts + cdef double[:] y_ends + + cdef int N1, N2, index1, index2, index + cdef double dt_p1, dt_p2, dt_f1, dt_f2, isi1, isi2, s1, s2 + + N1 = len(t1) + N2 = len(t2) + + spike_events = np.empty(N1+N2-2) + spike_events[0] = t1[0] + y_starts = np.empty(len(spike_events)-1) + y_ends = np.empty(len(spike_events)-1) + + index1 = 0 + index2 = 0 + index = 1 + dt_p1 = 0.0 + dt_f1 = get_min_dist_cython(t1[1], t2, N2, 0) + dt_p2 = 0.0 + dt_f2 = get_min_dist_cython(t2[1], t1, N1, 0) + isi1 = max(t1[1]-t1[0], t1[2]-t1[1]) + isi2 = max(t2[1]-t2[0], t2[2]-t2[1]) + s1 = dt_f1*(t1[1]-t1[0])/isi1 + s2 = dt_f2*(t2[1]-t2[0])/isi2 + y_starts[0] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2) + while True: + # print(index, index1, index2) + if t1[index1+1] < t2[index2+1]: + index1 += 1 + # break condition relies on existence of spikes at T_end + if index1+1 >= N1: + break + spike_events[index] = t1[index1] + # first calculate the previous interval end value + dt_p1 = dt_f1 # the previous time now was the following time before + s1 = dt_p1 + s2 = (dt_p2*(t2[index2+1]-t1[index1]) + dt_f2*(t1[index1]-t2[index2])) / isi2 + y_ends[index-1] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2) + # now the next interval start value + dt_f1 = get_min_dist_cython(t1[index1+1], t2, N2, index2) + isi1 = t1[index1+1]-t1[index1] + # s2 is the same as above, thus we can compute y2 immediately + y_starts[index] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2) + elif t1[index1+1] > t2[index2+1]: + index2 += 1 + if index2+1 >= N2: + break + spike_events[index] = t2[index2] + # first calculate the previous interval end value + dt_p2 = dt_f2 # the previous time now was the following time before + s1 = (dt_p1*(t1[index1+1]-t2[index2]) + dt_f1*(t2[index2]-t1[index1])) / isi1 + s2 = dt_p2 + y_ends[index-1] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2) + # now the next interval start value + dt_f2 = get_min_dist_cython(t2[index2+1], t1, N1, index1) + #s2 = dt_f2 + isi2 = t2[index2+1]-t2[index2] + # s2 is the same as above, thus we can compute y2 immediately + y_starts[index] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2) + else: # t1[index1+1] == t2[index2+1] - generate only one event + index1 += 1 + index2 += 1 + if (index1+1 >= N1) or (index2+1 >= N2): + break + spike_events[index] = t1[index1] + y_ends[index-1] = 0.0 + y_starts[index] = 0.0 + dt_p1 = 0.0 + dt_p2 = 0.0 + dt_f1 = get_min_dist_cython(t1[index1+1], t2, N2, index2) + dt_f2 = get_min_dist_cython(t2[index2+1], t1, N1, index1) + isi1 = t1[index1+1]-t1[index1] + isi2 = t2[index2+1]-t2[index2] + index += 1 + # the last event is the interval end + spike_events[index] = t1[N1-1] + # the ending value of the last interval + isi1 = max(t1[N1-1]-t1[N1-2], t1[N1-2]-t1[N1-3]) + isi2 = max(t2[N2-1]-t2[N2-2], t2[N2-2]-t2[N2-3]) + s1 = dt_p1*(t1[N1-1]-t1[N1-2])/isi1 + s2 = dt_p2*(t2[N2-1]-t2[N2-2])/isi2 + y_ends[index-1] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2) + # use only the data added above + # could be less than original length due to equal spike times + + return spike_events[:index+1], 
y_starts[:index], y_ends[:index]
diff --git a/pyspike/distances.py b/pyspike/distances.py
index 76dcd83..7a85471 100644
--- a/pyspike/distances.py
+++ b/pyspike/distances.py
@@ -62,6 +62,116 @@ def isi_distance(spikes1, spikes2):
     return PieceWiseConstFunc(times, values)
 
 
+############################################################
+# spike_distance
+############################################################
+def spike_distance(spikes1, spikes2):
+    """ Computes the instantaneous spike-distance S_spike (t) of the two given
+    spike trains. The spike trains are expected to have auxiliary spikes at the
+    beginning and end of the interval. Use the function add_auxiliary_spikes to
+    add those spikes to the spike train.
+    Args:
+    - spikes1, spikes2: ordered arrays of spike times with auxiliary spikes.
+    Returns:
+    - PieceWiseLinFunc describing the spike-distance.
+    """
+    # check for auxiliary spikes - first and last spikes should be identical
+    assert spikes1[0]==spikes2[0], \
+        "Given spike trains seems not to have auxiliary spikes!"
+    assert spikes1[-1]==spikes2[-1], \
+        "Given spike trains seems not to have auxiliary spikes!"
+
+    # compile and load cython implementation
+    import pyximport
+    pyximport.install(setup_args={'include_dirs': [np.get_include()]})
+    from cython_distance import spike_distance_cython
+
+    times, y_starts, y_ends = spike_distance_cython(spikes1, spikes2)
+
+    return PieceWiseLinFunc(times, y_starts, y_ends)
+
+
+############################################################
+# multi_distance
+############################################################
+def multi_distance(spike_trains, pair_distance_func, indices=None):
+    """ Internal implementation detail, use isi_distance_multi or
+    spike_distance_multi.
+
+    Computes the multi-variate distance for a set of spike-trains using the
+    pair_dist_func to compute pair-wise distances. That is it computes the
+    average distance of all pairs of spike-trains:
+    S(t) = 2/((N(N-1)) sum_{<i,j>} S_{i,j},
+    where the sum goes over all pairs <i,j>.
+    Args:
+    - spike_trains: list of spike trains
+    - pair_distance_func: function computing the distance of two spike trains
+    - indices: list of indices defining which spike trains to use,
+    if None all given spike trains are used (default=None)
+    Returns:
+    - The averaged multi-variate distance of all pairs
+    """
+    if indices==None:
+        indices = np.arange(len(spike_trains))
+    indices = np.array(indices)
+    # check validity of indices
+    assert (indices < len(spike_trains)).all() and (indices >= 0).all(), \
+        "Invalid index list."
+    # generate a list of possible index pairs
+    pairs = [(i,j) for i in indices for j in indices[i+1:]]
+    # start with first pair
+    (i,j) = pairs[0]
+    average_dist = pair_distance_func(spike_trains[i], spike_trains[j])
+    for (i,j) in pairs[1:]:
+        current_dist = pair_distance_func(spike_trains[i], spike_trains[j])
+        average_dist.add(current_dist) # add to the average
+    average_dist.mul_scalar(1.0/len(pairs)) # normalize
+    return average_dist
+
+
+############################################################
+# isi_distance_multi
+############################################################
+def isi_distance_multi(spike_trains, indices=None):
+    """ computes the multi-variate isi-distance for a set of spike-trains. That
+    is the average isi-distance of all pairs of spike-trains:
+    S(t) = 2/((N(N-1)) sum_{<i,j>} S_{i,j},
+    where the sum goes over all pairs <i,j>
+    Args:
+    - spike_trains: list of spike trains
+    - indices: list of indices defining which spike trains to use,
+    if None all given spike trains are used (default=None)
+    Returns:
+    - A PieceWiseConstFunc representing the averaged isi distance S
+    """
+    return multi_distance(spike_trains, isi_distance, indices)
+
+
+############################################################
+# spike_distance_multi
+############################################################
+def spike_distance_multi(spike_trains, indices=None):
+    """ computes the multi-variate spike-distance for a set of spike-trains.
+    That is the average spike-distance of all pairs of spike-trains:
+    S(t) = 2/((N(N-1)) sum_{<i,j>} S_{i,j},
+    where the sum goes over all pairs <i,j>
+    Args:
+    - spike_trains: list of spike trains
+    - indices: list of indices defining which spike trains to use,
+    if None all given spike trains are used (default=None)
+    Returns:
+    - A PieceWiseLinFunc representing the averaged spike distance S
+    """
+    return multi_distance(spike_trains, spike_distance, indices)
+
+
+############################################################
+############################################################
+# VANILLA PYTHON IMPLEMENTATIONS OF ISI AND SPIKE DISTANCE
+############################################################
+############################################################
+
+
 ############################################################
 # isi_distance_python
 ############################################################
@@ -134,9 +244,9 @@ def get_min_dist(spike_time, spike_train, start_index=0):
 
 ############################################################
-# spike_distance
+# spike_distance_python
 ############################################################
-def spike_distance(spikes1, spikes2):
+def spike_distance_python(spikes1, spikes2):
     """ Computes the instantaneous spike-distance S_spike (t) of the two given
     spike trains. The spike trains are expected to have auxiliary spikes at the
     beginning and end of the interval. Use the function add_auxiliary_spikes to
@@ -235,79 +345,3 @@ def spike_distance(spikes1, spikes2):
     # could be less than original length due to equal spike times
     return PieceWiseLinFunc(spike_events[:index+1], y_starts[:index],
                             y_ends[:index])
-
-
-
-
-############################################################
-# multi_distance
-############################################################
-def multi_distance(spike_trains, pair_distance_func, indices=None):
-    """ Internal implementation detail, use isi_distance_multi or
-    spike_distance_multi.
-
-    Computes the multi-variate distance for a set of spike-trains using the
-    pair_dist_func to compute pair-wise distances. That is it computes the
-    average distance of all pairs of spike-trains:
-    S(t) = 2/((N(N-1)) sum_{<i,j>} S_{i,j},
-    where the sum goes over all pairs <i,j>.
-    Args:
-    - spike_trains: list of spike trains
-    - pair_distance_func: function computing the distance of two spike trains
-    - indices: list of indices defining which spike trains to use,
-    if None all given spike trains are used (default=None)
-    Returns:
-    - The averaged multi-variate distance of all pairs
-    """
-    if indices==None:
-        indices = np.arange(len(spike_trains))
-    indices = np.array(indices)
-    # check validity of indices
-    assert (indices < len(spike_trains)).all() and (indices >= 0).all(), \
-        "Invalid index list."
-    # generate a list of possible index pairs
-    pairs = [(i,j) for i in indices for j in indices[i+1:]]
-    # start with first pair
-    (i,j) = pairs[0]
-    average_dist = pair_distance_func(spike_trains[i], spike_trains[j])
-    for (i,j) in pairs[1:]:
-        current_dist = pair_distance_func(spike_trains[i], spike_trains[j])
-        average_dist.add(current_dist) # add to the average
-    average_dist.mul_scalar(1.0/len(pairs)) # normalize
-    return average_dist
-
-
-############################################################
-# isi_distance_multi
-############################################################
-def isi_distance_multi(spike_trains, indices=None):
-    """ computes the multi-variate isi-distance for a set of spike-trains. That
-    is the average isi-distance of all pairs of spike-trains:
-    S(t) = 2/((N(N-1)) sum_{<i,j>} S_{i,j},
-    where the sum goes over all pairs <i,j>
-    Args:
-    - spike_trains: list of spike trains
-    - indices: list of indices defining which spike trains to use,
-    if None all given spike trains are used (default=None)
-    Returns:
-    - A PieceWiseConstFunc representing the averaged isi distance S
-    """
-    return multi_distance(spike_trains, isi_distance, indices)
-
-
-############################################################
-# spike_distance_multi
-############################################################
-def spike_distance_multi(spike_trains, indices=None):
-    """ computes the multi-variate spike-distance for a set of spike-trains.
-    That is the average spike-distance of all pairs of spike-trains:
-    S(t) = 2/((N(N-1)) sum_{<i,j>} S_{i,j},
-    where the sum goes over all pairs <i,j>
-    Args:
-    - spike_trains: list of spike trains
-    - indices: list of indices defining which spike trains to use,
-    if None all given spike trains are used (default=None)
-    Returns:
-    - A PieceWiseLinFunc representing the averaged spike distance S
-    """
-    return multi_distance(spike_trains, spike_distance, indices)
-- cgit v1.2.3
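
A minimal usage sketch of the new Cython entry points, not part of the patch: it
mirrors the pyximport pattern used by the spike_distance wrapper in distances.py
above and assumes float64 arrays that already carry the auxiliary spikes at the
start and end of the interval, with cython_distance.pyx importable from the
working directory (the spike times below are purely illustrative)::

    import numpy as np
    import pyximport
    # build cython_distance.pyx on import, with the numpy headers available
    pyximport.install(setup_args={'include_dirs': [np.get_include()]})
    from cython_distance import isi_distance_cython, spike_distance_cython

    # two example spike trains; 0.0 and 1.0 act as the auxiliary edge spikes
    t1 = np.array([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
    t2 = np.array([0.0, 0.25, 0.45, 0.7, 0.9, 1.0])

    times, isi_values = isi_distance_cython(t1, t2)
    ev, y_starts, y_ends = spike_distance_cython(t1, t2)

    # the cython functions return memoryview slices; wrap them for printing
    print(np.asarray(times), np.asarray(isi_values))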
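The multi-variate variants are plain averages over all unordered pairs of spike
trains, which is where the 2/(N(N-1)) factor in the docstrings comes from: there
are N(N-1)/2 such pairs. A small pure-Python sketch of the same bookkeeping with
scalar distances (the hypothetical pair_dist stands in for isi_distance or
spike_distance, which return function objects and therefore use add and
mul_scalar instead of float arithmetic)::

    def average_pair_distance(spike_trains, pair_dist):
        # all unordered pairs (i, j) with i < j -- there are N*(N-1)/2 of them
        n = len(spike_trains)
        pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
        total = 0.0
        for i, j in pairs:
            total += pair_dist(spike_trains[i], spike_trains[j])
        # same normalization as mul_scalar(1.0/len(pairs)) in multi_distance
        return total / len(pairs)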
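The "factor of 10" noted in the new module docstring refers to spike_distance
with memoryviews versus the ndarray-based version. Since this commit keeps the
pure-Python implementation as spike_distance_python, a rough check is possible
with a timeit sketch; the random spike trains are hypothetical and the import
path assumes the package layout of this commit::

    import timeit
    import numpy as np
    from pyspike.distances import spike_distance, spike_distance_python

    # random spike trains sharing auxiliary spikes at 0 and T
    T = 100.0
    t1 = np.concatenate(([0.0], np.sort(np.random.uniform(0, T, 1000)), [T]))
    t2 = np.concatenate(([0.0], np.sort(np.random.uniform(0, T, 1000)), [T]))

    cython_time = timeit.timeit(lambda: spike_distance(t1, t2), number=20)
    python_time = timeit.timeit(lambda: spike_distance_python(t1, t2), number=20)
    print("speed-up: %.1fx" % (python_time / cython_time))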