From a5e0f0d40d5046a6639924347ef97e2ac80ad0c9 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Wed, 2 Feb 2022 11:53:12 +0100 Subject: [MRG] Add weak OT solver (#341) * add info in release file * update tests * pep8 * add weak OT example * update plot in doc * correction ewample with empirical sinkhorn * better thumbnail * comment from review * update documenation --- test/test_weak.py | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 test/test_weak.py (limited to 'test/test_weak.py') diff --git a/test/test_weak.py b/test/test_weak.py new file mode 100644 index 0000000..c4c3278 --- /dev/null +++ b/test/test_weak.py @@ -0,0 +1,54 @@ +"""Tests for main module ot.weak """ + +# Author: Remi Flamary +# +# License: MIT License + +import ot +import numpy as np + + +def test_weak_ot(): + # test weak ot solver and identity stationary point + n = 50 + rng = np.random.RandomState(0) + + xs = rng.randn(n, 2) + xt = rng.randn(n, 2) + u = ot.utils.unif(n) + + G, log = ot.weak_optimal_transport(xs, xt, u, u, log=True) + + # check constraints + np.testing.assert_allclose(u, G.sum(1)) + np.testing.assert_allclose(u, G.sum(0)) + + # chaeck that identity is recovered + G = ot.weak_optimal_transport(xs, xs, G0=np.eye(n) / n) + + # check G is identity + np.testing.assert_allclose(G, np.eye(n) / n) + + # check constraints + np.testing.assert_allclose(u, G.sum(1)) + np.testing.assert_allclose(u, G.sum(0)) + + +def test_weak_ot_bakends(nx): + # test weak ot solver for different backends + n = 50 + rng = np.random.RandomState(0) + + xs = rng.randn(n, 2) + xt = rng.randn(n, 2) + u = ot.utils.unif(n) + + G = ot.weak_optimal_transport(xs, xt, u, u) + + xs2 = nx.from_numpy(xs) + xt2 = nx.from_numpy(xt) + u2 = nx.from_numpy(u) + + G2 = ot.weak_optimal_transport(xs2, xt2, u2, u2) + + np.testing.assert_allclose(nx.to_numpy(G2), G) -- cgit v1.2.3