diff options
author | Kilian <kilian.fatras@gmail.com> | 2018-06-26 11:10:40 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-06-26 11:10:40 -0700 |
commit | b4bc86176a5712fdd2f930fbf5d1968edd5efa5e (patch) | |
tree | 075c694496216f0a6db61e879ece5eb2e799fc07 /test/test_smooth.py | |
parent | 208ff4627ba28aa3f35328c5928324019c23ddac (diff) | |
parent | 327b0c6e0ccb0c9453179eb316021c34bcdffec4 (diff) |
Merge branch 'master' into stochastic_OT
Diffstat (limited to 'test/test_smooth.py')
-rw-r--r-- | test/test_smooth.py | 79 |
1 files changed, 79 insertions, 0 deletions
diff --git a/test/test_smooth.py b/test/test_smooth.py new file mode 100644 index 0000000..2afa4f8 --- /dev/null +++ b/test/test_smooth.py @@ -0,0 +1,79 @@ +"""Tests for ot.smooth model """ + +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + +import numpy as np +import ot +import pytest + + +def test_smooth_ot_dual(): + + # get data + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + with pytest.raises(NotImplementedError): + Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='none') + + Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10) + + # check constratints + np.testing.assert_allclose( + u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn + np.testing.assert_allclose( + u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn + + # kl regyularisation + G = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='kl', stopThr=1e-10) + + # check constratints + np.testing.assert_allclose( + u, G.sum(1), atol=1e-05) # cf convergence sinkhorn + np.testing.assert_allclose( + u, G.sum(0), atol=1e-05) # cf convergence sinkhorn + + G2 = ot.sinkhorn(u, u, M, 1, stopThr=1e-10) + np.testing.assert_allclose(G, G2, atol=1e-05) + + +def test_smooth_ot_semi_dual(): + + # get data + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + with pytest.raises(NotImplementedError): + Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='none') + + Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10) + + # check constratints + np.testing.assert_allclose( + u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn + np.testing.assert_allclose( + u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn + + # kl regyularisation + G = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='kl', stopThr=1e-10) + + # check constratints + np.testing.assert_allclose( + u, G.sum(1), atol=1e-05) # cf convergence sinkhorn + np.testing.assert_allclose( + u, G.sum(0), atol=1e-05) # cf convergence sinkhorn + + G2 = ot.sinkhorn(u, u, M, 1, stopThr=1e-10) + np.testing.assert_allclose(G, G2, atol=1e-05) |