summary refs log tree commit diff
path: root/ot/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/utils.py')
-rw-r--r-- ot/utils.py 56
1 files changed, 26 insertions, 30 deletions
diff --git a/ot/utils.py b/ot/utils.py
index 5707d9b..b71458b 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -111,12 +111,12 @@ def dist(x1, x2=None, metric='sqeuclidean'):
Parameters
----------
- x1 : np.array (n1,d)
+ x1 : ndarray, shape (n1,d)
matrix with n1 samples of size d
- x2 : np.array (n2,d), optional
+ x2 : ndarray, shape (n2,d), optional
matrix with n2 samples of size d (if None then x2=x1)
- metric : str, fun, optional
- name of the metric to be computed (full list in the doc of scipy), If a string,
+ metric : str | callable, optional
+ Name of the metric to be computed (full list in the doc of scipy). If a string,
the distance function can be 'braycurtis', 'canberra', 'chebyshev', 'cityblock',
'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski',
'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean',
@@ -138,26 +138,21 @@ def dist(x1, x2=None, metric='sqeuclidean'):
def dist0(n, method='lin_square'):
- """Compute standard cost matrices of size (n,n) for OT problems
+ """Compute standard cost matrices of size (n, n) for OT problems
Parameters
----------
-
n : int
- size of the cost matrix
+ Size of the cost matrix.
method : str, optional
Type of loss matrix chosen from:
* 'lin_square' : linear sampling between 0 and n-1, quadratic loss
-
Returns
-------
-
- M : np.array (n1,n2)
- distance matrix computed with given metric
-
-
+ M : ndarray, shape (n, n)
+ Cost matrix computed with the given method.
"""
res = 0
if method == 'lin_square':
@@ -169,33 +164,34 @@ def dist0(n, method='lin_square'):
def cost_normalization(C, norm=None):
""" Apply normalization to the loss matrix
-
Parameters
----------
- C : np.array (n1, n2)
+ C : ndarray, shape (n1, n2)
The cost matrix to normalize.
norm : str
- type of normalization from 'median','max','log','loglog'. Any other
- value do not normalize.
-
+ Type of normalization from 'median', 'max', 'log', 'loglog'. None
+ leaves C unnormalized; any other value raises a ValueError.
Returns
-------
-
- C : np.array (n1, n2)
+ C : ndarray, shape (n1, n2)
The input cost matrix normalized according to given norm.
-
"""
- if norm == "median":
+ if norm is None:
+ pass
+ elif norm == "median":
C /= float(np.median(C))
elif norm == "max":
C /= float(np.max(C))
elif norm == "log":
C = np.log(1 + C)
elif norm == "loglog":
- C = np.log(1 + np.log(1 + C))
-
+ C = np.log1p(np.log1p(C))
+ else:
+ raise ValueError('Norm %s is not a valid option.\n'
+ 'Valid options are:\n'
+ 'median, max, log, loglog' % norm)
return C
@@ -261,6 +257,7 @@ def check_params(**kwargs):
def check_random_state(seed):
"""Turn seed into a np.random.RandomState instance
+
Parameters
----------
seed : None | int | instance of RandomState
@@ -280,7 +277,6 @@ def check_random_state(seed):
class deprecated(object):
-
"""Decorator to mark a function or class as deprecated.
deprecated class from scikit-learn package
@@ -296,8 +292,8 @@ class deprecated(object):
Parameters
----------
- extra : string
- to be added to the deprecation messages
+ extra : str
+ To be added to the deprecation messages.
"""
# Adapted from http://wiki.python.org/moin/PythonDecoratorLibrary,
@@ -378,9 +374,9 @@ def _is_deprecated(func):
class BaseEstimator(object):
-
"""Base class for most objects in POT
- adapted from sklearn BaseEstimator class
+
+ Code adapted from sklearn BaseEstimator class
Notes
-----
@@ -422,7 +418,7 @@ class BaseEstimator(object):
Parameters
----------
- deep : boolean, optional
+ deep : bool, optional
If True, will return the parameters for this estimator and
contained subobjects that are estimators.