summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py14
1 files changed, 7 insertions, 7 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 89eb295..aa3cf1a 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -1281,7 +1281,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
regi = get_reg(ii)
G, logi = sinkhorn_stabilized(a, b, M, regi,
- numItermax=numInnerItermax, stopThr=1e-9,
+ numItermax=numInnerItermax, stopThr=stopThr,
warmstart=(alpha, beta), verbose=False,
print_period=20, tau=tau, log=True)
@@ -3306,17 +3306,17 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
if log:
sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
numIterMax=numIterMax,
- stopThr=1e-9, verbose=verbose,
+ stopThr=stopThr, verbose=verbose,
log=log, warn=warn, **kwargs)
sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
numIterMax=numIterMax,
- stopThr=1e-9, verbose=verbose,
+ stopThr=stopThr, verbose=verbose,
log=log, warn=warn, **kwargs)
sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
numIterMax=numIterMax,
- stopThr=1e-9, verbose=verbose,
+ stopThr=stopThr, verbose=verbose,
log=log, warn=warn, **kwargs)
sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
@@ -3333,17 +3333,17 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
else:
sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
- numIterMax=numIterMax, stopThr=1e-9,
+ numIterMax=numIterMax, stopThr=stopThr,
verbose=verbose, log=log,
warn=warn, **kwargs)
sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
- numIterMax=numIterMax, stopThr=1e-9,
+ numIterMax=numIterMax, stopThr=stopThr,
verbose=verbose, log=log,
warn=warn, **kwargs)
sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
- numIterMax=numIterMax, stopThr=1e-9,
+ numIterMax=numIterMax, stopThr=stopThr,
verbose=verbose, log=log,
warn=warn, **kwargs)