summaryrefslogtreecommitdiff
path: root/ot/gpu/da.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/gpu/da.py')
-rw-r--r--ot/gpu/da.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/ot/gpu/da.py b/ot/gpu/da.py
index 4a98038..7adb830 100644
--- a/ot/gpu/da.py
+++ b/ot/gpu/da.py
@@ -120,7 +120,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
labels_a2 = cp.asnumpy(labels_a)
classes = npp.unique(labels_a2)
for c in classes:
- idxc, = utils.to_gpu(npp.where(labels_a2 == c))
+ idxc = utils.to_gpu(*npp.where(labels_a2 == c))
indices_labels.append(idxc)
W = np.zeros(M.shape)