diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-11-03 14:53:52 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-11-03 14:53:52 +0100 |
commit | 566645ad184e1205f7f666ea2f19021254c33d74 (patch) | |
tree | 1d740a6771ab515d0cbfe9f21fde801398eb19b6 /ot/utils.py | |
parent | 981351165dbab740145d109b00782f0c41f2244b (diff) |
add mapping estimation (still debugging)
Diffstat (limited to 'ot/utils.py')
-rw-r--r-- | ot/utils.py | 51 |
1 files changed, 29 insertions, 22 deletions
diff --git a/ot/utils.py b/ot/utils.py index 24f65a8..47fe77f 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -6,28 +6,34 @@ import numpy as np from scipy.spatial.distance import cdist +def kernel(x1,x2,method='gaussian',sigma=1,**kwargs): + """Compute kernel matrix""" + if method.lower() in ['gaussian','gauss','rbf']: + K=np.exp(dist(x1,x2)/(2*sigma**2)) + return K + def unif(n): - """ return a uniform histogram of length n (simplex) - + """ return a uniform histogram of length n (simplex) + Parameters ---------- n : int number of bins in the histogram - + Returns ------- h : np.array (n,) - histogram of length n such that h_i=1/n for all i - - + histogram of length n such that h_i=1/n for all i + + """ return np.ones((n,))/n def dist(x1,x2=None,metric='sqeuclidean'): """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist - + Parameters ---------- @@ -36,28 +42,29 @@ def dist(x1,x2=None,metric='sqeuclidean'): x2 : np.array (n2,d), optional matrix with n2 samples of size d (if None then x2=x1) metric : str, fun, optional - name of the metric to be computed (full list in the doc of scipy), If a string, + name of the metric to be computed (full list in the doc of scipy), If a string, the distance function can be ‘braycurtis’, ‘canberra’, ‘chebyshev’, ‘cityblock’, ‘correlation’, ‘cosine’, ‘dice’, ‘euclidean’, ‘hamming’, ‘jaccard’, ‘kulsinski’, ‘mahalanobis’, ‘matching’, ‘minkowski’, ‘rogerstanimoto’, ‘russellrao’, ‘seuclidean’, ‘sokalmichener’, ‘sokalsneath’, ‘sqeuclidean’, ‘wminkowski’, ‘yule’. - + Returns ------- - + M : np.array (n1,n2) distance matrix computed with given metric - + """ if x2 is None: - return cdist(x1,x1,metric=metric) - else: - return cdist(x1,x2,metric=metric) - + x2=x1 + + return cdist(x1,x2,metric=metric) + + def dist0(n,method='lin_square'): """Compute standard cost matrices of size (n,n) for OT problems - + Parameters ---------- @@ -68,21 +75,21 @@ def dist0(n,method='lin_square'): * 'lin_square' : linear sampling between 0 and n-1, quadratic loss - + Returns ------- - + M : np.array (n1,n2) - distance matrix computed with given metric - - + distance matrix computed with given metric + + """ res=0 if method=='lin_square': x=np.arange(n,dtype=np.float64).reshape((n,1)) res=dist(x,x) return res - + def dots(*args): """ dots function for multiple matrix multiply """ |