From d4320382fa8873d15dcaec7adca3a4723c142515 Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 16:10:26 +0200 Subject: relative+absolute loss --- ot/optim.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) (limited to 'ot') diff --git a/ot/optim.py b/ot/optim.py index 82a91bf..7d103e2 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -135,7 +135,7 @@ def do_linesearch(cost, G, deltaG, Mi, f_val, def cg(a, b, M, reg, f, df, G0=None, numItermax=200, - stopThr=1e-9, verbose=False, log=False, **kwargs): + stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs): """ Solve the general regularized OT problem with conditional gradient @@ -173,7 +173,9 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshol on the relative variation (>0) + stopThr2 : float, optional + Stop threshol on the absolute variation (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -249,8 +251,9 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, if it >= numItermax: loop = 0 - delta_fval = (f_val - old_fval) / abs(f_val) - if abs(delta_fval) < stopThr: + abs_delta_fval = abs(f_val - old_fval) + relative_delta_fval = abs_delta_fval / abs(f_val) + if relative_delta_fval < stopThr and abs_delta_fval < stopThr2: loop = 0 if log: @@ -259,8 +262,8 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, if verbose: if it % 20 == 0: print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32) - print('{:5d}|{:8e}|{:8e}'.format(it, f_val, delta_fval)) + 'It.', 'Loss', 'Relative variation loss', 'Absolute variation loss') + '\n' + '-' * 32) + print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) if log: return G, log @@ -269,7 +272,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, - numInnerItermax=200, stopThr=1e-9, verbose=False, log=False): + numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False): """ Solve the general regularized OT problem with the generalized conditional gradient @@ -312,7 +315,9 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, numInnerItermax : int, optional Max number of iterations of Sinkhorn stopThr : float, optional - Stop threshol on error (>0) + Stop threshol on the relative variation (>0) + stopThr2 : float, optional + Stop threshol on the absolute variation (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -386,8 +391,10 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, if it >= numItermax: loop = 0 - delta_fval = (f_val - old_fval) / abs(f_val) - if abs(delta_fval) < stopThr: + abs_delta_fval = abs(f_val - old_fval) + relative_delta_fval = abs_delta_fval / abs(f_val) + + if relative_delta_fval < stopThr and abs_delta_fval < stopThr2: loop = 0 if log: @@ -396,8 +403,8 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, if verbose: if it % 20 == 0: print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32) - print('{:5d}|{:8e}|{:8e}'.format(it, f_val, delta_fval)) + 'It.', 'Loss', 'Relative variation loss', 'Absolute variation loss') + '\n' + '-' * 32) + print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) if log: return G, log -- cgit v1.2.3