summaryrefslogtreecommitdiff
path: root/test/test_regression/test_regression_random_spikes.py
blob: dfd5488ec100b7b5ecc29e3b323b2b49f96a7765 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
""" regression benchmark

Copyright 2015, Mario Mulansky <mario.mulansky@gmx.net>

Distributed under the BSD License
"""
from __future__ import print_function

import numpy as np
from scipy.io import loadmat
import pyspike as spk

from numpy.testing import assert_almost_equal

# Set pyspike's module-level `disable_backend_warning` flag before any
# distance is computed.
# NOTE(review): presumably this silences the warning emitted when the compiled
# backend is unavailable -- confirm against the pyspike documentation.
spk.disable_backend_warning = True


def test_regression_random():
    """ Compare multivariate ISI and SPIKE distances with stored results.

    Loads the spike train sets from ``regression_random_spikes.mat`` and the
    reference values from ``regression_random_results_cSPIKY.mat``. For every
    set, the ISI distance must match column 0 and the SPIKE distance must
    match column 1 of the reference row, to 14 decimals.

    The data files are looked up next to this module, so the test no longer
    depends on the current working directory.

    Raises:
        AssertionError: if any computed distance deviates from its reference.
    """
    # Function-scope import keeps the file-level import block untouched.
    import os

    # The .mat fixtures live in the same directory as this test module.
    data_dir = os.path.dirname(os.path.abspath(__file__))
    spike_file = os.path.join(data_dir, "regression_random_spikes.mat")
    result_file = os.path.join(data_dir,
                               "regression_random_results_cSPIKY.mat")
    spikes_name = "spikes"
    result_name = "Distances"

    spike_train_sets = loadmat(spike_file)[spikes_name][0]
    results_cSPIKY = loadmat(result_file)[result_name]

    for i, spike_train_data in enumerate(spike_train_sets):
        # 100.0 is passed as the second SpikeTrain argument for every train
        # (presumably the end of the observation interval -- TODO confirm).
        spike_trains = [spk.SpikeTrain(spikes.flatten(), 100.0)
                        for spikes in spike_train_data[0]]

        isi = spk.isi_distance_multi(spike_trains)
        spike = spk.spike_distance_multi(spike_trains)
        # NOTE: SPIKE-sync is not part of this regression check:
        # spike_sync = spk.spike_sync_multi(spike_trains)

        assert_almost_equal(isi, results_cSPIKY[i][0], decimal=14,
                            err_msg="Index: %d, ISI" % i)
        assert_almost_equal(spike, results_cSPIKY[i][1], decimal=14,
                            err_msg="Index: %d, SPIKE" % i)


def check_regression_dataset(spike_file="benchmark.mat",
                             spikes_name="spikes",
                             result_file="results_cSPIKY.mat",
                             result_name="Distances"):
    """ Debugging helper: report deviations from stored reference distances.

    For every spike train set in `spike_file` (variable `spikes_name`) the
    multivariate ISI and SPIKE distances are computed and compared with the
    values found in `result_file` (variable `result_name`). Column 0 of a
    reference row is compared with the ISI distance, column 1 with the SPIKE
    distance.

    ISI deviations above 1E-14 are printed together with the offending spike
    trains. SPIKE deviations above 1E-14 are only counted; the total count is
    printed at the end. The data set with the largest non-zero SPIKE
    deviation is printed last, whether or not it exceeds the tolerance.
    Nothing is returned.
    """
    np.set_printoptions(precision=15)

    all_sets = loadmat(spike_file)[spikes_name][0]
    reference = loadmat(result_file)[result_name]

    tolerance = 1E-14
    worst_err = 0.0
    worst_idx = -1
    n_errors = 0

    for idx, train_data in enumerate(all_sets):
        trains = [spk.SpikeTrain(raw.flatten(), 100.0)
                  for raw in train_data[0]]

        isi_value = spk.isi_distance_multi(trains)
        spike_value = spk.spike_distance_multi(trains)

        expected_isi = reference[idx][0]
        if abs(isi_value - expected_isi) > tolerance:
            print("Error in ISI:", idx, isi_value, expected_isi)
            print("Spike trains:")
            for train in trains:
                print(train.spikes)

        spike_err = abs(spike_value - reference[idx][1])
        if spike_err > tolerance:
            n_errors += 1
        # Track the largest deviation seen so far, independent of tolerance.
        if spike_err > worst_err:
            worst_err, worst_idx = spike_err, idx

    print("Total Errors:", n_errors)

    # No set deviated at all: nothing more to show.
    if worst_idx < 0:
        return

    print("Max SPIKE distance error:", worst_err, "at index:", worst_idx)
    for raw in all_sets[worst_idx][0]:
        print(raw.flatten())


def check_single_spike_train_set(index):
    """ Debugging helper: inspect one spike train set of the regression data.

    Prints the raw spikes of every train in set `index`, then the computed
    multivariate SPIKE distance, then the stored reference value (column 1 of
    the reference row), and finally the spikes of the second spike train.

    The data files are looked up next to this module, consistent with
    `test_regression_random`, so the helper works from any working directory.

    Args:
        index: position of the spike train set within the "spikes" variable
            of ``regression_random_spikes.mat``.
    """
    # Function-scope import keeps the file-level import block untouched.
    import os

    np.set_printoptions(precision=15)

    # The .mat fixtures live in the same directory as this module.
    data_dir = os.path.dirname(os.path.abspath(__file__))
    spike_file = os.path.join(data_dir, "regression_random_spikes.mat")
    result_file = os.path.join(data_dir,
                               "regression_random_results_cSPIKY.mat")
    spikes_name = "spikes"
    result_name = "Distances"

    spike_train_sets = loadmat(spike_file)[spikes_name][0]

    results_cSPIKY = loadmat(result_file)[result_name]

    spike_train_data = spike_train_sets[index]

    spike_trains = []
    for spikes in spike_train_data[0]:
        print("Spikes:", spikes.flatten())
        spike_trains.append(spk.SpikeTrain(spikes.flatten(), 100.0))

    print(spk.spike_distance_multi(spike_trains))

    print(results_cSPIKY[index][1])

    # Shows the spikes as stored inside the second SpikeTrain object.
    print(spike_trains[1].spikes)


# Script entry point: run the regression test directly, without a test runner.
if __name__ == "__main__":

    test_regression_random()
    # The debugging helpers below are left disabled; uncomment one of them to
    # investigate a deviating data set by hand.
    # check_regression_dataset()
    # check_single_spike_train_set(7633)