diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-28 12:31:07 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-28 12:31:07 +0200 |
commit | f33087d2b1790dac773782bb0d91bcfe7ce6a079 (patch) | |
tree | 428b0e271d7c8267dc8aee25c0068e087d1f35f5 | |
parent | c418ef460514a991d95bb2e2a1937b1aedd3e0c9 (diff) |
doc utils.py
-rw-r--r-- | ot/utils.py | 51 |
1 files changed, 46 insertions, 5 deletions
diff --git a/ot/utils.py b/ot/utils.py index 46c3775..e5ec864 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -6,21 +6,62 @@ from scipy.spatial.distance import cdist def unif(n): - """ return a uniform histogram (simplex) """ + """ return a uniform histogram of length n (simplex) """ return np.ones((n,))/n def dist(x1,x2=None,metric='sqeuclidean'): - """Compute distance between samples in x1 and x2""" + """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist + + Parameters + ---------- + + x1 : np.array (n1,d) + matrix with n1 samples of size d + x2 : np.array (n2,d), optional + matrix with n2 samples of size d (if None then x2=x1) + metric : str, fun, optional + name of the metric to be computed (full list in the doc of scipy), If a string, + the distance function can be ‘braycurtis’, ‘canberra’, ‘chebyshev’, ‘cityblock’, + ‘correlation’, ‘cosine’, ‘dice’, ‘euclidean’, ‘hamming’, ‘jaccard’, ‘kulsinski’, + ‘mahalanobis’, ‘matching’, ‘minkowski’, ‘rogerstanimoto’, ‘russellrao’, ‘seuclidean’, + ‘sokalmichener’, ‘sokalsneath’, ‘sqeuclidean’, ‘wminkowski’, ‘yule’. + + + Returns + ------- + M : np.array (n1,n2) + distance matrix computed with given metric + + """ if x2 is None: return cdist(x1,x1,metric=metric) else: return cdist(x1,x2,metric=metric) -def dist0(n,method='linear'): - """Compute stardard cos matrices for OT problems""" +def dist0(n,method='lin_square'): + """Compute standard cost matrices of size (n,n) for OT problems + + Parameters + ---------- + + n : int + size of the cost matrix + method : str, optional + Type of loss matrix chosen from: + + * 'lin_square' : linear sampling between 0 and n-1, quadratic loss + + + Returns + ------- + M : np.array (n1,n2) + distance matrix computed with given metric + + + """ res=0 - if method=='linear': + if method=='lin_square': x=np.arange(n,dtype=np.float64).reshape((n,1)) res=dist(x,x) return res |