summaryrefslogtreecommitdiff
path: root/ot/lp
diff options
context:
space:
mode:
Diffstat (limited to 'ot/lp')
-rw-r--r--ot/lp/__init__.py11
-rw-r--r--ot/lp/cvx.py2
-rw-r--r--ot/lp/solver_1d.py12
3 files changed, 14 insertions, 11 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 2ff02ab..4952a21 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -253,7 +253,7 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
Otherwise returns only the optimal transportation matrix.
center_dual: boolean, optional (default=True)
If True, centers the dual potential using function
- :ref:`center_ot_dual`.
+ :py:func:`ot.lp.center_ot_dual`.
numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
If compiled with OpenMP, chooses the number of threads to parallelize.
"max" selects the highest number possible.
@@ -418,7 +418,7 @@ def emd2(a, b, M, processes=1,
If True, returns the optimal transportation matrix in the log.
center_dual: boolean, optional (default=True)
If True, centers the dual potential using function
- :ref:`center_ot_dual`.
+ :py:func:`ot.lp.center_ot_dual`.
numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
If compiled with OpenMP, chooses the number of threads to parallelize.
"max" selects the highest number possible.
@@ -631,6 +631,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
.. _references-free-support-barycenter:
+
References
----------
.. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
@@ -688,7 +689,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
def generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary, Y_init=None, b=None, weights=None,
numItermax=100, stopThr=1e-7, verbose=False, log=None, numThreads=1, eps=0):
r"""
- Solves the free support generalised Wasserstein barycenter problem: finding a barycenter (a discrete measure with
+ Solves the free support generalized Wasserstein barycenter problem: finding a barycenter (a discrete measure with
a fixed amount of points of uniform weights) whose respective projections fit the input measures.
More formally:
@@ -776,7 +777,7 @@ def generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary,
Y_init = nx.randn(n_samples_bary, d, type_as=X_list[0])
if b is None:
- b = nx.ones(n_samples_bary, type_as=X_list[0]) / n_samples_bary # not optimised
+ b = nx.ones(n_samples_bary, type_as=X_list[0]) / n_samples_bary # not optimized
out = free_support_barycenter(Z_list, a_list, Y_init, b, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log, numThreads=numThreads)
@@ -786,7 +787,7 @@ def generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary,
else:
Y = out
log_dict = None
- Y = Y @ B.T # return to the Generalised WB formulation
+ Y = Y @ B.T # return to the Generalized WB formulation
if log:
return Y, log_dict
diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py
index 361ad0f..3f7eb36 100644
--- a/ot/lp/cvx.py
+++ b/ot/lp/cvx.py
@@ -52,7 +52,7 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po
reg : float
Regularization term >0
weights : np.ndarray (n,)
- Weights of each histogram a_i on the simplex (barycentric coodinates)
+ Weights of each histogram a_i on the simplex (barycentric coordinates)
verbose : bool, optional
Print information along iterations
log : bool, optional
diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py
index 840801a..8d841ec 100644
--- a/ot/lp/solver_1d.py
+++ b/ot/lp/solver_1d.py
@@ -37,7 +37,7 @@ def quantile_function(qs, cws, xs):
n = xs.shape[0]
if nx.__name__ == 'torch':
# this is to ensure the best performance for torch searchsorted
- # and avoid a warninng related to non-contiguous arrays
+ # and avoid a warning related to non-contiguous arrays
cws = cws.T.contiguous()
qs = qs.T.contiguous()
else:
@@ -145,6 +145,7 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
s.t. \gamma 1 = a,
\gamma^T 1= b,
\gamma\geq 0
+
where :
- d is the metric
@@ -283,6 +284,7 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
s.t. \gamma 1 = a,
\gamma^T 1= b,
\gamma\geq 0
+
where :
- d is the metric
@@ -464,7 +466,7 @@ def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2):
if nx.__name__ == 'torch':
# this is to ensure the best performance for torch searchsorted
- # and avoid a warninng related to non-contiguous arrays
+ # and avoid a warning related to non-contiguous arrays
u_cdf = u_cdf.contiguous()
v_cdf_theta = v_cdf_theta.contiguous()
@@ -478,7 +480,7 @@ def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2):
if nx.__name__ == 'torch':
# this is to ensure the best performance for torch searchsorted
- # and avoid a warninng related to non-contiguous arrays
+ # and avoid a warning related to non-contiguous arrays
u_cdfm = u_cdfm.contiguous()
v_cdf_theta = v_cdf_theta.contiguous()
@@ -665,8 +667,8 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1
if u_values.shape[1] != v_values.shape[1]:
raise ValueError(
- "u and v must have the same number of batchs {} and {} respectively given".format(u_values.shape[1],
- v_values.shape[1]))
+ "u and v must have the same number of batches {} and {} respectively given".format(u_values.shape[1],
+ v_values.shape[1]))
u_values = u_values % 1
v_values = v_values % 1