diff options
author | Théo Lacombe <lacombe1993@gmail.com> | 2022-12-21 09:00:21 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-12-21 09:00:21 +0100 |
commit | c9578b4cc29b58d9cde9ff586870140021471fc1 (patch) | |
tree | 700bea5d262d8cc1a6a9ac8330a306bee95e59fe | |
parent | f8277713d8a63293a4d04082b031d0e467a06d47 (diff) |
[MRG] Fix#421 pass stopThr to the sinkhorn function in empirical_sinkhorn_divergence (#422)
* fix stopThr hardcoded in some places
* added fix documentation in RELEASES.Md
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
-rw-r--r-- | RELEASES.md | 2 | ||||
-rw-r--r-- | ot/bregman.py | 14 |
2 files changed, 9 insertions, 7 deletions
diff --git a/RELEASES.md b/RELEASES.md index 49475f2..4e41af6 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -32,6 +32,8 @@ roughly 2^31) (PR #381) - Fixed weak optimal transport docstring (Issue #404, PR #410) - Fixed error whith parameter `log=True`for `SinkhornLpl1Transport` (Issue #412, PR #413) +- Fix an issue where the parameter `stopThr` in `empirical_sinkhorn_divergence` was rendered useless by subcalls + that explicitly specified `stopThr=1e-9` (Issue #421, PR #422). - Fixed a bug breaking an example where we would try to make an array of arrays of different shapes (Issue #424, PR #425) diff --git a/ot/bregman.py b/ot/bregman.py index 89eb295..aa3cf1a 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1281,7 +1281,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, regi = get_reg(ii) G, logi = sinkhorn_stabilized(a, b, M, regi, - numItermax=numInnerItermax, stopThr=1e-9, + numItermax=numInnerItermax, stopThr=stopThr, warmstart=(alpha, beta), verbose=False, print_period=20, tau=tau, log=True) @@ -3306,17 +3306,17 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli if log: sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, - stopThr=1e-9, verbose=verbose, + stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, - stopThr=1e-9, verbose=verbose, + stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, - stopThr=1e-9, verbose=verbose, + stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b) @@ -3333,17 +3333,17 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli else: sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, - numIterMax=numIterMax, stopThr=1e-9, + numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, - numIterMax=numIterMax, stopThr=1e-9, + numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, - numIterMax=numIterMax, stopThr=1e-9, + numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) |