summaryrefslogtreecommitdiff
path: root/ot/gromov.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/gromov.py')
-rw-r--r--ot/gromov.py59
1 files changed, 36 insertions, 23 deletions
diff --git a/ot/gromov.py b/ot/gromov.py
index ea667e4..6544260 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -822,8 +822,12 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T,
index_k = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
index_l = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
- index_i = generator.choice(len_p, size=nb_samples_p, p=p, replace=False)
- index_j = generator.choice(len_p, size=nb_samples_p, p=p, replace=False)
+ index_i = generator.choice(
+ len_p, size=nb_samples_p, p=nx.to_numpy(p), replace=False
+ )
+ index_j = generator.choice(
+ len_p, size=nb_samples_p, p=nx.to_numpy(p), replace=False
+ )
for i in range(nb_samples_p):
if nx.issparse(T):
@@ -836,13 +840,13 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T,
index_k[i] = generator.choice(
len_q,
size=nb_samples_q,
- p=T_indexi / nx.sum(T_indexi),
+ p=nx.to_numpy(T_indexi / nx.sum(T_indexi)),
replace=True
)
index_l[i] = generator.choice(
len_q,
size=nb_samples_q,
- p=T_indexj / nx.sum(T_indexj),
+ p=nx.to_numpy(T_indexj / nx.sum(T_indexj)),
replace=True
)
@@ -934,15 +938,17 @@ def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun,
index = np.zeros(2, dtype=int)
# Initialize with default marginal
- index[0] = generator.choice(len_p, size=1, p=p)
- index[1] = generator.choice(len_q, size=1, p=q)
+ index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p))
+ index[1] = generator.choice(len_q, size=1, p=nx.to_numpy(q))
T = nx.tocsr(emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False))
best_gw_dist_estimated = np.inf
for cpt in range(max_iter):
- index[0] = generator.choice(len_p, size=1, p=p)
+ index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p))
T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,))
- index[1] = generator.choice(len_q, size=1, p=T_index0 / T_index0.sum())
+ index[1] = generator.choice(
+ len_q, size=1, p=nx.to_numpy(T_index0 / nx.sum(T_index0))
+ )
if alpha == 1:
T = nx.tocsr(
@@ -1071,13 +1077,16 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun,
C_are_symmetric = nx.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and nx.allclose(C2, C2.T, rtol=1e-10, atol=1e-10)
for cpt in range(max_iter):
- index0 = generator.choice(len_p, size=nb_samples_grad_p, p=p, replace=False)
+ index0 = generator.choice(
+ len_p, size=nb_samples_grad_p, p=nx.to_numpy(p), replace=False
+ )
Lik = 0
for i, index0_i in enumerate(index0):
- index1 = generator.choice(len_q,
- size=nb_samples_grad_q,
- p=T[index0_i, :] / nx.sum(T[index0_i, :]),
- replace=False)
+ index1 = generator.choice(
+ len_q, size=nb_samples_grad_q,
+ p=nx.to_numpy(T[index0_i, :] / nx.sum(T[index0_i, :])),
+ replace=False
+ )
# If the matrices C are not symmetric, the gradient has 2 terms, thus the term is chosen randomly.
if (not C_are_symmetric) and generator.rand(1) > 0.5:
Lik += nx.mean(loss_fun(
@@ -1359,6 +1368,8 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
-------
C : array-like, shape (`N`, `N`)
Similarity matrix in the barycenter space (permutated arbitrarily)
+ log : dict
+ Log dictionary of error during iterations. Return only if `log=True` in parameters.
References
----------
@@ -1392,7 +1403,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
Cprev = C
T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
- max_iter, 1e-4, verbose, log) for s in range(S)]
+ max_iter, 1e-4, verbose, log=False) for s in range(S)]
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)
@@ -1405,9 +1416,6 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
err = nx.norm(C - Cprev)
error.append(err)
- if log:
- log['err'].append(err)
-
if verbose:
if cpt % 200 == 0:
print('{:5s}|{:12s}'.format(
@@ -1416,7 +1424,10 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
cpt += 1
- return C
+ if log:
+ return C, {"err": error}
+ else:
+ return C
def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
@@ -1470,6 +1481,8 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
-------
C : array-like, shape (`N`, `N`)
Similarity matrix in the barycenter space (permutated arbitrarily)
+ log : dict
+ Log dictionary of error during iterations. Return only if `log=True` in parameters.
References
----------
@@ -1504,7 +1517,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
Cprev = C
T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun,
- numItermax=max_iter, stopThr=1e-5, verbose=verbose, log=log) for s in range(S)]
+ numItermax=max_iter, stopThr=1e-5, verbose=verbose, log=False) for s in range(S)]
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)
@@ -1517,9 +1530,6 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
err = nx.norm(C - Cprev)
error.append(err)
- if log:
- log['err'].append(err)
-
if verbose:
if cpt % 200 == 0:
print('{:5s}|{:12s}'.format(
@@ -1528,7 +1538,10 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
cpt += 1
- return C
+ if log:
+ return C, {"err": error}
+ else:
+ return C
def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,