diff options
Diffstat (limited to 'ot/gpu/da.py')
-rw-r--r-- | ot/gpu/da.py | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/ot/gpu/da.py b/ot/gpu/da.py index 8c63870..6aba29c 100644 --- a/ot/gpu/da.py +++ b/ot/gpu/da.py @@ -22,7 +22,11 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, log=False, to_numpy=True): """ Solve the entropic regularization optimal transport problem with nonconvex - group lasso regularization + group lasso regularization on GPU + + If the input matrix are in numpy format, they will be uploaded to the + GPU first which can incur significant time overhead. + The function solves the following optimization problem: @@ -74,6 +78,8 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, Print information along iterations log : bool, optional record log if True + to_numpy : boolean, optional (default True) + If true convert back the GPU array result to numpy format. Returns |