From ccc076e0fc535b2c734214c0ac1936e9e2cbeb62 Mon Sep 17 00:00:00 2001 From: eloitanguy <69361683+eloitanguy@users.noreply.github.com> Date: Fri, 6 May 2022 08:43:21 +0200 Subject: [WIP] Generalized Wasserstein Barycenters (#372) * GWB first solver version * tests + example for gwb (untested) + free_bar doc fix * improved doc, fixed minor bugs, better example visu * minor doc + visu fixes * plot GWB pep8 fix * fixed partial gromov test reproductibility * added an animation for the GWB visu * added PR num * minor doc fixes + better gwb logo --- test/test_ot.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) (limited to 'test/test_ot.py') diff --git a/test/test_ot.py b/test/test_ot.py index bf832f6..ba3ef6a 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -320,6 +320,46 @@ def test_free_support_barycenter_backends(nx): np.testing.assert_allclose(X, nx.to_numpy(X2)) +def test_generalised_free_support_barycenter(): + np.random.seed(42) # random inits + X = [np.array([-1., -1.]).reshape((1, 2)), np.array([1., 1.]).reshape((1, 2))] # two 2D points bar is obviously 0 + a = [np.array([1.]), np.array([1.])] + + P = [np.eye(2), np.eye(2)] + + Y_init = np.array([-12., 7.]).reshape((1, 2)) + + # obvious barycenter location between two 2D diracs + Y_true = np.array([0., .0]).reshape((1, 2)) + + # test without log and no init + Y = ot.lp.generalized_free_support_barycenter(X, a, P, 1) + np.testing.assert_allclose(Y, Y_true, rtol=1e-5, atol=1e-7) + + # test with log and init + Y, _ = ot.lp.generalized_free_support_barycenter(X, a, P, 1, Y_init=Y_init, b=np.array([1.]), log=True) + np.testing.assert_allclose(Y, Y_true, rtol=1e-5, atol=1e-7) + + +def test_generalised_free_support_barycenter_backends(nx): + np.random.seed(42) + X = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))] + a = [np.array([1.]), np.array([1.])] + P = [np.array([1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))] + Y_init = np.array([-12.]).reshape((1, 1)) + + Y = ot.lp.generalized_free_support_barycenter(X, a, P, 1, Y_init=Y_init) + + X2 = nx.from_numpy(*X) + a2 = nx.from_numpy(*a) + P2 = nx.from_numpy(*P) + Y_init2 = nx.from_numpy(Y_init) + + Y2 = ot.lp.generalized_free_support_barycenter(X2, a2, P2, 1, Y_init=Y_init2) + + np.testing.assert_allclose(Y, nx.to_numpy(Y2)) + + @pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available") def test_lp_barycenter_cvxopt(): a1 = np.array([1.0, 0, 0])[:, None] -- cgit v1.2.3