From 99730806c22f79089d4cdaf2a1ce713712ad557b Mon Sep 17 00:00:00 2001 From: Mario Mulansky Date: Wed, 1 Oct 2014 18:21:36 +0200 Subject: added multithreaded version of multi_distance (slow) --- pyspike/cython_distance.pyx | 215 ++++++++++++++++++++-------------------- pyspike/distances.py | 233 ++++++++++---------------------------------- 2 files changed, 162 insertions(+), 286 deletions(-) (limited to 'pyspike') diff --git a/pyspike/cython_distance.pyx b/pyspike/cython_distance.pyx index 6edcc01..23ffc37 100644 --- a/pyspike/cython_distance.pyx +++ b/pyspike/cython_distance.pyx @@ -54,38 +54,41 @@ def isi_distance_cython(double[:] s1, spike_events[0] = s1[0] # the values have one entry less - the number of intervals between events isi_values = np.empty(N1+N2-1) - isi_values[0] = (nu1-nu2)/max(nu1,nu2) - index1 = 0 - index2 = 0 - index = 1 - while True: - # check which spike is next - from s1 or s2 - if s1[index1+1] < s2[index2+1]: - index1 += 1 - # break condition relies on existence of spikes at T_end - if index1 >= N1: - break - spike_events[index] = s1[index1] - nu1 = s1[index1+1]-s1[index1] - elif s1[index1+1] > s2[index2+1]: - index2 += 1 - if index2 >= N2: - break - spike_events[index] = s2[index2] - nu2 = s2[index2+1]-s2[index2] - else: # s1[index1+1] == s2[index2+1] - index1 += 1 - index2 += 1 - if (index1 >= N1) or (index2 >= N2): - break - spike_events[index] = s1[index1] - nu1 = s1[index1+1]-s1[index1] - nu2 = s2[index2+1]-s2[index2] - # compute the corresponding isi-distance - isi_values[index] = (nu1 - nu2) / max(nu1, nu2) - index += 1 - # the last event is the interval end - spike_events[index] = s1[N1] + + with nogil: # release the interpreter to allow multithreading + isi_values[0] = (nu1-nu2)/max(nu1,nu2) + index1 = 0 + index2 = 0 + index = 1 + while True: + # check which spike is next - from s1 or s2 + if s1[index1+1] < s2[index2+1]: + index1 += 1 + # break condition relies on existence of spikes at T_end + if index1 >= N1: + break + spike_events[index] = s1[index1] + nu1 = s1[index1+1]-s1[index1] + elif s1[index1+1] > s2[index2+1]: + index2 += 1 + if index2 >= N2: + break + spike_events[index] = s2[index2] + nu2 = s2[index2+1]-s2[index2] + else: # s1[index1+1] == s2[index2+1] + index1 += 1 + index2 += 1 + if (index1 >= N1) or (index2 >= N2): + break + spike_events[index] = s1[index1] + nu1 = s1[index1+1]-s1[index1] + nu2 = s2[index2+1]-s2[index2] + # compute the corresponding isi-distance + isi_values[index] = (nu1 - nu2) / max(nu1, nu2) + index += 1 + # the last event is the interval end + spike_events[index] = s1[N1] + # end nogil return spike_events[:index+1], isi_values[:index] @@ -98,7 +101,7 @@ cdef inline double get_min_dist_cython(double spike_time, # use memory view to ensure inlining # np.ndarray[DTYPE_t,ndim=1] spike_train, int N, - int start_index=0): + int start_index=0) nogil: """ Returns the minimal distance |spike_time - spike_train[i]| with i>=start_index. """ @@ -136,78 +139,80 @@ def spike_distance_cython(double[:] t1, y_starts = np.empty(len(spike_events)-1) y_ends = np.empty(len(spike_events)-1) - index1 = 0 - index2 = 0 - index = 1 - dt_p1 = 0.0 - dt_f1 = get_min_dist_cython(t1[1], t2, N2, 0) - dt_p2 = 0.0 - dt_f2 = get_min_dist_cython(t2[1], t1, N1, 0) - isi1 = max(t1[1]-t1[0], t1[2]-t1[1]) - isi2 = max(t2[1]-t2[0], t2[2]-t2[1]) - s1 = dt_f1*(t1[1]-t1[0])/isi1 - s2 = dt_f2*(t2[1]-t2[0])/isi2 - y_starts[0] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2) - while True: - # print(index, index1, index2) - if t1[index1+1] < t2[index2+1]: - index1 += 1 - # break condition relies on existence of spikes at T_end - if index1+1 >= N1: - break - spike_events[index] = t1[index1] - # first calculate the previous interval end value - dt_p1 = dt_f1 # the previous time now was the following time before - s1 = dt_p1 - s2 = (dt_p2*(t2[index2+1]-t1[index1]) + - dt_f2*(t1[index1]-t2[index2])) / isi2 - y_ends[index-1] = (s1*isi2 + s2*isi1)/(0.5*(isi1+isi2)*(isi1+isi2)) - # now the next interval start value - dt_f1 = get_min_dist_cython(t1[index1+1], t2, N2, index2) - isi1 = t1[index1+1]-t1[index1] - # s2 is the same as above, thus we can compute y2 immediately - y_starts[index] = (s1*isi2 + s2*isi1)/(0.5*(isi1+isi2)*(isi1+isi2)) - elif t1[index1+1] > t2[index2+1]: - index2 += 1 - if index2+1 >= N2: - break - spike_events[index] = t2[index2] - # first calculate the previous interval end value - dt_p2 = dt_f2 # the previous time now was the following time before - s1 = (dt_p1*(t1[index1+1]-t2[index2]) + - dt_f1*(t2[index2]-t1[index1])) / isi1 - s2 = dt_p2 - y_ends[index-1] = (s1*isi2 + s2*isi1) / (0.5*(isi1+isi2)*(isi1+isi2)) - # now the next interval start value - dt_f2 = get_min_dist_cython(t2[index2+1], t1, N1, index1) - #s2 = dt_f2 - isi2 = t2[index2+1]-t2[index2] - # s2 is the same as above, thus we can compute y2 immediately - y_starts[index] = (s1*isi2 + s2*isi1)/(0.5*(isi1+isi2)*(isi1+isi2)) - else: # t1[index1+1] == t2[index2+1] - generate only one event - index1 += 1 - index2 += 1 - if (index1+1 >= N1) or (index2+1 >= N2): - break - spike_events[index] = t1[index1] - y_ends[index-1] = 0.0 - y_starts[index] = 0.0 - dt_p1 = 0.0 - dt_p2 = 0.0 - dt_f1 = get_min_dist_cython(t1[index1+1], t2, N2, index2) - dt_f2 = get_min_dist_cython(t2[index2+1], t1, N1, index1) - isi1 = t1[index1+1]-t1[index1] - isi2 = t2[index2+1]-t2[index2] - index += 1 - # the last event is the interval end - spike_events[index] = t1[N1-1] - # the ending value of the last interval - isi1 = max(t1[N1-1]-t1[N1-2], t1[N1-2]-t1[N1-3]) - isi2 = max(t2[N2-1]-t2[N2-2], t2[N2-2]-t2[N2-3]) - s1 = dt_p1*(t1[N1-1]-t1[N1-2])/isi1 - s2 = dt_p2*(t2[N2-1]-t2[N2-2])/isi2 - y_ends[index-1] = (s1*isi2 + s2*isi1) / (0.5*(isi1+isi2)*(isi1+isi2)) + with nogil: # release the interpreter to allow multithreading + index1 = 0 + index2 = 0 + index = 1 + dt_p1 = 0.0 + dt_f1 = get_min_dist_cython(t1[1], t2, N2, 0) + dt_p2 = 0.0 + dt_f2 = get_min_dist_cython(t2[1], t1, N1, 0) + isi1 = max(t1[1]-t1[0], t1[2]-t1[1]) + isi2 = max(t2[1]-t2[0], t2[2]-t2[1]) + s1 = dt_f1*(t1[1]-t1[0])/isi1 + s2 = dt_f2*(t2[1]-t2[0])/isi2 + y_starts[0] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2) + while True: + # print(index, index1, index2) + if t1[index1+1] < t2[index2+1]: + index1 += 1 + # break condition relies on existence of spikes at T_end + if index1+1 >= N1: + break + spike_events[index] = t1[index1] + # first calculate the previous interval end value + dt_p1 = dt_f1 # the previous time now was the following time before + s1 = dt_p1 + s2 = (dt_p2*(t2[index2+1]-t1[index1]) + + dt_f2*(t1[index1]-t2[index2])) / isi2 + y_ends[index-1] = (s1*isi2 + s2*isi1)/(0.5*(isi1+isi2)*(isi1+isi2)) + # now the next interval start value + dt_f1 = get_min_dist_cython(t1[index1+1], t2, N2, index2) + isi1 = t1[index1+1]-t1[index1] + # s2 is the same as above, thus we can compute y2 immediately + y_starts[index] = (s1*isi2 + s2*isi1)/(0.5*(isi1+isi2)*(isi1+isi2)) + elif t1[index1+1] > t2[index2+1]: + index2 += 1 + if index2+1 >= N2: + break + spike_events[index] = t2[index2] + # first calculate the previous interval end value + dt_p2 = dt_f2 # the previous time now was the following time before + s1 = (dt_p1*(t1[index1+1]-t2[index2]) + + dt_f1*(t2[index2]-t1[index1])) / isi1 + s2 = dt_p2 + y_ends[index-1] = (s1*isi2 + s2*isi1) / (0.5*(isi1+isi2)*(isi1+isi2)) + # now the next interval start value + dt_f2 = get_min_dist_cython(t2[index2+1], t1, N1, index1) + #s2 = dt_f2 + isi2 = t2[index2+1]-t2[index2] + # s2 is the same as above, thus we can compute y2 immediately + y_starts[index] = (s1*isi2 + s2*isi1)/(0.5*(isi1+isi2)*(isi1+isi2)) + else: # t1[index1+1] == t2[index2+1] - generate only one event + index1 += 1 + index2 += 1 + if (index1+1 >= N1) or (index2+1 >= N2): + break + spike_events[index] = t1[index1] + y_ends[index-1] = 0.0 + y_starts[index] = 0.0 + dt_p1 = 0.0 + dt_p2 = 0.0 + dt_f1 = get_min_dist_cython(t1[index1+1], t2, N2, index2) + dt_f2 = get_min_dist_cython(t2[index2+1], t1, N1, index1) + isi1 = t1[index1+1]-t1[index1] + isi2 = t2[index2+1]-t2[index2] + index += 1 + # the last event is the interval end + spike_events[index] = t1[N1-1] + # the ending value of the last interval + isi1 = max(t1[N1-1]-t1[N1-2], t1[N1-2]-t1[N1-3]) + isi2 = max(t2[N2-1]-t2[N2-2], t2[N2-2]-t2[N2-3]) + s1 = dt_p1*(t1[N1-1]-t1[N1-2])/isi1 + s2 = dt_p2*(t2[N2-1]-t2[N2-2])/isi2 + y_ends[index-1] = (s1*isi2 + s2*isi1) / (0.5*(isi1+isi2)*(isi1+isi2)) + # end nogil + # use only the data added above # could be less than original length due to equal spike times - return spike_events[:index+1], y_starts[:index], y_ends[:index] diff --git a/pyspike/distances.py b/pyspike/distances.py index 79278b4..35650f7 100644 --- a/pyspike/distances.py +++ b/pyspike/distances.py @@ -6,6 +6,7 @@ Copyright 2014, Mario Mulansky """ import numpy as np +import threading from pyspike import PieceWiseConstFunc, PieceWiseLinFunc @@ -126,6 +127,57 @@ def multi_distance(spike_trains, pair_distance_func, indices=None): return average_dist +############################################################ +# multi_distance_par +############################################################ +def multi_distance_par(spike_trains, pair_distance_func, indices=None): + """ parallel implementation of the multi-distance. Not currently used as + it does not improve the performance. + """ + + num_threads = 2 + + lock = threading.Lock() + def run(spike_trains, index_pairs, average_dist): + (i,j) = index_pairs[0] + # print(i,j) + this_avrg = pair_distance_func(spike_trains[i], spike_trains[j]) + for (i,j) in index_pairs[1:]: + # print(i,j) + current_dist = pair_distance_func(spike_trains[i], spike_trains[j]) + this_avrg.add(current_dist) + with lock: + average_dist.add(this_avrg) + + if indices==None: + indices = np.arange(len(spike_trains)) + indices = np.array(indices) + # check validity of indices + assert (indices < len(spike_trains)).all() and (indices >= 0).all(), \ + "Invalid index list." + # generate a list of possible index pairs + pairs = [(i,j) for i in indices for j in indices[i+1:]] + num_pairs = len(pairs) + + # start with first pair + (i,j) = pairs[0] + average_dist = pair_distance_func(spike_trains[i], spike_trains[j]) + # remove the one we already computed + pairs = pairs[1:] + # distribute the rest into num_threads pieces + clustered_pairs = [ pairs[i::num_threads] for i in xrange(num_threads) ] + + threads = [] + for pairs in clustered_pairs: + t = threading.Thread(target=run, args=(spike_trains, pairs, average_dist)) + threads.append(t) + t.start() + for t in threads: + t.join() + average_dist.mul_scalar(1.0/num_pairs) # normalize + return average_dist + + ############################################################ # isi_distance_multi ############################################################ @@ -161,184 +213,3 @@ def spike_distance_multi(spike_trains, indices=None): """ return multi_distance(spike_trains, spike_distance, indices) - -############################################################ -############################################################ -# VANILLA PYTHON IMPLEMENTATIONS OF ISI AND SPIKE DISTANCE -############################################################ -############################################################ - - -############################################################ -# isi_distance_python -############################################################ -def isi_distance_python(s1, s2): - """ Plain Python implementation of the isi distance. - """ - # compute the interspike interval - nu1 = s1[1:]-s1[:-1] - nu2 = s2[1:]-s2[:-1] - - # compute the isi-distance - spike_events = np.empty(len(nu1)+len(nu2)) - spike_events[0] = s1[0] - # the values have one entry less - the number of intervals between events - isi_values = np.empty(len(spike_events)-1) - # add the distance of the first events - # isi_values[0] = nu1[0]/nu2[0] - 1.0 if nu1[0] <= nu2[0] \ - # else 1.0 - nu2[0]/nu1[0] - isi_values[0] = (nu1[0]-nu2[0])/max(nu1[0],nu2[0]) - index1 = 0 - index2 = 0 - index = 1 - while True: - # check which spike is next - from s1 or s2 - if s1[index1+1] < s2[index2+1]: - index1 += 1 - # break condition relies on existence of spikes at T_end - if index1 >= len(nu1): - break - spike_events[index] = s1[index1] - elif s1[index1+1] > s2[index2+1]: - index2 += 1 - if index2 >= len(nu2): - break - spike_events[index] = s2[index2] - else: # s1[index1+1] == s2[index2+1] - index1 += 1 - index2 += 1 - if (index1 >= len(nu1)) or (index2 >= len(nu2)): - break - spike_events[index] = s1[index1] - # compute the corresponding isi-distance - isi_values[index] = (nu1[index1]-nu2[index2]) / \ - max(nu1[index1], nu2[index2]) - index += 1 - # the last event is the interval end - spike_events[index] = s1[-1] - # use only the data added above - # could be less than original length due to equal spike times - return PieceWiseConstFunc(spike_events[:index+1], isi_values[:index]) - - -############################################################ -# get_min_dist -############################################################ -def get_min_dist(spike_time, spike_train, start_index=0): - """ Returns the minimal distance |spike_time - spike_train[i]| - with i>=start_index. - """ - d = abs(spike_time - spike_train[start_index]) - start_index += 1 - while start_index < len(spike_train): - d_temp = abs(spike_time - spike_train[start_index]) - if d_temp > d: - break - else: - d = d_temp - start_index += 1 - return d - - -############################################################ -# spike_distance_python -############################################################ -def spike_distance_python(spikes1, spikes2): - """ Computes the instantaneous spike-distance S_spike (t) of the two given - spike trains. The spike trains are expected to have auxiliary spikes at the - beginning and end of the interval. Use the function add_auxiliary_spikes to - add those spikes to the spike train. - Args: - - spikes1, spikes2: ordered arrays of spike times with auxiliary spikes. - Returns: - - PieceWiseLinFunc describing the spike-distance. - """ - # check for auxiliary spikes - first and last spikes should be identical - assert spikes1[0]==spikes2[0], \ - "Given spike trains seems not to have auxiliary spikes!" - assert spikes1[-1]==spikes2[-1], \ - "Given spike trains seems not to have auxiliary spikes!" - # shorter variables - t1 = spikes1 - t2 = spikes2 - - spike_events = np.empty(len(t1)+len(t2)-2) - spike_events[0] = t1[0] - y_starts = np.empty(len(spike_events)-1) - y_ends = np.empty(len(spike_events)-1) - - index1 = 0 - index2 = 0 - index = 1 - dt_p1 = 0.0 - dt_f1 = get_min_dist(t1[1], t2, 0) - dt_p2 = 0.0 - dt_f2 = get_min_dist(t2[1], t1, 0) - isi1 = max(t1[1]-t1[0], t1[2]-t1[1]) - isi2 = max(t2[1]-t2[0], t2[2]-t2[1]) - s1 = dt_f1*(t1[1]-t1[0])/isi1 - s2 = dt_f2*(t2[1]-t2[0])/isi2 - y_starts[0] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2) - while True: - # print(index, index1, index2) - if t1[index1+1] < t2[index2+1]: - index1 += 1 - # break condition relies on existence of spikes at T_end - if index1+1 >= len(t1): - break - spike_events[index] = t1[index1] - # first calculate the previous interval end value - dt_p1 = dt_f1 # the previous time now was the following time before - s1 = dt_p1 - s2 = (dt_p2*(t2[index2+1]-t1[index1]) + dt_f2*(t1[index1]-t2[index2])) / isi2 - y_ends[index-1] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2) - # now the next interval start value - dt_f1 = get_min_dist(t1[index1+1], t2, index2) - isi1 = t1[index1+1]-t1[index1] - # s2 is the same as above, thus we can compute y2 immediately - y_starts[index] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2) - elif t1[index1+1] > t2[index2+1]: - index2 += 1 - if index2+1 >= len(t2): - break - spike_events[index] = t2[index2] - # first calculate the previous interval end value - dt_p2 = dt_f2 # the previous time now was the following time before - s1 = (dt_p1*(t1[index1+1]-t2[index2]) + dt_f1*(t2[index2]-t1[index1])) / isi1 - s2 = dt_p2 - y_ends[index-1] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2) - # now the next interval start value - dt_f2 = get_min_dist(t2[index2+1], t1, index1) - #s2 = dt_f2 - isi2 = t2[index2+1]-t2[index2] - # s2 is the same as above, thus we can compute y2 immediately - y_starts[index] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2) - else: # t1[index1+1] == t2[index2+1] - generate only one event - index1 += 1 - index2 += 1 - if (index1+1 >= len(t1)) or (index2+1 >= len(t2)): - break - assert dt_f2 == 0.0 - assert dt_f1 == 0.0 - spike_events[index] = t1[index1] - y_ends[index-1] = 0.0 - y_starts[index] = 0.0 - dt_p1 = 0.0 - dt_p2 = 0.0 - dt_f1 = get_min_dist(t1[index1+1], t2, index2) - dt_f2 = get_min_dist(t2[index2+1], t1, index1) - isi1 = t1[index1+1]-t1[index1] - isi2 = t2[index2+1]-t2[index2] - index += 1 - # the last event is the interval end - spike_events[index] = t1[-1] - # the ending value of the last interval - isi1 = max(t1[-1]-t1[-2], t1[-2]-t1[-3]) - isi2 = max(t2[-1]-t2[-2], t2[-2]-t2[-3]) - s1 = dt_p1*(t1[-1]-t1[-2])/isi1 - s2 = dt_p2*(t2[-1]-t2[-2])/isi2 - y_ends[index-1] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2) - # use only the data added above - # could be less than original length due to equal spike times - return PieceWiseLinFunc(spike_events[:index+1], - y_starts[:index], y_ends[:index]) -- cgit v1.2.3