diff options
author | Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com> | 2020-01-07 16:03:18 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-01-07 16:03:18 +0100 |
commit | 69c666fc82553bed0fbbc7fc17a906eb2487ddf7 (patch) | |
tree | 8ae2bbcd7b9f42071d5156173b5e9849ac772fe4 /ot | |
parent | be33a36eb1916968d9281c5a76e12e04b7ddb686 (diff) |
set default param. for LBFGS in the function's prototype
Diffstat (limited to 'ot')
-rw-r--r-- | ot/bregman.py | 18 |
1 files changed, 12 insertions, 6 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index 4a899a6..12eaa65 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1789,7 +1789,8 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b) return max(0, sinkhorn_div) -def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=True, verbose=False, log=False): +def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=True, + maxiter=10000, maxfun=10000, pgtol=1e-09, verbose=False, log=False): """" Screening Sinkhorn Algorithm for Regularized Optimal Transport @@ -1834,6 +1835,15 @@ def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=Tru restricted: `bool`, default=True If `True`, a warm-start initialization for the L-BFGS-B solver using a restricted Sinkhorn algorithm with at most 5 iterations + + maxiter : `int`, default=10000 + Maximum number of iterations in LBFGS solver + + maxfun : `int`, default=10000 + Maximum number of function evaluations in LBFGS solver + + pgtol : `float`, default=1e-09 + Final objective function accuracy in LBFGS solver verbose: `bool`, default=False If `True`, dispaly informations along iterations @@ -2056,7 +2066,6 @@ def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=Tru return grad_u, grad_v def bfgspost(theta): - u = theta[:ns_budget] v = theta[ns_budget:] # objective @@ -2072,10 +2081,7 @@ def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=Tru u0, v0 = restricted_sinkhorn(u0, v0) theta0 = np.hstack([u0, v0]) - maxiter = 10000 # max number of iterations - maxfun = 10000 # max number of function evaluations - pgtol = 1e-09 # final objective function accuracy - + bounds = bounds_u + bounds_v # constraint bounds obj = lambda theta: bfgspost(theta) |