summaryrefslogtreecommitdiff
path: root/ot/bregman/sink.py
blob: 8b97e1eef75ba2bbd74559c17c8c3eeb075174e5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# -*- coding: utf-8 -*-
"""
Created on Fri Oct 21 09:40:21 2016

@author: rflamary
"""

import numpy as np


def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
    r"""
    Solve the entropic regularization optimal transport problem (OT)

    .. math::
        \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)

        s.t. \gamma 1 = a

             \gamma^T 1= b

             \gamma\geq 0

    where :

    - M is the metric cost matrix
    - :math:`\Omega` is the entropic regularization term
    - a and b are the sample weights (marginals of :math:`\gamma`)

    The solution has the form diag(u) K diag(v) with K = exp(-M/reg); the
    scaling vectors u and v are found by alternately matching the column
    marginal (v = b / K^T u) and the row marginal (u = a / K v).

    Parameters
    ----------
    a : (ns,) array-like
        weights of the samples in the source domain. Entries must be
        nonzero (the algorithm divides by a), and a should have the same
        total mass as b for the constraints to be feasible.
    b : (nt,) array-like
        weights of the samples in the target domain
    M : (ns,nt) array-like
        loss matrix
    reg : float
        Regularization term >0
    numItermax : int, optional
        Maximum number of scaling iterations (default 1000)
    stopThr : float, optional
        Stop when the squared Euclidean norm of the column-marginal error
        falls below this threshold (default 1e-9). The error is only
        evaluated every 10th iteration.

    Returns
    -------
    gamma : (ns x nt) ndarray
        Optimal transportation matrix for the given parameters

    Notes
    -----
    If the iterations hit numerical limits (a zero in K^T u, or a
    non-finite scaling vector), a warning is printed and the scaling
    vectors from the previous iteration are used for the result.
    """
    # accept lists / array-likes; no copy for ndarrays
    a = np.asarray(a)
    b = np.asarray(b)
    M = np.asarray(M)

    Nini = len(a)
    Nfin = len(b)

    # uniform initialisation of the scaling vectors
    u = np.ones(Nini) / Nini
    v = np.ones(Nfin) / Nfin
    # last numerically valid iterate, restored on numerical failure
    uprev = u
    vprev = v

    K = np.exp(-M / reg)
    # rows of K scaled by 1/a, so that 1/(Kp v) == a/(K v)
    # (broadcast instead of building the dense matrix diag(1/a))
    Kp = (1. / a)[:, None] * K

    cpt = 0
    err = 1
    while err > stopThr and cpt < numItermax:
        KtransposeU = np.dot(K.T, u)
        if np.any(KtransposeU == 0):
            # v = b / K^T u would blow up: we have reached the machine
            # precision, come back to previous solution and quit loop
            print('Warning: numerical errors')
            u = uprev
            v = vprev
            break

        uprev = u
        vprev = v
        v = np.divide(b, KtransposeU)
        u = 1. / np.dot(Kp, v)

        if not (np.all(np.isfinite(u)) and np.all(np.isfinite(v))):
            # checked right away: a NaN/inf iterate would otherwise make
            # err NaN, end the loop silently and return a NaN matrix
            print('Warning: numerical errors')
            u = uprev
            v = vprev
            break

        if cpt % 10 == 0:
            # we can speed up the process by checking for the error only
            # every 10th iteration. After the u update the row marginals
            # equal a, so only the column marginals need to be checked.
            transp = u[:, None] * K * v[None, :]
            err = np.linalg.norm(np.sum(transp, axis=0) - b) ** 2
        cpt = cpt + 1

    # diag(u) K diag(v) computed by broadcasting (O(ns*nt), no dense diag)
    return u[:, None] * K * v[None, :]