summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMario Mulansky <mario.mulansky@gmx.net>2015-04-24 15:58:35 +0200
committerMario Mulansky <mario.mulansky@gmx.net>2015-04-24 15:58:35 +0200
commit36d80c9ec1d28488f9b5c97cd202c196efff694e (patch)
treed1eadedf101057b1628353e45a92c78d276c8556
parent3bf9e12e6b5667fb1ea72c969848dacaff3cb470 (diff)
distance tests now pass with new spike trains
-rw-r--r--pyspike/cython/cython_distance.pyx8
-rw-r--r--pyspike/cython/python_backend.py41
-rw-r--r--test/test_distance.py40
3 files changed, 57 insertions, 32 deletions
diff --git a/pyspike/cython/cython_distance.pyx b/pyspike/cython/cython_distance.pyx
index 2841da8..dc2557f 100644
--- a/pyspike/cython/cython_distance.pyx
+++ b/pyspike/cython/cython_distance.pyx
@@ -345,9 +345,9 @@ cdef inline double get_tau(double[:] spikes1, double[:] spikes2,
m = fmin(m, spikes1[i+1]-spikes1[i])
if j < N2 and j > -1:
m = fmin(m, spikes2[j+1]-spikes2[j])
- if i > 1:
+ if i > 0:
m = fmin(m, spikes1[i]-spikes1[i-1])
- if j > 1:
+ if j > 0:
m = fmin(m, spikes2[j]-spikes2[j-1])
m *= 0.5
if max_tau > 0.0:
@@ -371,7 +371,7 @@ def coincidence_cython(double[:] spikes1, double[:] spikes2,
cdef double[:] mp = np.ones(N1 + N2 + 2) # multiplicity
cdef double tau
while i + j < N1 + N2 - 2:
- if (i < N1-1) and (spikes1[i+1] < spikes2[j+1] or j == N2-1):
+ if (i < N1-1) and (j == N2-1 or spikes1[i+1] < spikes2[j+1]):
i += 1
n += 1
tau = get_tau(spikes1, spikes2, i, j, max_tau)
@@ -381,7 +381,7 @@ def coincidence_cython(double[:] spikes1, double[:] spikes2,
# both get marked with 1
c[n] = 1
c[n-1] = 1
- elif (j < N2-1) and (spikes1[i+1] > spikes2[j+1] or i == N1-1):
+ elif (j < N2-1) and (i == N1-1 or spikes1[i+1] > spikes2[j+1]):
j += 1
n += 1
tau = get_tau(spikes1, spikes2, i, j, max_tau)
diff --git a/pyspike/cython/python_backend.py b/pyspike/cython/python_backend.py
index bcf9c30..c65bfb0 100644
--- a/pyspike/cython/python_backend.py
+++ b/pyspike/cython/python_backend.py
@@ -330,47 +330,48 @@ def cumulative_sync_python(spikes1, spikes2):
############################################################
# coincidence_python
############################################################
-def coincidence_python(spikes1, spikes2, max_tau):
+def coincidence_python(spikes1, spikes2, t_start, t_end, max_tau):
def get_tau(spikes1, spikes2, i, j, max_tau):
m = 1E100 # some huge number
- if i < len(spikes1)-2:
+ if i < len(spikes1)-1 and i > -1:
m = min(m, spikes1[i+1]-spikes1[i])
- if j < len(spikes2)-2:
+ if j < len(spikes2)-1 and j > -1:
m = min(m, spikes2[j+1]-spikes2[j])
- if i > 1:
+ if i > 0:
m = min(m, spikes1[i]-spikes1[i-1])
- if j > 1:
+ if j > 0:
m = min(m, spikes2[j]-spikes2[j-1])
m *= 0.5
if max_tau > 0.0:
m = min(m, max_tau)
return m
+
N1 = len(spikes1)
N2 = len(spikes2)
- i = 0
- j = 0
+ i = -1
+ j = -1
n = 0
- st = np.zeros(N1 + N2 - 2) # spike times
- c = np.zeros(N1 + N2 - 2) # coincidences
- mp = np.ones(N1 + N2 - 2) # multiplicity
- while n < N1 + N2 - 2:
- if spikes1[i+1] < spikes2[j+1]:
+ st = np.zeros(N1 + N2 + 2) # spike times
+ c = np.zeros(N1 + N2 + 2) # coincidences
+ mp = np.ones(N1 + N2 + 2) # multiplicity
+ while i + j < N1 + N2 - 2:
+ if (i < N1-1) and (j == N2-1 or spikes1[i+1] < spikes2[j+1]):
i += 1
n += 1
tau = get_tau(spikes1, spikes2, i, j, max_tau)
st[n] = spikes1[i]
- if j > 0 and spikes1[i]-spikes2[j] < tau:
+ if j > -1 and spikes1[i]-spikes2[j] < tau:
# coincidence between the current spike and the previous spike
# both get marked with 1
c[n] = 1
c[n-1] = 1
- elif spikes1[i+1] > spikes2[j+1]:
+ elif (j < N2-1) and (i == N1-1 or spikes1[i+1] > spikes2[j+1]):
j += 1
n += 1
tau = get_tau(spikes1, spikes2, i, j, max_tau)
st[n] = spikes2[j]
- if i > 0 and spikes2[j]-spikes1[i] < tau:
+ if i > -1 and spikes2[j]-spikes1[i] < tau:
# coincidence between the current spike and the previous spike
# both get marked with 1
c[n] = 1
@@ -379,8 +380,6 @@ def coincidence_python(spikes1, spikes2, max_tau):
# advance in both spike trains
j += 1
i += 1
- if i == N1-1 or j == N2-1:
- break
n += 1
# add only one event, but with coincidence 2 and multiplicity 2
st[n] = spikes1[i]
@@ -391,12 +390,12 @@ def coincidence_python(spikes1, spikes2, max_tau):
c = c[:n+2]
mp = mp[:n+2]
- st[0] = spikes1[0]
- st[-1] = spikes1[-1]
+ st[0] = t_start
+ st[len(st)-1] = t_end
c[0] = c[1]
- c[-1] = c[-2]
+ c[len(c)-1] = c[len(c)-2]
mp[0] = mp[1]
- mp[-1] = mp[-2]
+ mp[len(mp)-1] = mp[len(mp)-2]
return st, c, mp
diff --git a/test/test_distance.py b/test/test_distance.py
index 0fff840..88cf40e 100644
--- a/test/test_distance.py
+++ b/test/test_distance.py
@@ -177,6 +177,18 @@ def test_spike_sync():
assert_almost_equal(spk.spike_sync(spikes1, spikes2),
0.5, decimal=16)
+ spikes2 = SpikeTrain([3.0], 4.0)
+ assert_almost_equal(spk.spike_sync(spikes1, spikes2),
+ 0.5, decimal=16)
+
+ spikes2 = SpikeTrain([1.0], 4.0)
+ assert_almost_equal(spk.spike_sync(spikes1, spikes2),
+ 0.5, decimal=16)
+
+ spikes2 = SpikeTrain([1.5, 3.0], 4.0)
+ assert_almost_equal(spk.spike_sync(spikes1, spikes2),
+ 0.4, decimal=16)
+
def check_multi_profile(profile_func, profile_func_multi):
# generate spike trains:
@@ -250,19 +262,28 @@ def test_multi_spike_sync():
# multivariate regression test
spike_trains = spk.load_spike_trains_from_txt("test/SPIKE_Sync_Test.txt",
- interval=(0, 4000))
- print(spike_trains[0].spikes)
+ interval=[0, 4000])
+ # extract all spike times
+ spike_times = np.array([])
+ for st in spike_trains:
+ spike_times = np.append(spike_times, st.spikes)
+ spike_times = np.unique(np.sort(spike_times))
+
f = spk.spike_sync_profile_multi(spike_trains)
+
+ assert_equal(spike_times, f.x[1:-1])
+ assert_equal(len(f.x), len(f.y))
+
assert_equal(np.sum(f.y[1:-1]), 39932)
assert_equal(np.sum(f.mp[1:-1]), 85554)
def check_dist_matrix(dist_func, dist_matrix_func):
# generate spike trains:
- t1 = spk.add_auxiliary_spikes(np.array([0.2, 0.4, 0.6, 0.7]), 1.0)
- t2 = spk.add_auxiliary_spikes(np.array([0.3, 0.45, 0.8, 0.9, 0.95]), 1.0)
- t3 = spk.add_auxiliary_spikes(np.array([0.2, 0.4, 0.6]), 1.0)
- t4 = spk.add_auxiliary_spikes(np.array([0.1, 0.4, 0.5, 0.6]), 1.0)
+ t1 = SpikeTrain([0.2, 0.4, 0.6, 0.7], 1.0)
+ t2 = SpikeTrain([0.3, 0.45, 0.8, 0.9, 0.95], 1.0)
+ t3 = SpikeTrain([0.2, 0.4, 0.6], 1.0)
+ t4 = SpikeTrain([0.1, 0.4, 0.5, 0.6], 1.0)
spike_trains = [t1, t2, t3, t4]
f12 = dist_func(t1, t2)
@@ -340,4 +361,9 @@ if __name__ == "__main__":
test_spike_sync()
test_multi_isi()
test_multi_spike()
- # test_multi_spike_sync()
+ test_multi_spike_sync()
+ test_isi_matrix()
+ test_spike_matrix()
+ test_spike_sync_matrix()
+ test_regression_spiky()
+ test_multi_variate_subsets()