diff options
Diffstat (limited to 'ot/lp/solver_1d.py')
-rw-r--r-- | ot/lp/solver_1d.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 8b4d0c3..43763a9 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -100,11 +100,11 @@ def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, requ m = v_values.shape[0] if u_weights is None: - u_weights = nx.full(u_values.shape, 1. / n) + u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values) elif u_weights.ndim != u_values.ndim: u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) if v_weights is None: - v_weights = nx.full(v_values.shape, 1. / m) + v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values) elif v_weights.ndim != v_values.ndim: v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) |