summaryrefslogtreecommitdiff
path: root/pyspike/distances.py
diff options
context:
space:
mode:
Diffstat (limited to 'pyspike/distances.py')
-rw-r--r--pyspike/distances.py76
1 files changed, 76 insertions, 0 deletions
diff --git a/pyspike/distances.py b/pyspike/distances.py
index f4be625..52c6640 100644
--- a/pyspike/distances.py
+++ b/pyspike/distances.py
@@ -224,3 +224,79 @@ def spike_distance(spikes1, spikes2):
# could be less than original length due to equal spike times
return PieceWiseLinFunc(spike_events[:index+1],
y_starts[:index], y_ends[:index])
+
+
+
+
+############################################################
+# multi_distance
+############################################################
+def multi_distance(spike_trains, pair_distance_func, indices=None):
+ """ Internal implementation detail, use isi_distance_multi or
+ spike_distance_multi.
+
+ Computes the multi-variate distance for a set of spike-trains using the
+ pair_dist_func to compute pair-wise distances. That is it computes the
+ average distance of all pairs of spike-trains:
+ S(t) = 2/((N(N-1)) sum_{<i,j>} S_{i,j},
+ where the sum goes over all pairs <i,j>.
+ Args:
+ - spike_trains: list of spike trains
+ - pair_distance_func: function computing the distance of two spike trains
+ - indices: list of indices defining which spike trains to use,
+ if None all given spike trains are used (default=None)
+ Returns:
+ - The averaged multi-variate distance of all pairs
+ """
+ if indices==None:
+ indices = np.arange(len(spike_trains))
+ indices = np.array(indices)
+ # check validity of indices
+ assert (indices < len(spike_trains)).all() and (indices >= 0).all(), \
+ "Invalid index list."
+ # generate a list of possible index pairs
+ pairs = [(i,j) for i in indices for j in indices[i+1:]]
+ # start with first pair
+ (i,j) = pairs[0]
+ average_dist = pair_distance_func(spike_trains[i], spike_trains[j])
+ for (i,j) in pairs[1:]:
+ current_dist = pair_distance_func(spike_trains[i], spike_trains[j])
+ average_dist.add(current_dist) # add to the average
+ average_dist.mul_scalar(1.0/len(pairs)) # normalize
+ return average_dist
+
+
+############################################################
+# isi_distance_multi
+############################################################
+def isi_distance_multi(spike_trains, indices=None):
+ """ computes the multi-variate isi-distance for a set of spike-trains. That
+ is the average isi-distance of all pairs of spike-trains:
+ S(t) = 2/((N(N-1)) sum_{<i,j>} S_{i,j},
+ where the sum goes over all pairs <i,j>
+ Args:
+ - spike_trains: list of spike trains
+ - indices: list of indices defining which spike trains to use,
+ if None all given spike trains are used (default=None)
+ Returns:
+ - A PieceWiseConstFunc representing the averaged isi distance S
+ """
+ return multi_distance(spike_trains, isi_distance, indices)
+
+
+############################################################
+# spike_distance_multi
+############################################################
+def spike_distance_multi(spike_trains, indices=None):
+ """ computes the multi-variate spike-distance for a set of spike-trains.
+ That is the average spike-distance of all pairs of spike-trains:
+ S(t) = 2/((N(N-1)) sum_{<i,j>} S_{i,j},
+ where the sum goes over all pairs <i,j>
+ Args:
+ - spike_trains: list of spike trains
+ - indices: list of indices defining which spike trains to use,
+ if None all given spike trains are used (default=None)
+ Returns:
+ - A PieceWiseLinFunc representing the averaged spike distance S
+ """
+ return multi_distance(spike_trains, spike_distance, indices)