summaryrefslogtreecommitdiff
path: root/ot/gpu/da.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-09-24 14:30:44 +0200
committerRémi Flamary <remi.flamary@gmail.com>2018-09-24 14:30:44 +0200
commitf45f7a68b221ec5b619b8fd8de797815a1eecf43 (patch)
treee3cc97cdf0c38e457303ceba32f7dadc20a12139 /ot/gpu/da.py
parentd258c7d6936410cd78189445a0260d983f7684d6 (diff)
pep8
Diffstat (limited to 'ot/gpu/da.py')
-rw-r--r--ot/gpu/da.py16
1 files changed, 8 insertions, 8 deletions
diff --git a/ot/gpu/da.py b/ot/gpu/da.py
index 8bcc2aa..8c63870 100644
--- a/ot/gpu/da.py
+++ b/ot/gpu/da.py
@@ -10,15 +10,16 @@ Domain adaptation with optimal transport with GPU implementation
#
# License: MIT License
-import cupy as np # np used for matrix computation
-import cupy as cp # cp used for cupy specific operations
+import cupy as np # np used for matrix computation
+import cupy as cp # cp used for cupy specific operations
import numpy as npp
from . import utils
from .bregman import sinkhorn
+
def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
numInnerItermax=200, stopInnerThr=1e-9, verbose=False,
- log=False,to_numpy=True):
+ log=False, to_numpy=True):
"""
Solve the entropic regularization optimal transport problem with nonconvex
group lasso regularization
@@ -101,15 +102,14 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
ot.optim.cg : General regularized OT
"""
-
+
a, labels_a, b, M = utils.to_gpu(a, labels_a, b, M)
-
-
+
p = 0.5
epsilon = 1e-3
indices_labels = []
- labels_a2=cp.asnumpy(labels_a)
+ labels_a2 = cp.asnumpy(labels_a)
classes = npp.unique(labels_a2)
for c in classes:
idxc, = utils.to_gpu(npp.where(labels_a2 == c))
@@ -120,7 +120,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
for cpt in range(numItermax):
Mreg = M + eta * W
transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
- stopThr=stopInnerThr,to_numpy=False)
+ stopThr=stopInnerThr, to_numpy=False)
# the transport has been computed. Check if classes are really
# separated
W = np.ones(M.shape)