summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py72
1 files changed, 42 insertions, 30 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index cce52e2..fc20175 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -830,9 +830,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
a, b, M = list_to_array(a, b, M)
nx = get_backend(M, a, b)
- if nx.__name__ == "jax":
- raise TypeError("JAX arrays have been received. Greenkhorn is not "
- "compatible with JAX")
+ if nx.__name__ in ("jax", "tf"):
+ raise TypeError("JAX or TF arrays have been received. Greenkhorn is not "
+ "compatible with neither JAX nor TF")
if len(a) == 0:
a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
@@ -865,20 +865,20 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
if m_viol_1 > m_viol_2:
old_u = u[i_1]
- new_u = a[i_1] / (K[i_1, :].dot(v))
+ new_u = a[i_1] / nx.dot(K[i_1, :], v)
G[i_1, :] = new_u * K[i_1, :] * v
- viol[i_1] = new_u * K[i_1, :].dot(v) - a[i_1]
+ viol[i_1] = nx.dot(new_u * K[i_1, :], v) - a[i_1]
viol_2 += (K[i_1, :].T * (new_u - old_u) * v)
u[i_1] = new_u
else:
old_v = v[i_2]
- new_v = b[i_2] / (K[:, i_2].T.dot(u))
+ new_v = b[i_2] / nx.dot(K[:, i_2].T, u)
G[:, i_2] = u * K[:, i_2] * new_v
# aviol = (G@one_m - a)
# aviol_2 = (G.T@one_n - b)
viol += (-old_v + new_v) * K[:, i_2] * u
- viol_2[i_2] = new_v * K[:, i_2].dot(u) - b[i_2]
+ viol_2[i_2] = new_v * nx.dot(K[:, i_2], u) - b[i_2]
v[i_2] = new_v
if stopThr_val <= stopThr:
@@ -1550,9 +1550,11 @@ def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000,
nx = get_backend(A, M)
- if nx.__name__ == "jax":
- raise NotImplementedError("Log-domain functions are not yet implemented"
- " for Jax. Use numpy or torch arrays instead.")
+ if nx.__name__ in ("jax", "tf"):
+ raise NotImplementedError(
+ "Log-domain functions are not yet implemented"
+ " for Jax and tf. Use numpy or torch arrays instead."
+ )
if weights is None:
weights = nx.ones(n_hists, type_as=A) / n_hists
@@ -1886,9 +1888,11 @@ def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000,
dim, n_hists = A.shape
nx = get_backend(A, M)
- if nx.__name__ == "jax":
- raise NotImplementedError("Log-domain functions are not yet implemented"
- " for Jax. Use numpy or torch arrays instead.")
+ if nx.__name__ in ("jax", "tf"):
+ raise NotImplementedError(
+ "Log-domain functions are not yet implemented"
+ " for Jax and TF. Use numpy or torch arrays instead."
+ )
if weights is None:
weights = nx.ones(n_hists, type_as=A) / n_hists
@@ -2043,7 +2047,7 @@ def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
log = {'err': []}
bar = nx.ones(A.shape[1:], type_as=A)
- bar /= bar.sum()
+ bar /= nx.sum(bar)
U = nx.ones(A.shape, type_as=A)
V = nx.ones(A.shape, type_as=A)
err = 1
@@ -2069,9 +2073,11 @@ def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
KV = convol_imgs(V)
U = A / KV
KU = convol_imgs(U)
- bar = nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0))
+ bar = nx.exp(
+ nx.sum(weights[:, None, None] * nx.log(KU + stabThr), axis=0)
+ )
if ii % 10 == 9:
- err = (V * KU).std(axis=0).sum()
+ err = nx.sum(nx.std(V * KU, axis=0))
# log and verbose print
if log:
log['err'].append(err)
@@ -2106,9 +2112,11 @@ def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000,
A = list_to_array(A)
nx = get_backend(A)
- if nx.__name__ == "jax":
- raise NotImplementedError("Log-domain functions are not yet implemented"
- " for Jax. Use numpy or torch arrays instead.")
+ if nx.__name__ in ("jax", "tf"):
+ raise NotImplementedError(
+ "Log-domain functions are not yet implemented"
+ " for Jax and TF. Use numpy or torch arrays instead."
+ )
n_hists, width, height = A.shape
@@ -2298,13 +2306,15 @@ def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000,
KV = convol_imgs(V)
U = A / KV
KU = convol_imgs(U)
- bar = c * nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0))
+ bar = c * nx.exp(
+ nx.sum(weights[:, None, None] * nx.log(KU + stabThr), axis=0)
+ )
for _ in range(10):
- c = (c * bar / convol_imgs(c[None]).squeeze()) ** 0.5
+ c = (c * bar / nx.squeeze(convol_imgs(c[None]))) ** 0.5
if ii % 10 == 9:
- err = (V * KU).std(axis=0).sum()
+ err = nx.sum(nx.std(V * KU, axis=0))
# log and verbose print
if log:
log['err'].append(err)
@@ -2340,9 +2350,11 @@ def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10
A = list_to_array(A)
n_hists, width, height = A.shape
nx = get_backend(A)
- if nx.__name__ == "jax":
- raise NotImplementedError("Log-domain functions are not yet implemented"
- " for Jax. Use numpy or torch arrays instead.")
+ if nx.__name__ in ("jax", "tf"):
+ raise NotImplementedError(
+ "Log-domain functions are not yet implemented"
+ " for Jax and TF. Use numpy or torch arrays instead."
+ )
if weights is None:
weights = nx.ones((n_hists,), type_as=A) / n_hists
else:
@@ -2382,7 +2394,7 @@ def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10
c = 0.5 * (c + log_bar - convol_img(c))
if ii % 10 == 9:
- err = nx.exp(G + log_KU).std(axis=0).sum()
+ err = nx.sum(nx.std(nx.exp(G + log_KU), axis=0))
# log and verbose print
if log:
log['err'].append(err)
@@ -3312,9 +3324,9 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
a, b, M = list_to_array(a, b, M)
nx = get_backend(M, a, b)
- if nx.__name__ == "jax":
- raise TypeError("JAX arrays have been received but screenkhorn is not "
- "compatible with JAX.")
+ if nx.__name__ in ("jax", "tf"):
+ raise TypeError("JAX or TF arrays have been received but screenkhorn is not "
+ "compatible with neither JAX nor TF.")
ns, nt = M.shape
@@ -3328,7 +3340,7 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
K = nx.exp(-M / reg)
def projection(u, epsilon):
- u[u <= epsilon] = epsilon
+ u = nx.maximum(u, epsilon)
return u
# ----------------------------------------------------------------------------------------------------------------#