summaryrefslogtreecommitdiff
path: root/ot/lp/solver_1d.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/lp/solver_1d.py')
-rw-r--r--ot/lp/solver_1d.py12
1 files changed, 7 insertions, 5 deletions
diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py
index 840801a..8d841ec 100644
--- a/ot/lp/solver_1d.py
+++ b/ot/lp/solver_1d.py
@@ -37,7 +37,7 @@ def quantile_function(qs, cws, xs):
n = xs.shape[0]
if nx.__name__ == 'torch':
# this is to ensure the best performance for torch searchsorted
- # and avoid a warninng related to non-contiguous arrays
+ # and avoid a warning related to non-contiguous arrays
cws = cws.T.contiguous()
qs = qs.T.contiguous()
else:
@@ -145,6 +145,7 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
s.t. \gamma 1 = a,
\gamma^T 1= b,
\gamma\geq 0
+
where :
- d is the metric
@@ -283,6 +284,7 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
s.t. \gamma 1 = a,
\gamma^T 1= b,
\gamma\geq 0
+
where :
- d is the metric
@@ -464,7 +466,7 @@ def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2):
if nx.__name__ == 'torch':
# this is to ensure the best performance for torch searchsorted
- # and avoid a warninng related to non-contiguous arrays
+ # and avoid a warning related to non-contiguous arrays
u_cdf = u_cdf.contiguous()
v_cdf_theta = v_cdf_theta.contiguous()
@@ -478,7 +480,7 @@ def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2):
if nx.__name__ == 'torch':
# this is to ensure the best performance for torch searchsorted
- # and avoid a warninng related to non-contiguous arrays
+ # and avoid a warning related to non-contiguous arrays
u_cdfm = u_cdfm.contiguous()
v_cdf_theta = v_cdf_theta.contiguous()
@@ -665,8 +667,8 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1
if u_values.shape[1] != v_values.shape[1]:
raise ValueError(
- "u and v must have the same number of batchs {} and {} respectively given".format(u_values.shape[1],
- v_values.shape[1]))
+ "u and v must have the same number of batches {} and {} respectively given".format(u_values.shape[1],
+ v_values.shape[1]))
u_values = u_values % 1
v_values = v_values % 1