summaryrefslogtreecommitdiff
path: root/pyspike/distances.py
blob: 9077871e31326b44bb08fbae23493149288df7a8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
""" distances.py

Module containing several functions to compute spike distances

Copyright 2014, Mario Mulansky <mario.mulansky@gmx.net>

Distributed under the BSD License
"""

import numpy as np
import threading
from functools import partial

from pyspike import PieceWiseConstFunc, PieceWiseLinFunc, DiscreteFunc


############################################################
# isi_profile
############################################################
def isi_profile(spikes1, spikes2):
    """ Computes the isi-distance profile :math:`S_{isi}(t)` of the two given
    spike trains and returns it as a PieceWiseConstFunc object. The profile
    values are non-negative, S_isi(t) >= 0. Both spike trains must carry the
    same auxiliary spikes at the beginning and end of the interval; use the
    function add_auxiliary_spikes to add those spikes to the spike trains.

    :param spikes1: ordered array of spike times with auxiliary spikes.
    :param spikes2: ordered array of spike times with auxiliary spikes.
    :returns: The isi-distance profile :math:`S_{isi}(t)`
    :rtype: :class:`pyspike.function.PieceWiseConstFunc`

    """
    # the auxiliary edge spikes (first and last entry) must coincide
    for edge in (0, -1):
        assert spikes1[edge] == spikes2[edge], \
            "Given spike trains seems not to have auxiliary spikes!"

    # prefer the compiled cython backend; fall back to the slow pure-python
    # implementation when the extension module has not been built
    try:
        from cython_distance import isi_distance_cython as isi_distance_impl
    except ImportError:
        print("Warning: isi_distance_cython not found. Make sure that PySpike "
              "is installed by running\n 'python setup.py build_ext "
              "--inplace'!\n Falling back to slow python backend.")
        from python_backend import isi_distance_python as isi_distance_impl

    times, values = isi_distance_impl(spikes1, spikes2)
    return PieceWiseConstFunc(times, values)


############################################################
# isi_distance
############################################################
def isi_distance(spikes1, spikes2, interval=None):
    r""" Computes the isi-distance I of the given spike trains. The
    isi-distance is the time average of the isi-distance profile
    :math:`S_{isi}(t)`:

    .. math:: I = \frac{1}{T_1-T_0} \int_{T_0}^{T_1} S_{isi}(t) dt.

    :param spikes1: ordered array of spike times with auxiliary spikes.
    :param spikes2: ordered array of spike times with auxiliary spikes.
    :param interval: averaging interval given as a pair of floats (T0, T1),
                     if None the average over the whole function is computed.
    :type interval: Pair of floats or None.
    :returns: The isi-distance I.
    :rtype: double
    """
    # the docstring is raw (r""") so that the LaTeX backslashes are not
    # parsed as (invalid) string escape sequences
    return isi_profile(spikes1, spikes2).avrg(interval)


############################################################
# spike_profile
############################################################
def spike_profile(spikes1, spikes2):
    """ Computes the spike-distance profile S_spike(t) of the two given spike
    trains and returns it as a PieceWiseLinFunc object. The profile values
    are non-negative, S_spike(t) >= 0. Both spike trains must carry the same
    auxiliary spikes at the beginning and end of the interval; use the
    function add_auxiliary_spikes to add those spikes to the spike trains.

    :param spikes1: ordered array of spike times with auxiliary spikes.
    :param spikes2: ordered array of spike times with auxiliary spikes.
    :returns: The spike-distance profile :math:`S_{spike}(t)`.
    :rtype: :class:`pyspike.function.PieceWiseLinFunc`

    """
    # the auxiliary edge spikes (first and last entry) must coincide
    for edge in (0, -1):
        assert spikes1[edge] == spikes2[edge], \
            "Given spike trains seems not to have auxiliary spikes!"

    # prefer the compiled cython backend; fall back to the slow pure-python
    # implementation when the extension module has not been built
    try:
        from cython_distance import spike_distance_cython \
            as spike_distance_impl
    except ImportError:
        print("Warning: spike_distance_cython not found. Make sure that "
              "PySpike is installed by running\n 'python setup.py "
              "build_ext --inplace'!\n Falling back to slow python backend.")
        from python_backend import spike_distance_python as spike_distance_impl

    times, y_starts, y_ends = spike_distance_impl(spikes1, spikes2)
    return PieceWiseLinFunc(times, y_starts, y_ends)


