diff options
Diffstat (limited to 'ot/lp/solver_1d.py')
-rw-r--r-- | ot/lp/solver_1d.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index bcfc920..840801a 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -401,7 +401,7 @@ def roll_cols(M, shifts): n_rows, n_cols = M.shape - arange1 = nx.tile(nx.reshape(nx.arange(n_cols), (1, n_cols)), (n_rows, 1)) + arange1 = nx.tile(nx.reshape(nx.arange(n_cols, type_as=shifts), (1, n_cols)), (n_rows, 1)) arange2 = (arange1 - shifts) % n_cols return nx.take_along_axis(M, arange2, 1) @@ -600,7 +600,7 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 using e.g. ot.utils.get_coordinate_circle(x) - The function runs on backend but tensorflow is not supported. + The function runs on backend but tensorflow and jax are not supported. Parameters ---------- @@ -730,7 +730,7 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 tp[((1 - mask) * (dCp >= 0)) > 0] = tc[((1 - mask) * (dCp >= 0)) > 0] tc[((1 - mask) * (1 - done)) > 0] = (tm[((1 - mask) * (1 - done)) > 0] + tp[((1 - mask) * (1 - done)) > 0]) / 2 - w = ot_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p) + w = ot_cost_on_circle(nx.detach(tc), u_values, v_values, u_cdf, v_cdf, p) if log: return w, {"optimal_theta": tc[:, 0]} @@ -743,7 +743,7 @@ def wasserstein1_circle(u_values, v_values, u_weights=None, v_weights=None, requ takes the value modulo 1. If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates using e.g. the atan2 function. - The function runs on backend but tensorflow is not supported. + The function runs on backend but tensorflow and jax are not supported. .. math:: W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t @@ -864,7 +864,7 @@ def wasserstein_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, using e.g. ot.utils.get_coordinate_circle(x) - The function runs on backend but tensorflow is not supported. + The function runs on backend but tensorflow and jax are not supported. Parameters ---------- |