diff options
Diffstat (limited to 'ot')
-rw-r--r-- | ot/gromov.py | 26 | ||||
-rw-r--r-- | ot/optim.py | 32 | ||||
-rw-r--r-- | ot/utils.py | 8 |
3 files changed, 27 insertions, 39 deletions
diff --git a/ot/gromov.py b/ot/gromov.py index 53349b7..ca96b31 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -17,7 +17,7 @@ import numpy as np from .bregman import sinkhorn
-from .utils import dist
+from .utils import dist, UndefinedParameter
from .optim import cg
@@ -1011,9 +1011,6 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ International Conference on Machine Learning (ICML). 2019.
"""
- class UndefinedParameter(Exception):
- pass
-
S = len(Cs)
d = Ys[0].shape[1] # dimension on the node features
if p is None:
@@ -1049,10 +1046,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ T = [np.outer(p, q) for q in ps]
- # X is N,d
- # Ys is ns,d
- Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]
- # Ms is N,ns
+ Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))] # Ms is N,ns
cpt = 0
err_feature = 1
@@ -1072,27 +1066,13 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ Ys_temp = [y.T for y in Ys]
X = update_feature_matrix(lambdas, Ys_temp, T, p).T
- # X must be N,d
- # Ys must be ns,d
Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]
if not fixed_structure:
if loss_fun == 'square_loss':
- # T must be ns,N
- # Cs must be ns,ns
- # p must be N,1
T_temp = [t.T for t in T]
C = update_sructure_matrix(p, lambdas, T_temp, Cs)
- # Ys must be d,ns
- # Ts must be N,ns
- # p must be N,1
- # Ms is N,ns
- # C is N,N
- # Cs is ns,ns
- # p is N,1
- # ps is ns,1
-
T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
# T is N,ns
@@ -1115,7 +1095,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ if log:
log_['T'] = T # from target to Ys
log_['p'] = p
- log_['Ms'] = Ms # Ms are N,ns
+ log_['Ms'] = Ms
if log:
return X, C, log_
diff --git a/ot/optim.py b/ot/optim.py index 4d428d9..f94aceb 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -73,8 +73,8 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, return alpha, fc[0], phi1 -def do_linesearch(cost, G, deltaG, Mi, f_val, - armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None): +def solve_linesearch(cost, G, deltaG, Mi, f_val, + armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None): """ Solve the linesearch in the FW iterations Parameters @@ -93,17 +93,17 @@ def do_linesearch(cost, G, deltaG, Mi, f_val, If True the steps of the line-search is found via an armijo research. Else closed form is used. If there is convergence issues use False. C1 : ndarray (ns,ns), optional - Structure matrix in the source domain. Only used when armijo=False + Structure matrix in the source domain. Only used and necessary when armijo=False C2 : ndarray (nt,nt), optional - Structure matrix in the target domain. Only used when armijo=False + Structure matrix in the target domain. Only used and necessary when armijo=False reg : float, optional - Regularization parameter. Only used when armijo=False + Regularization parameter. Only used and necessary when armijo=False Gc : ndarray (ns,nt) - Optimal map found by linearization in the FW algorithm. Only used when armijo=False + Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False constC : ndarray (ns,nt) - Constant for the gromov cost. See [24]. Only used when armijo=False + Constant for the gromov cost. See [24]. Only used and necessary when armijo=False M : ndarray (ns,nt), optional - Cost matrix between the features. Only used when armijo=False + Cost matrix between the features. Only used and necessary when armijo=False Returns ------- alpha : float @@ -128,7 +128,7 @@ def do_linesearch(cost, G, deltaG, Mi, f_val, b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG)) c = cost(G) - alpha = solve_1d_linesearch_quad_funct(a, b, c) + alpha = solve_1d_linesearch_quad(a, b, c) fc = None f_val = cost(G + alpha * deltaG) @@ -181,7 +181,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, Print information along iterations log : bool, optional record log if True - kwargs : dict + **kwargs : dict Parameters for linesearch Returns @@ -244,7 +244,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, deltaG = Gc - G # line search - alpha, fc, f_val = do_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs) + alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs) G = G + alpha * deltaG @@ -254,7 +254,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, abs_delta_fval = abs(f_val - old_fval) relative_delta_fval = abs_delta_fval / abs(f_val) - if relative_delta_fval < stopThr and abs_delta_fval < stopThr2: + if relative_delta_fval < stopThr or abs_delta_fval < stopThr2: loop = 0 if log: @@ -395,7 +395,7 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, abs_delta_fval = abs(f_val - old_fval) relative_delta_fval = abs_delta_fval / abs(f_val) - if relative_delta_fval < stopThr and abs_delta_fval < stopThr2: + if relative_delta_fval < stopThr or abs_delta_fval < stopThr2: loop = 0 if log: @@ -413,11 +413,11 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, return G -def solve_1d_linesearch_quad_funct(a, b, c): +def solve_1d_linesearch_quad(a, b, c): """ - Solve on 0,1 the following problem: + For any convex or non-convex 1d quadratic function f, solve on [0,1] the following problem: .. math:: - \min f(x)=a*x^{2}+b*x+c + \argmin f(x)=a*x^{2}+b*x+c Parameters ---------- diff --git a/ot/utils.py b/ot/utils.py index bb21b38..efd1288 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -487,3 +487,11 @@ class BaseEstimator(object): (key, self.__class__.__name__)) setattr(self, key, value) return self + + +class UndefinedParameter(Exception): + """ + Aim at raising an Exception when a undefined parameter is called + + """ + pass |