summaryrefslogtreecommitdiff
path: root/ot/gpu/bregman.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-09-24 13:57:42 +0200
committerRémi Flamary <remi.flamary@gmail.com>2018-09-24 13:57:42 +0200
commitd258c7d6936410cd78189445a0260d983f7684d6 (patch)
tree415b2fc393c68204055ff334882b08349e265507 /ot/gpu/bregman.py
parentc9b99df8fffec1dcc6802ef43b6192774817c5fb (diff)
convert ot.gpu to cupy
Diffstat (limited to 'ot/gpu/bregman.py')
-rw-r--r--ot/gpu/bregman.py143
1 files changed, 87 insertions, 56 deletions
diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py
index 47939c4..912104c 100644
--- a/ot/gpu/bregman.py
+++ b/ot/gpu/bregman.py
@@ -8,14 +8,16 @@ Bregman projections for regularized OT with GPU
#
# License: MIT License
-import numpy as np
-import cudamat
+import cupy as np # np used for matrix computation
+import cupy as cp # cp used for cupy specific operations
+from . import utils
-def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
- log=False, returnAsGPU=False):
- r"""
- Solve the entropic regularization optimal transport problem on GPU
+
+def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
+ verbose=False, log=False, to_numpy=True, **kwargs):
+ """
+ Solve the entropic regularization optimal transport problem and return the OT matrix
The function solves the following optimization problem:
@@ -40,9 +42,10 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
----------
a : np.ndarray (ns,)
samples weights in the source domain
- b : np.ndarray (nt,)
- samples in the target domain
- M_GPU : cudamat.CUDAMatrix (ns,nt)
+ b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ samples weights in the target domain; if b is a matrix, sinkhorn is
+ computed for multiple targets with a fixed M (returns the OT loss, with the dual variables in log)
+ M : np.ndarray (ns,nt)
loss matrix
reg : float
Regularization term >0
@@ -54,8 +57,7 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
Print information along iterations
log : bool, optional
record log if True
- returnAsGPU : bool, optional
- return the OT matrix as a cudamat.CUDAMatrix
+
Returns
-------
@@ -88,60 +90,78 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
ot.optim.cg : General regularized OT
"""
+
+ a = cp.asarray(a, dtype=np.float64)
+ b = cp.asarray(b, dtype=np.float64)
+ M = cp.asarray(M, dtype=np.float64)
+
+ if len(a) == 0:
+ a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ if len(b) == 0:
+ b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+
# init data
Nini = len(a)
Nfin = len(b)
+ if len(b.shape) > 1:
+ nbb = b.shape[1]
+ else:
+ nbb = 0
+
if log:
log = {'err': []}
# we assume that no distances are zero except those on the diagonal of the
# distance matrix
- u = (np.ones(Nini) / Nini).reshape((Nini, 1))
- u_GPU = cudamat.CUDAMatrix(u)
- a_GPU = cudamat.CUDAMatrix(a.reshape((Nini, 1)))
- ones_GPU = cudamat.empty(u_GPU.shape).assign(1)
- v = (np.ones(Nfin) / Nfin).reshape((Nfin, 1))
- v_GPU = cudamat.CUDAMatrix(v)
- b_GPU = cudamat.CUDAMatrix(b.reshape((Nfin, 1)))
-
- M_GPU.divide(-reg)
+ if nbb:
+ u = np.ones((Nini, nbb)) / Nini
+ v = np.ones((Nfin, nbb)) / Nfin
+ else:
+ u = np.ones(Nini) / Nini
+ v = np.ones(Nfin) / Nfin
- K_GPU = cudamat.exp(M_GPU)
+ # print(reg)
- ones_GPU.divide(a_GPU, target=a_GPU)
- Kp_GPU = cudamat.empty(K_GPU.shape)
- K_GPU.mult_by_col(a_GPU, target=Kp_GPU)
+ # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
- tmp_GPU = cudamat.empty(K_GPU.shape)
+ # print(np.min(K))
+ tmp2 = np.empty(b.shape, dtype=M.dtype)
+ Kp = (1 / a).reshape(-1, 1) * K
cpt = 0
err = 1
while (err > stopThr and cpt < numItermax):
- uprev_GPU = u_GPU.copy()
- vprev_GPU = v_GPU.copy()
+ uprev = u
+ vprev = v
- KtransposeU_GPU = K_GPU.transpose().dot(u_GPU)
- b_GPU.divide(KtransposeU_GPU, target=v_GPU)
- ones_GPU.divide(Kp_GPU.dot(v_GPU), target=u_GPU)
+ KtransposeU = np.dot(K.T, u)
+ v = np.divide(b, KtransposeU)
+ u = 1. / np.dot(Kp, v)
- if (np.any(KtransposeU_GPU.asarray() == 0) or
- not u_GPU.allfinite() or not v_GPU.allfinite()):
+ if (np.any(KtransposeU == 0) or
+ np.any(np.isnan(u)) or np.any(np.isnan(v)) or
+ np.any(np.isinf(u)) or np.any(np.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
print('Warning: numerical errors at iteration', cpt)
- u_GPU = uprev_GPU.copy()
- v_GPU = vprev_GPU.copy()
+ u = uprev
+ v = vprev
break
if cpt % 10 == 0:
# we can speed up the process by checking the error only every
# 10th iteration
- K_GPU.mult_by_col(u_GPU, target=tmp_GPU)
- tmp_GPU.mult_by_row(v_GPU.transpose(), target=tmp_GPU)
-
- bcopy_GPU = b_GPU.copy().transpose()
- bcopy_GPU.add_sums(tmp_GPU, axis=0, beta=-1)
- err = bcopy_GPU.euclid_norm()**2
+ if nbb:
+ err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
+ np.sum((v - vprev)**2) / np.sum((v)**2)
+ else:
+ # compute right marginal tmp2= (diag(u)Kdiag(v))^T1
+ tmp2=np.sum(u[:,None]*K*v[None,:],0)
+ #tmp2=np.einsum('i,ij,j->j', u, K, v)
+ err = np.linalg.norm(tmp2 - b)**2 # violation of marginal
if log:
log['err'].append(err)
@@ -150,20 +170,31 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err))
- cpt += 1
- if log:
- log['u'] = u_GPU.asarray()
- log['v'] = v_GPU.asarray()
-
- K_GPU.mult_by_col(u_GPU, target=K_GPU)
- K_GPU.mult_by_row(v_GPU.transpose(), target=K_GPU)
-
- if returnAsGPU:
- res = K_GPU
- else:
- res = K_GPU.asarray()
-
+ cpt = cpt + 1
if log:
- return res, log
- else:
- return res
+ log['u'] = u
+ log['v'] = v
+
+ if nbb: # return only loss
+ #res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) (explodes cupy memory)
+ res=np.empty(nbb)
+ for i in range(nbb):
+ res[i]=np.sum(u[:,None,i]*(K*M)*v[None,:,i])
+ if to_numpy:
+ res=utils.to_np(res)
+ if log:
+ return res, log
+ else:
+ return res
+
+ else: # return OT matrix
+ res=u.reshape((-1, 1)) * K * v.reshape((1, -1))
+ if to_numpy:
+ res=utils.to_np(res)
+ if log:
+ return res, log
+ else:
+ return res
+
+# define sinkhorn as sinkhorn_knopp
+sinkhorn=sinkhorn_knopp \ No newline at end of file