summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py13
1 files changed, 10 insertions, 3 deletions
diff --git a/ot/da.py b/ot/da.py
index 0b9737e..083663c 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -126,8 +126,12 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
W = nx.zeros(M.shape, type_as=M)
for cpt in range(numItermax):
Mreg = M + eta * W
- transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
- stopThr=stopInnerThr)
+ if log:
+ transp, log = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
+ stopThr=stopInnerThr, log=True)
+ else:
+ transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
+ stopThr=stopInnerThr)
# the transport has been computed. Check if classes are really
# separated
W = nx.ones(M.shape, type_as=M)
@@ -136,7 +140,10 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
majs = p * ((majs + epsilon) ** (p - 1))
W[indices_labels[i]] = majs
- return transp
+ if log:
+ return transp, log
+ else:
+ return transp
def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,