diff options
Diffstat (limited to 'ot/lp')
-rw-r--r-- | ot/lp/__init__.py | 11 | ||||
-rw-r--r-- | ot/lp/cvx.py | 2 | ||||
-rw-r--r-- | ot/lp/solver_1d.py | 12 |
3 files changed, 14 insertions, 11 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 2ff02ab..4952a21 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -253,7 +253,7 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): Otherwise returns only the optimal transportation matrix. center_dual: boolean, optional (default=True) If True, centers the dual potential using function - :ref:`center_ot_dual`. + :py:func:`ot.lp.center_ot_dual`. numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) If compiled with OpenMP, chooses the number of threads to parallelize. "max" selects the highest number possible. @@ -418,7 +418,7 @@ def emd2(a, b, M, processes=1, If True, returns the optimal transportation matrix in the log. center_dual: boolean, optional (default=True) If True, centers the dual potential using function - :ref:`center_ot_dual`. + :py:func:`ot.lp.center_ot_dual`. numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) If compiled with OpenMP, chooses the number of threads to parallelize. "max" selects the highest number possible. @@ -631,6 +631,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None .. _references-free-support-barycenter: + References ---------- .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. @@ -688,7 +689,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None def generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary, Y_init=None, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None, numThreads=1, eps=0): r""" - Solves the free support generalised Wasserstein barycenter problem: finding a barycenter (a discrete measure with + Solves the free support generalized Wasserstein barycenter problem: finding a barycenter (a discrete measure with a fixed amount of points of uniform weights) whose respective projections fit the input measures. More formally: @@ -776,7 +777,7 @@ def generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary, Y_init = nx.randn(n_samples_bary, d, type_as=X_list[0]) if b is None: - b = nx.ones(n_samples_bary, type_as=X_list[0]) / n_samples_bary # not optimised + b = nx.ones(n_samples_bary, type_as=X_list[0]) / n_samples_bary # not optimized out = free_support_barycenter(Z_list, a_list, Y_init, b, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, numThreads=numThreads) @@ -786,7 +787,7 @@ def generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary, else: Y = out log_dict = None - Y = Y @ B.T # return to the Generalised WB formulation + Y = Y @ B.T # return to the Generalized WB formulation if log: return Y, log_dict diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index 361ad0f..3f7eb36 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -52,7 +52,7 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po reg : float Regularization term >0 weights : np.ndarray (n,) - Weights of each histogram a_i on the simplex (barycentric coodinates) + Weights of each histogram a_i on the simplex (barycentric coordinates) verbose : bool, optional Print information along iterations log : bool, optional diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 840801a..8d841ec 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -37,7 +37,7 @@ def quantile_function(qs, cws, xs): n = xs.shape[0] if nx.__name__ == 'torch': # this is to ensure the best performance for torch searchsorted - # and avoid a warninng related to non-contiguous arrays + # and avoid a warning related to non-contiguous arrays cws = cws.T.contiguous() qs = qs.T.contiguous() else: @@ -145,6 +145,7 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, s.t. \gamma 1 = a, \gamma^T 1= b, \gamma\geq 0 + where : - d is the metric @@ -283,6 +284,7 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, s.t. \gamma 1 = a, \gamma^T 1= b, \gamma\geq 0 + where : - d is the metric @@ -464,7 +466,7 @@ def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2): if nx.__name__ == 'torch': # this is to ensure the best performance for torch searchsorted - # and avoid a warninng related to non-contiguous arrays + # and avoid a warning related to non-contiguous arrays u_cdf = u_cdf.contiguous() v_cdf_theta = v_cdf_theta.contiguous() @@ -478,7 +480,7 @@ def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2): if nx.__name__ == 'torch': # this is to ensure the best performance for torch searchsorted - # and avoid a warninng related to non-contiguous arrays + # and avoid a warning related to non-contiguous arrays u_cdfm = u_cdfm.contiguous() v_cdf_theta = v_cdf_theta.contiguous() @@ -665,8 +667,8 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 if u_values.shape[1] != v_values.shape[1]: raise ValueError( - "u and v must have the same number of batchs {} and {} respectively given".format(u_values.shape[1], - v_values.shape[1])) + "u and v must have the same number of batches {} and {} respectively given".format(u_values.shape[1], + v_values.shape[1])) u_values = u_values % 1 v_values = v_values % 1 |