############################################################
# spike_distance
############################################################
def spike_distance(spikes1, spikes2, interval=None):
    r""" Computes the spike-distance S of the given spike trains. The
    spike-distance is the time average of the spike-distance profile
    :math:`S_{spike}(t)`:

    .. math:: S = \frac{1}{T_1-T_0} \int_{T_0}^{T_1} S_{spike}(t) dt.

    :param spikes1: ordered array of spike times with auxiliary spikes.
    :param spikes2: ordered array of spike times with auxiliary spikes.
    :param interval: averaging interval given as a pair of floats (T0, T1),
                     if None the average over the whole function is computed.
    :type interval: Pair of floats or None.
    :returns: The spike-distance.
    :rtype: double

    """
    # the docstring is raw (r""") so that the LaTeX backslashes are not
    # parsed as (invalid) string escape sequences
    return spike_profile(spikes1, spikes2).avrg(interval)


############################################################
# spike_sync_profile
############################################################
def spike_sync_profile(spikes1, spikes2):
    """ Computes the spike-synchronization profile S_sync(t) of the two given
    spike trains. Returns the profile as a DiscreteFunc object. The S_sync
    values are either 1 or 0, indicating the presence or absence of a
    coincidence. The spike trains are expected to have auxiliary spikes at the
    beginning and end of the interval. Use the function add_auxiliary_spikes to
    add those spikes to the spike train.

    :param spikes1: ordered array of spike times with auxiliary spikes.
    :param spikes2: ordered array of spike times with auxiliary spikes.
    :returns: The spike-synchronization profile :math:`S_{sync}(t)`.
    :rtype: :class:`pyspike.function.DiscreteFunc`

    """

    # prefer the compiled cython backend; fall back to the slow pure-python
    # implementation when the extension module has not been built
    try:
        from cython_distance import coincidence_cython \
            as coincidence_impl
    except ImportError:
        # name the function whose import actually failed
        print("Warning: coincidence_cython not found. Make sure that \
PySpike is installed by running\n 'python setup.py build_ext --inplace'!\n \
Falling back to slow python backend.")
        # use python backend
        from python_backend import coincidence_python \
            as coincidence_impl

    times, coincidences, multiplicity = coincidence_impl(spikes1, spikes2)

    return DiscreteFunc(times, coincidences, multiplicity)


############################################################
# spike_sync
############################################################
def spike_sync(spikes1, spikes2, interval=None):
    r""" Computes the spike synchronization value SYNC of the given spike
    trains. The spike synchronization value is computed as the total number
    of coincidences divided by the total number of spikes:

    .. math:: SYNC = \sum_n C_n / N.

    :param spikes1: ordered array of spike times with auxiliary spikes.
    :param spikes2: ordered array of spike times with auxiliary spikes.
    :param interval: averaging interval given as a pair of floats (T0, T1),
                     if None the average over the whole function is computed.
    :type interval: Pair of floats or None.
    :returns: The spike synchronization value.
    :rtype: double
    """
    # the docstring is raw (r""") so that "\sum" is not parsed as an
    # (invalid) string escape sequence
    return spike_sync_profile(spikes1, spikes2).avrg(interval)


