summaryrefslogtreecommitdiff
path: root/test/test_smooth.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-05-31 11:08:01 +0200
committerRémi Flamary <remi.flamary@gmail.com>2018-05-31 11:08:01 +0200
commit8d288976d398749c5261daca22ee4d772bd5b489 (patch)
tree6548d89d548ceea6a86f54b1cdc6c8de5c2cf5b0 /test/test_smooth.py
parent90efa5a8b189214d1aeb81920b2bb04ce0c261ca (diff)
add smooth.py + first tests
Diffstat (limited to 'test/test_smooth.py')
-rw-r--r--test/test_smooth.py33
1 files changed, 33 insertions, 0 deletions
diff --git a/test/test_smooth.py b/test/test_smooth.py
new file mode 100644
index 0000000..f951bf9
--- /dev/null
+++ b/test/test_smooth.py
@@ -0,0 +1,33 @@
+"""Tests for ot.smooth model """
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import warnings
+
+import numpy as np
+
+import ot
+from ot.datasets import get_1D_gauss as gauss
+import pytest
+
+
+def test_smooth_ot_dual():
+ # test sinkhorn
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ G = ot.smooth.smooth_ot_dual(u, u, M, 1, stopThr=1e-10)
+
+ # check constratints
+ np.testing.assert_allclose(
+ u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
+ np.testing.assert_allclose(
+ u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+ \ No newline at end of file