From 5970a9cfdbecc1af232b7ffe485bdc057591a2b8 Mon Sep 17 00:00:00 2001 From: Mario Mulansky Date: Thu, 16 Oct 2014 14:50:26 +0200 Subject: added spike_matrix, refactoring dist matrix functs --- pyspike/__init__.py | 2 +- pyspike/distances.py | 52 ++++++++++++++++++++++++++++++++++-------- test/test_distance.py | 62 +++++++++++++++++++++++++++++++++++++++++---------- test/test_spikes.py | 1 + 4 files changed, 95 insertions(+), 22 deletions(-) diff --git a/pyspike/__init__.py b/pyspike/__init__.py index 5146507..d2d5b57 100644 --- a/pyspike/__init__.py +++ b/pyspike/__init__.py @@ -10,6 +10,6 @@ from function import PieceWiseConstFunc, PieceWiseLinFunc, average_profile from distances import isi_profile, isi_distance, \ spike_profile, spike_distance, \ isi_profile_multi, isi_distance_multi, isi_distance_matrix, \ - spike_profile_multi, spike_distance_multi + spike_profile_multi, spike_distance_multi, spike_distance_matrix from spikes import add_auxiliary_spikes, load_spike_trains_from_txt, \ spike_train_from_string, merge_spike_trains diff --git a/pyspike/distances.py b/pyspike/distances.py index 9056863..7d7044b 100644 --- a/pyspike/distances.py +++ b/pyspike/distances.py @@ -99,9 +99,9 @@ def spike_distance(spikes1, spikes2): ############################################################ -# multi_profile +# generic_profile_multi ############################################################ -def multi_profile(spike_trains, pair_distance_func, indices=None): +def generic_profile_multi(spike_trains, pair_distance_func, indices=None): """ Internal implementation detail, don't call this function directly, use isi_profile_multi or spike_profile_multi instead. @@ -203,7 +203,7 @@ def isi_profile_multi(spike_trains, indices=None): Returns: - A PieceWiseConstFunc representing the averaged isi distance S_isi(t) """ - return multi_profile(spike_trains, isi_profile, indices) + return generic_profile_multi(spike_trains, isi_profile, indices) ############################################################ @@ -239,7 +239,7 @@ def spike_profile_multi(spike_trains, indices=None): Returns: - A PieceWiseLinFunc representing the averaged spike distance S(t) """ - return multi_profile(spike_trains, spike_profile, indices) + return generic_profile_multi(spike_trains, spike_profile, indices) ############################################################ @@ -261,17 +261,19 @@ def spike_distance_multi(spike_trains, indices=None): ############################################################ -# isi_distance_matrix +# generic_distance_matrix ############################################################ -def isi_distance_matrix(spike_trains, indices=None): - """ Computes the average isi-distance of all pairs of spike-trains. +def generic_distance_matrix(spike_trains, dist_function, indices=None): + """ Internal implementation detail. Don't use this function directly. + Instead use isi_distance_matrix or spike_distance_matrix. + Computes the time averaged distance of all pairs of spike-trains. Args: - spike_trains: list of spike trains - indices: list of indices defining which spike-trains to use if None all given spike-trains are used (default=None) Return: - a 2D array of size len(indices)*len(indices) containing the average - pair-wise isi-distance + pair-wise distance """ if indices is None: indices = np.arange(len(spike_trains)) @@ -284,7 +286,39 @@ def isi_distance_matrix(spike_trains, indices=None): distance_matrix = np.zeros((len(indices), len(indices))) for i, j in pairs: - d = isi_distance(spike_trains[i], spike_trains[j]) + d = dist_function(spike_trains[i], spike_trains[j]) distance_matrix[i, j] = d distance_matrix[j, i] = d return distance_matrix + + +############################################################ +# isi_distance_matrix +############################################################ +def isi_distance_matrix(spike_trains, indices=None): + """ Computes the time averaged isi-distance of all pairs of spike-trains. + Args: + - spike_trains: list of spike trains + - indices: list of indices defining which spike-trains to use + if None all given spike-trains are used (default=None) + Return: + - a 2D array of size len(indices)*len(indices) containing the average + pair-wise isi-distance + """ + return generic_distance_matrix(spike_trains, isi_distance, indices) + + +############################################################ +# spike_distance_matrix +############################################################ +def spike_distance_matrix(spike_trains, indices=None): + """ Computes the time averaged spike-distance of all pairs of spike-trains. + Args: + - spike_trains: list of spike trains + - indices: list of indices defining which spike-trains to use + if None all given spike-trains are used (default=None) + Return: + - a 2D array of size len(indices)*len(indices) containing the average + pair-wise spike-distance + """ + return generic_distance_matrix(spike_trains, spike_distance, indices) diff --git a/test/test_distance.py b/test/test_distance.py index 2a6bf4e..7be0d9b 100644 --- a/test/test_distance.py +++ b/test/test_distance.py @@ -130,7 +130,7 @@ def test_spike(): decimal=16) -def check_multi_distance(dist_func, dist_func_multi): +def check_multi_profile(profile_func, profile_func_multi): # generate spike trains: t1 = spk.add_auxiliary_spikes(np.array([0.2, 0.4, 0.6, 0.7]), 1.0) t2 = spk.add_auxiliary_spikes(np.array([0.3, 0.45, 0.8, 0.9, 0.95]), 1.0) @@ -138,21 +138,21 @@ def check_multi_distance(dist_func, dist_func_multi): t4 = spk.add_auxiliary_spikes(np.array([0.1, 0.4, 0.5, 0.6]), 1.0) spike_trains = [t1, t2, t3, t4] - f12 = dist_func(t1, t2) - f13 = dist_func(t1, t3) - f14 = dist_func(t1, t4) - f23 = dist_func(t2, t3) - f24 = dist_func(t2, t4) - f34 = dist_func(t3, t4) + f12 = profile_func(t1, t2) + f13 = profile_func(t1, t3) + f14 = profile_func(t1, t4) + f23 = profile_func(t2, t3) + f24 = profile_func(t2, t4) + f34 = profile_func(t3, t4) - f_multi = dist_func_multi(spike_trains, [0, 1]) + f_multi = profile_func_multi(spike_trains, [0, 1]) assert f_multi.almost_equal(f12, decimal=14) f = copy(f12) f.add(f13) f.add(f23) f.mul_scalar(1.0/3) - f_multi = dist_func_multi(spike_trains, [0, 1, 2]) + f_multi = profile_func_multi(spike_trains, [0, 1, 2]) assert f_multi.almost_equal(f, decimal=14) f.mul_scalar(3) # revert above normalization @@ -160,16 +160,54 @@ def check_multi_distance(dist_func, dist_func_multi): f.add(f24) f.add(f34) f.mul_scalar(1.0/6) - f_multi = dist_func_multi(spike_trains) + f_multi = profile_func_multi(spike_trains) assert f_multi.almost_equal(f, decimal=14) def test_multi_isi(): - check_multi_distance(spk.isi_profile, spk.isi_profile_multi) + check_multi_profile(spk.isi_profile, spk.isi_profile_multi) def test_multi_spike(): - check_multi_distance(spk.spike_profile, spk.spike_profile_multi) + check_multi_profile(spk.spike_profile, spk.spike_profile_multi) + + +def check_dist_matrix(dist_func, dist_matrix_func): + # generate spike trains: + t1 = spk.add_auxiliary_spikes(np.array([0.2, 0.4, 0.6, 0.7]), 1.0) + t2 = spk.add_auxiliary_spikes(np.array([0.3, 0.45, 0.8, 0.9, 0.95]), 1.0) + t3 = spk.add_auxiliary_spikes(np.array([0.2, 0.4, 0.6]), 1.0) + t4 = spk.add_auxiliary_spikes(np.array([0.1, 0.4, 0.5, 0.6]), 1.0) + spike_trains = [t1, t2, t3, t4] + + f12 = dist_func(t1, t2) + f13 = dist_func(t1, t3) + f14 = dist_func(t1, t4) + f23 = dist_func(t2, t3) + f24 = dist_func(t2, t4) + f34 = dist_func(t3, t4) + + f_matrix = dist_matrix_func(spike_trains) + # check zero diagonal + for i in xrange(4): + assert_equal(0.0, f_matrix[i, i]) + for i in xrange(4): + for j in xrange(i+1, 4): + assert_equal(f_matrix[i, j], f_matrix[j, i]) + assert_equal(f12, f_matrix[1, 0]) + assert_equal(f13, f_matrix[2, 0]) + assert_equal(f14, f_matrix[3, 0]) + assert_equal(f23, f_matrix[2, 1]) + assert_equal(f24, f_matrix[3, 1]) + assert_equal(f34, f_matrix[3, 2]) + + +def test_isi_matrix(): + check_dist_matrix(spk.isi_distance, spk.isi_distance_matrix) + + +def test_spike_matrix(): + check_dist_matrix(spk.spike_distance, spk.spike_distance_matrix) def test_regression_spiky(): diff --git a/test/test_spikes.py b/test/test_spikes.py index d650d5d..b12099e 100644 --- a/test/test_spikes.py +++ b/test/test_spikes.py @@ -66,6 +66,7 @@ def test_merge_spike_trains(): # first load the data spike_trains = spk.load_spike_trains_from_txt("test/PySpike_testdata.txt", time_interval=(0, 4000)) + spikes = spk.merge_spike_trains([spike_trains[0], spike_trains[1]]) # test if result is sorted assert((spikes == np.sort(spikes)).all()) -- cgit v1.2.3