summaryrefslogtreecommitdiff
path: root/ot/lp/solver_1d.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/lp/solver_1d.py')
-rw-r--r--ot/lp/solver_1d.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py
index 8b4d0c3..43763a9 100644
--- a/ot/lp/solver_1d.py
+++ b/ot/lp/solver_1d.py
@@ -100,11 +100,11 @@ def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, requ
m = v_values.shape[0]
if u_weights is None:
- u_weights = nx.full(u_values.shape, 1. / n)
+ u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values)
elif u_weights.ndim != u_values.ndim:
u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
if v_weights is None:
- v_weights = nx.full(v_values.shape, 1. / m)
+ v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values)
elif v_weights.ndim != v_values.ndim:
v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)