path: root/ot/gromov.py
author    tvayer <titouan.vayer@gmail.com>    2019-05-29 15:51:57 +0200
committer tvayer <titouan.vayer@gmail.com>    2019-05-29 15:51:57 +0200
commit    9421dddd8890d4c575b593d678eb7bdf5f933f83 (patch)
tree      ea589599791cf38f7f6c2420d919bc3a627f5ae0 /ot/gromov.py
parent    94d2fe5fd0b07060426e9449de0331b88ab53df4 (diff)
Doc+armijo
Diffstat (limited to 'ot/gromov.py')
-rw-r--r--    ot/gromov.py    39
1 file changed, 20 insertions, 19 deletions
diff --git a/ot/gromov.py b/ot/gromov.py
index 44248d1..5a57dc8 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -33,12 +33,12 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
* C2 : Metric cost matrix in the target space
* T : A coupling between those two spaces
- The square-loss function L(a,b)=(1/2)*|a-b|^2 is read as :
+ The square-loss function L(a,b)=|a-b|^2 is read as :
L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
- * f1(a)=(a^2)/2
- * f2(b)=(b^2)/2
+ * f1(a)=(a^2)
+ * f2(b)=(b^2)
* h1(a)=a
- * h2(b)=b
+ * h2(b)=2*b
The kl-loss function L(a,b)=a*log(a/b)-a+b is read as :
L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
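The updated decomposition still matches the loss: |a-b|^2 = a^2 + b^2 - 2ab, i.e. f1(a)+f2(b)-h1(a)*h2(b) with the new h2(b)=2*b absorbing the factor 2. A quick numeric check, illustrative only and not part of the patch:

import numpy as np

a, b = np.random.rand(2)
f1, f2 = a ** 2, b ** 2      # f1(a)=a^2, f2(b)=b^2
h1, h2 = a, 2 * b            # h1(a)=a,  h2(b)=2*b
assert np.isclose((a - b) ** 2, f1 + f2 - h1 * h2)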
@@ -269,7 +269,7 @@ def update_kl_loss(p, lambdas, T, Cs):
return np.exp(np.divide(tmpsum, ppt))
-def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs):
+def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
"""
Returns the gromov-wasserstein transport between (C1,p) and (C2,q)
@@ -307,8 +307,8 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs)
Print information along iterations
log : bool, optional
record log if True
- amijo : bool, optional
- If True the steps of the line-search is found via an amijo research. Else closed form is used.
+ armijo : bool, optional
+ If True the step of the line-search is found via an Armijo search. Else closed form is used.
If there are convergence issues use False.
**kwargs : dict
parameters can be directly passed to the ot.optim.cg solver
@@ -344,14 +344,14 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs)
return gwggrad(constC, hC1, hC2, G)
if log:
- res, log = cg(p, q, 0, 1, f, df, G0, log=True, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
log['gw_dist'] = gwloss(constC, hC1, hC2, res)
return res, log
else:
- return cg(p, q, 0, 1, f, df, G0, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
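The rename only touches the keyword; a minimal usage sketch of the solver with the new spelling, on toy data (the point clouds and sizes below are illustrative, not from the patch):

import numpy as np
import ot

xs = np.random.randn(20, 2)
xt = np.random.randn(30, 3)
C1 = ot.dist(xs, xs)
C2 = ot.dist(xt, xt)
C1 /= C1.max()
C2 /= C2.max()
p, q = ot.unif(20), ot.unif(30)

# closed-form line-search by default; armijo=True switches to the Armijo search
T, log = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', log=True, armijo=False)
print(log['gw_dist'])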
-def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, amijo=False, **kwargs):
+def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, **kwargs):
"""
Computes the FGW distance between two graphs see [3]
.. math::
@@ -363,6 +363,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
- M is the (ns,nt) metric cost matrix
- :math:`f` is the regularization term ( and df is its gradient)
- a and b are source and target weights (sum to 1)
+ - L is a loss function to account for the misfit between the similarity matrices
The algorithm used for solving the problem is conditional gradient as discussed in [1]_
Parameters
----------
@@ -386,8 +387,8 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
Print information along iterations
log : bool, optional
record log if True
- amijo : bool, optional
- If True the steps of the line-search is found via an amijo research. Else closed form is used.
+ armijo : bool, optional
+ If True the step of the line-search is found via an Armijo search. Else closed form is used.
If there are convergence issues use False.
**kwargs : dict
parameters can be directly passed to the ot.optim.cg solver
@@ -415,10 +416,10 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
def df(G):
return gwggrad(constC, hC1, hC2, G)
- return cg(p, q, M, alpha, f, df, G0, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
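Same rename for the FGW solver; a hedged sketch with toy node features, where M carries the feature cost and alpha balances it against the structure term (data below is illustrative, not from the patch):

import numpy as np
import ot

ys = np.random.randn(20, 5)          # source node features
yt = np.random.randn(30, 5)          # target node features
M = ot.dist(ys, yt)                  # (ns, nt) feature cost matrix
C1 = ot.dist(ys, ys)                 # source structure matrix
C2 = ot.dist(yt, yt)                 # target structure matrix
p, q = ot.unif(20), ot.unif(30)

T = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, armijo=False)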
-def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs):
+def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
"""
Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q)
@@ -456,8 +457,8 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs
Print information along iterations
log : bool, optional
record log if True
- amijo : bool, optional
- If True the steps of the line-search is found via an amijo research. Else closed form is used.
+ armijo : bool, optional
+ If True the step of the line-search is found via an Armijo search. Else closed form is used.
If there are convergence issues use False.
Returns
-------
@@ -487,7 +488,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs
def df(G):
return gwggrad(constC, hC1, hC2, G)
- res, log = cg(p, q, 0, 1, f, df, G0, log=True, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
log['gw_dist'] = gwloss(constC, hC1, hC2, res)
log['T'] = res
if log:
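gromov_wasserstein2 takes the same arguments but returns the discrepancy value rather than the coupling; continuing the sketch above (C1, C2, p, q reused):

gw_dist = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'square_loss', armijo=False)
gw_dist, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'square_loss', log=True)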
@@ -890,7 +891,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
p=None, loss_fun='square_loss', max_iter=100, tol=1e-9,
verbose=False, log=True, init_C=None, init_X=None):
"""
- Compute the fgw barycenter as presented eq (5) in [3].
+ Compute the fgw barycenter as presented in eq (5) in [24].
----------
N : integer
Desired number of samples of the target barycenter
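A rough sketch of calling the barycenter solver on two toy attributed graphs; the input construction and the single-variable return below are assumptions for illustration, not taken from the patch:

import numpy as np
import ot

sizes = [12, 15]
Ys = [np.random.randn(n, 3) for n in sizes]          # node features of each input graph
Xs = [np.random.randn(n, 2) for n in sizes]
Cs = [ot.dist(x, x) for x in Xs]                     # structure matrix of each input graph
ps = [ot.unif(n) for n in sizes]                     # node weights of each input graph
lambdas = [0.5, 0.5]                                 # weight of each input graph in the barycenter

# N=10 nodes in the barycenter, alpha=0.5; the exact return structure
# (barycenter features, structure, optional log) is an assumption
out = ot.gromov.fgw_barycenters(10, Ys, Cs, ps, lambdas, 0.5)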
@@ -1065,7 +1066,7 @@ def update_sructure_matrix(p, lambdas, T, Cs):
def update_feature_matrix(lambdas, Ys, Ts, p):
"""
- Updates the feature with respect to the S Ts couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" in [3]
+ Updates the feature with respect to the S Ts couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" in [24]
calculated at each iteration
Parameters
----------