diff options
author | Mario Mulansky <mario.mulansky@gmx.net> | 2014-09-29 16:08:45 +0200 |
---|---|---|
committer | Mario Mulansky <mario.mulansky@gmx.net> | 2014-09-29 16:08:45 +0200 |
commit | b726773a29f85d465ff71867fab4fa5b8e5bcfe1 (patch) | |
tree | e9a0add62dfc3f1a27beeaa7de000b8d7614aa72 | |
parent | e4f1c09672068e4778f7b5f3e27b47ff8986863c (diff) |
+ multivariate distances
-rw-r--r-- | pyspike/__init__.py | 3 | ||||
-rw-r--r-- | pyspike/distances.py | 76 | ||||
-rw-r--r-- | pyspike/function.py | 40 | ||||
-rw-r--r-- | test/test_distance.py | 47 |
4 files changed, 158 insertions, 8 deletions
diff --git a/pyspike/__init__.py b/pyspike/__init__.py index 1784037..2143bdc 100644 --- a/pyspike/__init__.py +++ b/pyspike/__init__.py @@ -1,5 +1,6 @@ __all__ = ["function", "distances", "spikes"] from function import PieceWiseConstFunc, PieceWiseLinFunc -from distances import add_auxiliary_spikes, isi_distance, spike_distance +from distances import add_auxiliary_spikes, isi_distance, spike_distance, \ + isi_distance_multi, spike_distance_multi from spikes import spike_train_from_string, merge_spike_trains diff --git a/pyspike/distances.py b/pyspike/distances.py index f4be625..52c6640 100644 --- a/pyspike/distances.py +++ b/pyspike/distances.py @@ -224,3 +224,79 @@ def spike_distance(spikes1, spikes2): # could be less than original length due to equal spike times return PieceWiseLinFunc(spike_events[:index+1], y_starts[:index], y_ends[:index]) + + + + +############################################################ +# multi_distance +############################################################ +def multi_distance(spike_trains, pair_distance_func, indices=None): + """ Internal implementation detail, use isi_distance_multi or + spike_distance_multi. + + Computes the multi-variate distance for a set of spike-trains using the + pair_dist_func to compute pair-wise distances. That is it computes the + average distance of all pairs of spike-trains: + S(t) = 2/((N(N-1)) sum_{<i,j>} S_{i,j}, + where the sum goes over all pairs <i,j>. + Args: + - spike_trains: list of spike trains + - pair_distance_func: function computing the distance of two spike trains + - indices: list of indices defining which spike trains to use, + if None all given spike trains are used (default=None) + Returns: + - The averaged multi-variate distance of all pairs + """ + if indices==None: + indices = np.arange(len(spike_trains)) + indices = np.array(indices) + # check validity of indices + assert (indices < len(spike_trains)).all() and (indices >= 0).all(), \ + "Invalid index list." + # generate a list of possible index pairs + pairs = [(i,j) for i in indices for j in indices[i+1:]] + # start with first pair + (i,j) = pairs[0] + average_dist = pair_distance_func(spike_trains[i], spike_trains[j]) + for (i,j) in pairs[1:]: + current_dist = pair_distance_func(spike_trains[i], spike_trains[j]) + average_dist.add(current_dist) # add to the average + average_dist.mul_scalar(1.0/len(pairs)) # normalize + return average_dist + + +############################################################ +# isi_distance_multi +############################################################ +def isi_distance_multi(spike_trains, indices=None): + """ computes the multi-variate isi-distance for a set of spike-trains. That + is the average isi-distance of all pairs of spike-trains: + S(t) = 2/((N(N-1)) sum_{<i,j>} S_{i,j}, + where the sum goes over all pairs <i,j> + Args: + - spike_trains: list of spike trains + - indices: list of indices defining which spike trains to use, + if None all given spike trains are used (default=None) + Returns: + - A PieceWiseConstFunc representing the averaged isi distance S + """ + return multi_distance(spike_trains, isi_distance, indices) + + +############################################################ +# spike_distance_multi +############################################################ +def spike_distance_multi(spike_trains, indices=None): + """ computes the multi-variate spike-distance for a set of spike-trains. + That is the average spike-distance of all pairs of spike-trains: + S(t) = 2/((N(N-1)) sum_{<i,j>} S_{i,j}, + where the sum goes over all pairs <i,j> + Args: + - spike_trains: list of spike trains + - indices: list of indices defining which spike trains to use, + if None all given spike trains are used (default=None) + Returns: + - A PieceWiseLinFunc representing the averaged spike distance S + """ + return multi_distance(spike_trains, spike_distance, indices) diff --git a/pyspike/function.py b/pyspike/function.py index 3a5a01c..26ca4b2 100644 --- a/pyspike/function.py +++ b/pyspike/function.py @@ -10,6 +10,7 @@ from __future__ import print_function import numpy as np + ############################################################## # PieceWiseConstFunc ############################################################## @@ -18,7 +19,7 @@ class PieceWiseConstFunc: def __init__(self, x, y): """ Constructs the piece-wise const function. - Params: + Args: - x: array of length N+1 defining the edges of the intervals of the pwc function. - y: array of length N defining the function values at the intervals. @@ -26,6 +27,19 @@ class PieceWiseConstFunc: self.x = np.array(x) self.y = np.array(y) + def almost_equal(self, other, decimal=14): + """ Checks if the function is equal to another function up to `decimal` + precision. + Args: + - other: another PieceWiseConstFunc object + Returns: + True if the two functions are equal up to `decimal` decimals, + False otherwise + """ + eps = 10.0**(-decimal) + return np.allclose(self.x, other.x, atol=eps, rtol=0.0) and \ + np.allclose(self.y, other.y, atol=eps, rtol=0.0) + def get_plottable_data(self): """ Returns two arrays containing x- and y-coordinates for immeditate plotting of the piece-wise function. @@ -63,7 +77,7 @@ class PieceWiseConstFunc: def add(self, f): """ Adds another PieceWiseConst function to this function. Note: only functions defined on the same interval can be summed. - Params: + Args: - f: PieceWiseConst function to be added. """ assert self.x[0] == f.x[0], "The functions have different intervals" @@ -111,7 +125,7 @@ class PieceWiseConstFunc: def mul_scalar(self, fac): """ Multiplies the function with a scalar value - Params: + Args: - fac: Value to multiply """ self.y *= fac @@ -125,7 +139,7 @@ class PieceWiseLinFunc: def __init__(self, x, y1, y2): """ Constructs the piece-wise linear function. - Params: + Args: - x: array of length N+1 defining the edges of the intervals of the pwc function. - y1: array of length N defining the function values at the left of the @@ -137,6 +151,20 @@ class PieceWiseLinFunc: self.y1 = np.array(y1) self.y2 = np.array(y2) + def almost_equal(self, other, decimal=14): + """ Checks if the function is equal to another function up to `decimal` + precision. + Args: + - other: another PieceWiseLinFunc object + Returns: + True if the two functions are equal up to `decimal` decimals, + False otherwise + """ + eps = 10.0**(-decimal) + return np.allclose(self.x, other.x, atol=eps, rtol=0.0) and \ + np.allclose(self.y1, other.y1, atol=eps, rtol=0.0) and \ + np.allclose(self.y2, other.y2, atol=eps, rtol=0.0) + def get_plottable_data(self): """ Returns two arrays containing x- and y-coordinates for immeditate plotting of the piece-wise function. @@ -171,7 +199,7 @@ class PieceWiseLinFunc: def add(self, f): """ Adds another PieceWiseLin function to this function. Note: only functions defined on the same interval can be summed. - Params: + Args: - f: PieceWiseLin function to be added. """ assert self.x[0] == f.x[0], "The functions have different intervals" @@ -246,7 +274,7 @@ class PieceWiseLinFunc: def mul_scalar(self, fac): """ Multiplies the function with a scalar value - Params: + Args: - fac: Value to multiply """ self.y1 *= fac diff --git a/test/test_distance.py b/test/test_distance.py index 35bdf85..c43f0b3 100644 --- a/test/test_distance.py +++ b/test/test_distance.py @@ -7,10 +7,12 @@ Copyright 2014, Mario Mulansky <mario.mulansky@gmx.net> from __future__ import print_function import numpy as np +from copy import copy from numpy.testing import assert_equal, assert_array_almost_equal import pyspike as spk + def test_auxiliary_spikes(): t = np.array([0.2, 0.4, 0.6, 0.7]) t_aux = spk.add_auxiliary_spikes(t, T_end=1.0, T_start=0.1) @@ -18,6 +20,7 @@ def test_auxiliary_spikes(): t_aux = spk.add_auxiliary_spikes(t_aux, 1.0) assert_equal(t_aux, [0.0, 0.1, 0.2, 0.4, 0.6, 0.7, 1.0]) + def test_isi(): # generate two spike trains: t1 = np.array([0.2, 0.4, 0.6, 0.7]) @@ -32,7 +35,7 @@ def test_isi(): t2 = spk.add_auxiliary_spikes(t2, 1.0) f = spk.isi_distance(t1, t2) - print("ISI: ", f.y) + # print("ISI: ", f.y) assert_equal(f.x, expected_times) assert_array_almost_equal(f.y, expected_isi, decimal=14) @@ -98,6 +101,48 @@ def test_spike(): assert_array_almost_equal(f.y2, expected_y2, decimal=14) +def check_multi_distance(dist_func, dist_func_multi): + # generate spike trains: + t1 = spk.add_auxiliary_spikes(np.array([0.2, 0.4, 0.6, 0.7]), 1.0) + t2 = spk.add_auxiliary_spikes(np.array([0.3, 0.45, 0.8, 0.9, 0.95]), 1.0) + t3 = spk.add_auxiliary_spikes(np.array([0.2,0.4,0.6]), 1.0) + t4 = spk.add_auxiliary_spikes(np.array([0.1,0.4,0.5,0.6]), 1.0) + spike_trains = [t1, t2, t3, t4] + + f12 = dist_func(t1, t2) + f13 = dist_func(t1, t3) + f14 = dist_func(t1, t4) + f23 = dist_func(t2, t3) + f24 = dist_func(t2, t4) + f34 = dist_func(t3, t4) + + f_multi = dist_func_multi(spike_trains, [0,1]) + assert f_multi.almost_equal(f12, decimal=14) + + f = copy(f12) + f.add(f13) + f.add(f23) + f.mul_scalar(1.0/3) + f_multi = dist_func_multi(spike_trains, [0,1,2]) + assert f_multi.almost_equal(f, decimal=14) + + f.mul_scalar(3) # revert above normalization + f.add(f14) + f.add(f24) + f.add(f34) + f.mul_scalar(1.0/6) + f_multi = dist_func_multi(spike_trains) + assert f_multi.almost_equal(f, decimal=14) + + +def test_multi_isi(): + check_multi_distance(spk.isi_distance, spk.isi_distance_multi) + + +def test_multi_spike(): + check_multi_distance(spk.spike_distance, spk.spike_distance_multi) + + if __name__ == "__main__": test_auxiliary_spikes() test_isi() |