summaryrefslogtreecommitdiff
path: root/ot/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/utils.py')
-rw-r--r--ot/utils.py47
1 files changed, 19 insertions, 28 deletions
diff --git a/ot/utils.py b/ot/utils.py
index e8249ef..8419c83 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -111,12 +111,12 @@ def dist(x1, x2=None, metric='sqeuclidean'):
Parameters
----------
- x1 : np.array (n1,d)
+ x1 : ndarray, shape (n1,d)
matrix with n1 samples of size d
- x2 : np.array (n2,d), optional
+ x2 : array, shape (n2,d), optional
matrix with n2 samples of size d (if None then x2=x1)
- metric : str, fun, optional
- name of the metric to be computed (full list in the doc of scipy), If a string,
+ metric : str | callable, optional
+ Name of the metric to be computed (full list in the doc of scipy), If a string,
the distance function can be 'braycurtis', 'canberra', 'chebyshev', 'cityblock',
'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski',
'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean',
@@ -138,26 +138,21 @@ def dist(x1, x2=None, metric='sqeuclidean'):
def dist0(n, method='lin_square'):
- """Compute standard cost matrices of size (n,n) for OT problems
+ """Compute standard cost matrices of size (n, n) for OT problems
Parameters
----------
-
n : int
- size of the cost matrix
+ Size of the cost matrix.
method : str, optional
Type of loss matrix chosen from:
* 'lin_square' : linear sampling between 0 and n-1, quadratic loss
-
Returns
-------
-
- M : np.array (n1,n2)
- distance matrix computed with given metric
-
-
+ M : ndarray, shape (n1,n2)
+ Distance matrix computed with given metric.
"""
res = 0
if method == 'lin_square':
@@ -169,22 +164,18 @@ def dist0(n, method='lin_square'):
def cost_normalization(C, norm=None):
""" Apply normalization to the loss matrix
-
Parameters
----------
- C : np.array (n1, n2)
+ C : ndarray, shape (n1, n2)
The cost matrix to normalize.
norm : str
- type of normalization from 'median','max','log','loglog'. Any other
- value do not normalize.
-
+ Type of normalization from 'median', 'max', 'log', 'loglog'. Any
+ other value do not normalize.
Returns
-------
-
- C : np.array (n1, n2)
+ C : ndarray, shape (n1, n2)
The input cost matrix normalized according to given norm.
-
"""
if norm == "median":
@@ -194,7 +185,7 @@ def cost_normalization(C, norm=None):
elif norm == "log":
C = np.log(1 + C)
elif norm == "loglog":
- C = np.log(1 + np.log(1 + C))
+ C = np.log1p(np.log1p(C))
return C
@@ -256,6 +247,7 @@ def check_params(**kwargs):
def check_random_state(seed):
"""Turn seed into a np.random.RandomState instance
+
Parameters
----------
seed : None | int | instance of RandomState
@@ -275,7 +267,6 @@ def check_random_state(seed):
class deprecated(object):
-
"""Decorator to mark a function or class as deprecated.
deprecated class from scikit-learn package
@@ -291,8 +282,8 @@ class deprecated(object):
Parameters
----------
- extra : string
- to be added to the deprecation messages
+ extra : str
+ To be added to the deprecation messages.
"""
# Adapted from http://wiki.python.org/moin/PythonDecoratorLibrary,
@@ -373,9 +364,9 @@ def _is_deprecated(func):
class BaseEstimator(object):
-
"""Base class for most objects in POT
- adapted from sklearn BaseEstimator class
+
+ Code adapted from sklearn BaseEstimator class
Notes
-----
@@ -417,7 +408,7 @@ class BaseEstimator(object):
Parameters
----------
- deep : boolean, optional
+ deep : bool, optional
If True, will return the parameters for this estimator and
contained subobjects that are estimators.