summaryrefslogtreecommitdiff
path: root/ot/gpu/da.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/gpu/da.py')
-rw-r--r--ot/gpu/da.py8
1 files changed, 7 insertions, 1 deletions
diff --git a/ot/gpu/da.py b/ot/gpu/da.py
index 8c63870..6aba29c 100644
--- a/ot/gpu/da.py
+++ b/ot/gpu/da.py
@@ -22,7 +22,11 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
log=False, to_numpy=True):
"""
Solve the entropic regularization optimal transport problem with nonconvex
- group lasso regularization
+ group lasso regularization on GPU
+
+ If the input matrix are in numpy format, they will be uploaded to the
+ GPU first which can incur significant time overhead.
+
The function solves the following optimization problem:
@@ -74,6 +78,8 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
Print information along iterations
log : bool, optional
record log if True
+ to_numpy : boolean, optional (default True)
+ If true convert back the GPU array result to numpy format.
Returns