diff options
Diffstat (limited to 'ot/gpu/da.py')
-rw-r--r-- | ot/gpu/da.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/ot/gpu/da.py b/ot/gpu/da.py index 4a98038..7adb830 100644 --- a/ot/gpu/da.py +++ b/ot/gpu/da.py @@ -120,7 +120,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, labels_a2 = cp.asnumpy(labels_a) classes = npp.unique(labels_a2) for c in classes: - idxc, = utils.to_gpu(npp.where(labels_a2 == c)) + idxc = utils.to_gpu(*npp.where(labels_a2 == c)) indices_labels.append(idxc) W = np.zeros(M.shape) |