############################################################
# _generic_profile_multi
############################################################
def _generic_profile_multi(spike_trains, pair_distance_func, indices=None):
    """ Internal implementation detail, don't call this function directly,
    use isi_profile_multi or spike_profile_multi instead.

    Computes the multi-variate distance for a set of spike-trains using the
    pair_dist_func to compute pair-wise distances. That is it computes the
    average distance of all pairs of spike-trains:
    :math:`S(t) = 2/((N(N-1)) sum_{<i,j>} S_{i,j}`,
    where the sum goes over all pairs <i,j>.
    Args:
    - spike_trains: list of spike trains
    - pair_distance_func: function computing the distance of two spike trains
    - indices: list of indices defining which spike trains to use,
    if None all given spike trains are used (default=None)
    Returns:
    - The averaged multi-variate distance of all pairs
    """
    if indices is None:
        indices = np.arange(len(spike_trains))
    indices = np.array(indices)
    # check validity of indices
    assert (indices < len(spike_trains)).all() and (indices >= 0).all(), \
        "Invalid index list."
    # generate a list of possible index pairs
    pairs = [(indices[i], j) for i in range(len(indices))
             for j in indices[i+1:]]
    # start with first pair
    (i, j) = pairs[0]
    average_dist = pair_distance_func(spike_trains[i], spike_trains[j])
    for (i, j) in pairs[1:]:
        current_dist = pair_distance_func(spike_trains[i], spike_trains[j])
        average_dist.add(current_dist)       # add to the average
    return average_dist, len(pairs)


############################################################
# multi_distance_par
############################################################
def _multi_distance_par(spike_trains, pair_distance_func, indices=None):
    """ parallel implementation of the multi-distance. Not currently used as
    it does not improve the performance.
    """

    num_threads = 2
    lock = threading.Lock()

    def run(spike_trains, index_pairs, average_dist):
        (i, j) = index_pairs[0]
        # print(i,j)
        this_avrg = pair_distance_func(spike_trains[i], spike_trains[j])
        for (i, j) in index_pairs[1:]:
            # print(i,j)
            current_dist = pair_distance_func(spike_trains[i], spike_trains[j])
            this_avrg.add(current_dist)
        with lock:
            average_dist.add(this_avrg)

    if indices is None:
        indices = np.arange(len(spike_trains))
    indices = np.array(indices)
    # check validity of indices
    assert (indices < len(spike_trains)).all() and (indices >= 0).all(), \
        "Invalid index list."
    # generate a list of possible index pairs
    pairs = [(indices[i], j) for i in range(len(indices))
             for j in indices[i+1:]]
    num_pairs = len(pairs)

    # start with first pair
    (i, j) = pairs[0]
    average_dist = pair_distance_func(spike_trains[i], spike_trains[j])
    # remove the one we already computed
    pairs = pairs[1:]
    # distribute the rest into num_threads pieces
    clustered_pairs = [pairs[n::num_threads] for n in xrange(num_threads)]

    threads = []
    for pairs in clustered_pairs:
        t = threading.Thread(target=run, args=(spike_trains, pairs,
                                               average_dist))
        threads.append(t)
        t.start()
    for t in threads:
        t.join()
    average_dist.mul_scalar(1.0/num_pairs)  # normalize
    return average_dist


############################################################
# isi_profile_multi
############################################################
def isi_profile_multi(spike_trains, indices=None):
    """ Computes the multi-variate isi-distance profile for a set of spike
    trains, i.e. the average isi-distance profile of all pairs of
    spike-trains:
    S_isi(t) = 2/(N(N-1)) sum_{<i,j>} S_{isi}^{i,j},
    where the sum goes over all pairs <i,j>

    :param spike_trains: list of spike trains
    :param indices: list of indices defining which spike trains to use,
                    if None all given spike trains are used (default=None)
    :type indices: list or None
    :returns: The averaged isi profile :math:`<S_{isi}>(t)`
    :rtype: :class:`pyspike.function.PieceWiseConstFunc`
    """
    summed, num_pairs = _generic_profile_multi(spike_trains, isi_profile,
                                               indices)
    # turn the sum over all pairs into the pair average
    summed.mul_scalar(1.0/num_pairs)
    return summed


