author    Mario Mulansky <mario.mulansky@gmx.net>  2014-10-16 14:50:26 +0200
committer Mario Mulansky <mario.mulansky@gmx.net>  2014-10-16 14:50:26 +0200
commit    5970a9cfdbecc1af232b7ffe485bdc057591a2b8 (patch)
tree      4ec6c23cd624bb33b0e87821541689874e659983
parent    d869d4d822c651ea3d094eaf17ba7732bf91136f (diff)
added spike_matrix, refactoring dist matrix functs
-rw-r--r--  pyspike/__init__.py    |  2
-rw-r--r--  pyspike/distances.py   | 52
-rw-r--r--  test/test_distance.py  | 62
-rw-r--r--  test/test_spikes.py    |  1
4 files changed, 95 insertions, 22 deletions
diff --git a/pyspike/__init__.py b/pyspike/__init__.py
index 5146507..d2d5b57 100644
--- a/pyspike/__init__.py
+++ b/pyspike/__init__.py
@@ -10,6 +10,6 @@ from function import PieceWiseConstFunc, PieceWiseLinFunc, average_profile
from distances import isi_profile, isi_distance, \
spike_profile, spike_distance, \
isi_profile_multi, isi_distance_multi, isi_distance_matrix, \
- spike_profile_multi, spike_distance_multi
+ spike_profile_multi, spike_distance_multi, spike_distance_matrix
from spikes import add_auxiliary_spikes, load_spike_trains_from_txt, \
spike_train_from_string, merge_spike_trains
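
With this re-export, the new matrix function is reachable from the package namespace. A minimal sketch of the intended import; the data file is the one used by the test suite, so the path is only illustrative:

import pyspike as spk

spike_trains = spk.load_spike_trains_from_txt("test/PySpike_testdata.txt",
                                               time_interval=(0, 4000))
m = spk.spike_distance_matrix(spike_trains)  # exported by __init__.py after this change
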
diff --git a/pyspike/distances.py b/pyspike/distances.py
index 9056863..7d7044b 100644
--- a/pyspike/distances.py
+++ b/pyspike/distances.py
@@ -99,9 +99,9 @@ def spike_distance(spikes1, spikes2):
############################################################
-# multi_profile
+# generic_profile_multi
############################################################
-def multi_profile(spike_trains, pair_distance_func, indices=None):
+def generic_profile_multi(spike_trains, pair_distance_func, indices=None):
""" Internal implementation detail, don't call this function directly,
use isi_profile_multi or spike_profile_multi instead.
@@ -203,7 +203,7 @@ def isi_profile_multi(spike_trains, indices=None):
Returns:
- A PieceWiseConstFunc representing the averaged isi distance S_isi(t)
"""
- return multi_profile(spike_trains, isi_profile, indices)
+ return generic_profile_multi(spike_trains, isi_profile, indices)
############################################################
@@ -239,7 +239,7 @@ def spike_profile_multi(spike_trains, indices=None):
Returns:
- A PieceWiseLinFunc representing the averaged spike distance S(t)
"""
- return multi_profile(spike_trains, spike_profile, indices)
+ return generic_profile_multi(spike_trains, spike_profile, indices)
############################################################
@@ -261,17 +261,19 @@ def spike_distance_multi(spike_trains, indices=None):
############################################################
-# isi_distance_matrix
+# generic_distance_matrix
############################################################
-def isi_distance_matrix(spike_trains, indices=None):
- """ Computes the average isi-distance of all pairs of spike-trains.
+def generic_distance_matrix(spike_trains, dist_function, indices=None):
+ """ Internal implementation detail. Don't use this function directly.
+ Instead use isi_distance_matrix or spike_distance_matrix.
+ Computes the time averaged distance of all pairs of spike-trains.
Args:
- spike_trains: list of spike trains
- indices: list of indices defining which spike-trains to use
if None all given spike-trains are used (default=None)
Return:
- a 2D array of size len(indices)*len(indices) containing the average
- pair-wise isi-distance
+ pair-wise distance
"""
if indices is None:
indices = np.arange(len(spike_trains))
@@ -284,7 +286,39 @@ def isi_distance_matrix(spike_trains, indices=None):
distance_matrix = np.zeros((len(indices), len(indices)))
for i, j in pairs:
- d = isi_distance(spike_trains[i], spike_trains[j])
+ d = dist_function(spike_trains[i], spike_trains[j])
distance_matrix[i, j] = d
distance_matrix[j, i] = d
return distance_matrix
+
+
+############################################################
+# isi_distance_matrix
+############################################################
+def isi_distance_matrix(spike_trains, indices=None):
+ """ Computes the time averaged isi-distance of all pairs of spike-trains.
+ Args:
+ - spike_trains: list of spike trains
+ - indices: list of indices defining which spike-trains to use
+ if None all given spike-trains are used (default=None)
+ Return:
+ - a 2D array of size len(indices)*len(indices) containing the average
+ pair-wise isi-distance
+ """
+ return generic_distance_matrix(spike_trains, isi_distance, indices)
+
+
+############################################################
+# spike_distance_matrix
+############################################################
+def spike_distance_matrix(spike_trains, indices=None):
+ """ Computes the time averaged spike-distance of all pairs of spike-trains.
+ Args:
+ - spike_trains: list of spike trains
+ - indices: list of indices defining which spike-trains to use
+ if None all given spike-trains are used (default=None)
+ Return:
+ - a 2D array of size len(indices)*len(indices) containing the average
+ pair-wise spike-distance
+ """
+ return generic_distance_matrix(spike_trains, spike_distance, indices)
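
Both public matrix functions are thin wrappers around generic_distance_matrix, which fills a symmetric len(indices) x len(indices) array with a zero diagonal. A minimal usage sketch, reusing the add_auxiliary_spikes setup from the tests below:

import numpy as np
import pyspike as spk

# spike trains with auxiliary spikes added at t=0 and t=1.0, as in the tests
t1 = spk.add_auxiliary_spikes(np.array([0.2, 0.4, 0.6, 0.7]), 1.0)
t2 = spk.add_auxiliary_spikes(np.array([0.3, 0.45, 0.8, 0.9, 0.95]), 1.0)
t3 = spk.add_auxiliary_spikes(np.array([0.2, 0.4, 0.6]), 1.0)
spike_trains = [t1, t2, t3]

m_isi = spk.isi_distance_matrix(spike_trains)      # 3x3, symmetric, zero diagonal
m_spike = spk.spike_distance_matrix(spike_trains)  # same layout for the spike distance
assert m_spike[0, 1] == m_spike[1, 0]              # pairwise values are mirrored
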
diff --git a/test/test_distance.py b/test/test_distance.py
index 2a6bf4e..7be0d9b 100644
--- a/test/test_distance.py
+++ b/test/test_distance.py
@@ -130,7 +130,7 @@ def test_spike():
decimal=16)
-def check_multi_distance(dist_func, dist_func_multi):
+def check_multi_profile(profile_func, profile_func_multi):
# generate spike trains:
t1 = spk.add_auxiliary_spikes(np.array([0.2, 0.4, 0.6, 0.7]), 1.0)
t2 = spk.add_auxiliary_spikes(np.array([0.3, 0.45, 0.8, 0.9, 0.95]), 1.0)
@@ -138,21 +138,21 @@ def check_multi_distance(dist_func, dist_func_multi):
t4 = spk.add_auxiliary_spikes(np.array([0.1, 0.4, 0.5, 0.6]), 1.0)
spike_trains = [t1, t2, t3, t4]
- f12 = dist_func(t1, t2)
- f13 = dist_func(t1, t3)
- f14 = dist_func(t1, t4)
- f23 = dist_func(t2, t3)
- f24 = dist_func(t2, t4)
- f34 = dist_func(t3, t4)
+ f12 = profile_func(t1, t2)
+ f13 = profile_func(t1, t3)
+ f14 = profile_func(t1, t4)
+ f23 = profile_func(t2, t3)
+ f24 = profile_func(t2, t4)
+ f34 = profile_func(t3, t4)
- f_multi = dist_func_multi(spike_trains, [0, 1])
+ f_multi = profile_func_multi(spike_trains, [0, 1])
assert f_multi.almost_equal(f12, decimal=14)
f = copy(f12)
f.add(f13)
f.add(f23)
f.mul_scalar(1.0/3)
- f_multi = dist_func_multi(spike_trains, [0, 1, 2])
+ f_multi = profile_func_multi(spike_trains, [0, 1, 2])
assert f_multi.almost_equal(f, decimal=14)
f.mul_scalar(3) # revert above normalization
@@ -160,16 +160,54 @@ def check_multi_distance(dist_func, dist_func_multi):
f.add(f24)
f.add(f34)
f.mul_scalar(1.0/6)
- f_multi = dist_func_multi(spike_trains)
+ f_multi = profile_func_multi(spike_trains)
assert f_multi.almost_equal(f, decimal=14)
def test_multi_isi():
- check_multi_distance(spk.isi_profile, spk.isi_profile_multi)
+ check_multi_profile(spk.isi_profile, spk.isi_profile_multi)
def test_multi_spike():
- check_multi_distance(spk.spike_profile, spk.spike_profile_multi)
+ check_multi_profile(spk.spike_profile, spk.spike_profile_multi)
+
+
+def check_dist_matrix(dist_func, dist_matrix_func):
+ # generate spike trains:
+ t1 = spk.add_auxiliary_spikes(np.array([0.2, 0.4, 0.6, 0.7]), 1.0)
+ t2 = spk.add_auxiliary_spikes(np.array([0.3, 0.45, 0.8, 0.9, 0.95]), 1.0)
+ t3 = spk.add_auxiliary_spikes(np.array([0.2, 0.4, 0.6]), 1.0)
+ t4 = spk.add_auxiliary_spikes(np.array([0.1, 0.4, 0.5, 0.6]), 1.0)
+ spike_trains = [t1, t2, t3, t4]
+
+ f12 = dist_func(t1, t2)
+ f13 = dist_func(t1, t3)
+ f14 = dist_func(t1, t4)
+ f23 = dist_func(t2, t3)
+ f24 = dist_func(t2, t4)
+ f34 = dist_func(t3, t4)
+
+ f_matrix = dist_matrix_func(spike_trains)
+ # check zero diagonal
+ for i in xrange(4):
+ assert_equal(0.0, f_matrix[i, i])
+ for i in xrange(4):
+ for j in xrange(i+1, 4):
+ assert_equal(f_matrix[i, j], f_matrix[j, i])
+ assert_equal(f12, f_matrix[1, 0])
+ assert_equal(f13, f_matrix[2, 0])
+ assert_equal(f14, f_matrix[3, 0])
+ assert_equal(f23, f_matrix[2, 1])
+ assert_equal(f24, f_matrix[3, 1])
+ assert_equal(f34, f_matrix[3, 2])
+
+
+def test_isi_matrix():
+ check_dist_matrix(spk.isi_distance, spk.isi_distance_matrix)
+
+
+def test_spike_matrix():
+ check_dist_matrix(spk.spike_distance, spk.spike_distance_matrix)
def test_regression_spiky():
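
The new tests check only the zero diagonal, symmetry, and agreement with the pairwise distance functions; a common way to inspect the resulting matrices beyond that is a quick plot. This is an illustrative sketch only, and matplotlib is not a dependency introduced by this commit:

import matplotlib.pyplot as plt
import pyspike as spk

spike_trains = spk.load_spike_trains_from_txt("test/PySpike_testdata.txt",
                                               time_interval=(0, 4000))
m = spk.spike_distance_matrix(spike_trains)

plt.imshow(m, interpolation='none')  # symmetric distance matrix, zero diagonal
plt.colorbar()
plt.show()
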
diff --git a/test/test_spikes.py b/test/test_spikes.py
index d650d5d..b12099e 100644
--- a/test/test_spikes.py
+++ b/test/test_spikes.py
@@ -66,6 +66,7 @@ def test_merge_spike_trains():
# first load the data
spike_trains = spk.load_spike_trains_from_txt("test/PySpike_testdata.txt",
time_interval=(0, 4000))
+
spikes = spk.merge_spike_trains([spike_trains[0], spike_trains[1]])
# test if result is sorted
assert((spikes == np.sort(spikes)).all())