summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
authorMokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>2020-01-07 16:21:40 +0100
committerGitHub <noreply@github.com>2020-01-07 16:21:40 +0100
commit92b7075568207a468cc821cd6a21e130b9d89f96 (patch)
treed00c7312ef96e2f1bca5f29936fd8f4846afe35f /ot/bregman.py
parent05a97b44a7137ef6cb0397cca3bb2ea1f8736ac5 (diff)
replace reshape by numpy slicing in return
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py12
1 files changed, 10 insertions, 2 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 8a20307..8cfea7e 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -2107,6 +2107,14 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=True, rest
vsc_full = np.full(nt, epsilon * kappa)
usc_full[I] = usc
vsc_full[J] = vsc
+
+ if log:
+ log['u'] = usc_full
+ log['v'] = vsc_full
+
+ gamma = usc_full[:, None] * K * vsc_full[None, :]
- Gsc = usc_full.reshape((-1, 1)) * K * vsc_full.reshape((1, -1))
- return Gsc
+ if log:
+ return gamma, log
+ else:
+ return gamma