summaryrefslogtreecommitdiff
path: root/ot/utils.py
blob: 1a1c6b869b24ada73bd7043ad858891c5e68f9c9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import numpy as np
from scipy.spatial.distance import cdist, pdist


def dist(x1, x2=None, metric='sqeuclidean'):
    """Compute distances between samples in x1 and x2.

    Parameters
    ----------
    x1 : array-like
        First set of samples, passed as-is to scipy.
    x2 : array-like, optional
        Second set of samples. When omitted (None), the distances are
        computed among the samples of ``x1`` only.
    metric : str, optional
        Metric name forwarded to scipy (default ``'sqeuclidean'``).

    Returns
    -------
    ndarray
        ``cdist(x1, x2)`` when ``x2`` is given (one row per sample of ``x1``,
        one column per sample of ``x2``); otherwise ``pdist(x1)``, which is
        scipy's condensed pairwise-distance vector, NOT a square matrix.
    """
    # Two inputs: full cross-distance matrix.
    if x2 is not None:
        return cdist(x1, x2, metric=metric)
    # Single input: condensed pairwise distances within x1.
    return pdist(x1, metric=metric)

def dots(*args):
    """Chain ``np.dot`` over several operands, left to right.

    ``dots(a, b, c)`` is equivalent to ``np.dot(np.dot(a, b), c)``.

    Parameters
    ----------
    *args : array-like
        Operands to multiply, in order. At least one is required.

    Returns
    -------
    ndarray or scalar
        The accumulated product. With a single operand, that operand is
        returned unchanged.

    Raises
    ------
    TypeError
        If called with no operands (``reduce`` has no initial value).
    """
    # `reduce` is not a builtin on Python 3; importing it from functools
    # works on both Python 2.6+ and Python 3.
    from functools import reduce

    return reduce(np.dot, args)