summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2021-11-04 11:00:09 +0100
committerGitHub <noreply@github.com>2021-11-04 11:00:09 +0100
commit2fe69eb130827560ada704bc25998397c4357821 (patch)
tree82973444cc4afc4c42cc7cdaf43a2ebd4b1a6a91 /test
parent9c6ac880d426b7577918b0c77bd74b3b01930ef6 (diff)
[MRG] Make gromov loss differentiable wrt matrices and weights (#302)
* grmov differentable * new stuff * test gromov gradients * fgwdifferentiable * fgw tested * correc name test * add awesome example with gromov optimizatrion * pep8+ typos * damn pep8 * thunbnail * remove prints
Diffstat (limited to 'test')
-rw-r--r--test/test_gromov.py76
1 files changed, 76 insertions, 0 deletions
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 509c54d..bcbcc3a 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -9,6 +9,7 @@
import numpy as np
import ot
from ot.backend import NumpyBackend
+from ot.backend import torch
import pytest
@@ -74,6 +75,42 @@ def test_gromov(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+def test_gromov2_gradients():
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5)
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ if torch:
+
+ p1 = torch.tensor(p, requires_grad=True)
+ q1 = torch.tensor(q, requires_grad=True)
+ C11 = torch.tensor(C1, requires_grad=True)
+ C12 = torch.tensor(C2, requires_grad=True)
+
+ val = ot.gromov_wasserstein2(C11, C12, p1, q1)
+
+ val.backward()
+
+ assert q1.shape == q1.grad.shape
+ assert p1.shape == p1.grad.shape
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+
+
@pytest.skip_backend("jax", reason="test very slow with jax backend")
def test_entropic_gromov(nx):
n_samples = 50 # nb samples
@@ -389,6 +426,45 @@ def test_fgw(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+def test_fgw2_gradients():
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5)
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+ M = ot.dist(xs, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ if torch:
+
+ p1 = torch.tensor(p, requires_grad=True)
+ q1 = torch.tensor(q, requires_grad=True)
+ C11 = torch.tensor(C1, requires_grad=True)
+ C12 = torch.tensor(C2, requires_grad=True)
+ M1 = torch.tensor(M, requires_grad=True)
+
+ val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1)
+
+ val.backward()
+
+ assert q1.shape == q1.grad.shape
+ assert p1.shape == p1.grad.shape
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+ assert M1.shape == M1.grad.shape
+
+
def test_fgw_barycenter(nx):
np.random.seed(42)