diff options
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 23 |
1 files changed, 11 insertions, 12 deletions
@@ -193,30 +193,30 @@ def sinkhorn_l1l2_gl(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter """ lstlab=np.unique(labels_a) - + def f(G): res=0 for i in range(G.shape[1]): for lab in lstlab: temp=G[labels_a==lab,i] - res+=np.linalg.norm(temp) + res+=np.linalg.norm(temp) return res - + def df(G): - W=np.zeros(G.shape) + W=np.zeros(G.shape) for i in range(G.shape[1]): for lab in lstlab: temp=G[labels_a==lab,i] n=np.linalg.norm(temp) if n: - W[labels_a==lab,i]=temp/n - return W + W[labels_a==lab,i]=temp/n + return W + - return gcg(a,b,M,reg,eta,f,df,G0=None,numItermax = numItermax,numInnerItermax=numInnerItermax, stopThr=stopInnerThr,verbose=verbose,log=log) - - - + + + def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 10,stopInnerThr=1e-6,stopThr=1e-5,log=False,**kwargs): """Joint OT and linear mapping estimation as proposed in [8] @@ -685,7 +685,6 @@ class OTDA(object): return xf[idx,:]+x-x0[idx,:] # aply the delta to the interpolation - class OTDA_sinkhorn(OTDA): """Class for domain adaptation with optimal transport with entropic regularization""" @@ -727,7 +726,7 @@ class OTDA_lpl1(OTDA): self.M=dist(xs,xt,metric=self.metric) self.G=sinkhorn_lpl1_mm(ws,ys,wt,self.M,reg,eta,**kwargs) self.computed=True - + class OTDA_l1l2(OTDA): """Class for domain adaptation with optimal transport with entropic and group lasso regularization""" |