diff options
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 35 |
1 files changed, 21 insertions, 14 deletions
@@ -217,25 +217,23 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos return G,L -def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kernel='gaussian',sigma=1,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 20,stopInnerThr=1e-9,stopThr=1e-6,log=False,**kwargs): +def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 20,stopInnerThr=1e-9,stopThr=1e-6,log=False,**kwargs): """Joint Ot and mapping estimation (uniform weights and ) """ ns,nt,d=xs.shape[0],xt.shape[0],xt.shape[1] + K=kernel(xs,xs,method=kerneltype,sigma=sigma) if bias: - K= - xs1=np.hstack((xs,np.ones((ns,1)))) - xstxs=xs1.T.dot(xs1) - I=np.eye(d+1) + K1=np.hstack((K,np.ones((ns,1)))) + I=np.eye(ns+1) I[-1]=0 - I0=I[:,:-1] + K0 = K1.T.dot(K1)+eta*I sel=lambda x : x[:-1,:] else: - xs1=xs - xstxs=xs1.T.dot(xs1) - I=np.eye(d) - I0=I + K1=K + I=np.eye(ns) + K0=K+eta*I sel=lambda x : x if log: @@ -249,16 +247,21 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kernel='gaussian',sigma=1,bias= def loss(L,G): """Compute full loss""" - return np.sum((xs1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.sum(sel(L-I0)**2) + return np.sum((K1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.sum(sel(L)**2) - def solve_L(G): + def solve_L_nobias(G): """ solve L problem with fixed G (least square)""" xst=ns*G.dot(xt) - return np.linalg.solve(xstxs+eta*I,xs1.T.dot(xst)+eta*I0) + return np.linalg.solve(K0,xst) + + def solve_L_bias(G): + """ solve L problem with fixed G (least square)""" + xst=ns*G.dot(xt) + return np.linalg.solve(K0,K1.T.dot(xst)) def solve_G(L,G0): """Update G with CG algorithm""" - xsi=xs1.dot(L) + xsi=K1.dot(L) def f(G): return np.sum((xsi-ns*G.dot(xt))**2) def df(G): @@ -266,6 +269,10 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kernel='gaussian',sigma=1,bias= G=cg(a,b,M,1.0/mu,f,df,G0=G0,numItermax=numInnerItermax,stopThr=stopInnerThr) return G + if bias: + solve_L=solve_L_bias + else: + solve_L=solve_L_nobias L=solve_L(G) |