path: root/ot/gromov.py
author    Rémi Flamary <remi.flamary@gmail.com>    2018-02-15 16:48:29 +0100
committer Rémi Flamary <remi.flamary@gmail.com>    2018-02-15 16:48:29 +0100
commit    71b17f4e956de4503cbdffc2a822bcea18ed85d1 (patch)
tree      f59c07b7218a3ca81461e892ad2dc9bdb312d111 /ot/gromov.py
parent    42d0aa94e7cb49711a646fe9b263a86cdb817161 (diff)
large update gromov
Diffstat (limited to 'ot/gromov.py')
-rw-r--r--  ot/gromov.py | 258
1 file changed, 148 insertions(+), 110 deletions(-)
diff --git a/ot/gromov.py b/ot/gromov.py
index 20bf7ee..a23303a 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -9,6 +9,7 @@ Gromov-Wasserstein transport method
# Author: Erwan Vautier <erwan.vautier@gmail.com>
# Nicolas Courty <ncourty@irisa.fr>
+# Rémi Flamary <remi.flamary@unice.fr>
#
# License: MIT License
@@ -16,40 +17,35 @@ import numpy as np
from .bregman import sinkhorn
from .utils import dist
+from .optim import cg
-def square_loss(a, b):
- """
- Returns the value of L(a,b)=(1/2)*|a-b|^2
- """
-
- return 0.5 * (a - b)**2
-
-
-def kl_loss(a, b):
- """
- Returns the value of L(a,b)=a*log(a/b)-a+b
- """
-
- return a * np.log(a / b) - a + b
+def init_matrix(C1, C2, T, p, q, loss_fun='square_loss'):
+ """ Return loss matrices and tensors for Gromov-Wasserstein fast computation
-
-def tensor_square_loss(C1, C2, T):
- """
- Returns the value of \mathcal{L}(C1,C2) \otimes T with the square loss
+ Returns the value of \mathcal{L}(C1,C2) \otimes T with the selected loss
function for the Gromov-Wasserstein discrepancy.
+
+ The matrices are computed as described in Proposition 1 in [12]
Where :
- C1 : Metric cost matrix in the source space
- C2 : Metric cost matrix in the target space
- T : A coupling between those two spaces
+ * C1 : Metric cost matrix in the source space
+ * C2 : Metric cost matrix in the target space
+ * T : A coupling between those two spaces
The square-loss function L(a,b)=(1/2)*|a-b|^2 is read as :
L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
- f1(a)=(a^2)/2
- f2(b)=(b^2)/2
- h1(a)=a
- h2(b)=b
+ * f1(a)=(a^2)/2
+ * f2(b)=(b^2)/2
+ * h1(a)=a
+ * h2(b)=b
+
+ The kl-loss function L(a,b)=a*log(a/b)-a+b is read as :
+ L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
+ * f1(a)=a*log(a)-a
+ * f2(b)=b
+ * h1(a)=a
+ * h2(b)=log(b)
Parameters
----------
@@ -57,66 +53,77 @@ def tensor_square_loss(C1, C2, T):
Metric cost matrix in the source space
C2 : ndarray, shape (nt, nt)
Metric cost matrix in the target space
- T : ndarray, shape (ns, nt)
+ T : ndarray, shape (ns, nt)
Coupling between source and target spaces
+ p : ndarray, shape (ns,)
+ Distribution in the source space
+ q : ndarray, shape (nt,)
+ Distribution in the target space
+ loss_fun : str, optional
+ Name of loss function to use: either 'square_loss' or 'kl_loss'
+
Returns
-------
- tens : ndarray, shape (ns, nt)
- \mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
- """
-
- C1 = np.asarray(C1, dtype=np.float64)
- C2 = np.asarray(C2, dtype=np.float64)
- T = np.asarray(T, dtype=np.float64)
-
- def f1(a):
- return (a**2) / 2
-
- def f2(b):
- return (b**2) / 2
-
- def h1(a):
- return a
-
- def h2(b):
- return b
-
- tens = -np.dot(h1(C1), T).dot(h2(C2).T)
- tens -= tens.min()
-
- return tens
-
+
+ constC : ndarray, shape (ns, nt)
+ Constant C matrix in Eq. (6)
+ hC1 : ndarray, shape (ns, ns)
+ h1(C1) matrix in Eq. (6)
+ hC2 : ndarray, shape (nt, nt)
+ h2(C2) matrix in Eq. (6)
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
-def tensor_kl_loss(C1, C2, T):
"""
- Returns the value of \mathcal{L}(C1,C2) \otimes T with the square loss
- function as the loss function of Gromow-Wasserstein discrepancy.
-
- Where :
- C1 : Metric cost matrix in the source space
- C2 : Metric cost matrix in the target space
- T : A coupling between those two spaces
- The square-loss function L(a,b)=(1/2)*|a-b|^2 is read as :
- L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
- f1(a)=a*log(a)-a
- f2(b)=b
- h1(a)=a
- h2(b)=log(b)
+ if loss_fun == 'square_loss':
+ def f1(a):
+ return (a ** 2) / 2
+ def f2(b):
+ return (b ** 2) / 2
+ def h1(a):
+ return a
+ def h2(b):
+ return b
+ elif loss_fun == 'kl_loss':
+ def f1(a):
+ return a * np.log(a + 1e-15) - a
+ def f2(b):
+ return b
+ def h1(a):
+ return a
+ def h2(b):
+ return np.log(b + 1e-15)
+ else:
+ raise ValueError("loss_fun should be 'square_loss' or 'kl_loss'")
+
+ constC1 = np.dot(np.dot(f1(C1), p.reshape(-1, 1)),
+ np.ones(len(q)).reshape(1, -1))
+ constC2 = np.dot(np.ones(len(p)).reshape(-1, 1),
+ np.dot(q.reshape(1, -1), f2(C2).T))
+ constC = constC1 + constC2
+ hC1 = h1(C1)
+ hC2 = h2(C2)
+
+ return constC, hC1, hC2
+
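A quick numerical check of the two factorizations listed in the docstring above, as a minimal numpy-only sketch (illustrative, not part of this commit):

import numpy as np

a, b = 3.0, 5.0

# square loss: (1/2)*|a-b|^2 = f1(a) + f2(b) - h1(a)*h2(b)
assert np.isclose(0.5 * (a - b) ** 2, a ** 2 / 2 + b ** 2 / 2 - a * b)

# kl loss: a*log(a/b) - a + b = (a*log(a) - a) + b - a*log(b)
assert np.isclose(a * np.log(a / b) - a + b,
                  (a * np.log(a) - a) + b - a * np.log(b))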
+def tensor_product(constC, hC1, hC2, T):
+ """ Return the tensor for Gromov-Wasserstein fast computation
+
+ The tensor is computed as described in Proposition 1 Eq. (6) in [12].
Parameters
----------
- C1 : ndarray, shape (ns, ns)
- Metric cost matrix in the source space
- C2 : ndarray, shape (nt, nt)
- Metric costfr matrix in the target space
- T : ndarray, shape (ns, nt)
- Coupling between source and target spaces
+ constC : ndarray, shape (ns, nt)
+ Constant C matrix in Eq. (6)
+ hC1 : ndarray, shape (ns, ns)
+ h1(C1) matrix in Eq. (6)
+ hC2 : ndarray, shape (nt, nt)
+ h2(C2) matrix in Eq. (6)
+ T : ndarray, shape (ns, nt)
+ Coupling between source and target spaces
+
Returns
-------
+
tens : ndarray, shape (ns, nt)
\mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
@@ -126,28 +133,43 @@ def tensor_kl_loss(C1, C2, T):
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
- """
+ """
+ A=-np.dot(hC1, T).dot(hC2.T)
+ tens = constC+A
+ #tens -= tens.min()
+ return tens
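The fast Eq. (6) product implemented above can be checked against the naive O(ns^2*nt^2) contraction; a self-contained numpy sketch for the square loss (the identity requires T to have marginals p and q, and the names tens_naive/tens_fast are illustrative):

import numpy as np

rng = np.random.RandomState(0)
ns, nt = 5, 4
C1 = rng.rand(ns, ns); C1 = (C1 + C1.T) / 2  # symmetric cost matrices
C2 = rng.rand(nt, nt); C2 = (C2 + C2.T) / 2
p, q = np.ones(ns) / ns, np.ones(nt) / nt
T = np.outer(p, q)  # a coupling with marginals (p, q)

# naive contraction: tens[i,j] = sum_{k,l} L(C1[i,k], C2[j,l]) * T[k,l]
L = 0.5 * (C1[:, None, :, None] - C2[None, :, None, :]) ** 2
tens_naive = np.einsum('ijkl,kl->ij', L, T)

# fast form of Eq. (6): constC - h1(C1) T h2(C2)^T with h1 = h2 = identity
constC = (np.outer((C1 ** 2 / 2).dot(p), np.ones(nt))
          + np.outer(np.ones(ns), (C2 ** 2 / 2).dot(q)))
tens_fast = constC - C1.dot(T).dot(C2.T)

assert np.allclose(tens_naive, tens_fast)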
- C1 = np.asarray(C1, dtype=np.float64)
- C2 = np.asarray(C2, dtype=np.float64)
- T = np.asarray(T, dtype=np.float64)
+def gwloss(constC, hC1, hC2, T):
+ """ Return \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} using the fast tensor_product. """
- def f1(a):
- return a * np.log(a + 1e-15) - a
+ tens = tensor_product(constC, hC1, hC2, T)
+
+ return np.sum(tens * T)
- def f2(b):
- return b
+def gwggrad(constC, hC1, hC2, T):
+ """ Return the gradient of the GW loss with respect to T. """
+
+ return 2 * tensor_product(constC, hC1, hC2, T)  # [12] Prop. 2 misses a 2 factor
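The factor-of-2 comment can be sanity-checked with a finite difference. Note that gwggrad differs from the plain Euclidean gradient of gwloss by constC, which has zero inner product with any direction D whose row and column sums vanish, so the test perturbs along such a direction. A sketch, assuming this module is importable as ot.gromov:

import numpy as np
from ot.gromov import init_matrix, gwloss, gwggrad

rng = np.random.RandomState(1)
ns, nt = 6, 5
C1 = rng.rand(ns, ns); C1 = (C1 + C1.T) / 2
C2 = rng.rand(nt, nt); C2 = (C2 + C2.T) / 2
p, q = np.ones(ns) / ns, np.ones(nt) / nt
T = np.outer(p, q)
constC, hC1, hC2 = init_matrix(C1, C2, T, p, q, 'square_loss')

# a direction with zero row/column sums, tangent to the marginal constraints
D = rng.rand(ns, nt)
D = D - D.mean(axis=1, keepdims=True) - D.mean(axis=0, keepdims=True) + D.mean()

eps = 1e-6
fd = (gwloss(constC, hC1, hC2, T + eps * D)
      - gwloss(constC, hC1, hC2, T - eps * D)) / (2 * eps)
assert np.isclose(fd, np.sum(gwggrad(constC, hC1, hC2, T) * D))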
- def h1(a):
- return a
+def gromov_wasserstein(C1, C2, p, q, loss_fun, alpha=1, log=False, **kwargs):
+ """ Return the Gromov-Wasserstein coupling between (C1,p) and (C2,q)
+
+ Solves the non-entropic problem with a conditional gradient (ot.optim.cg);
+ extra keyword arguments are passed to the solver.
+ """
- def h2(b):
- return np.log(b + 1e-15)
- tens = -np.dot(h1(C1), T).dot(h2(C2).T)
- tens -= tens.min()
+ T = np.eye(len(p), len(q))  # placeholder: init_matrix does not use T
+
+ constC, hC1, hC2 = init_matrix(C1, C2, T, p, q, loss_fun)
+
+ G0 = p[:, None] * q[None, :]  # product coupling as starting point
+
+ def f(G):
+ return gwloss(constC, hC1, hC2, G)
+
+ def df(G):
+ return gwggrad(constC, hC1, hC2, G)
+
+ if log:
+ res, log = cg(p, q, 0, alpha, f, df, G0, log=True, **kwargs)
+ log['gw_dist'] = gwloss(constC, hC1, hC2, res)
+ return res, log
+ else:
+ return cg(p, q, 0, alpha, f, df, G0, **kwargs)
- return tens
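A minimal usage sketch for the new conditional-gradient solver, assuming the package at this commit exposes the module as ot.gromov (toy data, illustrative names):

import numpy as np
import ot

rng = np.random.RandomState(42)
xs, xt = rng.randn(10, 2), rng.randn(8, 3)  # 10 source / 8 target samples
C1 = ot.dist(xs, xs); C1 /= C1.max()        # intra-domain cost matrices
C2 = ot.dist(xt, xt); C2 /= C2.max()
p, q = ot.unif(10), ot.unif(8)

T, log = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', log=True)
print(log['gw_dist'])                       # GW loss at the returned coupling
np.testing.assert_allclose(T.sum(axis=1), p, atol=1e-7)  # marginal constraint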
def update_square_loss(p, lambdas, T, Cs):
@@ -205,10 +227,10 @@ def update_kl_loss(p, lambdas, T, Cs):
return np.exp(np.divide(tmpsum, ppt))
-def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
+def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
max_iter=1000, tol=1e-9, verbose=False, log=False):
"""
- Returns the gromov-wasserstein coupling between the two measured similarity matrices
+ Returns the entropic-regularized Gromov-Wasserstein coupling between the two measured similarity matrices
(C1,p) and (C2,q)
@@ -259,25 +281,34 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
T : ndarray, shape (ns, nt)
coupling between the two spaces that minimizes :
\sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} - \epsilon H(T)
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
"""
C1 = np.asarray(C1, dtype=np.float64)
C2 = np.asarray(C2, dtype=np.float64)
T = np.outer(p, q) # Initialization
+
+ constC, hC1, hC2 = init_matrix(C1, C2, T, p, q, loss_fun)
cpt = 0
err = 1
+
+ if log:
+ log = {'err': []}
while (err > tol and cpt < max_iter):
Tprev = T
- if loss_fun == 'square_loss':
- tens = tensor_square_loss(C1, C2, T)
-
- elif loss_fun == 'kl_loss':
- tens = tensor_kl_loss(C1, C2, T)
+ # compute the gradient
+ tens = gwggrad(constC, hC1, hC2, T)
T = sinkhorn(p, q, tens, epsilon)
@@ -298,15 +329,15 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
cpt += 1
if log:
+ log['gw_dist'] = gwloss(constC, hC1, hC2, T)
return T, log
else:
return T
-
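A matching sketch for the renamed entropic solver; the epsilon value is illustrative (smaller values track the unregularized problem more closely but slow the Sinkhorn projections):

import numpy as np
import ot

rng = np.random.RandomState(0)
xs, xt = rng.randn(10, 2), rng.randn(8, 2)
C1 = ot.dist(xs, xs); C1 /= C1.max()
C2 = ot.dist(xt, xt); C2 /= C2.max()

T, log = ot.gromov.entropic_gromov_wasserstein(
    C1, C2, ot.unif(10), ot.unif(8), 'square_loss', epsilon=5e-2, log=True)
print(log['gw_dist'])  # GW loss evaluated at the entropic coupling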
-def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
+def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
max_iter=1000, tol=1e-9, verbose=False, log=False):
"""
- Returns the gromov-wasserstein discrepancy between the two measured similarity matrices
+ Returns the entropic Gromov-Wasserstein discrepancy between the two measured similarity matrices
(C1,p) and (C2,q)
@@ -350,25 +381,25 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
-------
gw_dist : float
Gromov-Wasserstein distance
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
"""
- if log:
- gw, logv = gromov_wasserstein(
- C1, C2, p, q, loss_fun, epsilon, max_iter, tol, verbose, log)
- else:
- gw = gromov_wasserstein(C1, C2, p, q, loss_fun,
- epsilon, max_iter, tol, verbose, log)
-
- if loss_fun == 'square_loss':
- gw_dist = np.sum(gw * tensor_square_loss(C1, C2, gw))
- elif loss_fun == 'kl_loss':
- gw_dist = np.sum(gw * tensor_kl_loss(C1, C2, gw))
+ gw, logv = entropic_gromov_wasserstein(
+ C1, C2, p, q, loss_fun, epsilon, max_iter, tol, verbose, log=True)
+
+ logv['T'] = gw
if log:
- return gw_dist, logv
+ return logv['gw_dist'], logv
else:
- return gw_dist
+ return logv['gw_dist']
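The wrapper thus returns the discrepancy while keeping the coupling available through the log dict; a short self-contained sketch under the same assumptions as the previous examples:

import numpy as np
import ot

rng = np.random.RandomState(0)
xs, xt = rng.randn(10, 2), rng.randn(8, 2)
C1 = ot.dist(xs, xs); C1 /= C1.max()
C2 = ot.dist(xt, xt); C2 /= C2.max()

gw_dist, logv = ot.gromov.entropic_gromov_wasserstein2(
    C1, C2, ot.unif(10), ot.unif(8), 'square_loss', epsilon=5e-2, log=True)
assert gw_dist == logv['gw_dist'] and logv['T'].shape == (10, 8)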
def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
@@ -421,6 +452,13 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
-------
C : ndarray, shape (N, N)
Similarity matrix in the barycenter space (permuted arbitrarily)
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
"""
S = len(Cs)
@@ -444,7 +482,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
while(err > tol and cpt < max_iter):
Cprev = C
- T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
+ T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
max_iter, 1e-5, verbose, log) for s in range(S)]
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)
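Finally, a hedged usage sketch for the barycenter routine updated above; the epsilon/max_iter/tol values are illustrative, and the call assumes the signature visible in this hunk:

import numpy as np
import ot

rng = np.random.RandomState(0)
xs, xt = rng.randn(10, 2), rng.randn(12, 2)
C1 = ot.dist(xs, xs); C1 /= C1.max()
C2 = ot.dist(xt, xt); C2 /= C2.max()

# barycenter of the two structures in a space of N = 6 points
Cb = ot.gromov.gromov_barycenters(6, [C1, C2], [ot.unif(10), ot.unif(12)],
                                  ot.unif(6), [0.5, 0.5], 'square_loss',
                                  epsilon=5e-2, max_iter=100, tol=1e-5)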