summaryrefslogtreecommitdiff
path: root/test/test_ot.py
diff options
context:
space:
mode:
authoreloitanguy <69361683+eloitanguy@users.noreply.github.com>2022-05-06 08:43:21 +0200
committerGitHub <noreply@github.com>2022-05-06 08:43:21 +0200
commitccc076e0fc535b2c734214c0ac1936e9e2cbeb62 (patch)
treeb5a20af6fabcaefa0de4bc27afd9049bd15612f6 /test/test_ot.py
parenteccb1386eea52b94b82456d126bd20cbe3198e05 (diff)
[WIP] Generalized Wasserstein Barycenters (#372)
* GWB first solver version * tests + example for gwb (untested) + free_bar doc fix * improved doc, fixed minor bugs, better example visu * minor doc + visu fixes * plot GWB pep8 fix * fixed partial gromov test reproductibility * added an animation for the GWB visu * added PR num * minor doc fixes + better gwb logo
Diffstat (limited to 'test/test_ot.py')
-rw-r--r--test/test_ot.py40
1 files changed, 40 insertions, 0 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index bf832f6..ba3ef6a 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -320,6 +320,46 @@ def test_free_support_barycenter_backends(nx):
np.testing.assert_allclose(X, nx.to_numpy(X2))
+def test_generalised_free_support_barycenter():
+ np.random.seed(42) # random inits
+ X = [np.array([-1., -1.]).reshape((1, 2)), np.array([1., 1.]).reshape((1, 2))] # two 2D points bar is obviously 0
+ a = [np.array([1.]), np.array([1.])]
+
+ P = [np.eye(2), np.eye(2)]
+
+ Y_init = np.array([-12., 7.]).reshape((1, 2))
+
+ # obvious barycenter location between two 2D diracs
+ Y_true = np.array([0., .0]).reshape((1, 2))
+
+ # test without log and no init
+ Y = ot.lp.generalized_free_support_barycenter(X, a, P, 1)
+ np.testing.assert_allclose(Y, Y_true, rtol=1e-5, atol=1e-7)
+
+ # test with log and init
+ Y, _ = ot.lp.generalized_free_support_barycenter(X, a, P, 1, Y_init=Y_init, b=np.array([1.]), log=True)
+ np.testing.assert_allclose(Y, Y_true, rtol=1e-5, atol=1e-7)
+
+
+def test_generalised_free_support_barycenter_backends(nx):
+ np.random.seed(42)
+ X = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))]
+ a = [np.array([1.]), np.array([1.])]
+ P = [np.array([1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))]
+ Y_init = np.array([-12.]).reshape((1, 1))
+
+ Y = ot.lp.generalized_free_support_barycenter(X, a, P, 1, Y_init=Y_init)
+
+ X2 = nx.from_numpy(*X)
+ a2 = nx.from_numpy(*a)
+ P2 = nx.from_numpy(*P)
+ Y_init2 = nx.from_numpy(Y_init)
+
+ Y2 = ot.lp.generalized_free_support_barycenter(X2, a2, P2, 1, Y_init=Y_init2)
+
+ np.testing.assert_allclose(Y, nx.to_numpy(Y2))
+
+
@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
def test_lp_barycenter_cvxopt():
a1 = np.array([1.0, 0, 0])[:, None]