summaryrefslogtreecommitdiff
path: root/ot/utils.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-11-03 14:53:52 +0100
committerRémi Flamary <remi.flamary@gmail.com>2016-11-03 14:53:52 +0100
commit566645ad184e1205f7f666ea2f19021254c33d74 (patch)
tree1d740a6771ab515d0cbfe9f21fde801398eb19b6 /ot/utils.py
parent981351165dbab740145d109b00782f0c41f2244b (diff)
add mapping estimation (still debugging)
Diffstat (limited to 'ot/utils.py')
-rw-r--r--ot/utils.py51
1 files changed, 29 insertions, 22 deletions
diff --git a/ot/utils.py b/ot/utils.py
index 24f65a8..47fe77f 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -6,28 +6,34 @@ import numpy as np
from scipy.spatial.distance import cdist
+def kernel(x1,x2,method='gaussian',sigma=1,**kwargs):
+ """Compute kernel matrix"""
+ if method.lower() in ['gaussian','gauss','rbf']:
+ K=np.exp(dist(x1,x2)/(2*sigma**2))
+ return K
+
def unif(n):
- """ return a uniform histogram of length n (simplex)
-
+ """ return a uniform histogram of length n (simplex)
+
Parameters
----------
n : int
number of bins in the histogram
-
+
Returns
-------
h : np.array (n,)
- histogram of length n such that h_i=1/n for all i
-
-
+ histogram of length n such that h_i=1/n for all i
+
+
"""
return np.ones((n,))/n
def dist(x1,x2=None,metric='sqeuclidean'):
"""Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist
-
+
Parameters
----------
@@ -36,28 +42,29 @@ def dist(x1,x2=None,metric='sqeuclidean'):
x2 : np.array (n2,d), optional
matrix with n2 samples of size d (if None then x2=x1)
metric : str, fun, optional
- name of the metric to be computed (full list in the doc of scipy), If a string,
+ name of the metric to be computed (full list in the doc of scipy), If a string,
the distance function can be ‘braycurtis’, ‘canberra’, ‘chebyshev’, ‘cityblock’,
‘correlation’, ‘cosine’, ‘dice’, ‘euclidean’, ‘hamming’, ‘jaccard’, ‘kulsinski’,
‘mahalanobis’, ‘matching’, ‘minkowski’, ‘rogerstanimoto’, ‘russellrao’, ‘seuclidean’,
‘sokalmichener’, ‘sokalsneath’, ‘sqeuclidean’, ‘wminkowski’, ‘yule’.
-
+
Returns
-------
-
+
M : np.array (n1,n2)
distance matrix computed with given metric
-
+
"""
if x2 is None:
- return cdist(x1,x1,metric=metric)
- else:
- return cdist(x1,x2,metric=metric)
-
+ x2=x1
+
+ return cdist(x1,x2,metric=metric)
+
+
def dist0(n,method='lin_square'):
"""Compute standard cost matrices of size (n,n) for OT problems
-
+
Parameters
----------
@@ -68,21 +75,21 @@ def dist0(n,method='lin_square'):
* 'lin_square' : linear sampling between 0 and n-1, quadratic loss
-
+
Returns
-------
-
+
M : np.array (n1,n2)
- distance matrix computed with given metric
-
-
+ distance matrix computed with given metric
+
+
"""
res=0
if method=='lin_square':
x=np.arange(n,dtype=np.float64).reshape((n,1))
res=dist(x,x)
return res
-
+
def dots(*args):
""" dots function for multiple matrix multiply """