diff options
Diffstat (limited to 'ot/bregman.py')
-rw-r--r-- | ot/bregman.py | 72 |
1 files changed, 42 insertions, 30 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index cce52e2..fc20175 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -830,9 +830,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, a, b, M = list_to_array(a, b, M) nx = get_backend(M, a, b) - if nx.__name__ == "jax": - raise TypeError("JAX arrays have been received. Greenkhorn is not " - "compatible with JAX") + if nx.__name__ in ("jax", "tf"): + raise TypeError("JAX or TF arrays have been received. Greenkhorn is not " + "compatible with neither JAX nor TF") if len(a) == 0: a = nx.ones((M.shape[0],), type_as=M) / M.shape[0] @@ -865,20 +865,20 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, if m_viol_1 > m_viol_2: old_u = u[i_1] - new_u = a[i_1] / (K[i_1, :].dot(v)) + new_u = a[i_1] / nx.dot(K[i_1, :], v) G[i_1, :] = new_u * K[i_1, :] * v - viol[i_1] = new_u * K[i_1, :].dot(v) - a[i_1] + viol[i_1] = nx.dot(new_u * K[i_1, :], v) - a[i_1] viol_2 += (K[i_1, :].T * (new_u - old_u) * v) u[i_1] = new_u else: old_v = v[i_2] - new_v = b[i_2] / (K[:, i_2].T.dot(u)) + new_v = b[i_2] / nx.dot(K[:, i_2].T, u) G[:, i_2] = u * K[:, i_2] * new_v # aviol = (G@one_m - a) # aviol_2 = (G.T@one_n - b) viol += (-old_v + new_v) * K[:, i_2] * u - viol_2[i_2] = new_v * K[:, i_2].dot(u) - b[i_2] + viol_2[i_2] = new_v * nx.dot(K[:, i_2], u) - b[i_2] v[i_2] = new_v if stopThr_val <= stopThr: @@ -1550,9 +1550,11 @@ def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000, nx = get_backend(A, M) - if nx.__name__ == "jax": - raise NotImplementedError("Log-domain functions are not yet implemented" - " for Jax. Use numpy or torch arrays instead.") + if nx.__name__ in ("jax", "tf"): + raise NotImplementedError( + "Log-domain functions are not yet implemented" + " for Jax and tf. Use numpy or torch arrays instead." + ) if weights is None: weights = nx.ones(n_hists, type_as=A) / n_hists @@ -1886,9 +1888,11 @@ def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000, dim, n_hists = A.shape nx = get_backend(A, M) - if nx.__name__ == "jax": - raise NotImplementedError("Log-domain functions are not yet implemented" - " for Jax. Use numpy or torch arrays instead.") + if nx.__name__ in ("jax", "tf"): + raise NotImplementedError( + "Log-domain functions are not yet implemented" + " for Jax and TF. Use numpy or torch arrays instead." + ) if weights is None: weights = nx.ones(n_hists, type_as=A) / n_hists @@ -2043,7 +2047,7 @@ def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, log = {'err': []} bar = nx.ones(A.shape[1:], type_as=A) - bar /= bar.sum() + bar /= nx.sum(bar) U = nx.ones(A.shape, type_as=A) V = nx.ones(A.shape, type_as=A) err = 1 @@ -2069,9 +2073,11 @@ def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, KV = convol_imgs(V) U = A / KV KU = convol_imgs(U) - bar = nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0)) + bar = nx.exp( + nx.sum(weights[:, None, None] * nx.log(KU + stabThr), axis=0) + ) if ii % 10 == 9: - err = (V * KU).std(axis=0).sum() + err = nx.sum(nx.std(V * KU, axis=0)) # log and verbose print if log: log['err'].append(err) @@ -2106,9 +2112,11 @@ def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000, A = list_to_array(A) nx = get_backend(A) - if nx.__name__ == "jax": - raise NotImplementedError("Log-domain functions are not yet implemented" - " for Jax. Use numpy or torch arrays instead.") + if nx.__name__ in ("jax", "tf"): + raise NotImplementedError( + "Log-domain functions are not yet implemented" + " for Jax and TF. Use numpy or torch arrays instead." + ) n_hists, width, height = A.shape @@ -2298,13 +2306,15 @@ def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, KV = convol_imgs(V) U = A / KV KU = convol_imgs(U) - bar = c * nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0)) + bar = c * nx.exp( + nx.sum(weights[:, None, None] * nx.log(KU + stabThr), axis=0) + ) for _ in range(10): - c = (c * bar / convol_imgs(c[None]).squeeze()) ** 0.5 + c = (c * bar / nx.squeeze(convol_imgs(c[None]))) ** 0.5 if ii % 10 == 9: - err = (V * KU).std(axis=0).sum() + err = nx.sum(nx.std(V * KU, axis=0)) # log and verbose print if log: log['err'].append(err) @@ -2340,9 +2350,11 @@ def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10 A = list_to_array(A) n_hists, width, height = A.shape nx = get_backend(A) - if nx.__name__ == "jax": - raise NotImplementedError("Log-domain functions are not yet implemented" - " for Jax. Use numpy or torch arrays instead.") + if nx.__name__ in ("jax", "tf"): + raise NotImplementedError( + "Log-domain functions are not yet implemented" + " for Jax and TF. Use numpy or torch arrays instead." + ) if weights is None: weights = nx.ones((n_hists,), type_as=A) / n_hists else: @@ -2382,7 +2394,7 @@ def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10 c = 0.5 * (c + log_bar - convol_img(c)) if ii % 10 == 9: - err = nx.exp(G + log_KU).std(axis=0).sum() + err = nx.sum(nx.std(nx.exp(G + log_KU), axis=0)) # log and verbose print if log: log['err'].append(err) @@ -3312,9 +3324,9 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, a, b, M = list_to_array(a, b, M) nx = get_backend(M, a, b) - if nx.__name__ == "jax": - raise TypeError("JAX arrays have been received but screenkhorn is not " - "compatible with JAX.") + if nx.__name__ in ("jax", "tf"): + raise TypeError("JAX or TF arrays have been received but screenkhorn is not " + "compatible with neither JAX nor TF.") ns, nt = M.shape @@ -3328,7 +3340,7 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, K = nx.exp(-M / reg) def projection(u, epsilon): - u[u <= epsilon] = epsilon + u = nx.maximum(u, epsilon) return u # ----------------------------------------------------------------------------------------------------------------# |