author    tvayer <titouan.vayer@gmail.com>    2019-05-29 17:05:38 +0200
committer tvayer <titouan.vayer@gmail.com>    2019-05-29 17:05:48 +0200
commit    e1bd94bb7e85a0d2fd0fcd7642b06da12c1db6db (patch)
tree      1e85920b878ab715d211db56f99e25bfa2482fd3 /ot
parent    d4320382fa8873d15dcaec7adca3a4723c142515 (diff)
code review1
Diffstat (limited to 'ot')
-rw-r--r--  ot/gromov.py  108
-rw-r--r--  ot/optim.py   31
2 files changed, 112 insertions, 27 deletions
diff --git a/ot/gromov.py b/ot/gromov.py
index 5a57dc8..53349b7 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -10,6 +10,7 @@ Gromov-Wasserstein transport method
# Nicolas Courty <ncourty@irisa.fr>
# RĂ©mi Flamary <remi.flamary@unice.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
+#
# License: MIT License
import numpy as np
@@ -351,9 +352,9 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs
return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
-def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, **kwargs):
+def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
"""
- Computes the FGW distance between two graphs see [3]
+ Computes the FGW transport between two graphs, see [24]
.. math::
\gamma = arg\min_\gamma (1-\alpha)*<\gamma,M>_F + alpha* \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
s.t. \gamma 1 = p
@@ -377,7 +378,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
distribution in the source space
q : ndarray, shape (nt,)
distribution in the target space
- loss_fun : string,optionnal
+ loss_fun : string,optional
loss function used for the solver
max_iter : int, optional
Max number of iterations
@@ -416,7 +417,86 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
def df(G):
return gwggrad(constC, hC1, hC2, G)
- return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ if log:
+ res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ log['fgw_dist'] = log['loss'][::-1][0]
+ return res, log
+ else:
+ return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+
+
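With this change fused_gromov_wasserstein gains a log option: when log=True, the solver's log dictionary is returned next to the coupling, with the final loss exposed under the 'fgw_dist' key. A minimal sketch of the intended call on toy data (the random matrices below are illustrative; only the signature and the 'fgw_dist' key come from this diff):

import numpy as np
from ot.gromov import fused_gromov_wasserstein

rng = np.random.RandomState(0)
ns, nt = 5, 4
M = rng.rand(ns, nt)                     # feature cost matrix between the two graphs
C1 = rng.rand(ns, ns)
C1 = C1 + C1.T                           # source structure matrix (symmetrized)
C2 = rng.rand(nt, nt)
C2 = C2 + C2.T                           # target structure matrix (symmetrized)
p = np.ones(ns) / ns                     # uniform weights on the source nodes
q = np.ones(nt) / nt                     # uniform weights on the target nodes

T, log = fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
print(T.shape)             # (ns, nt) optimal coupling
print(log['fgw_dist'])     # loss value recorded at the last iteration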
+def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
+ """
+ Computes the FGW distance between two graphs, see [24]
+ .. math::
+ \gamma = arg\min_\gamma (1-\alpha)*<\gamma,M>_F + \alpha* \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+ s.t. \gamma 1 = p
+ \gamma^T 1= q
+ \gamma\geq 0
+ where :
+ - M is the (ns,nt) metric cost matrix
+ - :math:`f` is the regularization term (and df is its gradient)
+ - p and q are the source and target weights (sum to 1)
+ - L is a loss function to account for the misfit between the similarity matrices
+ The algorithm used for solving the problem is conditional gradient as discussed in [1]_
+ Parameters
+ ----------
+ M : ndarray, shape (ns, nt)
+ Metric cost matrix between features across domains
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix representative of the structure in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix representative of the structure in the target space
+ p : ndarray, shape (ns,)
+ distribution in the source space
+ q : ndarray, shape (nt,)
+ distribution in the target space
+ loss_fun : string,optional
+ loss function used for the solver
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ armijo : bool, optional
+ If True, the step of the line search is found via an Armijo line search. Otherwise a closed form is used.
+ If there are convergence issues, use False.
+ **kwargs : dict
+ parameters can be directly passed to the ot.optim.cg solver
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if log==True in parameters
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+
+ G0 = p[:, None] * q[None, :]
+
+ def f(G):
+ return gwloss(constC, hC1, hC2, G)
+
+ def df(G):
+ return gwggrad(constC, hC1, hC2, G)
+
+ # keep the cg log under its own name so that the log flag is not shadowed
+ res, log0 = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ fgw_dist = log0['loss'][::-1][0]
+ if log:
+ log0['fgw_dist'] = fgw_dist
+ log0['T'] = res
+ return fgw_dist, log0
+ else:
+ return fgw_dist
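The new fused_gromov_wasserstein2 mirrors gromov_wasserstein2: it returns the FGW loss value itself rather than the coupling, which is the convenient form when only the distance is needed. A short usage sketch, reusing the toy matrices from the sketch above; when log=True, the dictionary additionally carries the coupling under 'T':

from ot.gromov import fused_gromov_wasserstein2

d = fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5)
print(d)                   # scalar FGW value
d, log = fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
print(log['T'].shape)      # optimal coupling stored in the log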
def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
@@ -889,7 +969,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
p=None, loss_fun='square_loss', max_iter=100, tol=1e-9,
- verbose=False, log=True, init_C=None, init_X=None):
+ verbose=False, log=False, init_C=None, init_X=None):
"""
Compute the fgw barycenter as presented in eq (5) of [24].
----------
@@ -919,7 +999,8 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
Barycenters' features
C : ndarray, shape (N,N)
Barycenters' structure matrix
- log_:
+ log_: dictionary
+ Only returned when log=True
T : list of (N,ns) transport matrices
Ms : all distance matrices between the feature of the barycenter and the other features dist(X,Ys) shape (N,ns)
References
@@ -1015,14 +1096,13 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
# T is N,ns
-
- log_['Ts_iter'].append(T)
err_feature = np.linalg.norm(X - Xprev.reshape(N, d))
err_structure = np.linalg.norm(C - Cprev)
if log:
log_['err_feature'].append(err_feature)
log_['err_structure'].append(err_structure)
+ log_['Ts_iter'].append(T)
if verbose:
if cpt % 200 == 0:
@@ -1032,11 +1112,15 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
print('{:5d}|{:8e}|'.format(cpt, err_feature))
cpt += 1
- log_['T'] = T # from target to Ys
- log_['p'] = p
- log_['Ms'] = Ms # Ms are N,ns
+ if log:
+ log_['T'] = T # from target to Ys
+ log_['p'] = p
+ log_['Ms'] = Ms # Ms are N,ns
- return X, C, log_
+ if log:
+ return X, C, log_
+ else:
+ return X, C
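Since log now defaults to False, a plain call to fgw_barycenters returns only the barycenter features X and structure matrix C; the per-iteration errors, couplings and distance matrices are collected in a dictionary only when explicitly requested. A minimal sketch on two small graphs with 1-D features, reusing the toy data above (the inputs are illustrative; only the signature and the log keys come from this diff):

from ot.gromov import fgw_barycenters

N = 3                                        # number of nodes in the barycenter
Ys = [rng.rand(ns, 1), rng.rand(nt, 1)]      # node features of the two input graphs
Cs = [C1, C2]                                # their structure matrices
ps = [p, q]                                  # their node weights
lambdas = [0.5, 0.5]                         # barycenter weights
pb = np.ones(N) / N                          # uniform weights on the barycenter nodes

X, C = fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha=0.5, p=pb)                 # default: no log
X, C, log_ = fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha=0.5, p=pb, log=True)
print(log_['err_feature'][-1], log_['err_structure'][-1])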
def update_sructure_matrix(p, lambdas, T, Cs):
diff --git a/ot/optim.py b/ot/optim.py
index 7d103e2..4d428d9 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -5,6 +5,7 @@ Optimization algorithms for OT
# Author: Remi Flamary <remi.flamary@unice.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
+#
# License: MIT License
import numpy as np
@@ -88,20 +89,20 @@ def do_linesearch(cost, G, deltaG, Mi, f_val,
Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost
f_val : float
Value of the cost at G
- armijo : bool, optionnal
+ armijo : bool, optional
If True, the step of the line search is found via an Armijo line search. Otherwise a closed form is used.
If there are convergence issues, use False.
- C1 : ndarray (ns,ns), optionnal
+ C1 : ndarray (ns,ns), optional
Structure matrix in the source domain. Only used when armijo=False
- C2 : ndarray (nt,nt), optionnal
+ C2 : ndarray (nt,nt), optional
Structure matrix in the target domain. Only used when armijo=False
- reg : float, optionnal
+ reg : float, optional
Regularization parameter. Only used when armijo=False
Gc : ndarray (ns,nt)
Optimal map found by linearization in the FW algorithm. Only used when armijo=False
constC : ndarray (ns,nt)
Constant for the gromov cost. See [24]. Only used when armijo=False
- M : ndarray (ns,nt), optionnal
+ M : ndarray (ns,nt), optional
Cost matrix between the features. Only used when armijo=False
Returns
-------
@@ -223,9 +224,9 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
it = 0
if verbose:
- print('{:5s}|{:12s}|{:8s}'.format(
- 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
- print('{:5d}|{:8e}|{:8e}'.format(it, f_val, 0))
+ print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+ 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
+ print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0))
while loop:
@@ -261,8 +262,8 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
if verbose:
if it % 20 == 0:
- print('{:5s}|{:12s}|{:8s}'.format(
- 'It.', 'Loss', 'Relative variation loss', 'Absolute variation loss') + '\n' + '-' * 32)
+ print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+ 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
if log:
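The console report of cg (and of gcg below) now prints both the relative and the absolute variation of the loss instead of a single delta column. As a rough sketch of how the two printed quantities relate, assuming old_fval holds the objective value of the previous iteration (that variable does not appear in this hunk, so the exact formula is an assumption for illustration):

abs_delta_fval = abs(f_val - old_fval)              # absolute change of the objective between iterations
relative_delta_fval = abs_delta_fval / abs(f_val)   # same change, rescaled by the current loss value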
@@ -363,9 +364,9 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
it = 0
if verbose:
- print('{:5s}|{:12s}|{:8s}'.format(
- 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
- print('{:5d}|{:8e}|{:8e}'.format(it, f_val, 0))
+ print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+ 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
+ print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0))
while loop:
@@ -402,8 +403,8 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
if verbose:
if it % 20 == 0:
- print('{:5s}|{:12s}|{:8s}'.format(
- 'It.', 'Loss', 'Relative variation loss', 'Absolute variation loss') + '\n' + '-' * 32)
+ print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+ 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
if log: