author     Rémi Flamary <remi.flamary@gmail.com>    2023-04-24 17:54:03 +0200
committer  GitHub <noreply@github.com>              2023-04-24 17:54:03 +0200
commit     03ca4ef659a037e400975e3b2116b637a2d94265 (patch)
tree       2fff6add4b430a9bb97cf594786777c7e48ea5a5
parent     25d72db09ed281c13b97aa8a68d82a4ed5ba7bf0 (diff)
[MRG] make alpha parameter in FGW differentiable (#463)

* make alpha differentiable
* update release file
* debug tensorflow to_numpy
-rw-r--r--  RELEASES.md          |  2
-rw-r--r--  ot/backend.py        |  6
-rw-r--r--  ot/gromov/_gw.py     | 20
-rw-r--r--  test/test_gromov.py  | 27

4 files changed, 49 insertions(+), 6 deletions(-)
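With this change, alpha can be passed as a torch tensor with requires_grad=True and receives a gradient through ot.fused_gromov_wasserstein2, mirroring the new test added below. A minimal usage sketch (the random problem data is illustrative, not from the patch):

    import numpy as np
    import torch
    import ot

    # illustrative toy problem: symmetric structure matrices and a feature cost
    rng = np.random.RandomState(42)
    n = 20
    C1 = rng.rand(n, n); C1 = (C1 + C1.T) / 2
    C2 = rng.rand(n, n); C2 = (C2 + C2.T) / 2
    M = rng.rand(n, n)
    p, q = ot.unif(n), ot.unif(n)

    M1 = torch.tensor(M, requires_grad=True)
    C11 = torch.tensor(C1, requires_grad=True)
    C12 = torch.tensor(C2, requires_grad=True)
    p1 = torch.tensor(p, requires_grad=True)
    q1 = torch.tensor(q, requires_grad=True)
    alpha = torch.tensor(0.5, requires_grad=True)  # 0-d tensor, now differentiable

    val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1, alpha=alpha)
    val.backward()
    print(alpha.grad)  # d FGW / d alpha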
diff --git a/RELEASES.md b/RELEASES.md
index 214cc2a..d912215 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -4,6 +4,8 @@
#### New features
+- Make alpha parameter in Fused Gromov Wasserstein differentiable (PR #463)
+
#### Closed issues
- Fix circleci-redirector action and codecov (PR #460)
diff --git a/ot/backend.py b/ot/backend.py
index 74f8366..0dd6fb8 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -1694,10 +1694,12 @@ class TorchBackend(Backend):
self.ValFunction = ValFunction
def _to_numpy(self, a):
+ if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray):
+ return np.array(a)
return a.cpu().detach().numpy()
def _from_numpy(self, a, type_as=None):
- if isinstance(a, float):
+ if isinstance(a, float) or isinstance(a, int):
a = np.array(a)
if type_as is None:
return torch.from_numpy(a)
@@ -2501,6 +2503,8 @@ class TensorflowBackend(Backend):
)
def _to_numpy(self, a):
+ if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray):
+ return np.array(a)
return a.numpy()
def _from_numpy(self, a, type_as=None):
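The backend change above makes _to_numpy accept plain Python scalars and NumPy arrays, so that nx.to_numpy(alpha0) works whether alpha arrives as a float or as a backend tensor. A small sketch of the intended behaviour (assuming a torch install):

    import torch
    from ot.backend import get_backend

    a = torch.tensor(0.5)
    nx = get_backend(a)      # TorchBackend

    print(nx.to_numpy(a))    # tensors still go through .cpu().detach().numpy()
    print(nx.to_numpy(0.5))  # plain floats now return np.array(0.5) instead of raising
    print(nx.to_numpy(3))    # same for ints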
diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py
index c6e4076..bc4719d 100644
--- a/ot/gromov/_gw.py
+++ b/ot/gromov/_gw.py
@@ -370,7 +370,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=
Information and Inference: A Journal of the IMA, 8(4), 757-787.
"""
p, q = list_to_array(p, q)
- p0, q0, C10, C20, M0 = p, q, C1, C2, M
+ p0, q0, C10, C20, M0, alpha0 = p, q, C1, C2, M, alpha
if G0 is None:
nx = get_backend(p0, q0, C10, C20, M0)
else:
@@ -382,6 +382,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=
C1 = nx.to_numpy(C10)
C2 = nx.to_numpy(C20)
M = nx.to_numpy(M0)
+ alpha = nx.to_numpy(alpha0)
if symmetric is None:
symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10)
@@ -535,10 +536,19 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', symmetric
if loss_fun == 'square_loss':
gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
- fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M),
- (log_fgw['u'] - nx.mean(log_fgw['u']),
- log_fgw['v'] - nx.mean(log_fgw['v']),
- alpha * gC1, alpha * gC2, (1 - alpha) * T))
+ if isinstance(alpha, int) or isinstance(alpha, float):
+ fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M),
+ (log_fgw['u'] - nx.mean(log_fgw['u']),
+ log_fgw['v'] - nx.mean(log_fgw['v']),
+ alpha * gC1, alpha * gC2, (1 - alpha) * T))
+ else:
+ lin_term = nx.sum(T * M)
+ gw_term = (fgw_dist - (1 - alpha) * lin_term) / alpha
+ fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M, alpha),
+ (log_fgw['u'] - nx.mean(log_fgw['u']),
+ log_fgw['v'] - nx.mean(log_fgw['v']),
+ alpha * gC1, alpha * gC2, (1 - alpha) * T,
+ gw_term - lin_term))
if log:
return fgw_dist, log_fgw
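The gw_term - lin_term expression is the derivative of the FGW value with respect to alpha at the fixed optimal plan T. A sketch of the reasoning, with lin_term = <T, M> and GW(T) the pure Gromov-Wasserstein part:

    \mathrm{FGW}(\alpha) = (1-\alpha)\,\langle T, M \rangle + \alpha\,\mathrm{GW}(T)
    \quad\Rightarrow\quad
    \frac{\partial\,\mathrm{FGW}}{\partial\alpha} = \mathrm{GW}(T) - \langle T, M \rangle,

where, by the envelope theorem, T can be treated as constant at the optimum. Rather than recomputing GW(T), the code recovers it from the returned value as gw_term = (fgw_dist - (1 - alpha) * lin_term) / alpha.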
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 80b6df4..f70f410 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -209,6 +209,8 @@ def test_gromov2_gradients():
if torch.cuda.is_available():
devices.append(torch.device("cuda"))
for device in devices:
+
+ # classical gradients
p1 = torch.tensor(p, requires_grad=True, device=device)
q1 = torch.tensor(q, requires_grad=True, device=device)
C11 = torch.tensor(C1, requires_grad=True, device=device)
@@ -226,6 +228,12 @@ def test_gromov2_gradients():
assert C12.shape == C12.grad.shape
# Test with armijo line-search
+ # classical gradients
+ p1 = torch.tensor(p, requires_grad=True, device=device)
+ q1 = torch.tensor(q, requires_grad=True, device=device)
+ C11 = torch.tensor(C1, requires_grad=True, device=device)
+ C12 = torch.tensor(C2, requires_grad=True, device=device)
+
q1.grad = None
p1.grad = None
C11.grad = None
@@ -830,6 +838,25 @@ def test_fgw2_gradients():
assert C12.shape == C12.grad.shape
assert M1.shape == M1.grad.shape
+ # full gradients with alpha
+ p1 = torch.tensor(p, requires_grad=True, device=device)
+ q1 = torch.tensor(q, requires_grad=True, device=device)
+ C11 = torch.tensor(C1, requires_grad=True, device=device)
+ C12 = torch.tensor(C2, requires_grad=True, device=device)
+ M1 = torch.tensor(M, requires_grad=True, device=device)
+ alpha = torch.tensor(0.5, requires_grad=True, device=device)
+
+ val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1, alpha=alpha)
+
+ val.backward()
+
+ assert val.device == p1.device
+ assert q1.shape == q1.grad.shape
+ assert p1.shape == p1.grad.shape
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+ assert alpha.shape == alpha.grad.shape
+
def test_fgw_helper_backend(nx):
n_samples = 20 # nb samples