diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-21 11:19:46 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-21 11:19:46 +0200 |
commit | 872e6db7c0d110069b450cbe7efcc186c4871428 (patch) | |
tree | f499a0258cf69f47a211a54447af990ac0afb591 /ot/utils.py | |
parent | 581c6de782dca279edd97778cc474e7597788c0f (diff) |
demo with sinkhorn
Diffstat (limited to 'ot/utils.py')
-rw-r--r-- | ot/utils.py | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/ot/utils.py b/ot/utils.py index 1a1c6b8..582c3ff 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1,14 +1,23 @@ import numpy as np -from scipy.spatial.distance import cdist, pdist +from scipy.spatial.distance import cdist def dist(x1,x2=None,metric='sqeuclidean'): """Compute distance between samples in x1 and x2""" if x2 is None: - return pdist(x1,metric=metric) + return cdist(x1,x1,metric=metric) else: return cdist(x1,x2,metric=metric) + +def dist0(n,method='linear'): + """Compute stardard cos matrices for OT problems""" + res=0 + if method=='linear': + x=np.arange(n,dtype=np.float64).reshape((n,1)) + res=dist(x,x) + return res + def dots(*args): """ Stupid but nice dots function for multiple matrix multiply """ |