# path: root/ot/dr.py
# blob: 3732d81ff17010c1e8a70a4940972d463d49f802 (plain)
# -*- coding: utf-8 -*-
"""
Dimension reduction with optimal transport
"""


import autograd.numpy as np
from pymanopt.manifolds import Stiefel
from pymanopt import Problem
from pymanopt.solvers import SteepestDescent, TrustRegions

def dist(x1,x2):
    """ Compute the squared Euclidean distance between all pairs of samples

    Relies on the expansion ||a-b||^2 = ||a||^2 + ||b||^2 - 2<a,b>.
    Entry (i,j) of the returned matrix is the squared distance between
    row i of x1 and row j of x2.
    """
    # squared norm of every row, laid out as a column (x1) and as a row (x2)
    # so that their sum broadcasts to the full pairwise matrix
    sq_norm_col = np.reshape(np.sum(np.square(x1), 1), (-1, 1))
    sq_norm_row = np.reshape(np.sum(np.square(x2), 1), (1, -1))
    # inner products between every row of x1 and every row of x2
    cross = np.dot(x1, x2.T)
    return sq_norm_col + sq_norm_row - 2 * cross

def sinkhorn(w1,w2,M,reg,k):
    """
    Sinkhorn scaling with a fixed number of iterations (no convergence test)

    w1 and w2 are the weights attached to the rows and columns of the cost
    matrix M, reg divides M inside the exponential kernel and k is the number
    of scaling iterations. Returns the matrix diag(u) * exp(-M/reg) * diag(v)
    built from the scaling vectors obtained after k iterations.
    """
    n_rows = M.shape[0]
    n_cols = M.shape[1]
    kernel = np.exp(-M / reg)
    # both scaling vectors start at one; with k=0 the kernel itself is returned
    u = np.ones((n_rows,))
    v = np.ones((n_cols,))
    for _ in range(k):
        # column scaling first, then row scaling: the row sums of the result
        # therefore match w1 after the last iteration
        v = w2 / np.dot(kernel.T, u)
        u = w1 / np.dot(kernel, v)
    return np.reshape(u, (n_rows, 1)) * kernel * np.reshape(v, (1, n_cols))

def split_classes(X,y):
    """
    Group the rows of X according to the labels in y

    Returns a list with one float32 array per distinct label, ordered like
    np.unique(y); each array holds the rows of X carrying that label.
    """
    groups = []
    for label in np.unique(y):
        members = (y == label)
        groups.append(X[members, :].astype(np.float32))
    return groups
    


def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0):
    r"""
    Wasserstein Discriminant Analysis

    The function solves the following optimization problem:

    .. math::
        P = \arg\min_P \frac{\sum_i W(PX^i,PX^i)}{\sum_{i,j\neq i} W(PX^i,PX^j)}

    where :

    - :math:`P` is a point of the Stiefel(d,p) manifold used as projection
    - :math:`W` is entropic regularized Wasserstein distances
    - :math:`X^i` are samples in the dataset corresponding to class i

    Parameters
    ----------
    X : 2-D array
        Training samples, one per row; ``d = X.shape[1]`` is the input
        dimension. The array passed by the caller is not modified.
    y : array
        Class label of every row of X.
    p : int, optional
        Second argument of the Stiefel(d, p) manifold, i.e. the number of
        columns of the projection matrix.
    reg : float, optional
        Regularization value handed to the Sinkhorn solver.
    k : int, optional
        Number of Sinkhorn iterations used for every transport problem.
    solver : None, str or solver object, optional
        None selects SteepestDescent; 'tr' or 'TrustRegions' selects
        TrustRegions. Any non-string object is used as is and has to provide
        a ``solve(problem)`` method.
    maxiter : int, optional
        ``maxiter`` argument of the solver built by this function.
    verbose : int, optional
        ``logverbosity`` argument of the solver built by this function.

    Returns
    -------
    Popt : value returned by ``solver.solve``
        Optimal projection found by the solver.
    proj : callable
        ``proj(X)`` subtracts the training mean from X and multiplies the
        result by Popt.

    Raises
    ------
    ValueError
        If ``solver`` is a string that is not recognized, or if ``y``
        contains fewer than two classes (the between-class loss would be 0).
    """

    # Validate the solver name first so that a typo fails with a clear
    # message instead of an AttributeError on str after the data is prepared.
    if isinstance(solver,str) and solver not in ['tr','TrustRegions']:
        raise ValueError(
            "Unknown solver '{}': use None, 'tr', 'TrustRegions' "
            "or a solver object".format(solver))

    # Centre a copy of the data: an in-place `X -= ...` would modify the
    # caller's array (and raise for integer arrays), and a later proj(X) on
    # that same array would then be centred twice.
    # NOTE(review): np.mean(X) without an axis is the mean over all entries
    # (a scalar), not a per-feature mean. Kept unchanged so that the output
    # of proj() is not shifted -- confirm whether axis=0 was intended.
    mx=np.mean(X)
    X=X-mx.reshape((1,-1))

    # data split between classes
    d=X.shape[1]
    xc=split_classes(X,y)
    if len(xc)<2:
        raise ValueError(
            "wda needs at least two classes in y, got {}".format(len(xc)))
    # compute uniform weights
    wc=[np.ones((x.shape[0]),dtype=np.float32)/x.shape[0] for x in xc]

    def cost(P):
        # wda loss
        # project every class once instead of once per pair of classes
        xp=[np.dot(x,P) for x in xc]
        loss_b=0
        loss_w=0

        for i,xi in enumerate(xp):
            # only pairs (i, i+j) with j>=0 are visited: j==0 is the
            # within-class term, j>0 the between-class terms
            for j,xj in enumerate(xp[i:]):
                M=dist(xi,xj)
                G=sinkhorn(wc[i],wc[j+i],M,reg,k)
                if j==0:
                    loss_w+=np.sum(G*M)
                else:
                    loss_b+=np.sum(G*M)

        # within/between ratio: the solver minimizes, so the ratio that a
        # discriminant analysis would maximize is inverted here
        return loss_w/loss_b


    # declare manifold and problem
    manifold = Stiefel(d, p)
    problem = Problem(manifold=manifold, cost=cost)

    # declare solver and solve
    if solver is None:
        solver= SteepestDescent(maxiter=maxiter,logverbosity=verbose)
    elif isinstance(solver,str):
        # only 'tr' / 'TrustRegions' can reach this point (checked above)
        solver= TrustRegions(maxiter=maxiter,logverbosity=verbose)

    Popt = solver.solve(problem)

    def proj(X):
        # apply the same centring as the one used for training, then project
        return (X-mx.reshape((1,-1))).dot(Popt)

    return Popt, proj