############################################################
# isi_distance_multi
############################################################
def isi_distance_multi(spike_trains, indices=None, interval=None):
    r""" Computes the multi-variate isi-distance for a set of spike-trains.
    That is the time average of the multi-variate isi-distance profile:
    I = 1/T \int_0^T 2/(N(N-1)) sum_{<i,j>} S_{isi}^{i,j} dt,
    where the sum goes over all pairs <i,j>

    :param spike_trains: list of spike trains
    :param indices: list of indices defining which spike trains to use,
                    if None all given spike trains are used (default=None)
    :param interval: averaging interval given as a pair of floats, if None
                     the average over the whole function is computed.
    :type interval: Pair of floats or None.
    :returns: The time-averaged isi distance :math:`I`
    :rtype: double
    """
    # the docstring is raw (r""") so that "\int" is not parsed as an
    # (invalid) string escape sequence
    return isi_profile_multi(spike_trains, indices).avrg(interval)


############################################################
# spike_profile_multi
############################################################
def spike_profile_multi(spike_trains, indices=None):
    """ Computes the multi-variate spike-distance profile for a set of spike
    trains, i.e. the average spike-distance profile of all pairs of
    spike-trains:
    :math:`S_spike(t) = 2/(N(N-1)) sum_{<i,j>} S_{spike}^{i, j}`,
    where the sum goes over all pairs <i,j>

    :param spike_trains: list of spike trains
    :param indices: list of indices defining which spike trains to use,
                    if None all given spike trains are used (default=None)
    :type indices: list or None
    :returns: The averaged spike profile :math:`<S_{spike}>(t)`
    :rtype: :class:`pyspike.function.PieceWiseLinFunc`

    """
    summed, num_pairs = _generic_profile_multi(spike_trains, spike_profile,
                                               indices)
    # turn the sum over all pairs into the pair average
    summed.mul_scalar(1.0/num_pairs)
    return summed


############################################################
# spike_distance_multi
############################################################
def spike_distance_multi(spike_trains, indices=None, interval=None):
    r""" Computes the multi-variate spike distance for a set of spike trains.
    That is the time average of the multi-variate spike profile:
    S_{spike} = 1/T \int_0^T 2/(N(N-1)) sum_{<i,j>} S_{spike}^{i, j} dt
    where the sum goes over all pairs <i,j>

    :param spike_trains: list of spike trains
    :param indices: list of indices defining which spike trains to use,
                    if None all given spike trains are used (default=None)
    :type indices: list or None
    :param interval: averaging interval given as a pair of floats, if None
                     the average over the whole function is computed.
    :type interval: Pair of floats or None.
    :returns: The averaged spike distance S.
    :rtype: double
    """
    # the docstring is raw (r""") so that "\int" is not parsed as an
    # (invalid) string escape sequence
    return spike_profile_multi(spike_trains, indices).avrg(interval)


############################################################
# spike_sync_profile_multi
############################################################
def spike_sync_profile_multi(spike_trains, indices=None):
    """ Computes the multi-variate spike synchronization profile for a set of
    spike trains. For each spike in the set of spike trains, the multi-variate
    profile is defined as the number of coincidences divided by the number of
    spike train pairs involving the spike train containing this spike, which
    is the number of spike trains minus one (N-1).

    :param spike_trains: list of spike trains
    :param indices: list of indices defining which spike trains to use,
                    if None all given spike trains are used (default=None)
    :type indices: list or None
    :returns: The multi-variate spike sync profile :math:`<S_{sync}>(t)`
    :rtype: :class:`pyspike.function.DiscreteFunc`

    """
    # The pair-wise profiles are only summed up here; unlike the isi/spike
    # profiles they are deliberately NOT divided by the number of pairs.
    # NOTE(review): presumably the multiplicity bookkeeping of DiscreteFunc
    # accounts for the normalization in avrg() -- confirm against DiscreteFunc.
    summed_profile, _ = _generic_profile_multi(spike_trains,
                                               spike_sync_profile, indices)
    return summed_profile


