From c9578b4cc29b58d9cde9ff586870140021471fc1 Mon Sep 17 00:00:00 2001 From: Théo Lacombe Date: Wed, 21 Dec 2022 09:00:21 +0100 Subject: [MRG] Fix#421 pass stopThr to the sinkhorn function in empirical_sinkhorn_divergence (#422) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix stopThr hardcoded in some places * added fix documentation in RELEASES.Md Co-authored-by: Rémi Flamary --- ot/bregman.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) (limited to 'ot') diff --git a/ot/bregman.py b/ot/bregman.py index 89eb295..aa3cf1a 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1281,7 +1281,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, regi = get_reg(ii) G, logi = sinkhorn_stabilized(a, b, M, regi, - numItermax=numInnerItermax, stopThr=1e-9, + numItermax=numInnerItermax, stopThr=stopThr, warmstart=(alpha, beta), verbose=False, print_period=20, tau=tau, log=True) @@ -3306,17 +3306,17 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli if log: sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, - stopThr=1e-9, verbose=verbose, + stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, - stopThr=1e-9, verbose=verbose, + stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, - stopThr=1e-9, verbose=verbose, + stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b) @@ -3333,17 +3333,17 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli else: sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, - numIterMax=numIterMax, stopThr=1e-9, + numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, - numIterMax=numIterMax, stopThr=1e-9, + numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, - numIterMax=numIterMax, stopThr=1e-9, + numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) -- cgit v1.2.3