summaryrefslogtreecommitdiff
path: root/ot/gpu/bregman.py
blob: f91f15fd77c66ce3567d1c8289ca0c55fd373cee (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# -*- coding: utf-8 -*-
"""
Bregman projections for regularized OT with GPU
"""

import numpy as np
import cudamat


def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
             log=False):
    """
    Solve the entropic regularized optimal transport problem on GPU

    Runs Sinkhorn-Knopp matrix scaling with cudamat and returns the
    transport plan ``diag(u) K diag(v)`` with ``K = exp(-M / reg)``.

    Parameters
    ----------
    a : np.ndarray, shape (ns,)
        Source weights (presumably a histogram summing to 1 -- not checked).
        Must support ``len`` and ``reshape``.
    b : np.ndarray, shape (nt,)
        Target weights.
    M_GPU : cudamat.CUDAMatrix, shape (ns, nt)
        Cost matrix already on the GPU. It is left unchanged: the kernel is
        computed on a private copy.
    reg : float
        Entropic regularization term; expected to be positive.
    numItermax : int, optional
        Maximum number of Sinkhorn iterations.
    stopThr : float, optional
        The loop stops once the squared L2 error on the column marginal is
        <= stopThr (the error is only recomputed every 10 iterations).
    verbose : bool, optional
        Print the error each time it is recomputed.
    log : bool, optional
        If True, also return a dict with keys 'err' (list of recorded
        errors), 'u' and 'v' (scaling vectors as (n, 1) numpy arrays).

    Returns
    -------
    gamma : np.ndarray, shape (ns, nt)
        Transport plan copied back to host memory.
    log : dict
        Returned only when ``log`` is True.
    """
    # init data
    Nini = len(a)
    Nfin = len(b)

    if log:
        log = {'err': []}

    # scaling vectors start uniform; all vectors are stored as columns
    u_GPU = cudamat.CUDAMatrix((np.ones(Nini) / Nini).reshape((Nini, 1)))
    v_GPU = cudamat.CUDAMatrix((np.ones(Nfin) / Nfin).reshape((Nfin, 1)))
    a_GPU = cudamat.CUDAMatrix(a.reshape((Nini, 1)))
    b_GPU = cudamat.CUDAMatrix(b.reshape((Nfin, 1)))
    ones_GPU = cudamat.empty(u_GPU.shape).assign(1)

    # Gibbs kernel K = exp(-M / reg). divide() without a target works in
    # place, so work on a copy to avoid clobbering the caller's M_GPU.
    K_GPU = M_GPU.copy()
    K_GPU.divide(-reg)
    cudamat.exp(K_GPU, target=K_GPU)

    # Kp = diag(1 / a) K, so that the u-update 1 / (Kp v) equals a / (K v)
    ones_GPU.divide(a_GPU, target=a_GPU)
    Kp_GPU = cudamat.empty(K_GPU.shape)
    K_GPU.mult_by_col(a_GPU, target=Kp_GPU)

    # K is constant inside the loop: transpose it only once
    Kt_GPU = K_GPU.transpose()
    tmp_GPU = cudamat.empty(K_GPU.shape)

    cpt = 0
    err = 1
    while err > stopThr and cpt < numItermax:
        uprev_GPU = u_GPU.copy()
        vprev_GPU = v_GPU.copy()

        KtransposeU_GPU = Kt_GPU.dot(u_GPU)
        b_GPU.divide(KtransposeU_GPU, target=v_GPU)
        ones_GPU.divide(Kp_GPU.dot(v_GPU), target=u_GPU)

        if (np.any(KtransposeU_GPU.asarray() == 0) or
                not u_GPU.allfinite() or not v_GPU.allfinite()):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            print('Warning: numerical errors at iteration', cpt)
            u_GPU = uprev_GPU
            v_GPU = vprev_GPU
            break
        if cpt % 10 == 0:
            # we can speed up the process by checking for the error only all
            # the 10th iterations
            # tmp = diag(u) K diag(v), the current plan
            K_GPU.mult_by_col(u_GPU, target=tmp_GPU)
            tmp_GPU.mult_by_row(v_GPU.transpose(), target=tmp_GPU)

            # err = ||column sums of the plan - b||^2
            bcopy_GPU = b_GPU.copy().transpose()
            bcopy_GPU.add_sums(tmp_GPU, axis=0, beta=-1)
            err = bcopy_GPU.euclid_norm()**2
            if log:
                log['err'].append(err)

            if verbose:
                if cpt % 200 == 0:
                    print('{:5s}|{:12s}'.format('It.', 'Err')+'\n'+'-'*19)
                print('{:5d}|{:8e}|'.format(cpt, err))
        cpt += 1
    if log:
        log['u'] = u_GPU.asarray()
        log['v'] = v_GPU.asarray()

    # final plan diag(u) K diag(v), built in place in the private kernel
    K_GPU.mult_by_col(u_GPU, target=K_GPU)
    K_GPU.mult_by_row(v_GPU.transpose(), target=K_GPU)
    if log:
        return K_GPU.asarray(), log
    return K_GPU.asarray()