summaryrefslogtreecommitdiff
path: root/pyspike/generic.py
diff options
context:
space:
mode:
Diffstat (limited to 'pyspike/generic.py')
-rw-r--r--pyspike/generic.py77
1 files changed, 70 insertions, 7 deletions
diff --git a/pyspike/generic.py b/pyspike/generic.py
index 4f278d2..41affcb 100644
--- a/pyspike/generic.py
+++ b/pyspike/generic.py
@@ -31,6 +31,69 @@ def _generic_profile_multi(spike_trains, pair_distance_func, indices=None):
Returns:
- The averaged multi-variate distance of all pairs
"""
+
+ def divide_and_conquer(pairs1, pairs2):
+ """ recursive calls by splitting the two lists in half.
+ """
+ L1 = len(pairs1)
+ if L1 > 1:
+ dist_prof1 = divide_and_conquer(pairs1[:L1/2], pairs1[L1/2:])
+ else:
+ dist_prof1 = pair_distance_func(spike_trains[pairs1[0][0]],
+ spike_trains[pairs1[0][1]])
+ L2 = len(pairs2)
+ if L2 > 1:
+ dist_prof2 = divide_and_conquer(pairs2[:L2/2], pairs2[L2/2:])
+ else:
+ dist_prof2 = pair_distance_func(spike_trains[pairs2[0][0]],
+ spike_trains[pairs2[0][1]])
+ dist_prof1.add(dist_prof2)
+ return dist_prof1
+
+ if indices is None:
+ indices = np.arange(len(spike_trains))
+ indices = np.array(indices)
+ # check validity of indices
+ assert (indices < len(spike_trains)).all() and (indices >= 0).all(), \
+ "Invalid index list."
+ # generate a list of possible index pairs
+ pairs = [(indices[i], j) for i in range(len(indices))
+ for j in indices[i+1:]]
+
+ L = len(pairs)
+ if L > 1:
+ # recursive iteration through the list of pairs to get average profile
+ avrg_dist = divide_and_conquer(pairs[:len(pairs)/2],
+ pairs[len(pairs)/2:])
+ else:
+ avrg_dist = pair_distance_func(spike_trains[pairs[0][0]],
+ spike_trains[pairs[0][1]])
+
+ return avrg_dist, L
+
+
+############################################################
+# _generic_distance_multi
+############################################################
+def _generic_distance_multi(spike_trains, pair_distance_func,
+ indices=None, interval=None):
+ """ Internal implementation detail, don't call this function directly,
+ use isi_distance_multi or spike_distance_multi instead.
+
+ Computes the multi-variate distance for a set of spike-trains using the
+ pair_dist_func to compute pair-wise distances. That is it computes the
+ average distance of all pairs of spike-trains:
+ :math:`S(t) = 2/((N(N-1)) sum_{<i,j>} D_{i,j}`,
+ where the sum goes over all pairs <i,j>.
+ Args:
+ - spike_trains: list of spike trains
+ - pair_distance_func: function computing the distance of two spike trains
+ - indices: list of indices defining which spike trains to use,
+ if None all given spike trains are used (default=None)
+ Returns:
+ - The averaged multi-variate distance of all pairs
+ """
+
if indices is None:
indices = np.arange(len(spike_trains))
indices = np.array(indices)
@@ -40,13 +103,13 @@ def _generic_profile_multi(spike_trains, pair_distance_func, indices=None):
# generate a list of possible index pairs
pairs = [(indices[i], j) for i in range(len(indices))
for j in indices[i+1:]]
- # start with first pair
- (i, j) = pairs[0]
- average_dist = pair_distance_func(spike_trains[i], spike_trains[j])
- for (i, j) in pairs[1:]:
- current_dist = pair_distance_func(spike_trains[i], spike_trains[j])
- average_dist.add(current_dist) # add to the average
- return average_dist, len(pairs)
+
+ avrg_dist = 0.0
+ for (i, j) in pairs:
+ avrg_dist += pair_distance_func(spike_trains[i], spike_trains[j],
+ interval)
+
+ return avrg_dist/len(pairs)
############################################################