diff options
Diffstat (limited to 'ot/lp/solver_1d.py')
-rw-r--r-- | ot/lp/solver_1d.py | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 840801a..8d841ec 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -37,7 +37,7 @@ def quantile_function(qs, cws, xs): n = xs.shape[0] if nx.__name__ == 'torch': # this is to ensure the best performance for torch searchsorted - # and avoid a warninng related to non-contiguous arrays + # and avoid a warning related to non-contiguous arrays cws = cws.T.contiguous() qs = qs.T.contiguous() else: @@ -145,6 +145,7 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, s.t. \gamma 1 = a, \gamma^T 1= b, \gamma\geq 0 + where : - d is the metric @@ -283,6 +284,7 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, s.t. \gamma 1 = a, \gamma^T 1= b, \gamma\geq 0 + where : - d is the metric @@ -464,7 +466,7 @@ def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2): if nx.__name__ == 'torch': # this is to ensure the best performance for torch searchsorted - # and avoid a warninng related to non-contiguous arrays + # and avoid a warning related to non-contiguous arrays u_cdf = u_cdf.contiguous() v_cdf_theta = v_cdf_theta.contiguous() @@ -478,7 +480,7 @@ def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2): if nx.__name__ == 'torch': # this is to ensure the best performance for torch searchsorted - # and avoid a warninng related to non-contiguous arrays + # and avoid a warning related to non-contiguous arrays u_cdfm = u_cdfm.contiguous() v_cdf_theta = v_cdf_theta.contiguous() @@ -665,8 +667,8 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 if u_values.shape[1] != v_values.shape[1]: raise ValueError( - "u and v must have the same number of batchs {} and {} respectively given".format(u_values.shape[1], - v_values.shape[1])) + "u and v must have the same number of batches {} and {} respectively given".format(u_values.shape[1], + v_values.shape[1])) u_values = u_values % 1 v_values = v_values % 1 |