summaryrefslogtreecommitdiff
path: root/ot/utils.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-21 11:19:46 +0200
committerRémi Flamary <remi.flamary@gmail.com>2016-10-21 11:19:46 +0200
commit872e6db7c0d110069b450cbe7efcc186c4871428 (patch)
treef499a0258cf69f47a211a54447af990ac0afb591 /ot/utils.py
parent581c6de782dca279edd97778cc474e7597788c0f (diff)
demo with sinkhorn
Diffstat (limited to 'ot/utils.py')
-rw-r--r--ot/utils.py13
1 files changed, 11 insertions, 2 deletions
diff --git a/ot/utils.py b/ot/utils.py
index 1a1c6b8..582c3ff 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -1,14 +1,23 @@
import numpy as np
-from scipy.spatial.distance import cdist, pdist
+from scipy.spatial.distance import cdist
def dist(x1,x2=None,metric='sqeuclidean'):
"""Compute distance between samples in x1 and x2"""
if x2 is None:
- return pdist(x1,metric=metric)
+ return cdist(x1,x1,metric=metric)
else:
return cdist(x1,x2,metric=metric)
+
+def dist0(n,method='linear'):
+ """Compute stardard cos matrices for OT problems"""
+ res=0
+ if method=='linear':
+ x=np.arange(n,dtype=np.float64).reshape((n,1))
+ res=dist(x,x)
+ return res
+
def dots(*args):
""" Stupid but nice dots function for multiple matrix multiply """