From ed85a9b72edcb7bba6ae1105e213b3b0a2f78d3a Mon Sep 17 00:00:00 2001 From: Mario Mulansky Date: Fri, 24 Apr 2015 00:49:16 +0200 Subject: changed spike distance to use new SpikeTrain class --- pyspike/cython/cython_distance.pyx | 199 +++++++++++++++++++++++++----------- pyspike/cython/python_backend.py | 203 ++++++++++++++++++++++++------------- pyspike/spike_distance.py | 37 ++++--- test/test_distance.py | 34 ++++--- 4 files changed, 313 insertions(+), 160 deletions(-) diff --git a/pyspike/cython/cython_distance.pyx b/pyspike/cython/cython_distance.pyx index 1d652ee..7999e0a 100644 --- a/pyspike/cython/cython_distance.pyx +++ b/pyspike/cython/cython_distance.pyx @@ -131,21 +131,30 @@ cdef inline double get_min_dist_cython(double spike_time, # use memory view to ensure inlining # np.ndarray[DTYPE_t,ndim=1] spike_train, int N, - int start_index=0) nogil: + int start_index, + double t_start, double t_end) nogil: """ Returns the minimal distance |spike_time - spike_train[i]| with i>=start_index. """ cdef double d, d_temp - d = fabs(spike_time - spike_train[start_index]) - start_index += 1 + # start with the distance to the start time + d = fabs(spike_time - t_start) + if start_index < 0: + start_index = 0 while start_index < N: d_temp = fabs(spike_time - spike_train[start_index]) if d_temp > d: - break + return d else: d = d_temp start_index += 1 - return d + + # finally, check the distance to end time + d_temp = fabs(t_end - spike_time) + if d_temp > d: + return d + else: + return d_temp ############################################################ @@ -160,96 +169,162 @@ cdef inline double isi_avrg_cython(double isi1, double isi2) nogil: ############################################################ # spike_distance_cython ############################################################ -def spike_distance_cython(double[:] t1, - double[:] t2): +def spike_distance_cython(double[:] t1, double[:] t2, + double t_start, double t_end): cdef double[:] spike_events cdef double[:] y_starts cdef double[:] y_ends cdef int N1, N2, index1, index2, index - cdef double dt_p1, dt_p2, dt_f1, dt_f2, isi1, isi2, s1, s2 + cdef double t_p1, t_f1, t_p2, t_f2, dt_p1, dt_p2, dt_f1, dt_f2 + cdef double isi1, isi2, s1, s2 N1 = len(t1) N2 = len(t2) - spike_events = np.empty(N1+N2-2) - spike_events[0] = t1[0] + spike_events = np.empty(N1+N2+2) + y_starts = np.empty(len(spike_events)-1) y_ends = np.empty(len(spike_events)-1) with nogil: # release the interpreter to allow multithreading - index1 = 0 - index2 = 0 - index = 1 - dt_p1 = 0.0 - dt_f1 = get_min_dist_cython(t1[1], t2, N2, 0) - dt_p2 = 0.0 - dt_f2 = get_min_dist_cython(t2[1], t1, N1, 0) - isi1 = max(t1[1]-t1[0], t1[2]-t1[1]) - isi2 = max(t2[1]-t2[0], t2[2]-t2[1]) - s1 = dt_f1*(t1[1]-t1[0])/isi1 - s2 = dt_f2*(t2[1]-t2[0])/isi2 + spike_events[0] = t_start + t_p1 = t_start + t_p2 = t_start + if t1[0] > t_start: + # dt_p1 = t2[0]-t_start + dt_p1 = 0.0 + t_f1 = t1[0] + dt_f1 = get_min_dist_cython(t_f1, t2, N2, 0, t_start, t_end) + isi1 = fmax(t_f1-t_start, t1[1]-t1[0]) + s1 = dt_f1*(t_f1-t_start)/isi1 + index1 = -1 + else: + dt_p1 = 0.0 + t_f1 = t1[1] + dt_f1 = get_min_dist_cython(t_f1, t2, N2, 0, t_start, t_end) + isi1 = t1[1]-t1[0] + s1 = dt_p1 + index1 = 0 + if t2[0] > t_start: + # dt_p1 = t2[0]-t_start + dt_p2 = 0.0 + t_f2 = t2[0] + dt_f2 = get_min_dist_cython(t_f2, t1, N1, 0, t_start, t_end) + isi2 = fmax(t_f2-t_start, t2[1]-t2[0]) + s2 = dt_f2*(t_f2-t_start)/isi2 + index2 = -1 + else: + dt_p2 = 0.0 + t_f2 = t2[1] + dt_f2 = get_min_dist_cython(t_f2, t1, N1, 0, t_start, t_end) + isi2 = t2[1]-t2[0] + s2 = dt_p2 + index2 = 0 + y_starts[0] = (s1*isi2 + s2*isi1) / isi_avrg_cython(isi1, isi2) - while True: + index = 1 + + while index1+index2 < N1+N2-2: # print(index, index1, index2) - if t1[index1+1] < t2[index2+1]: + if (index1 < N1-1) and (t_f1 < t_f2 or index2 == N2-1): index1 += 1 - # break condition relies on existence of spikes at T_end - if index1+1 >= N1: - break - spike_events[index] = t1[index1] # first calculate the previous interval end value - dt_p1 = dt_f1 # the previous time now was the following time before + # the previous time now was the following time before: + dt_p1 = dt_f1 + t_p1 = t_f1 # t_p1 contains the current time point + # get the next time + if index1 < N1-1: + t_f1 = t1[index1+1] + else: + t_f1 = t_end + spike_events[index] = t_p1 s1 = dt_p1 - s2 = (dt_p2*(t2[index2+1]-t1[index1]) + - dt_f2*(t1[index1]-t2[index2])) / isi2 - y_ends[index-1] = (s1*isi2 + s2*isi1)/isi_avrg_cython(isi1, isi2) + s2 = (dt_p2*(t_f2-t_p1) + dt_f2*(t_p1-t_p2)) / isi2 + y_ends[index-1] = (s1*isi2 + s2*isi1)/isi_avrg_cython(isi1, + isi2) # now the next interval start value - dt_f1 = get_min_dist_cython(t1[index1+1], t2, N2, index2) - isi1 = t1[index1+1]-t1[index1] + if index1 < N1-1: + dt_f1 = get_min_dist_cython(t_f1, t2, N2, index2, + t_start, t_end) + isi1 = t_f1-t_p1 + else: + dt_f1 = dt_p1 + isi1 = fmax(t_end-t1[N1-1], t1[N1-1]-t1[N1-2]) + # s1 needs adjustment due to change of isi1 + s1 = dt_p1*(t_end-t1[N1-1])/isi1 # s2 is the same as above, thus we can compute y2 immediately - y_starts[index] = (s1*isi2 + s2*isi1)/isi_avrg_cython(isi1, isi2) - elif t1[index1+1] > t2[index2+1]: + y_starts[index] = (s1*isi2 + s2*isi1)/isi_avrg_cython(isi1, + isi2) + elif (index2 < N2-1) and (t_f1 > t_f2 or index1 == N1-1): index2 += 1 - if index2+1 >= N2: - break - spike_events[index] = t2[index2] # first calculate the previous interval end value - dt_p2 = dt_f2 # the previous time now was the following time before - s1 = (dt_p1*(t1[index1+1]-t2[index2]) + - dt_f1*(t2[index2]-t1[index1])) / isi1 + # the previous time now was the following time before: + dt_p2 = dt_f2 + t_p2 = t_f2 # t_p1 contains the current time point + # get the next time + if index2 < N2-1: + t_f2 = t2[index2+1] + else: + t_f2 = t_end + spike_events[index] = t_p2 + s1 = (dt_p1*(t_f1-t_p2) + dt_f1*(t_p2-t_p1)) / isi1 s2 = dt_p2 - y_ends[index-1] = (s1*isi2 + s2*isi1) / isi_avrg_cython(isi1, isi2) + y_ends[index-1] = (s1*isi2 + s2*isi1) / isi_avrg_cython(isi1, + isi2) # now the next interval start value - dt_f2 = get_min_dist_cython(t2[index2+1], t1, N1, index1) - #s2 = dt_f2 - isi2 = t2[index2+1]-t2[index2] + if index2 < N2-1: + dt_f2 = get_min_dist_cython(t_f2, t1, N1, index1, + t_start, t_end) + isi2 = t_f2-t_p2 + else: + dt_f2 = dt_p2 + isi2 = fmax(t_end-t2[N2-1], t2[N2-1]-t2[N2-2]) + # s2 needs adjustment due to change of isi2 + s2 = dt_p2*(t_end-t2[N2-1])/isi2 # s2 is the same as above, thus we can compute y2 immediately y_starts[index] = (s1*isi2 + s2*isi1)/isi_avrg_cython(isi1, isi2) - else: # t1[index1+1] == t2[index2+1] - generate only one event + else: # t_f1 == t_f2 - generate only one event index1 += 1 index2 += 1 - if (index1+1 >= N1) or (index2+1 >= N2): - break - spike_events[index] = t1[index1] - y_ends[index-1] = 0.0 - y_starts[index] = 0.0 + t_p1 = t_f1 + t_p2 = t_f2 dt_p1 = 0.0 dt_p2 = 0.0 - dt_f1 = get_min_dist_cython(t1[index1+1], t2, N2, index2) - dt_f2 = get_min_dist_cython(t2[index2+1], t1, N1, index1) - isi1 = t1[index1+1]-t1[index1] - isi2 = t2[index2+1]-t2[index2] + spike_events[index] = t_f1 + y_ends[index-1] = 0.0 + y_starts[index] = 0.0 + if index1 < N1-1: + t_f1 = t1[index1+1] + dt_f1 = get_min_dist_cython(t_f1, t2, N2, index2, + t_start, t_end) + isi1 = t_f1 - t_p1 + else: + t_f1 = t_end + dt_f1 = dt_p1 + isi1 = fmax(t_end-t1[N1-1], t1[N1-1]-t1[N1-2]) + if index2 < N2-1: + t_f2 = t2[index2+1] + dt_f2 = get_min_dist_cython(t_f2, t1, N1, index1, + t_start, t_end) + isi2 = t_f2 - t_p2 + else: + t_f2 = t_end + dt_f2 = dt_p2 + isi2 = fmax(t_end-t2[N2-1], t2[N2-1]-t2[N2-2]) index += 1 # the last event is the interval end - spike_events[index] = t1[N1-1] - # the ending value of the last interval - isi1 = max(t1[N1-1]-t1[N1-2], t1[N1-2]-t1[N1-3]) - isi2 = max(t2[N2-1]-t2[N2-2], t2[N2-2]-t2[N2-3]) - s1 = dt_p1*(t1[N1-1]-t1[N1-2])/isi1 - s2 = dt_p2*(t2[N2-1]-t2[N2-2])/isi2 - y_ends[index-1] = (s1*isi2 + s2*isi1) / isi_avrg_cython(isi1, isi2) + if spike_events[index-1] == t_end: + index -= 1 + else: + spike_events[index] = t_end + # the ending value of the last interval + isi1 = max(t_end-t1[N1-1], t1[N1-1]-t1[N1-2]) + isi2 = max(t_end-t2[N2-1], t2[N2-1]-t2[N2-2]) + s1 = dt_f1*(t_end-t1[N1-1])/isi1 + s2 = dt_f2*(t_end-t2[N2-1])/isi2 + y_ends[index-1] = (s1*isi2 + s2*isi1) / isi_avrg_cython(isi1, isi2) # end nogil # use only the data added above diff --git a/pyspike/cython/python_backend.py b/pyspike/cython/python_backend.py index 4c37236..bcf9c30 100644 --- a/pyspike/cython/python_backend.py +++ b/pyspike/cython/python_backend.py @@ -89,122 +89,185 @@ def isi_distance_python(s1, s2, t_start, t_end): ############################################################ # get_min_dist ############################################################ -def get_min_dist(spike_time, spike_train, start_index=0): +def get_min_dist(spike_time, spike_train, start_index, t_start, t_end): """ Returns the minimal distance |spike_time - spike_train[i]| with i>=start_index. """ - d = abs(spike_time - spike_train[start_index]) - start_index += 1 + d = abs(spike_time - t_start) + if start_index < 0: + start_index = 0 while start_index < len(spike_train): d_temp = abs(spike_time - spike_train[start_index]) if d_temp > d: - break + return d else: d = d_temp start_index += 1 - return d + # finally, check the distance to end time + d_temp = abs(t_end - spike_time) + if d_temp > d: + return d + else: + return d_temp ############################################################ # spike_distance_python ############################################################ -def spike_distance_python(spikes1, spikes2): +def spike_distance_python(spikes1, spikes2, t_start, t_end): """ Computes the instantaneous spike-distance S_spike (t) of the two given spike trains. The spike trains are expected to have auxiliary spikes at the beginning and end of the interval. Use the function add_auxiliary_spikes to add those spikes to the spike train. Args: - spikes1, spikes2: ordered arrays of spike times with auxiliary spikes. + - t_start, t_end: edges of the spike train Returns: - PieceWiseLinFunc describing the spike-distance. """ - # check for auxiliary spikes - first and last spikes should be identical - assert spikes1[0] == spikes2[0], \ - "Given spike trains seems not to have auxiliary spikes!" - assert spikes1[-1] == spikes2[-1], \ - "Given spike trains seems not to have auxiliary spikes!" + # shorter variables t1 = spikes1 t2 = spikes2 - spike_events = np.empty(len(t1) + len(t2) - 2) - spike_events[0] = t1[0] - y_starts = np.empty(len(spike_events) - 1) - y_ends = np.empty(len(spike_events) - 1) + N1 = len(t1) + N2 = len(t2) - index1 = 0 - index2 = 0 + spike_events = np.empty(N1+N2+2) + + y_starts = np.empty(len(spike_events)-1) + y_ends = np.empty(len(spike_events)-1) + + spike_events[0] = t_start + t_p1 = t_start + t_p2 = t_start + if t1[0] > t_start: + # dt_p1 = t2[0]-t_start + dt_p1 = 0.0 + t_f1 = t1[0] + dt_f1 = get_min_dist(t_f1, t2, 0, t_start, t_end) + isi1 = max(t_f1-t_start, t1[1]-t1[0]) + s1 = dt_f1*(t_f1-t_start)/isi1 + index1 = -1 + else: + dt_p1 = 0.0 + t_f1 = t1[1] + dt_f1 = get_min_dist(t_f1, t2, 0, t_start, t_end) + isi1 = t1[1]-t1[0] + s1 = dt_p1 + index1 = 0 + if t2[0] > t_start: + # dt_p1 = t2[0]-t_start + dt_p2 = 0.0 + t_f2 = t2[0] + dt_f2 = get_min_dist(t_f2, t1, 0, t_start, t_end) + isi2 = max(t_f2-t_start, t2[1]-t2[0]) + s2 = dt_f2*(t_f2-t_start)/isi2 + index2 = -1 + else: + dt_p2 = 0.0 + t_f2 = t2[1] + dt_f2 = get_min_dist(t_f2, t1, 0, t_start, t_end) + isi2 = t2[1]-t2[0] + s2 = dt_p2 + index2 = 0 + + y_starts[0] = (s1*isi2 + s2*isi1) / (0.5*(isi1+isi2)**2) index = 1 - dt_p1 = 0.0 - dt_f1 = get_min_dist(t1[1], t2, 0) - dt_p2 = 0.0 - dt_f2 = get_min_dist(t2[1], t1, 0) - isi1 = max(t1[1]-t1[0], t1[2]-t1[1]) - isi2 = max(t2[1]-t2[0], t2[2]-t2[1]) - s1 = dt_f1*(t1[1]-t1[0])/isi1 - s2 = dt_f2*(t2[1]-t2[0])/isi2 - y_starts[0] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2) - while True: + + while index1+index2 < N1+N2-2: # print(index, index1, index2) - if t1[index1+1] < t2[index2+1]: + if (index1 < N1-1) and (t_f1 < t_f2 or index2 == N2-1): index1 += 1 - # break condition relies on existence of spikes at T_end - if index1+1 >= len(t1): - break - spike_events[index] = t1[index1] # first calculate the previous interval end value - dt_p1 = dt_f1 # the previous time was the following time before + # the previous time now was the following time before: + dt_p1 = dt_f1 + t_p1 = t_f1 # t_p1 contains the current time point + # get the next time + if index1 < N1-1: + t_f1 = t1[index1+1] + else: + t_f1 = t_end + spike_events[index] = t_p1 s1 = dt_p1 - s2 = (dt_p2*(t2[index2+1]-t1[index1]) + - dt_f2*(t1[index1]-t2[index2])) / isi2 - y_ends[index-1] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2) + s2 = (dt_p2*(t_f2-t_p1) + dt_f2*(t_p1-t_p2)) / isi2 + y_ends[index-1] = (s1*isi2 + s2*isi1) / (0.5*(isi1+isi2)**2) # now the next interval start value - dt_f1 = get_min_dist(t1[index1+1], t2, index2) - isi1 = t1[index1+1]-t1[index1] + if index1 < N1-1: + dt_f1 = get_min_dist(t_f1, t2, index2, t_start, t_end) + isi1 = t_f1-t_p1 + else: + dt_f1 = dt_p1 + isi1 = max(t_end-t1[N1-1], t1[N1-1]-t1[N1-2]) + # s1 needs adjustment due to change of isi1 + s1 = dt_p1*(t_end-t1[N1-1])/isi1 # s2 is the same as above, thus we can compute y2 immediately - y_starts[index] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2) - elif t1[index1+1] > t2[index2+1]: + y_starts[index] = (s1*isi2 + s2*isi1) / (0.5*(isi1+isi2)**2) + elif (index2 < N2-1) and (t_f1 > t_f2 or index1 == N1-1): index2 += 1 - if index2+1 >= len(t2): - break - spike_events[index] = t2[index2] # first calculate the previous interval end value - dt_p2 = dt_f2 # the previous time was the following time before - s1 = (dt_p1*(t1[index1+1]-t2[index2]) + - dt_f1*(t2[index2]-t1[index1])) / isi1 + # the previous time now was the following time before: + dt_p2 = dt_f2 + t_p2 = t_f2 # t_p1 contains the current time point + # get the next time + if index2 < N2-1: + t_f2 = t2[index2+1] + else: + t_f2 = t_end + spike_events[index] = t_p2 + s1 = (dt_p1*(t_f1-t_p2) + dt_f1*(t_p2-t_p1)) / isi1 s2 = dt_p2 - y_ends[index-1] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2) + y_ends[index-1] = (s1*isi2 + s2*isi1) / (0.5*(isi1+isi2)**2) # now the next interval start value - dt_f2 = get_min_dist(t2[index2+1], t1, index1) - #s2 = dt_f2 - isi2 = t2[index2+1]-t2[index2] + if index2 < N2-1: + dt_f2 = get_min_dist(t_f2, t1, index1, t_start, t_end) + isi2 = t_f2-t_p2 + else: + dt_f2 = dt_p2 + isi2 = max(t_end-t2[N2-1], t2[N2-1]-t2[N2-2]) + # s2 needs adjustment due to change of isi2 + s2 = dt_p2*(t_end-t2[N2-1])/isi2 # s2 is the same as above, thus we can compute y2 immediately - y_starts[index] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2) - else: # t1[index1+1] == t2[index2+1] - generate only one event + y_starts[index] = (s1*isi2 + s2*isi1) / (0.5*(isi1+isi2)**2) + else: # t_f1 == t_f2 - generate only one event index1 += 1 index2 += 1 - if (index1+1 >= len(t1)) or (index2+1 >= len(t2)): - break - assert dt_f2 == 0.0 - assert dt_f1 == 0.0 - spike_events[index] = t1[index1] - y_ends[index-1] = 0.0 - y_starts[index] = 0.0 + t_p1 = t_f1 + t_p2 = t_f2 dt_p1 = 0.0 dt_p2 = 0.0 - dt_f1 = get_min_dist(t1[index1+1], t2, index2) - dt_f2 = get_min_dist(t2[index2+1], t1, index1) - isi1 = t1[index1+1]-t1[index1] - isi2 = t2[index2+1]-t2[index2] + spike_events[index] = t_f1 + y_ends[index-1] = 0.0 + y_starts[index] = 0.0 + if index1 < N1-1: + t_f1 = t1[index1+1] + dt_f1 = get_min_dist(t_f1, t2, index2, t_start, t_end) + isi1 = t_f1 - t_p1 + else: + t_f1 = t_end + dt_f1 = dt_p1 + isi1 = max(t_end-t1[N1-1], t1[N1-1]-t1[N1-2]) + if index2 < N2-1: + t_f2 = t2[index2+1] + dt_f2 = get_min_dist(t_f2, t1, index1, t_start, t_end) + isi2 = t_f2 - t_p2 + else: + t_f2 = t_end + dt_f2 = dt_p2 + isi2 = max(t_end-t2[N2-1], t2[N2-1]-t2[N2-2]) index += 1 # the last event is the interval end - spike_events[index] = t1[-1] - # the ending value of the last interval - isi1 = max(t1[-1]-t1[-2], t1[-2]-t1[-3]) - isi2 = max(t2[-1]-t2[-2], t2[-2]-t2[-3]) - s1 = dt_p1*(t1[-1]-t1[-2])/isi1 - s2 = dt_p2*(t2[-1]-t2[-2])/isi2 - y_ends[index-1] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2) + if spike_events[index-1] == t_end: + index -= 1 + else: + spike_events[index] = t_end + # the ending value of the last interval + isi1 = max(t_end-t1[N1-1], t1[N1-1]-t1[N1-2]) + isi2 = max(t_end-t2[N2-1], t2[N2-1]-t2[N2-2]) + s1 = dt_f1*(t_end-t1[N1-1])/isi1 + s2 = dt_f2*(t_end-t2[N2-1])/isi2 + y_ends[index-1] = (s1*isi2 + s2*isi1) / (0.5*(isi1+isi2)**2) + # use only the data added above # could be less than original length due to equal spike times return spike_events[:index+1], y_starts[:index], y_ends[:index] diff --git a/pyspike/spike_distance.py b/pyspike/spike_distance.py index f721c86..8d03d70 100644 --- a/pyspike/spike_distance.py +++ b/pyspike/spike_distance.py @@ -14,23 +14,23 @@ from pyspike.generic import _generic_profile_multi, _generic_distance_matrix ############################################################ # spike_profile ############################################################ -def spike_profile(spikes1, spikes2): +def spike_profile(spike_train1, spike_train2): """ Computes the spike-distance profile S_spike(t) of the two given spike trains. Returns the profile as a PieceWiseLinFunc object. The S_spike - values are defined positive S_spike(t)>=0. The spike trains are expected to - have auxiliary spikes at the beginning and end of the interval. Use the - function add_auxiliary_spikes to add those spikes to the spike train. + values are defined positive S_spike(t)>=0. - :param spikes1: ordered array of spike times with auxiliary spikes. - :param spikes2: ordered array of spike times with auxiliary spikes. + :param spike_train1: First spike train. + :type spike_train1: :class:`pyspike.SpikeTrain` + :param spike_train2: Second spike train. + :type spike_train2: :class:`pyspike.SpikeTrain` :returns: The spike-distance profile :math:`S_{spike}(t)`. :rtype: :class:`pyspike.function.PieceWiseLinFunc` """ - # check for auxiliary spikes - first and last spikes should be identical - assert spikes1[0] == spikes2[0], \ + # check whether the spike trains are defined for the same interval + assert spike_train1.t_start == spike_train2.t_start, \ "Given spike trains seems not to have auxiliary spikes!" - assert spikes1[-1] == spikes2[-1], \ + assert spike_train1.t_end == spike_train2.t_end, \ "Given spike trains seems not to have auxiliary spikes!" # cython implementation @@ -45,21 +45,26 @@ Falling back to slow python backend.") from cython.python_backend import spike_distance_python \ as spike_distance_impl - times, y_starts, y_ends = spike_distance_impl(spikes1, spikes2) + times, y_starts, y_ends = spike_distance_impl(spike_train1.spikes, + spike_train2.spikes, + spike_train1.t_start, + spike_train1.t_end) return PieceWiseLinFunc(times, y_starts, y_ends) ############################################################ # spike_distance ############################################################ -def spike_distance(spikes1, spikes2, interval=None): +def spike_distance(spike_train1, spike_train2, interval=None): """ Computes the spike-distance S of the given spike trains. The spike-distance is the integral over the isi distance profile S_spike(t): .. math:: S = \int_{T_0}^{T_1} S_{spike}(t) dt. - :param spikes1: ordered array of spike times with auxiliary spikes. - :param spikes2: ordered array of spike times with auxiliary spikes. + :param spike_train1: First spike train. + :type spike_train1: :class:`pyspike.SpikeTrain` + :param spike_train2: Second spike train. + :type spike_train2: :class:`pyspike.SpikeTrain` :param interval: averaging interval given as a pair of floats (T0, T1), if None the average over the whole function is computed. :type interval: Pair of floats or None. @@ -67,7 +72,7 @@ def spike_distance(spikes1, spikes2, interval=None): :rtype: double """ - return spike_profile(spikes1, spikes2).avrg(interval) + return spike_profile(spike_train1, spike_train2).avrg(interval) ############################################################ @@ -102,7 +107,7 @@ def spike_distance_multi(spike_trains, indices=None, interval=None): S_{spike} = \int_0^T 2/((N(N-1)) sum_{} S_{spike}^{i, j} dt where the sum goes over all pairs - :param spike_trains: list of spike trains + :param spike_trains: list of :class:`pyspike.SpikeTrain` :param indices: list of indices defining which spike trains to use, if None all given spike trains are used (default=None) :type indices: list or None @@ -121,7 +126,7 @@ def spike_distance_multi(spike_trains, indices=None, interval=None): def spike_distance_matrix(spike_trains, indices=None, interval=None): """ Computes the time averaged spike-distance of all pairs of spike-trains. - :param spike_trains: list of spike trains + :param spike_trains: list of :class:`pyspike.SpikeTrain` :param indices: list of indices defining which spike trains to use, if None all given spike trains are used (default=None) :type indices: list or None diff --git a/test/test_distance.py b/test/test_distance.py index b54e908..4af0e63 100644 --- a/test/test_distance.py +++ b/test/test_distance.py @@ -64,13 +64,26 @@ def test_isi(): def test_spike(): # generate two spike trains: - t1 = np.array([0.2, 0.4, 0.6, 0.7]) - t2 = np.array([0.3, 0.45, 0.8, 0.9, 0.95]) + t1 = SpikeTrain([0.0, 2.0, 5.0, 8.0], 10.0) + t2 = SpikeTrain([0.0, 1.0, 5.0, 9.0], 10.0) + + expected_times = np.array([0.0, 1.0, 2.0, 5.0, 8.0, 9.0, 10.0]) + + f = spk.spike_profile(t1, t2) + + assert_equal(f.x, expected_times) + + assert_almost_equal(f.avrg(), 0.1662415, decimal=6) + assert_almost_equal(f.y2[-1], 0.1394558, decimal=6) + + t1 = SpikeTrain([0.2, 0.4, 0.6, 0.7], 1.0) + t2 = SpikeTrain([0.3, 0.45, 0.8, 0.9, 0.95], 1.0) # pen&paper calculation of the spike distance expected_times = [0.0, 0.2, 0.3, 0.4, 0.45, 0.6, 0.7, 0.8, 0.9, 0.95, 1.0] s1 = np.array([0.1, 0.1, (0.1*0.1+0.05*0.1)/0.2, 0.05, (0.05*0.15 * 2)/0.2, - 0.15, 0.1, 0.1*0.2/0.3, 0.1**2/0.3, 0.1*0.05/0.3, 0.1]) + 0.15, 0.1, (0.1*0.1+0.1*0.2)/0.3, (0.1*0.2+0.1*0.1)/0.3, + (0.1*0.05+0.1*0.25)/0.3, 0.1]) s2 = np.array([0.1, 0.1*0.2/0.3, 0.1, (0.1*0.05 * 2)/.15, 0.05, (0.05*0.2+0.1*0.15)/0.35, (0.05*0.1+0.1*0.25)/0.35, 0.1, 0.1, 0.05, 0.05]) @@ -86,19 +99,18 @@ def test_spike(): (expected_y1+expected_y2)/2) expected_spike_val /= (expected_times[-1]-expected_times[0]) - t1 = spk.add_auxiliary_spikes(t1, 1.0) - t2 = spk.add_auxiliary_spikes(t2, 1.0) f = spk.spike_profile(t1, t2) assert_equal(f.x, expected_times) assert_array_almost_equal(f.y1, expected_y1, decimal=15) assert_array_almost_equal(f.y2, expected_y2, decimal=15) - assert_equal(f.avrg(), expected_spike_val) - assert_equal(spk.spike_distance(t1, t2), expected_spike_val) + assert_almost_equal(f.avrg(), expected_spike_val, decimal=15) + assert_almost_equal(spk.spike_distance(t1, t2), expected_spike_val, + decimal=15) # check with some equal spike times - t1 = np.array([0.2, 0.4, 0.6]) - t2 = np.array([0.1, 0.4, 0.5, 0.6]) + t1 = SpikeTrain([0.2, 0.4, 0.6], [0.0, 1.0]) + t2 = SpikeTrain([0.1, 0.4, 0.5, 0.6], [0.0, 1.0]) expected_times = [0.0, 0.1, 0.2, 0.4, 0.5, 0.6, 1.0] s1 = np.array([0.1, 0.1*0.1/0.2, 0.1, 0.0, 0.0, 0.0, 0.0]) @@ -115,8 +127,6 @@ def test_spike(): (expected_y1+expected_y2)/2) expected_spike_val /= (expected_times[-1]-expected_times[0]) - t1 = spk.add_auxiliary_spikes(t1, 1.0) - t2 = spk.add_auxiliary_spikes(t2, 1.0) f = spk.spike_profile(t1, t2) assert_equal(f.x, expected_times) @@ -315,6 +325,6 @@ def test_multi_variate_subsets(): if __name__ == "__main__": test_isi() - # test_spike() + test_spike() # test_multi_isi() # test_multi_spike() -- cgit v1.2.3