From 2fe69eb130827560ada704bc25998397c4357821 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Thu, 4 Nov 2021 11:00:09 +0100 Subject: [MRG] Make gromov loss differentiable wrt matrices and weights (#302) * grmov differentable * new stuff * test gromov gradients * fgwdifferentiable * fgw tested * correc name test * add awesome example with gromov optimizatrion * pep8+ typos * damn pep8 * thunbnail * remove prints --- test/test_gromov.py | 76 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) (limited to 'test') diff --git a/test/test_gromov.py b/test/test_gromov.py index 509c54d..bcbcc3a 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -9,6 +9,7 @@ import numpy as np import ot from ot.backend import NumpyBackend +from ot.backend import torch import pytest @@ -74,6 +75,42 @@ def test_gromov(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov +def test_gromov2_gradients(): + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) + + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + if torch: + + p1 = torch.tensor(p, requires_grad=True) + q1 = torch.tensor(q, requires_grad=True) + C11 = torch.tensor(C1, requires_grad=True) + C12 = torch.tensor(C2, requires_grad=True) + + val = ot.gromov_wasserstein2(C11, C12, p1, q1) + + val.backward() + + assert q1.shape == q1.grad.shape + assert p1.shape == p1.grad.shape + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + + @pytest.skip_backend("jax", reason="test very slow with jax backend") def test_entropic_gromov(nx): n_samples = 50 # nb samples @@ -389,6 +426,45 @@ def test_fgw(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov +def test_fgw2_gradients(): + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) + + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + M = ot.dist(xs, xt) + + C1 /= C1.max() + C2 /= C2.max() + + if torch: + + p1 = torch.tensor(p, requires_grad=True) + q1 = torch.tensor(q, requires_grad=True) + C11 = torch.tensor(C1, requires_grad=True) + C12 = torch.tensor(C2, requires_grad=True) + M1 = torch.tensor(M, requires_grad=True) + + val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1) + + val.backward() + + assert q1.shape == q1.grad.shape + assert p1.shape == p1.grad.shape + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + assert M1.shape == M1.grad.shape + + def test_fgw_barycenter(nx): np.random.seed(42) -- cgit v1.2.3