1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
|
"""Tests for ot solvers"""
# Author: Remi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License
import itertools
import numpy as np
import pytest
import ot
# Parameter grids for test_solve_grid: it is parametrized with
# itertools.product over these four lists, so every combination is run.
lst_reg = [
    None,
    1.0,
]
lst_reg_type = [
    'KL',
    'entropy',
    'L2',
]
lst_unbalanced = [
    None,
    0.9,
]
lst_unbalanced_type = [
    'KL',
    'L2',
    'TV',
]
def assert_allclose_sol(sol1, sol2):
    """Assert that two solver results hold numerically close values.

    Each attribute in ``lst_attr`` is fetched from both results and converted
    to a NumPy array with that result's own backend (a NumPy backend when the
    result's ``_backend`` is None). If fetching or converting an attribute
    raises ``NotImplementedError`` for either result, that attribute is
    skipped.

    Parameters
    ----------
    sol1, sol2 : solver results
        Objects exposing ``_backend`` and the attributes in ``lst_attr``.

    Raises
    ------
    AssertionError
        If an attribute available on both results is not ``np.allclose``
        (default tolerances).
    """
    lst_attr = ['value', 'value_linear', 'plan',
                'potential_a', 'potential_b', 'marginal_a', 'marginal_b']

    nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend()
    nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend()

    for attr in lst_attr:
        try:
            val1 = nx1.to_numpy(getattr(sol1, attr))
            val2 = nx2.to_numpy(getattr(sol2, attr))
        except NotImplementedError:
            # attribute not available on one of the results: nothing to compare
            continue
        # np.allclose only returns a bool; it must be asserted for a
        # mismatch to actually fail the calling test
        assert np.allclose(val1, val2), (
            f"attribute {attr!r} differs between the two solutions"
        )
def test_solve(nx):
    """Check ``ot.solve`` on a small random problem.

    Covers default (uniform) weights versus explicit uniform weights,
    access to several result attributes, solving with the arrays converted
    to the backend given by the ``nx`` fixture, and the
    ``NotImplementedError`` raised for unknown divergence names.
    """
    n_samples_s = 10
    n_samples_t = 7
    n_features = 2
    rng = np.random.RandomState(0)

    x = rng.randn(n_samples_s, n_features)
    y = rng.randn(n_samples_t, n_features)
    a = ot.utils.unif(n_samples_s)
    b = ot.utils.unif(n_samples_t)

    M = ot.dist(x, y)

    # solve with default (uniform) weights
    sol0 = ot.solve(M)
    # printing exercises the result's string representation
    print(sol0)

    # solve with explicitly given weights
    sol = ot.solve(M, a, b)

    # check that these attributes can be accessed
    sol.potentials
    sol.sparse_plan
    sol.marginals
    sol.status

    assert_allclose_sol(sol0, sol)

    # solve in backend: use the converted arrays, otherwise the backend
    # is never actually exercised
    ab, bb, Mb = nx.from_numpy(a, b, M)
    solb = ot.solve(Mb, ab, bb)

    assert_allclose_sol(sol, solb)

    # an unknown unbalanced_type must raise
    with pytest.raises(NotImplementedError):
        ot.solve(M, unbalanced=1, unbalanced_type='cryptic divergence')

    # an unknown reg_type must raise
    with pytest.raises(NotImplementedError):
        ot.solve(M, reg=1, reg_type='cryptic divergence')
@pytest.mark.parametrize(
    "reg,reg_type,unbalanced,unbalanced_type",
    itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type),
)
def test_solve_grid(nx, reg, reg_type, unbalanced, unbalanced_type):
    """Check ``ot.solve`` for one combination of the parameter grids.

    For each combination this compares:
    - default weights against explicit uniform weights, and
    - the NumPy solution against the solution computed on arrays converted to
      the ``nx`` backend.

    Combinations for which a call raises ``NotImplementedError`` are reported
    as skipped instead of silently passing.
    """
    n_samples_s = 10
    n_samples_t = 7
    n_features = 2
    rng = np.random.RandomState(0)

    x = rng.randn(n_samples_s, n_features)
    y = rng.randn(n_samples_t, n_features)
    a = ot.utils.unif(n_samples_s)
    b = ot.utils.unif(n_samples_t)

    M = ot.dist(x, y)

    # shared solver options for every call below
    kwargs = dict(reg=reg, reg_type=reg_type,
                  unbalanced=unbalanced, unbalanced_type=unbalanced_type)

    try:
        # solve with default (uniform) weights
        sol0 = ot.solve(M, **kwargs)

        # solve with explicitly given weights
        sol = ot.solve(M, a, b, **kwargs)

        assert_allclose_sol(sol0, sol)

        # solve in backend: use the converted arrays, otherwise the backend
        # is never actually exercised
        ab, bb, Mb = nx.from_numpy(a, b, M)
        solb = ot.solve(Mb, ab, bb, **kwargs)

        assert_allclose_sol(sol, solb)
    except NotImplementedError:
        pytest.skip("ot.solve raised NotImplementedError for this configuration")
def test_solve_not_implemented(nx):
    """Unknown divergences, and incompatible pairs of them, must raise."""
    rng = np.random.RandomState(0)
    source = rng.randn(10, 2)
    target = rng.randn(7, 2)
    cost = ot.dist(source, target)

    unsupported_options = [
        # unknown regularization divergence
        dict(reg=1.0, reg_type='cryptic divergence'),
        # unknown marginal-relaxation divergence
        dict(unbalanced=1.0, unbalanced_type='cryptic divergence'),
        # pair of incompatible divergences
        dict(reg=1.0, reg_type='kl', unbalanced=1.0, unbalanced_type='tv'),
    ]
    for options in unsupported_options:
        with pytest.raises(NotImplementedError):
            ot.solve(cost, **options)
|