summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
authorMokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>2020-01-07 16:10:36 +0100
committerGitHub <noreply@github.com>2020-01-07 16:10:36 +0100
commit05a97b44a7137ef6cb0397cca3bb2ea1f8736ac5 (patch)
tree22d37a8673af44b16c880d52194f54cba7fa27cb /ot/bregman.py
parent69c666fc82553bed0fbbc7fc17a906eb2487ddf7 (diff)
fix default values for the budget arguments
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py16
1 files changed, 12 insertions, 4 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 12eaa65..8a20307 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -1789,7 +1789,7 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
return max(0, sinkhorn_div)
-def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=True,
+def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=True, restricted=True,
maxiter=10000, maxfun=10000, pgtol=1e-09, verbose=False, log=False):
""""
@@ -1823,11 +1823,13 @@ def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=Tru
reg : `float`
Level of the entropy regularisation
- ns_budget: `int`
+ ns_budget: `int`, deafult=None
Number budget of points to be keeped in the source domain
+ If it is None then 50% of the source sample points will be keeped
- nt_budget: `int`
+ nt_budget: `int`, deafult=None
Number budget of points to be keeped in the target domain
+ If it is None then 50% of the target sample points will be keeped
uniform: `bool`, default=True
If `True`, a_i = 1. / ns and b_j = 1. / nt
@@ -1874,11 +1876,17 @@ def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=Tru
import bottleneck
except ImportError as e:
print("Bottleneck module doesn't exist. Install it from https://pypi.org/project/Bottleneck/")
-
+
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
ns, nt = M.shape
+
+ # by default, we keep only 50% of the sapmle data points
+ if ns_budget is None:
+ ns_budget = int(np.floor(0.5*ns))
+ if nt_budget is None:
+ ns_budget = int(np.floor(0.5*ns))
# calculate the Gibbs kernel
K = np.empty_like(M)