summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMario Mulansky <mario.mulansky@gmx.net>2015-01-22 16:19:37 +0100
committerMario Mulansky <mario.mulansky@gmx.net>2015-01-22 16:19:37 +0100
commit27121e76e92a244c484b970f61da990eff19dc1d (patch)
tree3407050f25ab7b2c0561208b6376db956693834f
parent7bec7f0fe1c40146e8b45757d4c156a4f9b49001 (diff)
parente30a16a76a78aad51c59972b6c5eae3dd74f0459 (diff)
Merge branch 'richardtomsett-master'
-rw-r--r--pyspike/distances.py6
-rw-r--r--test/test_distance.py23
2 files changed, 26 insertions, 3 deletions
diff --git a/pyspike/distances.py b/pyspike/distances.py
index 2cac4bc..5476b6f 100644
--- a/pyspike/distances.py
+++ b/pyspike/distances.py
@@ -212,7 +212,7 @@ def _generic_profile_multi(spike_trains, pair_distance_func, indices=None):
assert (indices < len(spike_trains)).all() and (indices >= 0).all(), \
"Invalid index list."
# generate a list of possible index pairs
- pairs = [(i, j) for i in indices for j in indices[i+1:]]
+ pairs = [(indices[i], j) for i in range(len(indices)) for j in indices[i+1:]]
# start with first pair
(i, j) = pairs[0]
average_dist = pair_distance_func(spike_trains[i], spike_trains[j])
@@ -251,7 +251,7 @@ def _multi_distance_par(spike_trains, pair_distance_func, indices=None):
assert (indices < len(spike_trains)).all() and (indices >= 0).all(), \
"Invalid index list."
# generate a list of possible index pairs
- pairs = [(i, j) for i in indices for j in indices[i+1:]]
+ pairs = [(indices[i], j) for i in range(len(indices)) for j in indices[i+1:]]
num_pairs = len(pairs)
# start with first pair
@@ -430,7 +430,7 @@ def _generic_distance_matrix(spike_trains, dist_function,
assert (indices < len(spike_trains)).all() and (indices >= 0).all(), \
"Invalid index list."
# generate a list of possible index pairs
- pairs = [(i, j) for i in indices for j in indices[i+1:]]
+ pairs = [(indices[i], j) for i in range(len(indices)) for j in indices[i+1:]]
distance_matrix = np.zeros((len(indices), len(indices)))
for i, j in pairs:
diff --git a/test/test_distance.py b/test/test_distance.py
index 2650313..41f625e 100644
--- a/test/test_distance.py
+++ b/test/test_distance.py
@@ -172,6 +172,10 @@ def check_multi_profile(profile_func, profile_func_multi):
f_multi = profile_func_multi(spike_trains, [0, 1])
assert f_multi.almost_equal(f12, decimal=14)
+ f_multi1 = profile_func_multi(spike_trains, [1, 2, 3])
+ f_multi2 = profile_func_multi(spike_trains[1:])
+ assert f_multi1.almost_equal(f_multi2, decimal=14)
+
f = copy(f12)
f.add(f13)
f.add(f23)
@@ -289,6 +293,25 @@ def test_regression_spiky():
# assert_equal(spike_dist, 0.2445)
+def test_multi_variate_subsets():
+ spike_trains = spk.load_spike_trains_from_txt("test/PySpike_testdata.txt",
+ (0.0, 4000.0))
+ sub_set = [1, 3, 5, 7]
+ spike_trains_sub_set = [spike_trains[i] for i in sub_set]
+
+ v1 = spk.isi_distance_multi(spike_trains_sub_set)
+ v2 = spk.isi_distance_multi(spike_trains, sub_set)
+ assert_equal(v1, v2)
+
+ v1 = spk.spike_distance_multi(spike_trains_sub_set)
+ v2 = spk.spike_distance_multi(spike_trains, sub_set)
+ assert_equal(v1, v2)
+
+ v1 = spk.spike_sync_multi(spike_trains_sub_set)
+ v2 = spk.spike_sync_multi(spike_trains, sub_set)
+ assert_equal(v1, v2)
+
+
if __name__ == "__main__":
test_isi()
test_spike()