summaryrefslogtreecommitdiff
path: root/src/python/test/test_wasserstein_distance.py
blob: a76b6ce768eb7bfcc7f38c1b5c1889b8d6e51b65 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
    See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
    Author(s):       Theo Lacombe, Marc Glisse

    Copyright (C) 2019 Inria

    Modification(s):
      - 2020/07 Théo Lacombe: Added tests about handling essential parts in diagrams.
      - YYYY/MM Author: Description of the modification
"""

from gudhi.wasserstein.wasserstein import _proj_on_diag, _finite_part, _handle_essential_parts, _get_essential_parts
from gudhi.wasserstein.wasserstein import _warn_infty
from gudhi.wasserstein import wasserstein_distance as pot
from gudhi.hera import wasserstein_distance as hera
import numpy as np
import pytest


# Module metadata (author, copyright holder, license).
__author__, __copyright__, __license__ = "Theo Lacombe", "Copyright (C) 2019 Inria", "MIT"


def test_proj_on_diag():
    # Each point (b, d) is expected on the diagonal at ((b + d) / 2, (b + d) / 2).
    points = np.array([[1., 1.], [1., 2.], [3., 5.]])
    projected = [[1., 1.], [1.5, 1.5], [4., 4.]]
    assert np.array_equal(_proj_on_diag(points), projected)
    # An empty diagram projects to an empty diagram.
    no_points = np.empty((0, 2))
    assert np.array_equal(_proj_on_diag(no_points), no_points)


def test_finite_part():
    inf = np.inf
    finite_points = [[0, 1], [3, 5]]
    # Every remaining row has at least one infinite coordinate and must be dropped.
    infinite_points = [[2, inf], [3, inf], [-inf, 8], [-inf, 12], [-inf, -inf],
                       [inf, inf], [-inf, inf], [-inf, inf]]
    diag = np.array(finite_points + infinite_points)
    assert np.array_equal(_finite_part(diag), finite_points)


def test_handle_essential_parts():
    inf = np.inf
    # Rows shared by all three diagrams: (-inf, -inf), (+inf, +inf) and twice (-inf, +inf).
    shared_tail = [[-inf, -inf], [inf, inf], [-inf, inf], [-inf, inf]]
    diag1 = np.array([[0, 1], [3, 5], [2, inf], [3, inf], [-inf, 8], [-inf, 12]] + shared_tail)
    diag2 = np.array([[0, 2], [3, 5], [2, inf], [4, inf], [-inf, 8], [-inf, 11]] + shared_tail)
    # diag3 has one more (x, +inf) point than diag1, so their essential parts cannot be matched.
    diag3 = np.array([[0, 2], [3, 5], [2, inf], [4, inf], [6, inf], [-inf, 8], [-inf, 11]] + shared_tail)

    cost, matching = _handle_essential_parts(diag1, diag2, order=1)
    # Only the cost due to the essential part is returned here (thus 2, not 3).
    assert cost == pytest.approx(2, 0.0001)
    # Likewise, the matching covers only essential points; the (-inf, -inf) and
    # (+inf, +inf) points (indices 6 and 7) are matched to the diagonal (-1).
    expected = [[4, 4], [5, 5], [2, 2], [3, 3], [8, 8], [9, 9], [6, -1], [7, -1], [-1, 6], [-1, 7]]
    assert np.array_equal(matching, expected)

    cost, matching = _handle_essential_parts(diag1, diag3, order=1)
    assert cost == np.inf
    assert matching is None


def test_get_essential_parts():
    inf = np.inf
    full = np.array([[0, 1], [3, 5], [2, inf], [3, inf], [-inf, 8], [-inf, 12], [-inf, -inf],
                     [inf, inf], [-inf, inf], [-inf, inf]])
    partial = np.array([[0, 1], [3, 5], [2, inf], [3, inf]])

    # One list of row indices per essential type, in the order _get_essential_parts returns them.
    expected_full = ([4, 5], [2, 3], [8, 9], [6], [7])
    expected_partial = ([], [2, 3], [], [], [])

    parts_full = _get_essential_parts(full)
    parts_partial = _get_essential_parts(partial)
    for k, indices in enumerate(expected_full):
        assert np.array_equal(parts_full[k], indices)
    for k, indices in enumerate(expected_partial):
        assert np.array_equal(parts_partial[k], indices)


def test_warn_infty():
    """`_warn_infty` must emit a UserWarning and report an infinite cost (and no matching)."""
    # One `pytest.warns` context per call: a single shared context passes as soon as
    # the first call warns, which would leave the second call's warning unchecked.
    with pytest.warns(UserWarning):
        assert _warn_infty(matching=False) == np.inf
    with pytest.warns(UserWarning):
        c, m = _warn_infty(matching=True)
    assert c == np.inf
    assert m is None


def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True):
    """Run a shared battery of checks against a Wasserstein distance implementation.

    :param wasserstein_distance: callable taking two diagrams plus keyword arguments such
        as ``internal_p``, ``order`` and ``matching``; with ``matching=True`` it must
        return a ``(cost, matching)`` pair.
    :param delta: relative tolerance used to compare costs with reference values.
    :param test_infinity: also check diagrams containing points with infinite coordinates.
    :param test_matching: also check the matchings returned with ``matching=True``.
        The last group of checks runs only when both flags are true.
    """
    diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]])
    diag2 = np.array([[2.8, 4.45], [9.5, 14.1]])
    diag3 = np.array([[0, 2], [4, 6]])
    diag4 = np.array([[0, 3], [4, 8]])
    emptydiag = np.array([])  # note: 1-D, shape (0,), rather than (0, 2)

    # We just need to handle positive numbers here
    def approx(x):
        return pytest.approx(x, rel=delta)

    def same_matching(match, expected):
        # Compare matchings as multisets of index pairs, independently of row order.
        # A bare `match == expected` on a numpy array is element-wise, and `assert` on
        # a multi-element boolean array raises ValueError instead of checking anything.
        if match is None:
            return False
        return sorted(map(tuple, np.asarray(match).tolist())) == sorted(map(tuple, expected))

    assert wasserstein_distance(emptydiag, emptydiag, internal_p=2.,     order=1.) == 0.
    assert wasserstein_distance(emptydiag, emptydiag, internal_p=np.inf, order=1.) == 0.
    assert wasserstein_distance(emptydiag, emptydiag, internal_p=np.inf, order=2.) == 0.
    assert wasserstein_distance(emptydiag, emptydiag, internal_p=2.,     order=2.) == 0.

    assert wasserstein_distance(diag3, emptydiag, internal_p=np.inf,     order=1.) == approx(2.)
    assert wasserstein_distance(diag3, emptydiag, internal_p=1.,         order=1.) == approx(4.)

    # With internal_p=1 the distances to the diagonal are 3 and 4: a (3, 4, 5) Pythagorean triple.
    assert wasserstein_distance(diag4, emptydiag, internal_p=1.,     order=2.) == approx(5.)
    assert wasserstein_distance(diag4, emptydiag, internal_p=np.inf, order=2.) == approx(2.5)
    assert wasserstein_distance(diag4, emptydiag, internal_p=2.,     order=2.) == approx(3.5355339059327378)

    assert wasserstein_distance(diag1, diag2, internal_p=2.,   order=1.)   == approx(1.4453593023967701)
    assert wasserstein_distance(diag1, diag2, internal_p=2.35, order=1.74) == approx(0.9772734057168739)

    assert wasserstein_distance(diag1, emptydiag, internal_p=2.35, order=1.7863) == approx(3.141592214572228)

    assert wasserstein_distance(diag3, diag4, internal_p=1.,     order=1.) == approx(3.)
    assert wasserstein_distance(diag3, diag4, internal_p=np.inf, order=1.) == approx(3.)  # no diag matching here
    assert wasserstein_distance(diag3, diag4, internal_p=np.inf, order=2.) == approx(np.sqrt(5))
    assert wasserstein_distance(diag3, diag4, internal_p=1.,     order=2.) == approx(np.sqrt(5))
    assert wasserstein_distance(diag3, diag4, internal_p=4.5,    order=2.) == approx(np.sqrt(5))

    if test_infinity:
        diag5 = np.array([[0, 3], [4, np.inf]])
        diag6 = np.array([[7, 8], [4, 6], [3, np.inf]])

        assert wasserstein_distance(diag4, diag5) == np.inf
        assert wasserstein_distance(diag5, diag6, order=1, internal_p=np.inf) == approx(4.)
        assert wasserstein_distance(diag5, emptydiag) == np.inf

    if test_matching:
        match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=1., order=2)[1]
        assert np.array_equal(match, [])
        match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1]
        assert np.array_equal(match, [])
        match = wasserstein_distance(emptydiag, diag2, matching=True, internal_p=np.inf, order=2.)[1]
        assert np.array_equal(match, [[-1, 0], [-1, 1]])
        match = wasserstein_distance(diag2, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1]
        assert np.array_equal(match, [[0, -1], [1, -1]])
        match = wasserstein_distance(diag1, diag2, matching=True, internal_p=2., order=2.)[1]
        assert np.array_equal(match, [[0, 0], [1, 1], [2, -1]])

    # NOTE(review): none of the callers in this file enable both flags, so this group
    # of checks does not currently run.
    if test_matching and test_infinity:
        diag7 = np.array([[0, 3], [4, np.inf], [5, np.inf]])
        diag8 = np.array([[0, 1], [0, np.inf], [-np.inf, -np.inf], [np.inf, np.inf]])
        diag9 = np.array([[-np.inf, -np.inf], [np.inf, np.inf]])
        diag10 = np.array([[0, 1], [-np.inf, -np.inf], [np.inf, np.inf]])

        match = wasserstein_distance(diag5, diag6, matching=True, internal_p=2., order=2.)[1]
        assert np.array_equal(match, [[0, -1], [-1, 0], [-1, 1], [1, 2]])
        match = wasserstein_distance(diag5, diag7, matching=True, internal_p=2., order=2.)[1]
        assert (match is None)
        cost, match = wasserstein_distance(diag7, emptydiag, matching=True, internal_p=2., order=2.3)
        assert (cost == np.inf)
        assert (match is None)
        cost, match = wasserstein_distance(emptydiag, diag7, matching=True, internal_p=2.42, order=2.)
        assert (cost == np.inf)
        assert (match is None)
        cost, match = wasserstein_distance(diag8, diag9, matching=True, internal_p=2., order=2.)
        assert (cost == np.inf)
        assert (match is None)
        cost, match = wasserstein_distance(diag9, diag10, matching=True, internal_p=1., order=1.)
        assert (cost == 1)
        # Points of essential types 4 and 5, (-inf, -inf) and (+inf, +inf), are matched
        # to the diagonal anyway.
        assert same_matching(match, [[0, -1], [1, -1], [-1, 0], [-1, 1], [-1, 2]])
        cost, match = wasserstein_distance(diag9, emptydiag, matching=True, internal_p=2., order=2.)
        assert (cost == 0.)
        assert same_matching(match, [[0, -1], [1, -1]])


def hera_wrap(**extra):
    """Return a callable that forwards to ``hera`` with the keyword arguments ``extra`` appended."""
    return lambda *args, **kwargs: hera(*args, **kwargs, **extra)


def pot_wrap(**extra):
    """Return a callable that forwards to ``pot`` with the keyword arguments ``extra`` appended."""
    return lambda *args, **kwargs: pot(*args, **kwargs, **extra)


def test_wasserstein_distance_pot():
    # pot with its standard arguments: matching checks on, infinite-coordinate checks off.
    _basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True)
    # Autodiff path restricted to finite parts; matchings are not checked here.
    autodiff_pot = pot_wrap(enable_autodiff=True, keep_essential_parts=False)
    _basic_wasserstein(autodiff_pot, 1e-15, test_infinity=False, test_matching=False)


def test_wasserstein_distance_hera():
    # Tight then loose approximation: the same value is hera's `delta` and the check tolerance.
    for tolerance in (1e-12, .1):
        _basic_wasserstein(hera_wrap(delta=tolerance), tolerance, test_matching=False)