summaryrefslogtreecommitdiff
path: root/test/test_1d_solver.py
diff options
context:
space:
mode:
authorClément Bonet <32179275+clbonet@users.noreply.github.com>2023-05-05 10:53:48 +0200
committerGitHub <noreply@github.com>2023-05-05 10:53:48 +0200
commit7e0ea27ad9cad31cfc2181430d837c0a77a61568 (patch)
tree0a41128a975500bfef52a4c21b5af634adecc71a /test/test_1d_solver.py
parent83dc498b496087aea293df1445442d8728435211 (diff)
[MRG] Fix bug SSW backend (#471)
* fix bug np vs torch matmul * typo error * einsum projections ssw * Test broadcast matmul * einsum projections ssw * Test broadcast matmul * projections SSW with einsum * reduce number of samples in test wasserstein_circle_unif * Update releases.md --------- Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test/test_1d_solver.py')
-rw-r--r--test/test_1d_solver.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
index 21abd1d..075a415 100644
--- a/test/test_1d_solver.py
+++ b/test/test_1d_solver.py
@@ -279,7 +279,7 @@ def test_wasserstein1d_circle_devices(nx):
def test_wasserstein_1d_unif_circle():
# test semidiscrete_wasserstein2_unif_circle versus wasserstein_circle
n = 20
- m = 50000
+ m = 1000
rng = np.random.RandomState(0)
u = rng.rand(n,)
@@ -298,8 +298,8 @@ def test_wasserstein_1d_unif_circle():
wass2_unif_circle = ot.semidiscrete_wasserstein2_unif_circle(u, w_u)
# check loss is similar
- np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-3)
- np.testing.assert_allclose(wass2_circle, wass2_unif_circle, atol=1e-3)
+ np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-2)
+ np.testing.assert_allclose(wass2_circle, wass2_unif_circle, atol=1e-2)
def test_wasserstein1d_unif_circle_devices(nx):