From 69c666fc82553bed0fbbc7fc17a906eb2487ddf7 Mon Sep 17 00:00:00 2001 From: "Mokhtar Z. Alaya" Date: Tue, 7 Jan 2020 16:03:18 +0100 Subject: set default param. for LBFGS in the function's prototype --- ot/bregman.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) (limited to 'ot') diff --git a/ot/bregman.py b/ot/bregman.py index 4a899a6..12eaa65 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1789,7 +1789,8 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b) return max(0, sinkhorn_div) -def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=True, verbose=False, log=False): +def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=True, + maxiter=10000, maxfun=10000, pgtol=1e-09, verbose=False, log=False): """" Screening Sinkhorn Algorithm for Regularized Optimal Transport @@ -1834,6 +1835,15 @@ def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=Tru restricted: `bool`, default=True If `True`, a warm-start initialization for the L-BFGS-B solver using a restricted Sinkhorn algorithm with at most 5 iterations + + maxiter : `int`, default=10000 + Maximum number of iterations in LBFGS solver + + maxfun : `int`, default=10000 + Maximum number of function evaluations in LBFGS solver + + pgtol : `float`, default=1e-09 + Final objective function accuracy in LBFGS solver verbose: `bool`, default=False If `True`, dispaly informations along iterations @@ -2056,7 +2066,6 @@ def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=Tru return grad_u, grad_v def bfgspost(theta): - u = theta[:ns_budget] v = theta[ns_budget:] # objective @@ -2072,10 +2081,7 @@ def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=Tru u0, v0 = restricted_sinkhorn(u0, v0) theta0 = np.hstack([u0, v0]) - maxiter = 10000 # max number of iterations - maxfun = 10000 # max number of function evaluations - pgtol = 1e-09 # final objective function accuracy - + bounds = bounds_u + bounds_v # constraint bounds obj = lambda theta: bfgspost(theta) -- cgit v1.2.3