author    Rémi Flamary <remi.flamary@gmail.com>  2021-11-04 11:00:09 +0100
committer GitHub <noreply@github.com>  2021-11-04 11:00:09 +0100
commit    2fe69eb130827560ada704bc25998397c4357821 (patch)
tree      82973444cc4afc4c42cc7cdaf43a2ebd4b1a6a91 /ot
parent    9c6ac880d426b7577918b0c77bd74b3b01930ef6 (diff)
[MRG] Make gromov loss differentiable wrt matrices and weights (#302)
* gromov differentiable
* new stuff
* test gromov gradients
* fgw differentiable
* fgw tested
* correct name test
* add awesome example with gromov optimization
* pep8 + typos
* damn pep8
* thumbnail
* remove prints
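For context, a minimal sketch of what this change enables, assuming the torch backend is installed; the matrices below are random and purely illustrative:

# GW discrepancy is now differentiable wrt the structure matrices
# (square loss), so autograd can flow through it.
import torch
import ot

n = 10
C1 = torch.rand(n, n, requires_grad=True)   # source structure matrix
C2 = torch.rand(n, n, requires_grad=True)   # target structure matrix
p = torch.ones(n) / n                       # uniform source weights
q = torch.ones(n) / n                       # uniform target weights

gw = ot.gromov_wasserstein2(C1, C2, p, q)   # scalar GW discrepancy
gw.backward()                               # populates C1.grad and C2.grad
print(C1.grad.shape, C2.grad.shape)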
Diffstat (limited to 'ot')
-rw-r--r--  ot/__init__.py    2
-rw-r--r--  ot/gromov.py    141
-rw-r--r--  ot/optim.py       3
3 files changed, 118 insertions, 28 deletions
diff --git a/ot/__init__.py b/ot/__init__.py
index f20332c..4292b41 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -43,6 +43,8 @@ from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced,
sinkhorn_unbalanced2)
from .da import sinkhorn_lpl1_mm
from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance
+from .gromov import (gromov_wasserstein, gromov_wasserstein2,
+ gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
# utils functions
from .utils import dist, unif, tic, toc, toq
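With the import added above, the GW solvers become reachable directly from the top-level ot namespace. A small numpy usage sketch with illustrative data:

import numpy as np
import ot

n = 5
xs = np.random.randn(n, 2)
xt = np.random.randn(n, 2)
C1 = ot.dist(xs, xs)        # intra-domain cost matrices
C2 = ot.dist(xt, xt)
p = ot.unif(n)              # uniform weights
q = ot.unif(n)

T = ot.gromov_wasserstein(C1, C2, p, q)    # optimal coupling
gw = ot.gromov_wasserstein2(C1, C2, p, q)  # scalar discrepancy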
diff --git a/ot/gromov.py b/ot/gromov.py
index 465693d..ea667e4 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -174,7 +174,7 @@ def tensor_product(constC, hC1, hC2, T):
def gwloss(constC, hC1, hC2, T):
- """Return the Loss for Gromov-Wasserstein
+ r"""Return the Loss for Gromov-Wasserstein
The loss is computed as described in Proposition 1 Eq. (6) in :ref:`[12] <references-gwloss>`
@@ -213,7 +213,7 @@ def gwloss(constC, hC1, hC2, T):
def gwggrad(constC, hC1, hC2, T):
- """Return the gradient for Gromov-Wasserstein
+ r"""Return the gradient for Gromov-Wasserstein
The gradient is computed as described in Proposition 2 in :ref:`[12] <references-gwggrad>`
@@ -247,7 +247,7 @@ def gwggrad(constC, hC1, hC2, T):
def update_square_loss(p, lambdas, T, Cs):
- """
+ r"""
Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s`
couplings calculated at each iteration
@@ -284,7 +284,7 @@ def update_square_loss(p, lambdas, T, Cs):
def update_kl_loss(p, lambdas, T, Cs):
- """
+ r"""
Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
@@ -320,7 +320,7 @@ def update_kl_loss(p, lambdas, T, Cs):
return nx.exp(tmpsum / ppt)
-def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
+def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs):
r"""
Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
@@ -386,6 +386,14 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs
"""
p, q = list_to_array(p, q)
+ p0, q0, C10, C20 = p, q, C1, C2
+ nx = get_backend(p0, q0, C10, C20)
+
+ p = nx.to_numpy(p)
+ q = nx.to_numpy(q)
+ C1 = nx.to_numpy(C10)
+ C2 = nx.to_numpy(C20)
+
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
G0 = p[:, None] * q[None, :]
@@ -398,13 +406,15 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs
if log:
res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
- log['gw_dist'] = gwloss(constC, hC1, hC2, res)
- return res, log
+ log['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, res), type_as=C10)
+ log['u'] = nx.from_numpy(log['u'], type_as=C10)
+ log['v'] = nx.from_numpy(log['v'], type_as=C10)
+ return nx.from_numpy(res, type_as=C10), log
else:
- return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ return nx.from_numpy(cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=False, **kwargs), type_as=C10)
-def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
+def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs):
r"""
Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
@@ -420,7 +430,11 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg
- :math:`\mathbf{C_2}`: Metric cost matrix in the target space
- :math:`\mathbf{p}`: distribution in the source space
- :math:`\mathbf{q}`: distribution in the target space
- - `L`: loss function to account for the misfit between the similarity matrices
+ - `L`: loss function to account for the misfit between the similarity
+   matrices
+
+ Note that when using backends, this loss function is differentiable wrt the
+ matrices and weights for quadratic loss using the gradients from [38]_.
Parameters
----------
@@ -463,9 +477,21 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg
metric approach to object matching. Foundations of computational
mathematics 11.4 (2011): 417-487.
+ .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
+ Graph Dictionary Learning, International Conference on Machine Learning
+ (ICML), 2021.
+
"""
p, q = list_to_array(p, q)
+ p0, q0, C10, C20 = p, q, C1, C2
+ nx = get_backend(p0, q0, C10, C20)
+
+ p = nx.to_numpy(p)
+ q = nx.to_numpy(q)
+ C1 = nx.to_numpy(C10)
+ C2 = nx.to_numpy(C20)
+
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
G0 = p[:, None] * q[None, :]
@@ -475,13 +501,28 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg
def df(G):
return gwggrad(constC, hC1, hC2, G)
- res, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
- log_gw['gw_dist'] = gwloss(constC, hC1, hC2, res)
- log_gw['T'] = res
+
+ T, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+
+ T0 = nx.from_numpy(T, type_as=C10)
+
+ log_gw['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, T), type_as=C10)
+ log_gw['u'] = nx.from_numpy(log_gw['u'], type_as=C10)
+ log_gw['v'] = nx.from_numpy(log_gw['v'], type_as=C10)
+ log_gw['T'] = T0
+
+ gw = log_gw['gw_dist']
+
+ if loss_fun == 'square_loss':
+ gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T))
+ gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T))
+ gw = nx.set_gradients(gw, (p0, q0, C10, C20),
+ (log_gw['u'], log_gw['v'], gC1, gC2))
+
if log:
- return log_gw['gw_dist'], log_gw
+ return gw, log_gw
else:
- return log_gw['gw_dist']
+ return gw
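A hedged numerical check of the closed-form square-loss gradient used above, dGW/dC1 = 2 * C1 * (p p^T) - 2 * T C2 T^T, evaluated at a fixed feasible coupling T (the envelope-theorem argument of [38] justifies fixing T at the optimum); data is illustrative:

import numpy as np

rng = np.random.RandomState(0)
n, m = 6, 7
C1, C2 = rng.rand(n, n), rng.rand(m, m)
p, q = np.ones(n) / n, np.ones(m) / m
T = p[:, None] * q[None, :]                 # feasible coupling (row sums = p)

def gw_quad(C1):
    # sum_{ijkl} (C1[i,k] - C2[j,l])^2 T[i,j] T[k,l]
    D = C1[:, :, None, None] - C2[None, None, :, :]
    return np.sum(D ** 2 * T[:, None, :, None] * T[None, :, None, :])

gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)
eps = 1e-6
E = np.zeros_like(C1); E[0, 1] = eps
fd = (gw_quad(C1 + E) - gw_quad(C1 - E)) / (2 * eps)  # central difference
assert np.isclose(fd, gC1[0, 1], atol=1e-5)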
def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
@@ -548,6 +589,15 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
"""
p, q = list_to_array(p, q)
+ p0, q0, C10, C20, M0 = p, q, C1, C2, M
+ nx = get_backend(p0, q0, C10, C20, M0)
+
+ p = nx.to_numpy(p)
+ q = nx.to_numpy(q)
+ C1 = nx.to_numpy(C10)
+ C2 = nx.to_numpy(C20)
+ M = nx.to_numpy(M0)
+
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
G0 = p[:, None] * q[None, :]
@@ -560,10 +610,16 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
if log:
res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
- log['fgw_dist'] = log['loss'][::-1][0]
- return res, log
+
+ fgw_dist = nx.from_numpy(log['loss'][-1], type_as=C10)
+
+ log['fgw_dist'] = fgw_dist
+ log['u'] = nx.from_numpy(log['u'], type_as=C10)
+ log['v'] = nx.from_numpy(log['v'], type_as=C10)
+ return nx.from_numpy(res, type_as=C10), log
+
else:
- return cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs), type_as=C10)
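Quick usage sketch for the coupling solver above, with illustrative data; M carries the feature (linear OT) cost and alpha trades it off against the structure term:

import numpy as np
import ot

n, m = 5, 6
xs, xt = np.random.randn(n, 2), np.random.randn(m, 2)
C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
M = ot.dist(xs, xt)                   # feature cost between domains
p, q = ot.unif(n), ot.unif(m)

T = ot.fused_gromov_wasserstein(M, C1, C2, p, q, alpha=0.5)  # (n, m) coupling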
def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
@@ -586,7 +642,11 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
- :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
- `L` is a loss function to account for the misfit between the similarity matrices
- The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] <references-fused-gromov-wasserstein2>`
+ The algorithm used for solving the problem is conditional gradient as
+ discussed in :ref:`[24] <references-fused-gromov-wasserstein2>`
+
+ Note that when using backends, this loss function is differentiable wrt the
+ matrices and weights for quadratic loss using the gradients from [38]_.
Parameters
----------
@@ -627,9 +687,22 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
and Courty Nicolas
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
+
+ .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
+ Graph Dictionary Learning, International Conference on Machine Learning
+ (ICML), 2021.
"""
p, q = list_to_array(p, q)
+ p0, q0, C10, C20, M0 = p, q, C1, C2, M
+ nx = get_backend(p0, q0, C10, C20, M0)
+
+ p = nx.to_numpy(p)
+ q = nx.to_numpy(q)
+ C1 = nx.to_numpy(C10)
+ C2 = nx.to_numpy(C20)
+ M = nx.to_numpy(M0)
+
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
G0 = p[:, None] * q[None, :]
@@ -640,13 +713,27 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
def df(G):
return gwggrad(constC, hC1, hC2, G)
- res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ T, log_fgw = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+
+ fgw_dist = nx.from_numpy(log_fgw['loss'][-1], type_as=C10)
+
+ T0 = nx.from_numpy(T, type_as=C10)
+
+ log_fgw['fgw_dist'] = fgw_dist
+ log_fgw['u'] = nx.from_numpy(log_fgw['u'], type_as=C10)
+ log_fgw['v'] = nx.from_numpy(log_fgw['v'], type_as=C10)
+ log_fgw['T'] = T0
+
+ if loss_fun == 'square_loss':
+ gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T))
+ gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T))
+ fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0),
+ (log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0))
+
if log:
- log['fgw_dist'] = log['loss'][::-1][0]
- log['T'] = res
- return log['fgw_dist'], log
+ return fgw_dist, log_fgw
else:
- return log['fgw_dist']
+ return fgw_dist
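A sketch of the FGW differentiability added above, assuming the torch backend; per the set_gradients call, the gradient wrt M should be (1 - alpha) * T. Data is illustrative:

import torch
import ot

n, m = 8, 9
C1 = torch.rand(n, n, requires_grad=True)
C2 = torch.rand(m, m, requires_grad=True)
M = torch.rand(n, m, requires_grad=True)
p = torch.ones(n) / n
q = torch.ones(m) / m

fgw = ot.fused_gromov_wasserstein2(M, C1, C2, p, q, alpha=0.5)
fgw.backward()
print(M.grad.shape)   # (n, m), equal to (1 - alpha) * T here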
def GW_distance_estimation(C1, C2, p, q, loss_fun, T,
@@ -1447,7 +1534,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
p=None, loss_fun='square_loss', max_iter=100, tol=1e-9,
verbose=False, log=False, init_C=None, init_X=None, random_state=None):
- """Compute the fgw barycenter as presented eq (5) in :ref:`[24] <references-fgw-barycenters>`
+ r"""Compute the fgw barycenter as presented eq (5) in :ref:`[24] <references-fgw-barycenters>`
Parameters
----------
@@ -1604,7 +1691,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
def update_structure_matrix(p, lambdas, T, Cs):
- """Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings.
+ r"""Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings.
It is calculated at each iteration
@@ -1640,7 +1727,7 @@ def update_structure_matrix(p, lambdas, T, Cs):
def update_feature_matrix(lambdas, Ys, Ts, p):
- """Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings.
+ r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings.
See "Solving the barycenter problem with Block Coordinate Descent (BCD)"
diff --git a/ot/optim.py b/ot/optim.py
index cc286b6..bd8ca26 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -267,7 +267,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
Mi += nx.min(Mi)
# solve linear program
- Gc = emd(a, b, Mi, numItermax=numItermaxEmd)
+ Gc, logemd = emd(a, b, Mi, numItermax=numItermaxEmd, log=True)
deltaG = Gc - G
@@ -297,6 +297,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
if log:
+ log.update(logemd)
return G, log
else:
return G
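What the optim.py change exposes downstream: the dual potentials of the final linear-program step now ride along in the cg log, which is how the GW wrappers above obtain gradients wrt the weights. A short sketch with illustrative data:

import numpy as np
import ot

n = 6
C1 = np.random.rand(n, n)
C2 = np.random.rand(n, n)
p, q = ot.unif(n), ot.unif(n)

T, log = ot.gromov_wasserstein(C1, C2, p, q, log=True)
print(log['u'].shape, log['v'].shape)  # dual potentials from the last emd call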