From 566645ad184e1205f7f666ea2f19021254c33d74 Mon Sep 17 00:00:00 2001
From: Rémi Flamary
Date: Thu, 3 Nov 2016 14:53:52 +0100
Subject: add mapping estimation (still debugging)

---
 ot/da.py       | 174 ++++++++++++++++++++++++++++++++++++++++++++++++++++++---
 ot/datasets.py |  62 ++++++++++++--------
 ot/utils.py    |  51 +++++++++--------
 3 files changed, 234 insertions(+), 53 deletions(-)

diff --git a/ot/da.py b/ot/da.py
index 7cfbca1..66680cd 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -6,13 +6,15 @@ Domain adaptation with optimal transport
 import numpy as np
 from .bregman import sinkhorn
 from .lp import emd
-from .utils import unif,dist
+from .utils import unif,dist,kernel
 from .optim import cg
 
 def indices(a, func):
     return [i for (i, val) in enumerate(a) if func(val)]
 
+
+
 def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerItermax = 200,stopInnerThr=1e-9,verbose=False,log=False):
     """
     Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization
@@ -129,13 +131,15 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
 
     if bias:
         xs1=np.hstack((xs,np.ones((ns,1))))
-        I=eta*np.eye(d+1)
+        xstxs=xs1.T.dot(xs1)
+        I=np.eye(d+1)
         I[-1]=0
         I0=I[:,:-1]
         sel=lambda x : x[:-1,:]
     else:
         xs1=xs
-        I=eta*np.eye(d)
+        xstxs=xs1.T.dot(xs1)
+        I=np.eye(d)
         I0=I
         sel=lambda x : x
 
@@ -143,20 +147,22 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
         log={'err':[]}
 
     a,b=unif(ns),unif(nt)
-    M=dist(xs,xt)
+    M=dist(xs,xt)*ns
     G=emd(a,b,M)
 
     vloss=[]
 
     def loss(L,G):
+        """Compute full loss"""
         return np.sum((xs1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.sum(sel(L-I0)**2)
 
     def solve_L(G):
-        """ solve problem with fixed G"""
+        """ solve L problem with fixed G (least square)"""
         xst=ns*G.dot(xt)
-        return np.linalg.solve(xs1.T.dot(xs1)+I,xs1.T.dot(xst)+I0)
+        return np.linalg.solve(xstxs+eta*I,xs1.T.dot(xst)+eta*I0)
 
     def solve_G(L,G0):
+        """Update G with CG algorithm"""
         xsi=xs1.dot(L)
         def f(G):
             return np.sum((xsi-ns*G.dot(xt))**2)
@@ -175,8 +181,11 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
         print('{:5d}|{:8e}|{:8e}'.format(0,vloss[-1],0))
 
 
-    # regul matrix
-    loop=1
+    # init loop
+    if numItermax>0:
+        loop=1
+    else:
+        loop=0
     it=0
 
     while loop:
@@ -191,6 +200,9 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
 
         vloss.append(loss(L,G))
 
+        if it>=numItermax:
+            loop=0
+
         if abs(vloss[-1]-vloss[-2])<stopThr:
             loop=0
 

[... remainder of the ot/da.py diff and the start of the ot/datasets.py diff are missing from this export; the hunks below modify get_data_classif() in ot/datasets.py ...]

         noise level (>0)
-
+
     Returns
     -------
     X : np.array (n,d)
-         n observation of size d
+        n observation of size d
     y : np.array (n,)
-         labels of the samples
+        labels of the samples
 
     """
 
     if dataset.lower()=='3gauss':
@@ -90,10 +91,10 @@ def get_data_classif(dataset,n,nz=.5,**kwargs):
         x[y==1,0]=-1.; x[y==1,1]=-1.
         x[y==2,0]=-1.; x[y==2,1]=1.
         x[y==3,0]=1. ; x[y==3,1]=0
-
+
         x[y!=3,:]+=1.5*nz*np.random.randn(sum(y!=3),2)
         x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
-
+
     elif dataset.lower()=='3gauss2':
         y=np.floor((np.arange(n)*1.0/n*3))+1
         x=np.zeros((n,2))
@@ -102,12 +103,29 @@ def get_data_classif(dataset,n,nz=.5,**kwargs):
         x[y==1,0]=-2.; x[y==1,1]=-2.
         x[y==2,0]=-2.; x[y==2,1]=2.
         x[y==3,0]=2. ; x[y==3,1]=0
-
+
         x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2)
-        x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
+        x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
+
+    elif dataset.lower()=='gaussrot' :
+        rot=np.array([[np.cos(theta),-np.sin(theta)],[np.sin(theta),np.cos(theta)]])
+        m1=np.array([-1,-1])
+        m2=np.array([1,1])
+        y=np.floor((np.arange(n)*1.0/n*2))+1
+        n1=np.sum(y==1)
+        n2=np.sum(y==2)
+        x=np.zeros((n,2))
+
+        x[y==1,:]=get_2D_samples_gauss(n1,m1,nz)
+        x[y==2,:]=get_2D_samples_gauss(n2,m2,nz)
+
+        x=x.dot(rot)
+
+
+
     else:
         x=0
         y=0
         print("unknown dataset")
-
+
     return x,y.astype(int)
\ No newline at end of file
diff --git a/ot/utils.py b/ot/utils.py
index 24f65a8..47fe77f 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -6,28 +6,34 @@
 import numpy as np
 
 from scipy.spatial.distance import cdist
 
+def kernel(x1,x2,method='gaussian',sigma=1,**kwargs):
+    """Compute kernel matrix"""
+    if method.lower() in ['gaussian','gauss','rbf']:
+        K=np.exp(-dist(x1,x2)/(2*sigma**2))
+    return K
+
 def unif(n):
-    """ return a uniform histogram of length n (simplex)
-
+    """ return a uniform histogram of length n (simplex)
+
     Parameters
     ----------
 
     n : int
         number of bins in the histogram
-
+
     Returns
     -------
     h : np.array (n,)
-        histogram of length n such that h_i=1/n for all i
-
-
+        histogram of length n such that h_i=1/n for all i
+
+
     """
     return np.ones((n,))/n
 
 
 def dist(x1,x2=None,metric='sqeuclidean'):
     """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist
-
+
     Parameters
     ----------
 
@@ -36,28 +42,29 @@ def dist(x1,x2=None,metric='sqeuclidean'):
     x2 : np.array (n2,d), optional
         matrix with n2 samples of size d (if None then x2=x1)
     metric : str, fun, optional
-        name of the metric to be computed (full list in the doc of scipy), If a string,
+        name of the metric to be computed (full list in the doc of scipy), If a string,
        the distance function can be ‘braycurtis’, ‘canberra’, ‘chebyshev’, ‘cityblock’, ‘correlation’, ‘cosine’, ‘dice’, ‘euclidean’, ‘hamming’, ‘jaccard’, ‘kulsinski’, ‘mahalanobis’, ‘matching’, ‘minkowski’, ‘rogerstanimoto’, ‘russellrao’, ‘seuclidean’, ‘sokalmichener’, ‘sokalsneath’, ‘sqeuclidean’, ‘wminkowski’, ‘yule’.
-
+
     Returns
     -------
-
+
     M : np.array (n1,n2)
         distance matrix computed with given metric
-
+
     """
     if x2 is None:
-        return cdist(x1,x1,metric=metric)
-    else:
-        return cdist(x1,x2,metric=metric)
-
+        x2=x1
+
+    return cdist(x1,x2,metric=metric)
+
+
 def dist0(n,method='lin_square'):
     """Compute standard cost matrices of size (n,n) for OT problems
-
+
     Parameters
     ----------
 
@@ -68,21 +75,21 @@ def dist0(n,method='lin_square'):
 
         * 'lin_square' : linear sampling between 0 and n-1, quadratic loss
 
-
+
     Returns
     -------
-
+
     M : np.array (n1,n2)
-         distance matrix computed with given metric
-
-
+        distance matrix computed with given metric
+
+
     """
     res=0
     if method=='lin_square':
         x=np.arange(n,dtype=np.float64).reshape((n,1))
         res=dist(x,x)
     return res
-
+
 def dots(*args):
     """ dots function for multiple matrix multiply """
-- 
cgit v1.2.3
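
For context, here is a minimal usage sketch of the pieces this patch touches, based only on the signatures and code visible in the hunks above: the toy generators in ot.datasets, the new ot.utils.kernel helper, and the reworked joint_OT_mapping_linear solver in ot.da. The sample sizes and parameter values are illustrative, the return values (coupling G and mapping matrix L) are assumed from the function body, and since the commit is flagged "still debugging" the exact interface may still change.

    import numpy as np

    from ot.datasets import get_data_classif
    from ot.da import joint_OT_mapping_linear
    from ot.utils import kernel

    np.random.seed(0)

    # two 2D toy datasets from the generators modified in this patch
    xs, ys = get_data_classif('3gauss', 100, nz=0.5)    # source samples, shape (100, 2)
    xt, yt = get_data_classif('3gauss2', 100, nz=0.5)   # target samples, shape (100, 2)

    # Gaussian kernel matrix from the new ot.utils.kernel helper
    K = kernel(xs, xs, method='gaussian', sigma=1)
    print(K.shape)                                      # (100, 100)

    # joint estimation of the OT coupling G and of an affine map L (bias=True),
    # assumed to be returned in that order by joint_OT_mapping_linear
    G, L = joint_OT_mapping_linear(xs, xt, mu=1, eta=1e-3, bias=True, verbose=True)

    # with bias=True the map acts on [xs, 1], so append a constant feature
    xs1 = np.hstack((xs, np.ones((xs.shape[0], 1))))
    xs_mapped = xs1.dot(L)                              # source samples mapped toward the target
    print(G.shape, L.shape, xs_mapped.shape)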