From 36d80c9ec1d28488f9b5c97cd202c196efff694e Mon Sep 17 00:00:00 2001 From: Mario Mulansky Date: Fri, 24 Apr 2015 15:58:35 +0200 Subject: distance tests now pass with new spike trains --- pyspike/cython/cython_distance.pyx | 8 ++++---- pyspike/cython/python_backend.py | 41 +++++++++++++++++++------------------- test/test_distance.py | 40 ++++++++++++++++++++++++++++++------- 3 files changed, 57 insertions(+), 32 deletions(-) diff --git a/pyspike/cython/cython_distance.pyx b/pyspike/cython/cython_distance.pyx index 2841da8..dc2557f 100644 --- a/pyspike/cython/cython_distance.pyx +++ b/pyspike/cython/cython_distance.pyx @@ -345,9 +345,9 @@ cdef inline double get_tau(double[:] spikes1, double[:] spikes2, m = fmin(m, spikes1[i+1]-spikes1[i]) if j < N2 and j > -1: m = fmin(m, spikes2[j+1]-spikes2[j]) - if i > 1: + if i > 0: m = fmin(m, spikes1[i]-spikes1[i-1]) - if j > 1: + if j > 0: m = fmin(m, spikes2[j]-spikes2[j-1]) m *= 0.5 if max_tau > 0.0: @@ -371,7 +371,7 @@ def coincidence_cython(double[:] spikes1, double[:] spikes2, cdef double[:] mp = np.ones(N1 + N2 + 2) # multiplicity cdef double tau while i + j < N1 + N2 - 2: - if (i < N1-1) and (spikes1[i+1] < spikes2[j+1] or j == N2-1): + if (i < N1-1) and (j == N2-1 or spikes1[i+1] < spikes2[j+1]): i += 1 n += 1 tau = get_tau(spikes1, spikes2, i, j, max_tau) @@ -381,7 +381,7 @@ def coincidence_cython(double[:] spikes1, double[:] spikes2, # both get marked with 1 c[n] = 1 c[n-1] = 1 - elif (j < N2-1) and (spikes1[i+1] > spikes2[j+1] or i == N1-1): + elif (j < N2-1) and (i == N1-1 or spikes1[i+1] > spikes2[j+1]): j += 1 n += 1 tau = get_tau(spikes1, spikes2, i, j, max_tau) diff --git a/pyspike/cython/python_backend.py b/pyspike/cython/python_backend.py index bcf9c30..c65bfb0 100644 --- a/pyspike/cython/python_backend.py +++ b/pyspike/cython/python_backend.py @@ -330,47 +330,48 @@ def cumulative_sync_python(spikes1, spikes2): ############################################################ # coincidence_python ############################################################ -def coincidence_python(spikes1, spikes2, max_tau): +def coincidence_python(spikes1, spikes2, t_start, t_end, max_tau): def get_tau(spikes1, spikes2, i, j, max_tau): m = 1E100 # some huge number - if i < len(spikes1)-2: + if i < len(spikes1)-1 and i > -1: m = min(m, spikes1[i+1]-spikes1[i]) - if j < len(spikes2)-2: + if j < len(spikes2)-1 and j > -1: m = min(m, spikes2[j+1]-spikes2[j]) - if i > 1: + if i > 0: m = min(m, spikes1[i]-spikes1[i-1]) - if j > 1: + if j > 0: m = min(m, spikes2[j]-spikes2[j-1]) m *= 0.5 if max_tau > 0.0: m = min(m, max_tau) return m + N1 = len(spikes1) N2 = len(spikes2) - i = 0 - j = 0 + i = -1 + j = -1 n = 0 - st = np.zeros(N1 + N2 - 2) # spike times - c = np.zeros(N1 + N2 - 2) # coincidences - mp = np.ones(N1 + N2 - 2) # multiplicity - while n < N1 + N2 - 2: - if spikes1[i+1] < spikes2[j+1]: + st = np.zeros(N1 + N2 + 2) # spike times + c = np.zeros(N1 + N2 + 2) # coincidences + mp = np.ones(N1 + N2 + 2) # multiplicity + while i + j < N1 + N2 - 2: + if (i < N1-1) and (j == N2-1 or spikes1[i+1] < spikes2[j+1]): i += 1 n += 1 tau = get_tau(spikes1, spikes2, i, j, max_tau) st[n] = spikes1[i] - if j > 0 and spikes1[i]-spikes2[j] < tau: + if j > -1 and spikes1[i]-spikes2[j] < tau: # coincidence between the current spike and the previous spike # both get marked with 1 c[n] = 1 c[n-1] = 1 - elif spikes1[i+1] > spikes2[j+1]: + elif (j < N2-1) and (i == N1-1 or spikes1[i+1] > spikes2[j+1]): j += 1 n += 1 tau = get_tau(spikes1, spikes2, i, j, max_tau) st[n] = spikes2[j] - if i > 0 and spikes2[j]-spikes1[i] < tau: + if i > -1 and spikes2[j]-spikes1[i] < tau: # coincidence between the current spike and the previous spike # both get marked with 1 c[n] = 1 @@ -379,8 +380,6 @@ def coincidence_python(spikes1, spikes2, max_tau): # advance in both spike trains j += 1 i += 1 - if i == N1-1 or j == N2-1: - break n += 1 # add only one event, but with coincidence 2 and multiplicity 2 st[n] = spikes1[i] @@ -391,12 +390,12 @@ def coincidence_python(spikes1, spikes2, max_tau): c = c[:n+2] mp = mp[:n+2] - st[0] = spikes1[0] - st[-1] = spikes1[-1] + st[0] = t_start + st[len(st)-1] = t_end c[0] = c[1] - c[-1] = c[-2] + c[len(c)-1] = c[len(c)-2] mp[0] = mp[1] - mp[-1] = mp[-2] + mp[len(mp)-1] = mp[len(mp)-2] return st, c, mp diff --git a/test/test_distance.py b/test/test_distance.py index 0fff840..88cf40e 100644 --- a/test/test_distance.py +++ b/test/test_distance.py @@ -177,6 +177,18 @@ def test_spike_sync(): assert_almost_equal(spk.spike_sync(spikes1, spikes2), 0.5, decimal=16) + spikes2 = SpikeTrain([3.0], 4.0) + assert_almost_equal(spk.spike_sync(spikes1, spikes2), + 0.5, decimal=16) + + spikes2 = SpikeTrain([1.0], 4.0) + assert_almost_equal(spk.spike_sync(spikes1, spikes2), + 0.5, decimal=16) + + spikes2 = SpikeTrain([1.5, 3.0], 4.0) + assert_almost_equal(spk.spike_sync(spikes1, spikes2), + 0.4, decimal=16) + def check_multi_profile(profile_func, profile_func_multi): # generate spike trains: @@ -250,19 +262,28 @@ def test_multi_spike_sync(): # multivariate regression test spike_trains = spk.load_spike_trains_from_txt("test/SPIKE_Sync_Test.txt", - interval=(0, 4000)) - print(spike_trains[0].spikes) + interval=[0, 4000]) + # extract all spike times + spike_times = np.array([]) + for st in spike_trains: + spike_times = np.append(spike_times, st.spikes) + spike_times = np.unique(np.sort(spike_times)) + f = spk.spike_sync_profile_multi(spike_trains) + + assert_equal(spike_times, f.x[1:-1]) + assert_equal(len(f.x), len(f.y)) + assert_equal(np.sum(f.y[1:-1]), 39932) assert_equal(np.sum(f.mp[1:-1]), 85554) def check_dist_matrix(dist_func, dist_matrix_func): # generate spike trains: - t1 = spk.add_auxiliary_spikes(np.array([0.2, 0.4, 0.6, 0.7]), 1.0) - t2 = spk.add_auxiliary_spikes(np.array([0.3, 0.45, 0.8, 0.9, 0.95]), 1.0) - t3 = spk.add_auxiliary_spikes(np.array([0.2, 0.4, 0.6]), 1.0) - t4 = spk.add_auxiliary_spikes(np.array([0.1, 0.4, 0.5, 0.6]), 1.0) + t1 = SpikeTrain([0.2, 0.4, 0.6, 0.7], 1.0) + t2 = SpikeTrain([0.3, 0.45, 0.8, 0.9, 0.95], 1.0) + t3 = SpikeTrain([0.2, 0.4, 0.6], 1.0) + t4 = SpikeTrain([0.1, 0.4, 0.5, 0.6], 1.0) spike_trains = [t1, t2, t3, t4] f12 = dist_func(t1, t2) @@ -340,4 +361,9 @@ if __name__ == "__main__": test_spike_sync() test_multi_isi() test_multi_spike() - # test_multi_spike_sync() + test_multi_spike_sync() + test_isi_matrix() + test_spike_matrix() + test_spike_sync_matrix() + test_regression_spiky() + test_multi_variate_subsets() -- cgit v1.2.3