1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
|
"""
Factored OT solvers (low rank, cost or OT plan)
"""
# Author: Remi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License
from .backend import get_backend
from .utils import dist
from .lp import emd
from .bregman import sinkhorn
# Public names exported by a star-import of this module.
__all__ = ['factored_optimal_transport']
def factored_optimal_transport(Xa, Xb, a=None, b=None, reg=0.0, r=100, X0=None,
                               stopThr=1e-7, numItermax=100, verbose=False,
                               log=False, **kwargs):
    r"""Solves factored OT problem and return OT plans and intermediate distribution

    This function solve the following OT problem [40]_

    .. math::
        \mathop{\arg \min}_\mu \quad W_2^2(\mu_a,\mu)+ W_2^2(\mu,\mu_b)

    where :

    - :math:`\mu_a` and :math:`\mu_b` are empirical distributions.
    - :math:`\mu` is an empirical distribution with r samples

    And returns the two OT plans between :math:`\mu_a` and :math:`\mu`, and
    between :math:`\mu` and :math:`\mu_b`, together with the support of
    :math:`\mu`.

    .. note:: This function is backend-compatible and will work on arrays
        from all compatible backends. But the algorithm uses the C++ CPU backend
        which can lead to copy overhead on GPU arrays.

    The problem is solved by alternating two steps until the support stops
    moving: (1) solve the two OT problems source -> intermediate and
    intermediate -> target for the current support, (2) move each support
    point to the average of its two barycentric projections, as proposed in
    :ref:`[40] <references-factored>`. The intermediate distribution always
    carries uniform weights.

    Parameters
    ----------
    Xa : (ns,d) array-like, float
        Source samples
    Xb : (nt,d) array-like, float
        Target samples
    a : (ns,) array-like, float, optional
        Source histogram (uniform weights if None)
    b : (nt,) array-like, float, optional
        Target histogram (uniform weights if None)
    reg : float, optional
        Entropic regularization. If ``reg > 0`` the two OT problems are solved
        with :any:`ot.bregman.sinkhorn`, otherwise with :any:`ot.lp.emd`.
    r : int, optional
        Number of samples of the intermediate distribution. Ignored when `X0`
        is given: the number of rows of `X0` is used instead.
    X0 : (r,d) array-like, float, optional
        Initial support of the intermediate distribution. If None, it is drawn
        from the backend ``randn`` generator.
    numItermax : int, optional
        Max number of iterations (must be at least 1)
    stopThr : float, optional
        Stop threshold on the norm of the variation of the support between
        two iterations (>0)
    verbose : bool, optional
        Accepted for API compatibility; this implementation does not print
        anything.
    log : bool, optional
        record log if True
    **kwargs
        Extra keyword arguments forwarded to the inner OT solver
        (``sinkhorn`` when ``reg > 0``, ``emd`` otherwise).

    Returns
    -------
    Ga: array-like, shape (ns, r)
        Optimal transportation matrix between source and the intermediate
        distribution
    Gb: array-like, shape (r, nt)
        Optimal transportation matrix between the intermediate and target
        distribution
    X: array-like, shape (r, d)
        Support of the intermediate distribution
    log: dict, optional
        Only returned if input log is true. Contains ``'delta_iter'`` (norm of
        the support update at each non-converged iteration), the dual
        variables ``'ua'``, ``'va'``, ``'ub'``, ``'vb'`` and the costs
        ``'costa'``, ``'costb'`` of the last two OT problems solved.

    Raises
    ------
    ValueError
        If `numItermax` is lower than 1 (no OT plan could be computed).


    .. _references-factored:
    References
    ----------
    .. [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger,
        G., & Weed, J. (2019, April). Statistical optimal transport via factored
        couplings. In The 22nd International Conference on Artificial
        Intelligence and Statistics (pp. 2454-2465). PMLR.

    See Also
    --------
    ot.bregman.sinkhorn : Entropic regularized OT
    ot.optim.cg : General regularized OT
    """
    if numItermax < 1:
        # Without at least one iteration Ga/Gb would never be assigned.
        raise ValueError("numItermax must be at least 1, got {}".format(numItermax))

    nx = get_backend(Xa, Xb)

    n_a = Xa.shape[0]
    n_b = Xb.shape[0]
    d = Xa.shape[1]

    if a is None:
        a = nx.ones((n_a), type_as=Xa) / n_a
    if b is None:
        b = nx.ones((n_b), type_as=Xb) / n_b

    if X0 is None:
        X = nx.randn(r, d, type_as=Xa)
    else:
        X = X0
        # The weights and the rescaling below must match the actual number of
        # support points, whatever value was passed (or defaulted) for r.
        r = X0.shape[0]

    # uniform weights on the intermediate distribution
    w = nx.ones(r, type_as=Xa) / r

    def solve_ot(X1, X2, w1, w2):
        # Squared-Euclidean cost (default metric of dist) between the supports.
        M = dist(X1, X2)
        if reg > 0:
            G, log_ot = sinkhorn(w1, w2, M, reg, log=True, **kwargs)
            # the cost entry is added here so both branches expose log['cost']
            log_ot['cost'] = nx.sum(G * M)
            return G, log_ot
        return emd(w1, w2, M, log=True, **kwargs)

    norm_delta = []

    # solve the barycenter
    for _ in range(numItermax):
        old_X = X

        # solve OT with template
        Ga, loga = solve_ot(Xa, X, a, w)
        Gb, logb = solve_ot(X, Xb, w, b)

        # Average of the two barycentric projections; multiplying by r undoes
        # the uniform 1/r mass carried by each row of the plans.
        X = 0.5 * (nx.dot(Ga.T, Xa) + nx.dot(Gb, Xb)) * r

        delta = nx.norm(X - old_X)
        if delta < stopThr:
            # the delta of the converged step is not appended to the log
            break
        if log:
            norm_delta.append(delta)

    if log:
        log_dic = {
            'delta_iter': norm_delta,
            'ua': loga['u'],
            'va': loga['v'],
            'ub': logb['u'],
            'vb': logb['v'],
            'costa': loga['cost'],
            'costb': logb['cost'],
        }
        return Ga, Gb, X, log_dic

    return Ga, Gb, X
|