summaryrefslogtreecommitdiff
path: root/ot/gromov.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-02-15 16:54:23 +0100
committerRémi Flamary <remi.flamary@gmail.com>2018-02-15 16:54:23 +0100
commit341a6b55cde2c0826e4c5a558b7507828d01ae08 (patch)
treea21be77f96750bcec1d11d76c2a57197b3a916cd /ot/gromov.py
parent71b17f4e956de4503cbdffc2a822bcea18ed85d1 (diff)
add documentation
Diffstat (limited to 'ot/gromov.py')
-rw-r--r--ot/gromov.py110
1 files changed, 108 insertions, 2 deletions
diff --git a/ot/gromov.py b/ot/gromov.py
index a23303a..c3ce415 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -140,17 +140,123 @@ def tensor_product(constC,hC1,hC2,T):
return tens
def gwloss(constC,hC1,hC2,T):
+ """ Return the Loss for Gromov-Wasserstein
+
+ The loss is computed as described in Proposition 1 Eq. (6) in [12].
+
+ Parameters
+ ----------
+ constC : ndarray, shape (ns, nt)
+ Constant C matrix in Eq. (6)
+ hC1 : ndarray, shape (ns, ns)
+ h1(C1) matrix in Eq. (6)
+ hC2 : ndarray, shape (nt, nt)
+ h2(C) matrix in Eq. (6)
+ T : ndarray, shape (ns, nt)
+ Current value of transport matrix T
+
+ Returns
+ -------
+
+ loss : float
+ Gromov Wasserstein loss
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
tens=tensor_product(constC,hC1,hC2,T)
return np.sum(tens*T)
def gwggrad(constC,hC1,hC2,T):
-
+ """ Return the gradient for Gromov-Wasserstein
+
+ The gradient is computed as described in Proposition 2 in [12].
+
+ Parameters
+ ----------
+ constC : ndarray, shape (ns, nt)
+ Constant C matrix in Eq. (6)
+ hC1 : ndarray, shape (ns, ns)
+ h1(C1) matrix in Eq. (6)
+ hC2 : ndarray, shape (nt, nt)
+ h2(C) matrix in Eq. (6)
+ T : ndarray, shape (ns, nt)
+ Current value of transport matrix T
+
+ Returns
+ -------
+
+ grad : ndarray, shape (ns, nt)
+ Gromov Wasserstein gradient
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
return 2*tensor_product(constC,hC1,hC2,T) # [12] Prop. 2 misses a 2 factor
-def gromov_wasserstein(C1,C2,p,q,loss_fun,alpha=1,log=False,**kwargs):
+def gromov_wasserstein(C1,C2,p,q,loss_fun,log=False,**kwargs):
+ """
+ Returns the gromov-wasserstein discrepancy between the two measured similarity matrices
+
+ (C1,p) and (C2,q)
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \GW_Dist = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+
+ Where :
+ C1 : Metric cost matrix in the source space
+ C2 : Metric cost matrix in the target space
+ p : distribution in the source space
+ q : distribution in the target space
+ L : loss function to account for the misfit between the similarity matrices
+ H : entropy
+ Parameters
+ ----------
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric costfr matrix in the target space
+ p : ndarray, shape (ns,)
+ distribution in the source space
+ q : ndarray, shape (nt,)
+ distribution in the target space
+ loss_fun : string
+ loss function used for the solver either 'square_loss' or 'kl_loss'
+
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ gw_dist : float
+ Gromov-Wasserstein distance
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
T = np.eye(len(p), len(q))