summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorThéo Lacombe <lacombe1993@gmail.com>2022-12-21 09:00:21 +0100
committerGitHub <noreply@github.com>2022-12-21 09:00:21 +0100
commitc9578b4cc29b58d9cde9ff586870140021471fc1 (patch)
tree700bea5d262d8cc1a6a9ac8330a306bee95e59fe
parentf8277713d8a63293a4d04082b031d0e467a06d47 (diff)
[MRG] Fix#421 pass stopThr to the sinkhorn function in empirical_sinkhorn_divergence (#422)
* fix stopThr hardcoded in some places * added fix documentation in RELEASES.Md Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
-rw-r--r--RELEASES.md2
-rw-r--r--ot/bregman.py14
2 files changed, 9 insertions, 7 deletions
diff --git a/RELEASES.md b/RELEASES.md
index 49475f2..4e41af6 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -32,6 +32,8 @@ roughly 2^31) (PR #381)
- Fixed weak optimal transport docstring (Issue #404, PR #410)
- Fixed error whith parameter `log=True`for `SinkhornLpl1Transport` (Issue #412,
PR #413)
+- Fix an issue where the parameter `stopThr` in `empirical_sinkhorn_divergence` was rendered useless by subcalls
+ that explicitly specified `stopThr=1e-9` (Issue #421, PR #422).
- Fixed a bug breaking an example where we would try to make an array of arrays of different shapes (Issue #424, PR #425)
diff --git a/ot/bregman.py b/ot/bregman.py
index 89eb295..aa3cf1a 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -1281,7 +1281,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
regi = get_reg(ii)
G, logi = sinkhorn_stabilized(a, b, M, regi,
- numItermax=numInnerItermax, stopThr=1e-9,
+ numItermax=numInnerItermax, stopThr=stopThr,
warmstart=(alpha, beta), verbose=False,
print_period=20, tau=tau, log=True)
@@ -3306,17 +3306,17 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
if log:
sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
numIterMax=numIterMax,
- stopThr=1e-9, verbose=verbose,
+ stopThr=stopThr, verbose=verbose,
log=log, warn=warn, **kwargs)
sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
numIterMax=numIterMax,
- stopThr=1e-9, verbose=verbose,
+ stopThr=stopThr, verbose=verbose,
log=log, warn=warn, **kwargs)
sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
numIterMax=numIterMax,
- stopThr=1e-9, verbose=verbose,
+ stopThr=stopThr, verbose=verbose,
log=log, warn=warn, **kwargs)
sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
@@ -3333,17 +3333,17 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
else:
sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
- numIterMax=numIterMax, stopThr=1e-9,
+ numIterMax=numIterMax, stopThr=stopThr,
verbose=verbose, log=log,
warn=warn, **kwargs)
sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
- numIterMax=numIterMax, stopThr=1e-9,
+ numIterMax=numIterMax, stopThr=stopThr,
verbose=verbose, log=log,
warn=warn, **kwargs)
sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
- numIterMax=numIterMax, stopThr=1e-9,
+ numIterMax=numIterMax, stopThr=stopThr,
verbose=verbose, log=log,
warn=warn, **kwargs)