diff options
-rw-r--r-- | ot/bregman.py | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index 8a20307..8cfea7e 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -2107,6 +2107,14 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=True, rest vsc_full = np.full(nt, epsilon * kappa) usc_full[I] = usc vsc_full[J] = vsc + + if log: + log['u'] = usc_full + log['v'] = vsc_full + + gamma = usc_full[:, None] * K * vsc_full[None, :] - Gsc = usc_full.reshape((-1, 1)) * K * vsc_full.reshape((1, -1)) - return Gsc + if log: + return gamma, log + else: + return gamma |