diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2022-12-06 18:02:44 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-12-06 18:02:44 +0100 |
commit | 8490196dcc982c492b7565e1ec4de5f75f006acf (patch) | |
tree | 7f330e9c15dd4dff8d36c2308fa8d416b552ce23 /ot/da.py | |
parent | ac830dd2b85cfd39f4fadd879a721b36ded033ea (diff) |
[MRG] Fix bug in regularized OTDA l1lp with log (#413)
* correct bug in DA l1lp with log
* better tests and speedup with smaller dataset size
* remove jax for log test
* remove trndorflow for log test
* pep8!
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 13 |
1 files changed, 10 insertions, 3 deletions
@@ -126,8 +126,12 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, W = nx.zeros(M.shape, type_as=M) for cpt in range(numItermax): Mreg = M + eta * W - transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, - stopThr=stopInnerThr) + if log: + transp, log = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, + stopThr=stopInnerThr, log=True) + else: + transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, + stopThr=stopInnerThr) # the transport has been computed. Check if classes are really # separated W = nx.ones(M.shape, type_as=M) @@ -136,7 +140,10 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, majs = p * ((majs + epsilon) ** (p - 1)) W[indices_labels[i]] = majs - return transp + if log: + return transp, log + else: + return transp def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, |