summaryrefslogtreecommitdiff
path: root/ot/utils.py
blob: 46c3775a7cfb976e210d04203694ff2baaa6870a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
"""
Various useful utility functions.
"""
import numpy as np
from scipy.spatial.distance import cdist


def unif(n):
    """Return the uniform histogram of length n.

    Every entry equals 1/n, so the returned 1-D float array lies on the
    probability simplex.
    """
    weights = np.ones(shape=(n,), dtype=np.float64)
    weights /= n
    return weights


def dist(x1, x2=None, metric='sqeuclidean'):
    """Return the pairwise distance matrix between the rows of x1 and x2.

    When x2 is omitted, distances are computed between the rows of x1 and
    themselves. ``metric`` is forwarded unchanged to
    ``scipy.spatial.distance.cdist``.
    """
    other = x1 if x2 is None else x2
    return cdist(x1, other, metric=metric)
        
def dist0(n, method='linear'):
    """Compute a standard n x n cost matrix for OT problems.

    Parameters
    ----------
    n : int
        Number of bins, i.e. the size of the returned square matrix.
    method : str, optional
        Kind of cost matrix. Only 'linear' is supported: bins are placed at
        positions 0, 1, ..., n-1 and the cost is the squared Euclidean
        distance between positions, so M[i, j] = (i - j)**2.

    Returns
    -------
    numpy.ndarray
        (n, n) float64 cost matrix.

    Raises
    ------
    ValueError
        If ``method`` is not a supported method name.
    """
    if method != 'linear':
        # An unknown method used to silently return the scalar 0, which only
        # failed later (far from the cause) as a shape error; fail loudly here.
        raise ValueError(
            "Unknown method '{}' (supported: 'linear')".format(method))
    # Column vector of bin positions: cdist expects 2-D (samples, features).
    x = np.arange(n, dtype=np.float64).reshape((n, 1))
    return cdist(x, x, metric='sqeuclidean')
    

def dots(*args):
    """Multiply several matrices together, left to right.

    ``dots(A, B, C)`` is equivalent to ``np.dot(np.dot(A, B), C)``. A single
    argument is returned unchanged. Calling with no arguments raises
    ``TypeError`` (from ``functools.reduce`` on an empty sequence).
    """
    # ``reduce`` is not a builtin on Python 3; import it explicitly so this
    # works on both Python 2.6+ and Python 3.
    from functools import reduce
    return reduce(np.dot, args)