summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ot/bregman.py3
-rw-r--r--test/test_bregman.py8
2 files changed, 6 insertions, 5 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index aff9f8c..c304b5d 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -2117,10 +2117,11 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
log['v'] = vsc_full
log['Isel'] = Isel
log['Jsel'] = Jsel
+
gamma = usc_full[:, None] * K * vsc_full[None, :]
gamma = gamma / gamma.sum()
if log:
return gamma, log
else:
- return gamma \ No newline at end of file
+ return gamma
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 2398d45..e376715 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -348,7 +348,7 @@ def test_screenkhorn():
x = rng.randn(n, 2)
M = ot.dist(x, x)
- G_sink = ot.sinkhorn(a, b, M, 1e-03)
- G_screen = ot.bregman.screenkhorn(a, b, M, 1e-03, uniform=True, verbose=True)
- np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
- np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02) \ No newline at end of file
+ G_s = ot.sinkhorn(a, b, M, 1e-03)
+ G_sc = ot.bregman.screenkhorn(a, b, M, 1e-03, uniform=True, verbose=True)
+ np.testing.assert_allclose(G_s.sum(0), G_sc.sum(0), atol=1e-02)
+ np.testing.assert_allclose(G_s.sum(1), G_sc.sum(1), atol=1e-02) \ No newline at end of file