1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
|
# -*- coding: utf-8 -*-
"""
Created on Fri Oct 21 09:40:21 2016
@author: rflamary
"""
import numpy as np
def sinkhorn(a, b, M, reg, numItermax=1000, stopThr=1e-9):
    r"""
    Solve the entropic regularized optimal transport problem (OT).

    .. math::
        \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)

        s.t. \gamma 1 = a

             \gamma^T 1 = b

             \gamma \geq 0

    where :

    - M is the metric cost matrix
    - :math:`\Omega` is the entropic regularization term
    - a and b are the sample weights

    The solution is computed by Sinkhorn-Knopp matrix scaling: the plan is
    ``diag(u) @ K @ diag(v)`` with ``K = exp(-M / reg)``, where ``v`` and
    ``u`` are updated alternately so that the row sums equal ``a`` after
    every iteration.

    Parameters
    ----------
    a : (ns,) array_like
        weights of the samples in the source domain
    b : (nt,) array_like
        weights of the samples in the target domain
    M : (ns, nt) array_like
        loss matrix
    reg : float
        Regularization term >0
    numItermax : int, optional
        Maximum number of scaling iterations (default 1000).
    stopThr : float, optional
        Iterations stop once the squared L2 norm of the column-marginal
        error ``gamma.sum(axis=0) - b`` is <= stopThr. The error is only
        re-evaluated every 10th iteration (default 1e-9).

    Returns
    -------
    gamma : (ns, nt) ndarray
        Optimal transportation matrix for the given parameters

    Notes
    -----
    If ``K.T @ u`` has a zero entry, or an update produces a non-finite
    scaling (under/overflow), ``'Warning: numerical errors'`` is printed,
    the last finite scalings are kept and the iterations stop. The
    returned plan is therefore always built from finite scalings.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    M = np.asarray(M)

    # Start from uniform scalings.
    u = np.ones(len(a)) / len(a)
    v = np.ones(len(b)) / len(b)

    K = np.exp(-M / reg)
    # diag(1/a) @ K computed by broadcasting (no dense diagonal matrix).
    Kp = (1.0 / a)[:, None] * K

    cpt = 0
    err = 1
    while err > stopThr and cpt < numItermax:
        KtransposeU = np.dot(K.T, u)
        if np.any(KtransposeU == 0):
            # The v update would divide by zero: machine precision reached,
            # keep the current (finite) scalings and quit the loop.
            print('Warning: numerical errors')
            break
        # Non-finite results are detected explicitly below, so silence
        # NumPy's divide/invalid runtime warnings for these updates.
        with np.errstate(divide='ignore', invalid='ignore'):
            vnew = b / KtransposeU
            unew = 1. / np.dot(Kp, vnew)
        if not (np.all(np.isfinite(unew)) and np.all(np.isfinite(vnew))):
            # Discard the broken update; u and v still hold the previous
            # finite solution. Without this, err would become NaN and the
            # loop would exit silently with a NaN plan.
            print('Warning: numerical errors')
            break
        u, v = unew, vnew
        if cpt % 10 == 0:
            # Only check the error every 10th iteration to save time.
            # Column sums of diag(u) K diag(v) are v * (K^T u), so the
            # (ns, nt) plan does not need to be materialized here.
            err = np.linalg.norm(v * np.dot(K.T, u) - b) ** 2
        cpt += 1

    return u[:, None] * K * v[None, :]
|