summaryrefslogtreecommitdiff
path: root/ot/utils.py
blob: 2f68775a30887ca9febe96e8bc3e1ef558820f49 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# -*- coding: utf-8 -*-
"""
Various useful functions
"""
import numpy as np
from scipy.spatial.distance import cdist


import time
# Reference timestamp (seconds, from time.time()) shared by tic(), toc() and
# toq(). It is initialised at import time, so toc()/toq() return the time
# elapsed since module import if tic() has never been called.
__time_tic_toc=time.time()

def tic():
    """Start (or restart) the module-level timer, like Matlab's ``tic``.

    The current time.time() value is stored as the reference timestamp
    later read by toc() and toq().
    """
    global __time_tic_toc
    now = time.time()
    __time_tic_toc = now

def toc(message='Elapsed time : {} s'):
    """Print and return the time elapsed since the last tic() (Matlab ``toc``).

    If tic() was never called, the reference is the module import time.

    Parameters
    ----------
    message : str, optional
        format string; its ``{}`` placeholder receives the elapsed time

    Returns
    -------
    elapsed : float
        elapsed time in seconds
    """
    elapsed = time.time() - __time_tic_toc
    print(message.format(elapsed))
    return elapsed

def toq():
    """Return, without printing, the time elapsed since the last tic().

    Mirrors Julia's ``toq()``. If tic() was never called, the reference is
    the module import time.

    Returns
    -------
    float
        elapsed time in seconds
    """
    return time.time() - __time_tic_toc


def kernel(x1,x2,method='gaussian',sigma=1,**kwargs):
    """Compute the kernel matrix between samples in x1 and x2

    Parameters
    ----------

    x1 : np.array (n1,d)
        matrix with n1 samples of size d
    x2 : np.array (n2,d)
        matrix with n2 samples of size d
    method : str, optional
        kernel type (case-insensitive); only 'gaussian' (aliases 'gauss',
        'rbf') is supported
    sigma : float, optional
        bandwidth of the gaussian kernel
    **kwargs
        ignored, accepted for API compatibility

    Returns
    -------

    K : np.array (n1,n2)
        K_ij = exp(-d_ij / (2*sigma**2)) where d_ij is the squared euclidean
        distance between x1_i and x2_j (computed with dist)

    Raises
    ------

    ValueError
        if method is not a supported kernel name
    """
    if method.lower() in ['gaussian','gauss','rbf']:
        return np.exp(-dist(x1,x2)/(2*sigma**2))
    # Previously an unknown method fell through to `return K` with K unbound,
    # crashing with an obscure UnboundLocalError.
    raise ValueError("Unknown kernel method '{}'".format(method))

def unif(n):
    """Return the uniform histogram of length n on the probability simplex

    Parameters
    ----------

    n : int
        number of bins in the histogram

    Returns
    -------
    h : np.array (n,)
        float array whose entries are all equal to 1/n

    """
    ones = np.ones((n,))
    return ones / n


def dist(x1,x2=None,metric='sqeuclidean'):
    """Compute the pairwise distance matrix between samples of x1 and x2

    This is a thin wrapper around scipy.spatial.distance.cdist.

    Parameters
    ----------

    x1 : np.array (n1,d)
        matrix with n1 samples of size d
    x2 : np.array (n2,d), optional
        matrix with n2 samples of size d; when None, x1 is used instead
    metric : str, fun, optional
        metric passed to cdist (default 'sqeuclidean'). Any metric name
        accepted by cdist, e.g. 'euclidean', 'sqeuclidean', 'cityblock',
        'cosine', or a callable; see the scipy documentation for the
        full list supported by your version.

    Returns
    -------

    M : np.array (n1,n2)
        distance matrix computed with the given metric

    """
    other = x1 if x2 is None else x2
    return cdist(x1, other, metric=metric)


def dist0(n,method='lin_square'):
    """Compute a standard (n,n) cost matrix for OT problems

    Parameters
    ----------

    n : int
        size of the cost matrix
    method : str, optional
        Type of loss matrix chosen from:

        * 'lin_square' : points 0, 1, ..., n-1 on a line, squared distance

        Any other value currently returns the scalar 0 instead of a matrix.


    Returns
    -------

    M : np.array (n,n)
        cost matrix (M_ij = (i-j)**2 for 'lin_square')


    """
    if method == 'lin_square':
        grid = np.arange(n, dtype=np.float64).reshape((n, 1))
        return dist(grid, grid)
    # NOTE(review): unknown methods silently yield 0 rather than raising.
    return 0


def dots(*args):
    """Multiply several matrices left to right: dots(A, B, C) == A.dot(B).dot(C)

    Parameters
    ----------

    *args : np.array
        matrices with compatible shapes

    Returns
    -------

    np.array
        product of all arguments; a single argument is returned unchanged

    Raises
    ------

    TypeError
        if called without any argument (reduce of an empty sequence)
    """
    # `reduce` is not a builtin in Python 3 (the bare name raised NameError);
    # functools.reduce is available on Python 2.6+ and 3.
    from functools import reduce
    return reduce(np.dot, args)