summaryrefslogtreecommitdiff
path: root/ot/utils.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-28 12:31:07 +0200
committerRémi Flamary <remi.flamary@gmail.com>2016-10-28 12:31:07 +0200
commitf33087d2b1790dac773782bb0d91bcfe7ce6a079 (patch)
tree428b0e271d7c8267dc8aee25c0068e087d1f35f5 /ot/utils.py
parentc418ef460514a991d95bb2e2a1937b1aedd3e0c9 (diff)
doc utils.py
Diffstat (limited to 'ot/utils.py')
-rw-r--r--ot/utils.py51
1 files changed, 46 insertions, 5 deletions
diff --git a/ot/utils.py b/ot/utils.py
index 46c3775..e5ec864 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -6,21 +6,62 @@ from scipy.spatial.distance import cdist
def unif(n):
- """ return a uniform histogram (simplex) """
+ """ return a uniform histogram of length n (simplex) """
return np.ones((n,))/n
def dist(x1,x2=None,metric='sqeuclidean'):
- """Compute distance between samples in x1 and x2"""
+ """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist
+
+ Parameters
+ ----------
+
+ x1 : np.array (n1,d)
+ matrix with n1 samples of size d
+ x2 : np.array (n2,d), optional
+ matrix with n2 samples of size d (if None then x2=x1)
+ metric : str, fun, optional
+ name of the metric to be computed (full list in the doc of scipy), If a string,
+ the distance function can be ‘braycurtis’, ‘canberra’, ‘chebyshev’, ‘cityblock’,
+ ‘correlation’, ‘cosine’, ‘dice’, ‘euclidean’, ‘hamming’, ‘jaccard’, ‘kulsinski’,
+ ‘mahalanobis’, ‘matching’, ‘minkowski’, ‘rogerstanimoto’, ‘russellrao’, ‘seuclidean’,
+ ‘sokalmichener’, ‘sokalsneath’, ‘sqeuclidean’, ‘wminkowski’, ‘yule’.
+
+
+ Returns
+ -------
+ M : np.array (n1,n2)
+ distance matrix computed with given metric
+
+ """
if x2 is None:
return cdist(x1,x1,metric=metric)
else:
return cdist(x1,x2,metric=metric)
-def dist0(n,method='linear'):
- """Compute stardard cos matrices for OT problems"""
+def dist0(n,method='lin_square'):
+ """Compute standard cost matrices of size (n,n) for OT problems
+
+ Parameters
+ ----------
+
+ n : int
+ size of the cost matrix
+ method : str, optional
+ Type of loss matrix chosen from:
+
+ * 'lin_square' : linear sampling between 0 and n-1, quadratic loss
+
+
+ Returns
+ -------
+ M : np.array (n1,n2)
+ distance matrix computed with given metric
+
+
+ """
res=0
- if method=='linear':
+ if method=='lin_square':
x=np.arange(n,dtype=np.float64).reshape((n,1))
res=dist(x,x)
return res