summaryrefslogtreecommitdiff
path: root/ot/gromov.py
diff options
context:
space:
mode:
authorNicolas Courty <Nico@MacBook-Pro-de-Nicolas.local>2017-09-01 15:37:09 +0200
committerNicolas Courty <Nico@MacBook-Pro-de-Nicolas.local>2017-09-01 15:37:09 +0200
commit53e1115349ddbdff83b74c5dd15fc4b258c46cd4 (patch)
treeac4619872b341ef0471cc70d671d566526113c0b /ot/gromov.py
parentf12322c1a288baedffd5b6aedcff15747aadac8e (diff)
docstrings + naming
Diffstat (limited to 'ot/gromov.py')
-rw-r--r--ot/gromov.py92
1 files changed, 46 insertions, 46 deletions
diff --git a/ot/gromov.py b/ot/gromov.py
index 9dbf463..cf9c4da 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -58,13 +58,13 @@ def tensor_square_loss(C1, C2, T):
Metric cost matrix in the source space
C2 : ndarray, shape (nt, nt)
Metric costfr matrix in the target space
- T : np.ndarray(ns,nt)
+ T : ndarray, shape (ns, nt)
Coupling between source and target spaces
Returns
-------
- tens : (ns*nt) ndarray
+ tens : ndarray, shape (ns, nt)
\mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
@@ -89,7 +89,7 @@ def tensor_square_loss(C1, C2, T):
tens = -np.dot(h1(C1), T).dot(h2(C2).T)
tens -= tens.min()
- return np.array(tens)
+ return tens
def tensor_kl_loss(C1, C2, T):
@@ -116,13 +116,13 @@ def tensor_kl_loss(C1, C2, T):
Metric cost matrix in the source space
C2 : ndarray, shape (nt, nt)
Metric costfr matrix in the target space
- T : np.ndarray(ns,nt)
+ T : ndarray, shape (ns, nt)
Coupling between source and target spaces
Returns
-------
- tens : (ns*nt) ndarray
+ tens : ndarray, shape (ns, nt)
\mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
References
@@ -151,34 +151,36 @@ def tensor_kl_loss(C1, C2, T):
tens = -np.dot(h1(C1), T).dot(h2(C2).T)
tens -= tens.min()
- return np.array(tens)
+ return tens
def update_square_loss(p, lambdas, T, Cs):
"""
- Updates C according to the L2 Loss kernel with the S Ts couplings calculated at each iteration
+ Updates C according to the L2 Loss kernel with the S Ts couplings
+ calculated at each iteration
Parameters
----------
- p : np.ndarray(N,)
+ p : ndarray, shape (N,)
weights in the targeted barycenter
lambdas : list of the S spaces' weights
T : list of S np.ndarray(ns,N)
the S Ts couplings calculated at each iteration
- Cs : Cs : list of S np.ndarray(ns,ns)
+ Cs : list of S ndarray, shape(ns,ns)
Metric cost matrices
Returns
----------
- C updated
+ C : ndarray, shape (nt,nt)
+ updated C matrix
"""
tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))])
ppt = np.outer(p, p)
- return(np.divide(tmpsum, ppt))
+ return np.divide(tmpsum, ppt)
def update_kl_loss(p, lambdas, T, Cs):
@@ -188,27 +190,28 @@ def update_kl_loss(p, lambdas, T, Cs):
Parameters
----------
- p : np.ndarray(N,)
+ p : ndarray, shape (N,)
weights in the targeted barycenter
lambdas : list of the S spaces' weights
T : list of S np.ndarray(ns,N)
the S Ts couplings calculated at each iteration
- Cs : Cs : list of S np.ndarray(ns,ns)
+ Cs : list of S ndarray, shape(ns,ns)
Metric cost matrices
Returns
----------
- C updated
+ C : ndarray, shape (ns,ns)
+ updated C matrix
"""
tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))])
ppt = np.outer(p, p)
- return(np.exp(np.divide(tmpsum, ppt)))
+ return np.exp(np.divide(tmpsum, ppt))
-def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
+def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False):
"""
Returns the gromov-wasserstein coupling between the two measured similarity matrices
@@ -241,31 +244,28 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1
Metric cost matrix in the source space
C2 : ndarray, shape (nt, nt)
Metric costfr matrix in the target space
- p : np.ndarray(ns,)
+ p : ndarray, shape (ns,)
distribution in the source space
- q : np.ndarray(nt)
+ q : ndarray, shape (nt,)
distribution in the target space
- loss_fun : loss function used for the solver either 'square_loss' or 'kl_loss'
+ loss_fun : string
+ loss function used for the solver either 'square_loss' or 'kl_loss'
epsilon : float
Regularization term >0
-<<<<<<< HEAD
max_iter : int, optional
-=======
- numItermax : int, optional
->>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
- Max number of iterations
- stopThr : float, optional
+ Max number of iterations
+ tol : float, optional
Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
- forcing : np.ndarray(N,2)
- list of forced couplings (where N is the number of forcing)
+
Returns
-------
- T : coupling between the two spaces that minimizes :
+ T : ndarray, shape (ns, nt)
+ coupling between the two spaces that minimizes :
\sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T))
"""
@@ -278,7 +278,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1
cpt = 0
err = 1
- while (err > stopThr and cpt < max_iter):
+ while (err > tol and cpt < max_iter):
Tprev = T
@@ -303,7 +303,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1
'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err))
- cpt = cpt + 1
+ cpt += 1
if log:
return T, log
@@ -311,7 +311,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1
return T
-def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
+def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False):
"""
Returns the gromov-wasserstein discrepancy between the two measured similarity matrices
@@ -339,37 +339,36 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=
Metric cost matrix in the source space
C2 : ndarray, shape (nt, nt)
Metric costfr matrix in the target space
- p : np.ndarray(ns,)
+ p : ndarray, shape (ns,)
distribution in the source space
- q : np.ndarray(nt)
+ q : ndarray, shape (nt,)
distribution in the target space
- loss_fun : loss function used for the solver either 'square_loss' or 'kl_loss'
+ loss_fun : string
+ loss function used for the solver either 'square_loss' or 'kl_loss'
epsilon : float
Regularization term >0
max_iter : int, optional
Max number of iterations
- stopThr : float, optional
+ tol : float, optional
Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
- forcing : np.ndarray(N,2)
- list of forced couplings (where N is the number of forcing)
Returns
-------
- T : coupling between the two spaces that minimizes :
- \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T))
+ gw_dist : float
+ Gromov-Wasserstein distance
"""
if log:
gw, logv = gromov_wasserstein(
- C1, C2, p, q, loss_fun, epsilon, max_iter, stopThr, verbose, log)
+ C1, C2, p, q, loss_fun, epsilon, max_iter, tol, verbose, log)
else:
gw = gromov_wasserstein(C1, C2, p, q, loss_fun,
- epsilon, max_iter, stopThr, verbose, log)
+ epsilon, max_iter, tol, verbose, log)
if loss_fun == 'square_loss':
gw_dist = np.sum(gw * tensor_square_loss(C1, C2, gw))
@@ -383,7 +382,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=
return gw_dist
-def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
+def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False):
"""
Returns the gromov-wasserstein barycenters of S measured similarity matrices
@@ -408,7 +407,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
Metric cost matrices
ps : list of S np.ndarray(ns,)
sample weights in the S spaces
- p : np.ndarray(N,)
+ p : ndarray, shape(N,)
weights in the targeted barycenter
lambdas : list of the S spaces' weights
L : tensor-matrix multiplication function based on specific loss function
@@ -418,7 +417,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
Regularization term >0
max_iter : int, optional
Max number of iterations
- stopThr : float, optional
+ tol : float, optional
Stop threshol on error (>0)
verbose : bool, optional
Print information along iterations
@@ -427,7 +426,8 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
Returns
-------
- C : Similarity matrix in the barycenter space (permutated arbitrarily)
+ C : ndarray, shape (N, N)
+ Similarity matrix in the barycenter space (permutated arbitrarily)
"""
@@ -446,7 +446,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
error = []
- while(err > stopThr and cpt < max_iter):
+ while(err > tol and cpt < max_iter):
Cprev = C
T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,