diff options
author | ncourty <ncourty@irisa.fr> | 2017-04-27 11:17:59 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-04-27 11:17:59 +0200 |
commit | 55ff888aab4e3a5f0a4a3180887d40e96fbea3d8 (patch) | |
tree | de0ea03345a1e4192cd404d52fecce5b38e53176 /ot/gpu/bregman.py | |
parent | bd325a391aa5156faec0446c455be1f8f94eb80b (diff) | |
parent | 9df60d467ddd3316334578adac8a80667cfa8759 (diff) |
Merge pull request #10 from aje/master
performance improvement sinkhorn lpl1
Diffstat (limited to 'ot/gpu/bregman.py')
-rw-r--r-- | ot/gpu/bregman.py | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py index f91f15f..cc610b7 100644 --- a/ot/gpu/bregman.py +++ b/ot/gpu/bregman.py @@ -8,7 +8,7 @@ import cudamat def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False, - log=False): + log=False, returnAsGPU=False): # init data Nini = len(a) Nfin = len(b) @@ -77,7 +77,13 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False, K_GPU.mult_by_col(u_GPU, target=K_GPU) K_GPU.mult_by_row(v_GPU.transpose(), target=K_GPU) + + if returnAsGPU: + res = K_GPU + else: + res = K_GPU.asarray() + if log: - return K_GPU.asarray(), log + return res, log else: - return K_GPU.asarray() + return res |