summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authortvayer <titouan.vayer@gmail.com>2019-06-04 10:32:30 +0200
committertvayer <titouan.vayer@gmail.com>2019-06-04 10:32:35 +0200
commitad450b0a5bb63ee9731e88d4a8e7423b16f1abd8 (patch)
treecab0421292074e59cb4eeb2846e8cca5aa159d3a /ot
parent89a2e0aee4353a051d924de0457f8976c26fa5d7 (diff)
changes forgotten coments
Diffstat (limited to 'ot')
-rw-r--r--ot/gromov.py26
-rw-r--r--ot/optim.py32
-rw-r--r--ot/utils.py8
3 files changed, 27 insertions, 39 deletions
diff --git a/ot/gromov.py b/ot/gromov.py
index 53349b7..ca96b31 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -17,7 +17,7 @@ import numpy as np
from .bregman import sinkhorn
-from .utils import dist
+from .utils import dist, UndefinedParameter
from .optim import cg
@@ -1011,9 +1011,6 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
International Conference on Machine Learning (ICML). 2019.
"""
- class UndefinedParameter(Exception):
- pass
-
S = len(Cs)
d = Ys[0].shape[1] # dimension on the node features
if p is None:
@@ -1049,10 +1046,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
T = [np.outer(p, q) for q in ps]
- # X is N,d
- # Ys is ns,d
- Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]
- # Ms is N,ns
+ Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))] # Ms is N,ns
cpt = 0
err_feature = 1
@@ -1072,27 +1066,13 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
Ys_temp = [y.T for y in Ys]
X = update_feature_matrix(lambdas, Ys_temp, T, p).T
- # X must be N,d
- # Ys must be ns,d
Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]
if not fixed_structure:
if loss_fun == 'square_loss':
- # T must be ns,N
- # Cs must be ns,ns
- # p must be N,1
T_temp = [t.T for t in T]
C = update_sructure_matrix(p, lambdas, T_temp, Cs)
- # Ys must be d,ns
- # Ts must be N,ns
- # p must be N,1
- # Ms is N,ns
- # C is N,N
- # Cs is ns,ns
- # p is N,1
- # ps is ns,1
-
T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
# T is N,ns
@@ -1115,7 +1095,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
if log:
log_['T'] = T # from target to Ys
log_['p'] = p
- log_['Ms'] = Ms # Ms are N,ns
+ log_['Ms'] = Ms
if log:
return X, C, log_
diff --git a/ot/optim.py b/ot/optim.py
index 4d428d9..f94aceb 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -73,8 +73,8 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
return alpha, fc[0], phi1
-def do_linesearch(cost, G, deltaG, Mi, f_val,
- armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
+def solve_linesearch(cost, G, deltaG, Mi, f_val,
+ armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
"""
Solve the linesearch in the FW iterations
Parameters
@@ -93,17 +93,17 @@ def do_linesearch(cost, G, deltaG, Mi, f_val,
If True the steps of the line-search is found via an armijo research. Else closed form is used.
If there is convergence issues use False.
C1 : ndarray (ns,ns), optional
- Structure matrix in the source domain. Only used when armijo=False
+ Structure matrix in the source domain. Only used and necessary when armijo=False
C2 : ndarray (nt,nt), optional
- Structure matrix in the target domain. Only used when armijo=False
+ Structure matrix in the target domain. Only used and necessary when armijo=False
reg : float, optional
- Regularization parameter. Only used when armijo=False
+ Regularization parameter. Only used and necessary when armijo=False
Gc : ndarray (ns,nt)
- Optimal map found by linearization in the FW algorithm. Only used when armijo=False
+ Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False
constC : ndarray (ns,nt)
- Constant for the gromov cost. See [24]. Only used when armijo=False
+ Constant for the gromov cost. See [24]. Only used and necessary when armijo=False
M : ndarray (ns,nt), optional
- Cost matrix between the features. Only used when armijo=False
+ Cost matrix between the features. Only used and necessary when armijo=False
Returns
-------
alpha : float
@@ -128,7 +128,7 @@ def do_linesearch(cost, G, deltaG, Mi, f_val,
b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG))
c = cost(G)
- alpha = solve_1d_linesearch_quad_funct(a, b, c)
+ alpha = solve_1d_linesearch_quad(a, b, c)
fc = None
f_val = cost(G + alpha * deltaG)
@@ -181,7 +181,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
Print information along iterations
log : bool, optional
record log if True
- kwargs : dict
+ **kwargs : dict
Parameters for linesearch
Returns
@@ -244,7 +244,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
deltaG = Gc - G
# line search
- alpha, fc, f_val = do_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs)
+ alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs)
G = G + alpha * deltaG
@@ -254,7 +254,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
abs_delta_fval = abs(f_val - old_fval)
relative_delta_fval = abs_delta_fval / abs(f_val)
- if relative_delta_fval < stopThr and abs_delta_fval < stopThr2:
+ if relative_delta_fval < stopThr or abs_delta_fval < stopThr2:
loop = 0
if log:
@@ -395,7 +395,7 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
abs_delta_fval = abs(f_val - old_fval)
relative_delta_fval = abs_delta_fval / abs(f_val)
- if relative_delta_fval < stopThr and abs_delta_fval < stopThr2:
+ if relative_delta_fval < stopThr or abs_delta_fval < stopThr2:
loop = 0
if log:
@@ -413,11 +413,11 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
return G
-def solve_1d_linesearch_quad_funct(a, b, c):
+def solve_1d_linesearch_quad(a, b, c):
"""
- Solve on 0,1 the following problem:
+ For any convex or non-convex 1d quadratic function f, solve on [0,1] the following problem:
.. math::
- \min f(x)=a*x^{2}+b*x+c
+ \argmin f(x)=a*x^{2}+b*x+c
Parameters
----------
diff --git a/ot/utils.py b/ot/utils.py
index bb21b38..efd1288 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -487,3 +487,11 @@ class BaseEstimator(object):
(key, self.__class__.__name__))
setattr(self, key, value)
return self
+
+
+class UndefinedParameter(Exception):
+ """
+ Aim at raising an Exception when a undefined parameter is called
+
+ """
+ pass