summaryrefslogtreecommitdiff
path: root/ot/gpu/bregman.py
diff options
context:
space:
mode:
authorLeo gautheron <gautheron@iv-cm-359.creatis.insa-lyon.fr>2017-04-24 10:43:44 +0200
committerLeo gautheron <gautheron@iv-cm-359.creatis.insa-lyon.fr>2017-04-24 10:43:44 +0200
commite2cf8538b05f026d73c6033777699af77e7508b5 (patch)
tree65e3537011f45cb4200aea2e9fdc5d68ffe1de90 /ot/gpu/bregman.py
parent8fc74ed792cda0fc0073dab218f7fdc08bf3c1ba (diff)
add GPU implementation sinkhorn lpl1
Diffstat (limited to 'ot/gpu/bregman.py')
-rw-r--r--ot/gpu/bregman.py12
1 files changed, 9 insertions, 3 deletions
diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py
index f91f15f..cc610b7 100644
--- a/ot/gpu/bregman.py
+++ b/ot/gpu/bregman.py
@@ -8,7 +8,7 @@ import cudamat
def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
- log=False):
+ log=False, returnAsGPU=False):
# init data
Nini = len(a)
Nfin = len(b)
@@ -77,7 +77,13 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
K_GPU.mult_by_col(u_GPU, target=K_GPU)
K_GPU.mult_by_row(v_GPU.transpose(), target=K_GPU)
+
+ if returnAsGPU:
+ res = K_GPU
+ else:
+ res = K_GPU.asarray()
+
if log:
- return K_GPU.asarray(), log
+ return res, log
else:
- return K_GPU.asarray()
+ return res