summaryrefslogtreecommitdiff
path: root/ot/gpu/da.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/gpu/da.py')
-rw-r--r--ot/gpu/da.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/ot/gpu/da.py b/ot/gpu/da.py
index 6aba29c..4a98038 100644
--- a/ot/gpu/da.py
+++ b/ot/gpu/da.py
@@ -10,10 +10,12 @@ Domain adaptation with optimal transport with GPU implementation
#
# License: MIT License
+
import cupy as np # np used for matrix computation
import cupy as cp # cp used for cupy specific operations
import numpy as npp
from . import utils
+
from .bregman import sinkhorn
@@ -131,6 +133,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
# separated
W = np.ones(M.shape)
for (i, c) in enumerate(classes):
+
majs = np.sum(transp[indices_labels[i]], axis=0)
majs = p * ((majs + epsilon)**(p - 1))
W[indices_labels[i]] = majs