############################################################
# spike_sync_multi
############################################################
def spike_sync_multi(spike_trains, indices=None, interval=None):
    """ Computes the multi-variate spike synchronization value for a set of
    spike trains by averaging their multi-variate spike synchronization
    profile.

    :param spike_trains: list of spike trains
    :param indices: list of indices defining which spike trains to use,
                    if None all given spike trains are used (default=None)
    :type indices: list or None
    :param interval: averaging interval given as a pair of floats, if None
                     the average over the whole function is computed.
    :type interval: Pair of floats or None.
    :returns: The multi-variate spike synchronization value SYNC.
    :rtype: double
    """
    sync_profile = spike_sync_profile_multi(spike_trains, indices)
    return sync_profile.avrg(interval)


############################################################
# generic_distance_matrix
############################################################
def _generic_distance_matrix(spike_trains, dist_function,
                             indices=None, interval=None):
    """ Internal implementation detail. Don't use this function directly.
    Instead use isi_distance_matrix or spike_distance_matrix.
    Computes the time averaged distance of all pairs of spike-trains.
    Args:
    - spike_trains: list of spike trains
    - indices: list of indices defining which spike-trains to use
    if None all given spike-trains are used (default=None)
    Return:
    - a 2D array of size len(indices)*len(indices) containing the average
    pair-wise distance
    """
    if indices is None:
        indices = np.arange(len(spike_trains))
    indices = np.array(indices)
    # check validity of indices
    assert (indices < len(spike_trains)).all() and (indices >= 0).all(), \
        "Invalid index list."
    # generate a list of possible index pairs
    pairs = [(indices[i], j) for i in range(len(indices))
             for j in indices[i+1:]]

    distance_matrix = np.zeros((len(indices), len(indices)))
    for i, j in pairs:
        d = dist_function(spike_trains[i], spike_trains[j], interval)
        distance_matrix[i, j] = d
        distance_matrix[j, i] = d
    return distance_matrix


############################################################
# isi_distance_matrix
############################################################
def isi_distance_matrix(spike_trains, indices=None, interval=None):
    """ Computes the time averaged isi-distance of every pair of the selected
    spike-trains.

    :param spike_trains: list of spike trains
    :param indices: list of indices defining which spike trains to use,
                    if None all given spike trains are used (default=None)
    :type indices: list or None
    :param interval: averaging interval given as a pair of floats, if None
                     the average over the whole function is computed.
    :type interval: Pair of floats or None.
    :returns: 2D array with the pair wise time average isi distances
              :math:`I_{ij}`
    :rtype: np.array
    """
    return _generic_distance_matrix(spike_trains, isi_distance,
                                    indices=indices, interval=interval)


############################################################
# spike_distance_matrix
############################################################
def spike_distance_matrix(spike_trains, indices=None, interval=None):
    """ Computes the time averaged spike-distance of every pair of the
    selected spike-trains.

    :param spike_trains: list of spike trains
    :param indices: list of indices defining which spike trains to use,
                    if None all given spike trains are used (default=None)
    :type indices: list or None
    :param interval: averaging interval given as a pair of floats, if None
                     the average over the whole function is computed.
    :type interval: Pair of floats or None.
    :returns: 2D array with the pair wise time average spike distances
              :math:`S_{ij}`
    :rtype: np.array
    """
    return _generic_distance_matrix(spike_trains, spike_distance,
                                    indices=indices, interval=interval)


############################################################
# spike_sync_matrix
############################################################
def spike_sync_matrix(spike_trains, indices=None, interval=None):
    """ Computes the overall spike-synchronization value of every pair of the
    selected spike-trains.

    :param spike_trains: list of spike trains
    :param indices: list of indices defining which spike trains to use,
                    if None all given spike trains are used (default=None)
    :type indices: list or None
    :param interval: averaging interval given as a pair of floats, if None
                     the average over the whole function is computed.
    :type interval: Pair of floats or None.
    :returns: 2D array with the pair wise time spike synchronization values
              :math:`SYNC_{ij}`
    :rtype: np.array
    """
    return _generic_distance_matrix(spike_trains, spike_sync,
                                    indices=indices, interval=interval)