summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
authorMokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>2020-01-08 11:03:59 +0100
committerGitHub <noreply@github.com>2020-01-08 11:03:59 +0100
commit45119609cbc317f59beb92382c28de6c51290c53 (patch)
tree43210060d212d8d5425526311874be18fc2c85e0 /ot/bregman.py
parent88fb534d83f42e45a42c0a9773ccfe338cd3a811 (diff)
using binary indexing for definition the active sets
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py44
1 files changed, 21 insertions, 23 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index ceb7754..b664ac1 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -1884,13 +1884,13 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=True, rest
# by default, we keep only 50% of the sapmle data points
if ns_budget is None:
- ns_budget = int(np.floor(0.5*ns))
+ ns_budget = int(np.floor(0.5*ns))
if nt_budget is None:
- ns_budget = int(np.floor(0.5*ns))
+ nt_budget = int(np.floor(0.5*ns))
# calculate the Gibbs kernel
K = np.empty_like(M)
- np.divide(M, - reg, out=K)
+ np.divide(M, -reg, out=K)
np.exp(K, out=K)
def projection(u, epsilon):
@@ -1898,13 +1898,13 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=True, rest
return u
# ----------------------------------------------------------------------------------------------------------------#
- # Step 1: Screening Pre-processing #
+ # Step 1: Screening pre-processing #
# ----------------------------------------------------------------------------------------------------------------#
if ns_budget == ns and nt_budget == nt:
# full number of budget points (ns, nt) = (ns_budget, nt_budget)
- I = list(np.arange(ns))
- J = list(np.arange(nt))
+ I = np.arange(ns)
+ J = np.arange(nt)
epsilon = 0.0
kappa = 1.0
@@ -1953,37 +1953,34 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=True, rest
bK_sort = np.sort(bK)[::-1]
epsilon_v_square = bK_sort[ns_budget - 1]
- # active sets I and J (see Proposition .. in [1])
- I = np.where(a >= epsilon_u_square * K_sum_cols)[0].tolist()
- J = np.where(b >= epsilon_v_square * K_sum_rows)[0].tolist()
+ # active sets I and J (see Lemma 1 in [26])
+ I = np.where(a >= epsilon_u_square * K_sum_cols)[0]
+ J = np.where(b >= epsilon_v_square * K_sum_rows)[0]
if len(I) != ns_budget:
if uniform:
aK = a / K_sum_cols
aK_sort = np.sort(aK)[::-1]
epsilon_u_square = aK_sort[ns_budget - 1:ns_budget + 1].mean()
- I = np.where(a >= epsilon_u_square * K_sum_cols)[0].tolist()
+ I = np.where(a >= epsilon_u_square * K_sum_cols)[0]
if len(J) != nt_budget:
if uniform:
bK = b / K_sum_rows
bK_sort = np.sort(bK)[::-1]
- epsilon_v_square = bK_sort[ns_budget - 1:ns_budget + 1].mean()
- J = np.where(b >= epsilon_v_square * K_sum_rows)[0].tolist()
+ epsilon_v_square = bK_sort[nt_budget - 1:nt_budget + 1].mean()
+ J = np.where(b >= epsilon_v_square * K_sum_rows)[0]
epsilon = (epsilon_u_square * epsilon_v_square) ** (1 / 4)
kappa = (epsilon_v_square / epsilon_u_square) ** (1 / 2)
-
+
if verbose:
print("Epsilon = %s\n" %epsilon)
- print("Scaling factor = %s\n" %kappa)
-
- if verbose:
print('|I_active| = %s \t |J_active| = %s ' %(len(I), len(J)))
# Ic, Jc: complementary of the active sets I and J
- Ic = list(set(np.arange(ns)) - set(I))
- Jc = list(set(np.arange(nt)) - set(J))
+ Ic = np.arange(ns)[~np.isin(np.arange(ns), I)]
+ Jc = np.arange(nt)[~np.isin(np.arange(nt), J)]
K_IJ = K[np.ix_(I, J)]
K_IcJ = K[np.ix_(Ic, J)]
@@ -2109,12 +2106,13 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=True, rest
vsc_full[J] = vsc
if log:
- log['u'] = usc_full
- log['v'] = vsc_full
+ log['u'] = usc_full
+ log['v'] = vsc_full
gamma = usc_full[:, None] * K * vsc_full[None, :]
-
+ gamma = gamma / gamma.sum()
+
if log:
- return gamma, log
+ return gamma, log
else:
- return gamma
+ return gamma