diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-11-03 14:53:52 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-11-03 14:53:52 +0100 |
commit | 566645ad184e1205f7f666ea2f19021254c33d74 (patch) | |
tree | 1d740a6771ab515d0cbfe9f21fde801398eb19b6 | |
parent | 981351165dbab740145d109b00782f0c41f2244b (diff) |
add mapping estimation (still debugging)
-rw-r--r-- | ot/da.py | 174 | ||||
-rw-r--r-- | ot/datasets.py | 62 | ||||
-rw-r--r-- | ot/utils.py | 51 |
3 files changed, 234 insertions, 53 deletions
@@ -6,13 +6,15 @@ Domain adaptation with optimal transport import numpy as np from .bregman import sinkhorn from .lp import emd -from .utils import unif,dist +from .utils import unif,dist,kernel from .optim import cg def indices(a, func): return [i for (i, val) in enumerate(a) if func(val)] + + def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerItermax = 200,stopInnerThr=1e-9,verbose=False,log=False): """ Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization @@ -129,13 +131,15 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos if bias: xs1=np.hstack((xs,np.ones((ns,1)))) - I=eta*np.eye(d+1) + xstxs=xs1.T.dot(xs1) + I=np.eye(d+1) I[-1]=0 I0=I[:,:-1] sel=lambda x : x[:-1,:] else: xs1=xs - I=eta*np.eye(d) + xstxs=xs1.T.dot(xs1) + I=np.eye(d) I0=I sel=lambda x : x @@ -143,20 +147,22 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos log={'err':[]} a,b=unif(ns),unif(nt) - M=dist(xs,xt) + M=dist(xs,xt)*ns G=emd(a,b,M) vloss=[] def loss(L,G): + """Compute full loss""" return np.sum((xs1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.sum(sel(L-I0)**2) def solve_L(G): - """ solve problem with fixed G""" + """ solve L problem with fixed G (least square)""" xst=ns*G.dot(xt) - return np.linalg.solve(xs1.T.dot(xs1)+I,xs1.T.dot(xst)+I0) + return np.linalg.solve(xstxs+eta*I,xs1.T.dot(xst)+eta*I0) def solve_G(L,G0): + """Update G with CG algorithm""" xsi=xs1.dot(L) def f(G): return np.sum((xsi-ns*G.dot(xt))**2) @@ -175,8 +181,11 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos print('{:5d}|{:8e}|{:8e}'.format(0,vloss[-1],0)) - # regul matrix - loop=1 + # init loop + if numItermax>0: + loop=1 + else: + loop=0 it=0 while loop: @@ -191,6 +200,9 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos vloss.append(loss(L,G)) + if it>=numItermax: + loop=0 + if abs(vloss[-1]-vloss[-2])<stopThr: loop=0 @@ -198,11 +210,106 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos if it%20==0: print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32) print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],abs(vloss[-1]-vloss[-2])/abs(vloss[-2]))) + if log: + log['loss']=vloss + return G,L,log + else: + return G,L - return G,L +def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kernel='gaussian',sigma=1,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 20,stopInnerThr=1e-9,stopThr=1e-6,log=False,**kwargs): + """Joint Ot and mapping estimation (uniform weights and ) + """ + ns,nt,d=xs.shape[0],xt.shape[0],xt.shape[1] + if bias: + K= + xs1=np.hstack((xs,np.ones((ns,1)))) + xstxs=xs1.T.dot(xs1) + I=np.eye(d+1) + I[-1]=0 + I0=I[:,:-1] + sel=lambda x : x[:-1,:] + else: + xs1=xs + xstxs=xs1.T.dot(xs1) + I=np.eye(d) + I0=I + sel=lambda x : x + + if log: + log={'err':[]} + + a,b=unif(ns),unif(nt) + M=dist(xs,xt)*ns + G=emd(a,b,M) + + vloss=[] + + def loss(L,G): + """Compute full loss""" + return np.sum((xs1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.sum(sel(L-I0)**2) + + def solve_L(G): + """ solve L problem with fixed G (least square)""" + xst=ns*G.dot(xt) + return np.linalg.solve(xstxs+eta*I,xs1.T.dot(xst)+eta*I0) + + def solve_G(L,G0): + """Update G with CG algorithm""" + xsi=xs1.dot(L) + def f(G): + return np.sum((xsi-ns*G.dot(xt))**2) + def df(G): + return -2*ns*(xsi-ns*G.dot(xt)).dot(xt.T) + G=cg(a,b,M,1.0/mu,f,df,G0=G0,numItermax=numInnerItermax,stopThr=stopInnerThr) + return G + + + L=solve_L(G) + + vloss.append(loss(L,G)) + + if verbose: + print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32) + print('{:5d}|{:8e}|{:8e}'.format(0,vloss[-1],0)) + + + # init loop + if numItermax>0: + loop=1 + else: + loop=0 + it=0 + + while loop: + + it+=1 + + # update G + G=solve_G(L,G) + + #update L + L=solve_L(G) + + vloss.append(loss(L,G)) + + if it>=numItermax: + loop=0 + + if abs(vloss[-1]-vloss[-2])<stopThr: + loop=0 + + if verbose: + if it%20==0: + print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32) + print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],abs(vloss[-1]-vloss[-2])/abs(vloss[-2]))) + if log: + log['loss']=vloss + return G,L,log + else: + return G,L class OTDA(object): @@ -294,6 +401,7 @@ class OTDA(object): class OTDA_sinkhorn(OTDA): """Class for domain adaptation with optimal transport with entropic regularization""" + def fit(self,xs,xt,reg=1,ws=None,wt=None,**kwargs): """ Fit domain adaptation between samples is xs and xt (with optional weights)""" @@ -335,3 +443,51 @@ class OTDA_lpl1(OTDA): self.G=sinkhorn_lpl1_mm(ws,ys,wt,self.M,reg,eta,**kwargs) self.computed=True +class OTDA_mapping(OTDA): + """Class for optimal transport with joint linear mapping estimation""" + + + def __init__(self,metric='sqeuclidean'): + """ Class initialization""" + + + self.xs=0 + self.xt=0 + self.G=0 + self.L=0 + self.bias=False + self.metric=metric + self.computed=False + + def fit(self,xs,xt,mu=1,eta=1,bias=False,**kwargs): + """ Fit domain adaptation between samples is xs and xt (with optional + weights)""" + self.xs=xs + self.xt=xt + self.bias=bias + + self.ws=unif(xs.shape[0]) + self.wt=unif(xt.shape[0]) + + self.G,self.L=joint_OT_mapping_linear(xs,xt,mu=mu,eta=eta,bias=bias,**kwargs) + self.computed=True + + def mapping(self): + return lambda x: self.predict(x) + + + def predict(self,x): + """ Out of sample mapping using the formulation from Ferradans + + It basically find the source sample the nearset to the nex sample and + apply the difference to the displaced source sample. + + """ + if self.computed: + if self.bias: + x=np.hstack((x,np.ones((x.shape[0],1)))) + return x.dot(self.L) # aply the delta to the interpolation + else: + print("Warning, model not fitted yet, returning None") + return None + diff --git a/ot/datasets.py b/ot/datasets.py index 6388d94..588f501 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -8,8 +8,8 @@ import scipy as sp def get_1D_gauss(n,m,s): - """return a 1D histogram for a gaussian distribution (n bins, mean m and std s) - + """return a 1D histogram for a gaussian distribution (n bins, mean m and std s) + Parameters ---------- @@ -20,21 +20,21 @@ def get_1D_gauss(n,m,s): s : float standard deviaton of the gaussian distribution - + Returns ------- h : np.array (n,) - 1D histogram for a gaussian distribution - + 1D histogram for a gaussian distribution + """ x=np.arange(n,dtype=np.float64) h=np.exp(-(x-m)**2/(2*s^2)) return h/h.sum() - - + + def get_2D_samples_gauss(n,m,sigma): - """return n samples drawn from 2D gaussian N(m,sigma) - + """return n samples drawn from 2D gaussian N(m,sigma) + Parameters ---------- @@ -45,12 +45,12 @@ def get_2D_samples_gauss(n,m,sigma): sigma : np.array (2,2) covariance matrix of the gaussian distribution - + Returns ------- X : np.array (n,2) - n samples drawn from N(m,sigma) - + n samples drawn from N(m,sigma) + """ if np.isscalar(sigma): sigma=np.array([sigma,]) @@ -61,9 +61,10 @@ def get_2D_samples_gauss(n,m,sigma): res= np.random.randn(n,2)*np.sqrt(sigma)+m return res -def get_data_classif(dataset,n,nz=.5,**kwargs): + +def get_data_classif(dataset,n,nz=.5,theta=0,**kwargs): """ dataset generation for classification problems - + Parameters ---------- @@ -74,13 +75,13 @@ def get_data_classif(dataset,n,nz=.5,**kwargs): nz : float noise level (>0) - + Returns ------- X : np.array (n,d) - n observation of size d + n observation of size d y : np.array (n,) - labels of the samples + labels of the samples """ if dataset.lower()=='3gauss': @@ -90,10 +91,10 @@ def get_data_classif(dataset,n,nz=.5,**kwargs): x[y==1,0]=-1.; x[y==1,1]=-1. x[y==2,0]=-1.; x[y==2,1]=1. x[y==3,0]=1. ; x[y==3,1]=0 - + x[y!=3,:]+=1.5*nz*np.random.randn(sum(y!=3),2) x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2) - + elif dataset.lower()=='3gauss2': y=np.floor((np.arange(n)*1.0/n*3))+1 x=np.zeros((n,2)) @@ -102,12 +103,29 @@ def get_data_classif(dataset,n,nz=.5,**kwargs): x[y==1,0]=-2.; x[y==1,1]=-2. x[y==2,0]=-2.; x[y==2,1]=2. x[y==3,0]=2. ; x[y==3,1]=0 - + x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2) - x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2) + x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2) + + elif dataset.lower()=='gaussrot' : + rot=np.array([[np.cos(theta),-np.sin(theta)],[np.sin(theta),np.cos(theta)]]) + m1=np.array([-1,-1]) + m2=np.array([1,1]) + y=np.floor((np.arange(n)*1.0/n*2))+1 + n1=np.sum(y==1) + n2=np.sum(y==2) + x=np.zeros((n,2)) + + x[y==1,:]=get_2D_samples_gauss(n1,m1,nz) + x[y==2,:]=get_2D_samples_gauss(n2,m2,nz) + + x=x.dot(rot) + + + else: x=0 y=0 print("unknown dataset") - + return x,y.astype(int)
\ No newline at end of file diff --git a/ot/utils.py b/ot/utils.py index 24f65a8..47fe77f 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -6,28 +6,34 @@ import numpy as np from scipy.spatial.distance import cdist +def kernel(x1,x2,method='gaussian',sigma=1,**kwargs): + """Compute kernel matrix""" + if method.lower() in ['gaussian','gauss','rbf']: + K=np.exp(dist(x1,x2)/(2*sigma**2)) + return K + def unif(n): - """ return a uniform histogram of length n (simplex) - + """ return a uniform histogram of length n (simplex) + Parameters ---------- n : int number of bins in the histogram - + Returns ------- h : np.array (n,) - histogram of length n such that h_i=1/n for all i - - + histogram of length n such that h_i=1/n for all i + + """ return np.ones((n,))/n def dist(x1,x2=None,metric='sqeuclidean'): """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist - + Parameters ---------- @@ -36,28 +42,29 @@ def dist(x1,x2=None,metric='sqeuclidean'): x2 : np.array (n2,d), optional matrix with n2 samples of size d (if None then x2=x1) metric : str, fun, optional - name of the metric to be computed (full list in the doc of scipy), If a string, + name of the metric to be computed (full list in the doc of scipy), If a string, the distance function can be ‘braycurtis’, ‘canberra’, ‘chebyshev’, ‘cityblock’, ‘correlation’, ‘cosine’, ‘dice’, ‘euclidean’, ‘hamming’, ‘jaccard’, ‘kulsinski’, ‘mahalanobis’, ‘matching’, ‘minkowski’, ‘rogerstanimoto’, ‘russellrao’, ‘seuclidean’, ‘sokalmichener’, ‘sokalsneath’, ‘sqeuclidean’, ‘wminkowski’, ‘yule’. - + Returns ------- - + M : np.array (n1,n2) distance matrix computed with given metric - + """ if x2 is None: - return cdist(x1,x1,metric=metric) - else: - return cdist(x1,x2,metric=metric) - + x2=x1 + + return cdist(x1,x2,metric=metric) + + def dist0(n,method='lin_square'): """Compute standard cost matrices of size (n,n) for OT problems - + Parameters ---------- @@ -68,21 +75,21 @@ def dist0(n,method='lin_square'): * 'lin_square' : linear sampling between 0 and n-1, quadratic loss - + Returns ------- - + M : np.array (n1,n2) - distance matrix computed with given metric - - + distance matrix computed with given metric + + """ res=0 if method=='lin_square': x=np.arange(n,dtype=np.float64).reshape((n,1)) res=dist(x,x) return res - + def dots(*args): """ dots function for multiple matrix multiply """ |