path: root/ot/optim.py
author     tvayer <titouan.vayer@gmail.com>  2019-05-29 16:10:26 +0200
committer  tvayer <titouan.vayer@gmail.com>  2019-05-29 16:10:26 +0200
commit     d4320382fa8873d15dcaec7adca3a4723c142515 (patch)
tree       15769bbf4faa109949b40720ee9f751846fffdd8 /ot/optim.py
parent     9421dddd8890d4c575b593d678eb7bdf5f933f83 (diff)
relative+absolute loss
Diffstat (limited to 'ot/optim.py')
-rw-r--r--  ot/optim.py  31
1 file changed, 19 insertions, 12 deletions
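
The commit adds a second stopping threshold to ot.optim.cg and ot.optim.gcg: the solvers now stop only when both the relative and the absolute variation of the loss are small. A minimal usage sketch, not part of the commit; the toy data and the squared-l2 regularizer f/df are illustrative assumptions:

    import numpy as np
    import ot

    # toy problem: uniform marginals and a random cost matrix
    rng = np.random.RandomState(0)
    n = 20
    a = ot.unif(n)
    b = ot.unif(n)
    M = rng.rand(n, n)
    M /= M.max()

    # illustrative squared-l2 regularizer and its gradient
    def f(G):
        return 0.5 * np.sum(G ** 2)

    def df(G):
        return G

    # stopThr bounds the relative variation of the loss,
    # stopThr2 (new in this commit) bounds the absolute variation
    G = ot.optim.cg(a, b, M, reg=1e-1, f=f, df=df,
                    stopThr=1e-9, stopThr2=1e-9, verbose=True)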
diff --git a/ot/optim.py b/ot/optim.py
index 82a91bf..7d103e2 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -135,7 +135,7 @@ def do_linesearch(cost, G, deltaG, Mi, f_val,
 def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
-       stopThr=1e-9, verbose=False, log=False, **kwargs):
+       stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs):
     """
     Solve the general regularized OT problem with conditional gradient
@@ -173,7 +173,9 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
     numItermax : int, optional
         Max number of iterations
     stopThr : float, optional
-        Stop threshol on error (>0)
+        Stop threshold on the relative variation (>0)
+    stopThr2 : float, optional
+        Stop threshold on the absolute variation (>0)
     verbose : bool, optional
         Print information along iterations
     log : bool, optional
@@ -249,8 +251,9 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
         if it >= numItermax:
             loop = 0

-        delta_fval = (f_val - old_fval) / abs(f_val)
-        if abs(delta_fval) < stopThr:
+        abs_delta_fval = abs(f_val - old_fval)
+        relative_delta_fval = abs_delta_fval / abs(f_val)
+        if relative_delta_fval < stopThr and abs_delta_fval < stopThr2:
             loop = 0

         if log:
@@ -259,8 +262,8 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
         if verbose:
             if it % 20 == 0:
-                print('{:5s}|{:12s}|{:8s}'.format(
-                    'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
-            print('{:5d}|{:8e}|{:8e}'.format(it, f_val, delta_fval))
+                print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+                    'It.', 'Loss', 'Relative variation loss', 'Absolute variation loss') + '\n' + '-' * 32)
+            print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))

     if log:
         return G, log
@@ -269,7 +272,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
 def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
-        numInnerItermax=200, stopThr=1e-9, verbose=False, log=False):
+        numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False):
     """
     Solve the general regularized OT problem with the generalized conditional gradient
@@ -312,7 +315,9 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
     numInnerItermax : int, optional
         Max number of iterations of Sinkhorn
     stopThr : float, optional
-        Stop threshol on error (>0)
+        Stop threshold on the relative variation (>0)
+    stopThr2 : float, optional
+        Stop threshold on the absolute variation (>0)
     verbose : bool, optional
         Print information along iterations
     log : bool, optional
@@ -386,8 +391,10 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
         if it >= numItermax:
             loop = 0

-        delta_fval = (f_val - old_fval) / abs(f_val)
-        if abs(delta_fval) < stopThr:
+        abs_delta_fval = abs(f_val - old_fval)
+        relative_delta_fval = abs_delta_fval / abs(f_val)
+
+        if relative_delta_fval < stopThr and abs_delta_fval < stopThr2:
             loop = 0

         if log:
@@ -396,8 +403,8 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
         if verbose:
             if it % 20 == 0:
-                print('{:5s}|{:12s}|{:8s}'.format(
-                    'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
-            print('{:5d}|{:8e}|{:8e}'.format(it, f_val, delta_fval))
+                print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+                    'It.', 'Loss', 'Relative variation loss', 'Absolute variation loss') + '\n' + '-' * 32)
+            print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))

     if log:
         return G, log
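
For reference, a standalone sketch of the combined stopping test that both hunks implement (a hypothetical helper, not part of ot/optim.py; like the patched code, it assumes f_val is nonzero):

    def converged(f_val, old_fval, stopThr=1e-9, stopThr2=1e-9):
        # absolute variation of the loss between two iterations
        abs_delta_fval = abs(f_val - old_fval)
        # relative variation, normalized by the current loss
        # (assumes f_val != 0, as in the patched code)
        relative_delta_fval = abs_delta_fval / abs(f_val)
        # stop only when both variations fall below their thresholds
        return relative_delta_fval < stopThr and abs_delta_fval < stopThr2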