From 05a97b44a7137ef6cb0397cca3bb2ea1f8736ac5 Mon Sep 17 00:00:00 2001 From: "Mokhtar Z. Alaya" Date: Tue, 7 Jan 2020 16:10:36 +0100 Subject: fix default values for the budget arguments --- ot/bregman.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) (limited to 'ot/bregman.py') diff --git a/ot/bregman.py b/ot/bregman.py index 12eaa65..8a20307 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1789,7 +1789,7 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b) return max(0, sinkhorn_div) -def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=True, +def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=True, restricted=True, maxiter=10000, maxfun=10000, pgtol=1e-09, verbose=False, log=False): """" @@ -1823,11 +1823,13 @@ def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=Tru reg : `float` Level of the entropy regularisation - ns_budget: `int` + ns_budget: `int`, deafult=None Number budget of points to be keeped in the source domain + If it is None then 50% of the source sample points will be keeped - nt_budget: `int` + nt_budget: `int`, deafult=None Number budget of points to be keeped in the target domain + If it is None then 50% of the target sample points will be keeped uniform: `bool`, default=True If `True`, a_i = 1. / ns and b_j = 1. / nt @@ -1874,11 +1876,17 @@ def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=Tru import bottleneck except ImportError as e: print("Bottleneck module doesn't exist. Install it from https://pypi.org/project/Bottleneck/") - + a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64) ns, nt = M.shape + + # by default, we keep only 50% of the sapmle data points + if ns_budget is None: + ns_budget = int(np.floor(0.5*ns)) + if nt_budget is None: + ns_budget = int(np.floor(0.5*ns)) # calculate the Gibbs kernel K = np.empty_like(M) -- cgit v1.2.3