author    Rémi Flamary <remi.flamary@gmail.com>   2021-11-04 11:00:09 +0100
committer GitHub <noreply@github.com>             2021-11-04 11:00:09 +0100
commit    2fe69eb130827560ada704bc25998397c4357821 (patch)
tree      82973444cc4afc4c42cc7cdaf43a2ebd4b1a6a91
parent    9c6ac880d426b7577918b0c77bd74b3b01930ef6 (diff)
[MRG] Make gromov loss differentiable wrt matrices and weights (#302)
* gromov differentiable * new stuff * test gromov gradients * fgw differentiable * fgw tested * correct name test * add awesome example with gromov optimization * pep8 + typos * damn pep8 * thumbnail * remove prints
-rw-r--r--  README.md                                          9
-rw-r--r--  examples/backends/plot_optim_gromov_pytorch.py   260
-rw-r--r--  ot/__init__.py                                     2
-rw-r--r--  ot/gromov.py                                     141
-rw-r--r--  ot/optim.py                                        3
-rw-r--r--  test/test_gromov.py                               76
6 files changed, 460 insertions, 31 deletions
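With this change, ot.gromov_wasserstein2 (and its fused variant) can be called directly on PyTorch tensors and back-propagated through: the gradients with respect to the weights are the dual potentials of the inner linear OT problem, and the gradients with respect to the cost matrices follow the closed form from [38]. A minimal usage sketch, mirroring the new test_gromov2_gradients test (it assumes PyTorch is installed; the sizes and point clouds are purely illustrative):

import numpy as np
import torch
import ot

rng = np.random.RandomState(0)
x1, x2 = rng.randn(20, 2), rng.randn(15, 2)

# cost matrices and weights as torch tensors that require gradients
C1 = torch.tensor(ot.dist(x1, x1), requires_grad=True)
C2 = torch.tensor(ot.dist(x2, x2), requires_grad=True)
p = torch.tensor(ot.unif(20), requires_grad=True)
q = torch.tensor(ot.unif(15), requires_grad=True)

# square-loss GW value, now differentiable wrt C1, C2, p and q
val = ot.gromov_wasserstein2(C1, C2, p, q)
val.backward()

print(C1.grad.shape, C2.grad.shape, p.grad.shape, q.grad.shape)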
diff --git a/README.md b/README.md
index ff32c53..08db003 100644
--- a/README.md
+++ b/README.md
@@ -26,7 +26,7 @@ POT provides the following generic OT solvers (links to examples):
* Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37]
* [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17].
* Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html)) with LP solver (only small scale).
-* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12])
+* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from [38]
* [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24]
* [Stochastic solver](https://pythonot.github.io/auto_examples/plot_stochastic.html) for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19])
* [Stochastic solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33]
@@ -295,5 +295,8 @@ You can also post bug reports and feature requests in Github issues. Make sure t
via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on
Machine Learning (pp. 4104-4113). PMLR.
-[37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International
-Conference on Machine Learning, PMLR 119:4692-4701, 2020
\ No newline at end of file
+[37] Janati, H., Cuturi, M., Gramfort, A. [Debiased sinkhorn barycenters](http://proceedings.mlr.press/v119/janati20a/janati20a.pdf) Proceedings of the 37th International
+Conference on Machine Learning, PMLR 119:4692-4701, 2020
+
+[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph
+Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021.
\ No newline at end of file
diff --git a/examples/backends/plot_optim_gromov_pytorch.py b/examples/backends/plot_optim_gromov_pytorch.py
new file mode 100644
index 0000000..465f612
--- /dev/null
+++ b/examples/backends/plot_optim_gromov_pytorch.py
@@ -0,0 +1,260 @@
+r"""
+=======================================================
+Optimizing the Gromov-Wasserstein distance with PyTorch
+=======================================================
+
+In this example we use the PyTorch backend to optimize the Gromov-Wasserstein
+(GW) loss between two graphs expressed as empirical distributions.
+
+In the first example we optimize the weights on the nodes of a simple template
+graph so that they minimize the GW distance to a given Stochastic Block Model
+graph. We can see that this actually recovers the proportion of classes in the
+SBM and allows for an accurate clustering of the nodes using the GW optimal
+plan.
+
+In a second example we optimize simultaneously the weights and the structure
+of the template graph, which allows us to perform graph compression and to
+recover other properties of the SBM.
+
+The backend actually uses the gradients expressed in [38] to optimize the
+weights.
+
+[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online Graph
+Dictionary Learning, International Conference on Machine Learning (ICML), 2021.
+
+"""
+# Author: Rémi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 3
+
+from sklearn.manifold import MDS
+import numpy as np
+import matplotlib.pylab as pl
+import torch
+
+import ot
+from ot.gromov import gromov_wasserstein2
+
+# %%
+# Graph generation
+# ----------------
+
+rng = np.random.RandomState(42)
+
+
+def get_sbm(n, nc, ratio, P):
+ """Generate the adjacency matrix of a Stochastic Block Model graph.
+
+ The n nodes are split between nc classes with proportions given by ratio
+ and edges are drawn with the intra/inter class probabilities given in P.
+ """
+ nbpc = np.round(n * ratio).astype(int)
+ n = np.sum(nbpc)
+ C = np.zeros((n, n))
+ for c1 in range(nc):
+ for c2 in range(c1 + 1):
+ if c1 == c2:
+ for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[:c1 + 1])):
+ for j in range(np.sum(nbpc[:c2]), i):
+ if rng.rand() <= P[c1, c2]:
+ C[i, j] = 1
+ else:
+ for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[:c1 + 1])):
+ for j in range(np.sum(nbpc[:c2]), np.sum(nbpc[:c2 + 1])):
+ if rng.rand() <= P[c1, c2]:
+ C[i, j] = 1
+
+ return C + C.T
+
+
+n = 100
+nc = 3
+ratio = np.array([.5, .3, .2])
+P = np.array(0.6 * np.eye(3) + 0.05 * np.ones((3, 3)))
+C1 = get_sbm(n, nc, ratio, P)
+
+# get 2d position for nodes
+x1 = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C1)
+
+
+def plot_graph(x, C, color='C0', s=None):
+ for j in range(C.shape[0]):
+ for i in range(j):
+ if C[i, j] > 0:
+ pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color='k')
+ pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors='k', cmap='tab10', vmax=9)
+
+
+pl.figure(1, (10, 5))
+pl.clf()
+pl.subplot(1, 2, 1)
+plot_graph(x1, C1, color='C0')
+pl.title("SBM Graph")
+pl.axis("off")
+pl.subplot(1, 2, 2)
+pl.imshow(C1, interpolation='nearest')
+pl.title("Adjacency matrix")
+pl.axis("off")
+
+
+# %%
+# Optimizing the weights of a simple template C0=eye(3) to fit Graph 1
+# ----------------------------------------------------------------------
+# The adjacency matrix C1 is block diagonal with 3 blocks. We want to
+# optimize the weights of a simple template C0=eye(3) and see if we can
+# recover the proportion of classes from the SBM (up to a permutation).
+
+C0 = np.eye(3)
+
+
+def min_weight_gw(C1, C2, a2, nb_iter_max=100, lr=1e-2):
+ """Solve min_a GW(C1, C2, a, a2) by projected gradient descent."""
+
+ # use PyTorch tensors for our data
+ C1_torch = torch.tensor(C1)
+ C2_torch = torch.tensor(C2)
+
+ a0 = rng.rand(C1.shape[0]) # random_init
+ a0 /= a0.sum() # on simplex
+ a1_torch = torch.tensor(a0).requires_grad_(True)
+ a2_torch = torch.tensor(a2)
+
+ loss_iter = []
+
+ for i in range(nb_iter_max):
+
+ loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch)
+
+ loss_iter.append(loss.clone().detach().cpu().numpy())
+ loss.backward()
+
+ #print("{:03d} | {}".format(i, loss_iter[-1]))
+
+ # performs a step of projected gradient descent
+ with torch.no_grad():
+ grad = a1_torch.grad
+ a1_torch -= grad * lr # step
+ a1_torch.grad.zero_()
+ a1_torch.data = ot.utils.proj_simplex(a1_torch)
+
+ a1 = a1_torch.clone().detach().cpu().numpy()
+
+ return a1, loss_iter
+
+
+a0_est, loss_iter0 = min_weight_gw(C0, C1, ot.unif(n), nb_iter_max=100, lr=1e-2)
+
+pl.figure(2)
+pl.plot(loss_iter0)
+pl.title("Loss along iterations")
+
+print("Estimated weights : ", a0_est)
+print("True proportions : ", ratio)
+
+
+# %%
+# It is clear that the optimization has converged and that we recover the
+# ratio of the different classes in the SBM graph up to a permutation.
+
+
+# %%
+# Community clustering with uniform and estimated weights
+# --------------------------------------------------------
+# The GW OT plan can be used to perform a clustering of the nodes of a graph
+# when computing the GW with a simple template like C0, by labeling nodes in
+# the original graph with the index of the node in the template receiving
+# the most mass.
+#
+# We show here the result of such a clustering when using uniform weights on
+# the template C0 and when using the optimal weights previously estimated.
+
+
+T_unif = ot.gromov_wasserstein(C1, C0, ot.unif(n), ot.unif(3))
+label_unif = T_unif.argmax(1)
+
+T_est = ot.gromov_wasserstein(C1, C0, ot.unif(n), a0_est)
+label_est = T_est.argmax(1)
+
+pl.figure(3, (10, 5))
+pl.clf()
+pl.subplot(1, 2, 1)
+plot_graph(x1, C1, color=label_unif)
+pl.title("Graph clustering unif. weights")
+pl.axis("off")
+pl.subplot(1, 2, 2)
+plot_graph(x1, C1, color=label_est)
+pl.title("Graph clustering est. weights")
+pl.axis("off")
+
+
+# %%
+# Graph compression with GW
+# -------------------------
+
+# Now we optimize both the weights and the structure of a small graph that
+# minimizes the GW distance to our data graph. This can be seen as graph
+# compression, but it can also recover important properties of the SBM such
+# as its class proportions and its matrix of link probabilities between
+# classes.
+
+
+def graph_compression_gw(nb_nodes, C2, a2, nb_iter_max=100, lr=1e-2):
+ """Solve min_{a, C1} GW(C1, C2, a, a2) by projected gradient descent."""
+
+ # use PyTorch tensors for our data
+
+ C2_torch = torch.tensor(C2)
+ a2_torch = torch.tensor(a2)
+
+ a0 = rng.rand(nb_nodes) # random_init
+ a0 /= a0.sum() # on simplex
+ a1_torch = torch.tensor(a0).requires_grad_(True)
+ C0 = np.eye(nb_nodes)
+ C1_torch = torch.tensor(C0).requires_grad_(True)
+
+ loss_iter = []
+
+ for i in range(nb_iter_max):
+
+ loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch)
+
+ loss_iter.append(loss.clone().detach().cpu().numpy())
+ loss.backward()
+
+ #print("{:03d} | {}".format(i, loss_iter[-1]))
+
+ # performs a step of projected gradient descent
+ with torch.no_grad():
+ grad = a1_torch.grad
+ a1_torch -= grad * lr # step
+ a1_torch.grad.zero_()
+ a1_torch.data = ot.utils.proj_simplex(a1_torch)
+
+ grad = C1_torch.grad
+ C1_torch -= grad * lr # step
+ C1_torch.grad.zero_()
+ C1_torch.data = torch.clamp(C1_torch, 0, 1)
+
+ a1 = a1_torch.clone().detach().cpu().numpy()
+ C1 = C1_torch.clone().detach().cpu().numpy()
+
+ return a1, C1, loss_iter
+
+
+nb_nodes = 3
+a0_est2, C0_est2, loss_iter2 = graph_compression_gw(nb_nodes, C1, ot.unif(n),
+ nb_iter_max=100, lr=5e-2)
+
+pl.figure(4)
+pl.plot(loss_iter2)
+pl.title("Loss along iterations")
+
+
+print("Estimated weights : ", a0_est2)
+print("True proportions : ", ratio)
+
+pl.figure(6, (10, 3.5))
+pl.clf()
+pl.subplot(1, 2, 1)
+pl.imshow(P, vmin=0, vmax=1)
+pl.title('True SBM P matrix')
+pl.subplot(1, 2, 2)
+pl.imshow(C0_est2, vmin=0, vmax=1)
+pl.title('Estimated C0 matrix')
+pl.colorbar()
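The projected gradient step in the example above relies on ot.utils.proj_simplex to keep the estimated weights on the probability simplex after every update. A small standalone sketch of that projection (the input values are arbitrary; PyTorch assumed, as in the example):

import torch
import ot

# arbitrary point, e.g. the weights after an unconstrained gradient step
a = torch.tensor([0.7, -0.1, 0.6])

# euclidean projection onto {a_i >= 0, sum_i a_i = 1}
a_proj = ot.utils.proj_simplex(a)
print(a_proj, a_proj.sum())  # nonnegative entries that sum to 1

Projecting after each step keeps the iterate a valid distribution without an explicit renormalization.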
diff --git a/ot/__init__.py b/ot/__init__.py
index f20332c..4292b41 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -43,6 +43,8 @@ from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced,
sinkhorn_unbalanced2)
from .da import sinkhorn_lpl1_mm
from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance
+from .gromov import (gromov_wasserstein, gromov_wasserstein2,
+ gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
# utils functions
from .utils import dist, unif, tic, toc, toq
diff --git a/ot/gromov.py b/ot/gromov.py
index 465693d..ea667e4 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -174,7 +174,7 @@ def tensor_product(constC, hC1, hC2, T):
def gwloss(constC, hC1, hC2, T):
- """Return the Loss for Gromov-Wasserstein
+ r"""Return the Loss for Gromov-Wasserstein
The loss is computed as described in Proposition 1 Eq. (6) in :ref:`[12] <references-gwloss>`
@@ -213,7 +213,7 @@ def gwloss(constC, hC1, hC2, T):
def gwggrad(constC, hC1, hC2, T):
- """Return the gradient for Gromov-Wasserstein
+ r"""Return the gradient for Gromov-Wasserstein
The gradient is computed as described in Proposition 2 in :ref:`[12] <references-gwggrad>`
@@ -247,7 +247,7 @@ def gwggrad(constC, hC1, hC2, T):
def update_square_loss(p, lambdas, T, Cs):
- """
+ r"""
Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s`
couplings calculated at each iteration
@@ -284,7 +284,7 @@ def update_square_loss(p, lambdas, T, Cs):
def update_kl_loss(p, lambdas, T, Cs):
- """
+ r"""
Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
@@ -320,7 +320,7 @@ def update_kl_loss(p, lambdas, T, Cs):
return nx.exp(tmpsum / ppt)
-def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
+def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs):
r"""
Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
@@ -386,6 +386,14 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs
"""
p, q = list_to_array(p, q)
+ p0, q0, C10, C20 = p, q, C1, C2
+ nx = get_backend(p0, q0, C10, C20)
+
+ p = nx.to_numpy(p)
+ q = nx.to_numpy(q)
+ C1 = nx.to_numpy(C10)
+ C2 = nx.to_numpy(C20)
+
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
G0 = p[:, None] * q[None, :]
@@ -398,13 +406,15 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs
if log:
res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
- log['gw_dist'] = gwloss(constC, hC1, hC2, res)
- return res, log
+ log['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, res), type_as=C10)
+ log['u'] = nx.from_numpy(log['u'], type_as=C10)
+ log['v'] = nx.from_numpy(log['v'], type_as=C10)
+ return nx.from_numpy(res, type_as=C10), log
else:
- return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ return nx.from_numpy(cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=False, **kwargs), type_as=C10)
-def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
+def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs):
r"""
Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
@@ -420,7 +430,11 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg
- :math:`\mathbf{C_2}`: Metric cost matrix in the target space
- :math:`\mathbf{p}`: distribution in the source space
- :math:`\mathbf{q}`: distribution in the target space
- - `L`: loss function to account for the misfit between the similarity matrices
+ - `L`: loss function to account for the misfit between the similarity
+ matrices
+
+ Note that when using backends, this loss function is differentiable wrt the
+ matrices and weights for quadratic loss using the gradients from [38]_.
Parameters
----------
@@ -463,9 +477,21 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg
metric approach to object matching. Foundations of computational
mathematics 11.4 (2011): 417-487.
+ .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
+ Graph Dictionary Learning, International Conference on Machine Learning
+ (ICML), 2021.
+
"""
p, q = list_to_array(p, q)
+ p0, q0, C10, C20 = p, q, C1, C2
+ nx = get_backend(p0, q0, C10, C20)
+
+ p = nx.to_numpy(p)
+ q = nx.to_numpy(q)
+ C1 = nx.to_numpy(C10)
+ C2 = nx.to_numpy(C20)
+
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
G0 = p[:, None] * q[None, :]
@@ -475,13 +501,28 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg
def df(G):
return gwggrad(constC, hC1, hC2, G)
- res, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
- log_gw['gw_dist'] = gwloss(constC, hC1, hC2, res)
- log_gw['T'] = res
+
+ T, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+
+ T0 = nx.from_numpy(T, type_as=C10)
+
+ log_gw['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, T), type_as=C10)
+ log_gw['u'] = nx.from_numpy(log_gw['u'], type_as=C10)
+ log_gw['v'] = nx.from_numpy(log_gw['v'], type_as=C10)
+ log_gw['T'] = T0
+
+ gw = log_gw['gw_dist']
+
+ if loss_fun == 'square_loss':
+ gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T))
+ gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T))
+ gw = nx.set_gradients(gw, (p0, q0, C10, C20),
+ (log_gw['u'], log_gw['v'], gC1, gC2))
+
if log:
- return log_gw['gw_dist'], log_gw
+ return gw, log_gw
else:
- return log_gw['gw_dist']
+ return gw
def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
@@ -548,6 +589,15 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
"""
p, q = list_to_array(p, q)
+ p0, q0, C10, C20, M0 = p, q, C1, C2, M
+ nx = get_backend(p0, q0, C10, C20, M0)
+
+ p = nx.to_numpy(p)
+ q = nx.to_numpy(q)
+ C1 = nx.to_numpy(C10)
+ C2 = nx.to_numpy(C20)
+ M = nx.to_numpy(M0)
+
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
G0 = p[:, None] * q[None, :]
@@ -560,10 +610,16 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
if log:
res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
- log['fgw_dist'] = log['loss'][::-1][0]
- return res, log
+
+ fgw_dist = nx.from_numpy(log['loss'][-1], type_as=C10)
+
+ log['fgw_dist'] = fgw_dist
+ log['u'] = nx.from_numpy(log['u'], type_as=C10)
+ log['v'] = nx.from_numpy(log['v'], type_as=C10)
+ return nx.from_numpy(res, type_as=C10), log
+
else:
- return cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs), type_as=C10)
def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
@@ -586,7 +642,11 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
- :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
- `L` is a loss function to account for the misfit between the similarity matrices
- The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] <references-fused-gromov-wasserstein2>`
+ The algorithm used for solving the problem is conditional gradient as
+ discussed in :ref:`[24] <references-fused-gromov-wasserstein2>`
+
+ Note that when using backends, this loss function is differentiable wrt the
+ matrices and weights for quadratic loss using the gradients from [38]_.
Parameters
----------
@@ -627,9 +687,22 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
and Courty Nicolas
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
+
+ .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
+ Graph Dictionary Learning, International Conference on Machine Learning
+ (ICML), 2021.
"""
p, q = list_to_array(p, q)
+ p0, q0, C10, C20, M0 = p, q, C1, C2, M
+ nx = get_backend(p0, q0, C10, C20, M0)
+
+ p = nx.to_numpy(p)
+ q = nx.to_numpy(q)
+ C1 = nx.to_numpy(C10)
+ C2 = nx.to_numpy(C20)
+ M = nx.to_numpy(M0)
+
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
G0 = p[:, None] * q[None, :]
@@ -640,13 +713,27 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
def df(G):
return gwggrad(constC, hC1, hC2, G)
- res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ T, log_fgw = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+
+ fgw_dist = nx.from_numpy(log_fgw['loss'][-1], type_as=C10)
+
+ T0 = nx.from_numpy(T, type_as=C10)
+
+ log_fgw['fgw_dist'] = fgw_dist
+ log_fgw['u'] = nx.from_numpy(log_fgw['u'], type_as=C10)
+ log_fgw['v'] = nx.from_numpy(log_fgw['v'], type_as=C10)
+ log_fgw['T'] = T0
+
+ if loss_fun == 'square_loss':
+ gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T))
+ gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T))
+ fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0),
+ (log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0))
+
if log:
- log['fgw_dist'] = log['loss'][::-1][0]
- log['T'] = res
- return log['fgw_dist'], log
+ return fgw_dist, log_fgw
else:
- return log['fgw_dist']
+ return fgw_dist
def GW_distance_estimation(C1, C2, p, q, loss_fun, T,
@@ -1447,7 +1534,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
p=None, loss_fun='square_loss', max_iter=100, tol=1e-9,
verbose=False, log=False, init_C=None, init_X=None, random_state=None):
- """Compute the fgw barycenter as presented eq (5) in :ref:`[24] <references-fgw-barycenters>`
+ r"""Compute the fgw barycenter as presented eq (5) in :ref:`[24] <references-fgw-barycenters>`
Parameters
----------
@@ -1604,7 +1691,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
def update_structure_matrix(p, lambdas, T, Cs):
- """Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings.
+ r"""Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings.
It is calculated at each iteration
@@ -1640,7 +1727,7 @@ def update_structure_matrix(p, lambdas, T, Cs):
def update_feature_matrix(lambdas, Ys, Ts, p):
- """Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings.
+ r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings.
See "Solving the barycenter problem with Block Coordinate Descent (BCD)"
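For the square loss, the gradients registered above with nx.set_gradients follow the closed form from [38]: with the optimal plan T held fixed (envelope theorem), the GW value is differentiated directly with respect to the cost matrices. A minimal numpy sketch of those two expressions, assuming an optimal plan has already been computed on small illustrative data:

import numpy as np
import ot

rng = np.random.RandomState(0)
x1, x2 = rng.randn(10, 2), rng.randn(12, 2)
C1, C2 = ot.dist(x1, x1), ot.dist(x2, x2)
p, q = ot.unif(10), ot.unif(12)

# optimal GW plan for the square loss
T = ot.gromov_wasserstein(C1, C2, p, q)

# closed-form gradients of the square-loss GW value wrt C1 and C2 (T fixed)
gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)
gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)

These are exactly the gC1/gC2 terms above; in fused_gromov_wasserstein2 they are scaled by alpha and the gradient with respect to M is (1 - alpha) * T.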
diff --git a/ot/optim.py b/ot/optim.py
index cc286b6..bd8ca26 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -267,7 +267,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
Mi += nx.min(Mi)
# solve linear program
- Gc = emd(a, b, Mi, numItermax=numItermaxEmd)
+ Gc, logemd = emd(a, b, Mi, numItermax=numItermaxEmd, log=True)
deltaG = Gc - G
@@ -297,6 +297,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
if log:
+ log.update(logemd)
return G, log
else:
return G
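The change to cg() above forwards the log of the last inner emd() call; its dual potentials u and v are what gromov_wasserstein2 and fused_gromov_wasserstein2 register as gradients with respect to p and q. A short sketch of what that emd log contains (inputs are illustrative):

import numpy as np
import ot

rng = np.random.RandomState(0)
a, b = ot.unif(5), ot.unif(6)
M = ot.dist(rng.randn(5, 2), rng.randn(6, 2))

# exact OT plan together with the dual potentials of the linear program
G, log = ot.emd(a, b, M, log=True)
print(log['u'].shape, log['v'].shape)  # dual variables associated with a and b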
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 509c54d..bcbcc3a 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -9,6 +9,7 @@
import numpy as np
import ot
from ot.backend import NumpyBackend
+from ot.backend import torch
import pytest
@@ -74,6 +75,42 @@ def test_gromov(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+def test_gromov2_gradients():
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5)
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ if torch:
+
+ p1 = torch.tensor(p, requires_grad=True)
+ q1 = torch.tensor(q, requires_grad=True)
+ C11 = torch.tensor(C1, requires_grad=True)
+ C12 = torch.tensor(C2, requires_grad=True)
+
+ val = ot.gromov_wasserstein2(C11, C12, p1, q1)
+
+ val.backward()
+
+ assert q1.shape == q1.grad.shape
+ assert p1.shape == p1.grad.shape
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+
+
@pytest.skip_backend("jax", reason="test very slow with jax backend")
def test_entropic_gromov(nx):
n_samples = 50 # nb samples
@@ -389,6 +426,45 @@ def test_fgw(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+def test_fgw2_gradients():
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5)
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+ M = ot.dist(xs, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ if torch:
+
+ p1 = torch.tensor(p, requires_grad=True)
+ q1 = torch.tensor(q, requires_grad=True)
+ C11 = torch.tensor(C1, requires_grad=True)
+ C12 = torch.tensor(C2, requires_grad=True)
+ M1 = torch.tensor(M, requires_grad=True)
+
+ val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1)
+
+ val.backward()
+
+ assert q1.shape == q1.grad.shape
+ assert p1.shape == p1.grad.shape
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+ assert M1.shape == M1.grad.shape
+
+
def test_fgw_barycenter(nx):
np.random.seed(42)