summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
authorMokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>2020-01-07 16:03:18 +0100
committerGitHub <noreply@github.com>2020-01-07 16:03:18 +0100
commit69c666fc82553bed0fbbc7fc17a906eb2487ddf7 (patch)
tree8ae2bbcd7b9f42071d5156173b5e9849ac772fe4 /ot/bregman.py
parentbe33a36eb1916968d9281c5a76e12e04b7ddb686 (diff)
set default param. for LBFGS in the function's prototype
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py18
1 files changed, 12 insertions, 6 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 4a899a6..12eaa65 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -1789,7 +1789,8 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
return max(0, sinkhorn_div)
-def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=True, verbose=False, log=False):
+def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=True,
+ maxiter=10000, maxfun=10000, pgtol=1e-09, verbose=False, log=False):
""""
Screening Sinkhorn Algorithm for Regularized Optimal Transport
@@ -1834,6 +1835,15 @@ def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=Tru
restricted: `bool`, default=True
If `True`, a warm-start initialization for the L-BFGS-B solver
using a restricted Sinkhorn algorithm with at most 5 iterations
+
+ maxiter : `int`, default=10000
+ Maximum number of iterations in LBFGS solver
+
+ maxfun : `int`, default=10000
+ Maximum number of function evaluations in LBFGS solver
+
+ pgtol : `float`, default=1e-09
+ Final objective function accuracy in LBFGS solver
verbose: `bool`, default=False
If `True`, dispaly informations along iterations
@@ -2056,7 +2066,6 @@ def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=Tru
return grad_u, grad_v
def bfgspost(theta):
-
u = theta[:ns_budget]
v = theta[ns_budget:]
# objective
@@ -2072,10 +2081,7 @@ def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=Tru
u0, v0 = restricted_sinkhorn(u0, v0)
theta0 = np.hstack([u0, v0])
- maxiter = 10000 # max number of iterations
- maxfun = 10000 # max number of function evaluations
- pgtol = 1e-09 # final objective function accuracy
-
+
bounds = bounds_u + bounds_v # constraint bounds
obj = lambda theta: bfgspost(theta)