diff options
author | Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com> | 2020-01-07 16:10:36 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-01-07 16:10:36 +0100 |
commit | 05a97b44a7137ef6cb0397cca3bb2ea1f8736ac5 (patch) | |
tree | 22d37a8673af44b16c880d52194f54cba7fa27cb /ot | |
parent | 69c666fc82553bed0fbbc7fc17a906eb2487ddf7 (diff) |
fix default values for the budget arguments
Diffstat (limited to 'ot')
-rw-r--r-- | ot/bregman.py | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index 12eaa65..8a20307 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1789,7 +1789,7 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b) return max(0, sinkhorn_div) -def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=True, +def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=True, restricted=True, maxiter=10000, maxfun=10000, pgtol=1e-09, verbose=False, log=False): """" @@ -1823,11 +1823,13 @@ def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=Tru reg : `float` Level of the entropy regularisation - ns_budget: `int` + ns_budget: `int`, deafult=None Number budget of points to be keeped in the source domain + If it is None then 50% of the source sample points will be keeped - nt_budget: `int` + nt_budget: `int`, deafult=None Number budget of points to be keeped in the target domain + If it is None then 50% of the target sample points will be keeped uniform: `bool`, default=True If `True`, a_i = 1. / ns and b_j = 1. / nt @@ -1874,11 +1876,17 @@ def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=Tru import bottleneck except ImportError as e: print("Bottleneck module doesn't exist. Install it from https://pypi.org/project/Bottleneck/") - + a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64) ns, nt = M.shape + + # by default, we keep only 50% of the sapmle data points + if ns_budget is None: + ns_budget = int(np.floor(0.5*ns)) + if nt_budget is None: + ns_budget = int(np.floor(0.5*ns)) # calculate the Gibbs kernel K = np.empty_like(M) |