summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
author Rémi Flamary <remi.flamary@gmail.com> 2016-11-03 14:53:52 +0100
committer Rémi Flamary <remi.flamary@gmail.com> 2016-11-03 14:53:52 +0100
commit 566645ad184e1205f7f666ea2f19021254c33d74 (patch)
tree 1d740a6771ab515d0cbfe9f21fde801398eb19b6 /ot/da.py
parent 981351165dbab740145d109b00782f0c41f2244b (diff)
add mapping estimation (still debugging)
Diffstat (limited to 'ot/da.py')
-rw-r--r-- ot/da.py 174
1 files changed, 165 insertions, 9 deletions
diff --git a/ot/da.py b/ot/da.py
index 7cfbca1..66680cd 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -6,13 +6,15 @@ Domain adaptation with optimal transport
import numpy as np
from .bregman import sinkhorn
from .lp import emd
-from .utils import unif,dist
+from .utils import unif,dist,kernel
from .optim import cg
def indices(a, func):
return [i for (i, val) in enumerate(a) if func(val)]
+
+
def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerItermax = 200,stopInnerThr=1e-9,verbose=False,log=False):
"""
Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization
@@ -129,13 +131,15 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
if bias:
xs1=np.hstack((xs,np.ones((ns,1))))
- I=eta*np.eye(d+1)
+ xstxs=xs1.T.dot(xs1)
+ I=np.eye(d+1)
I[-1]=0
I0=I[:,:-1]
sel=lambda x : x[:-1,:]
else:
xs1=xs
- I=eta*np.eye(d)
+ xstxs=xs1.T.dot(xs1)
+ I=np.eye(d)
I0=I
sel=lambda x : x
@@ -143,20 +147,22 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
log={'err':[]}
a,b=unif(ns),unif(nt)
- M=dist(xs,xt)
+ M=dist(xs,xt)*ns
G=emd(a,b,M)
vloss=[]
def loss(L,G):
+ """Compute full loss"""
return np.sum((xs1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.sum(sel(L-I0)**2)
def solve_L(G):
- """ solve problem with fixed G"""
+ """ solve L problem with fixed G (least square)"""
xst=ns*G.dot(xt)
- return np.linalg.solve(xs1.T.dot(xs1)+I,xs1.T.dot(xst)+I0)
+ return np.linalg.solve(xstxs+eta*I,xs1.T.dot(xst)+eta*I0)
def solve_G(L,G0):
+ """Update G with CG algorithm"""
xsi=xs1.dot(L)
def f(G):
return np.sum((xsi-ns*G.dot(xt))**2)
@@ -175,8 +181,11 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
print('{:5d}|{:8e}|{:8e}'.format(0,vloss[-1],0))
- # regul matrix
- loop=1
+ # init loop
+ if numItermax>0:
+ loop=1
+ else:
+ loop=0
it=0
while loop:
@@ -191,6 +200,9 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
vloss.append(loss(L,G))
+ if it>=numItermax:
+ loop=0
+
if abs(vloss[-1]-vloss[-2])<stopThr:
loop=0
@@ -198,11 +210,106 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
if it%20==0:
print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],abs(vloss[-1]-vloss[-2])/abs(vloss[-2])))
+ if log:
+ log['loss']=vloss
+ return G,L,log
+ else:
+ return G,L
- return G,L
+def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kernel='gaussian',sigma=1,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 20,stopInnerThr=1e-9,stopThr=1e-6,log=False,**kwargs):
+ """Joint OT and mapping estimation (uniform weights)
+ """
+ ns,nt,d=xs.shape[0],xt.shape[0],xt.shape[1]
+ if bias:
+ K=
+ xs1=np.hstack((xs,np.ones((ns,1))))
+ xstxs=xs1.T.dot(xs1)
+ I=np.eye(d+1)
+ I[-1]=0
+ I0=I[:,:-1]
+ sel=lambda x : x[:-1,:]
+ else:
+ xs1=xs
+ xstxs=xs1.T.dot(xs1)
+ I=np.eye(d)
+ I0=I
+ sel=lambda x : x
+
+ if log:
+ log={'err':[]}
+
+ a,b=unif(ns),unif(nt)
+ M=dist(xs,xt)*ns
+ G=emd(a,b,M)
+
+ vloss=[]
+
+ def loss(L,G):
+ """Compute full loss"""
+ return np.sum((xs1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.sum(sel(L-I0)**2)
+
+ def solve_L(G):
+ """ solve L problem with fixed G (least square)"""
+ xst=ns*G.dot(xt)
+ return np.linalg.solve(xstxs+eta*I,xs1.T.dot(xst)+eta*I0)
+
+ def solve_G(L,G0):
+ """Update G with CG algorithm"""
+ xsi=xs1.dot(L)
+ def f(G):
+ return np.sum((xsi-ns*G.dot(xt))**2)
+ def df(G):
+ return -2*ns*(xsi-ns*G.dot(xt)).dot(xt.T)
+ G=cg(a,b,M,1.0/mu,f,df,G0=G0,numItermax=numInnerItermax,stopThr=stopInnerThr)
+ return G
+
+
+ L=solve_L(G)
+
+ vloss.append(loss(L,G))
+
+ if verbose:
+ print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
+ print('{:5d}|{:8e}|{:8e}'.format(0,vloss[-1],0))
+
+
+ # init loop
+ if numItermax>0:
+ loop=1
+ else:
+ loop=0
+ it=0
+
+ while loop:
+
+ it+=1
+
+ # update G
+ G=solve_G(L,G)
+
+ #update L
+ L=solve_L(G)
+
+ vloss.append(loss(L,G))
+
+ if it>=numItermax:
+ loop=0
+
+ if abs(vloss[-1]-vloss[-2])<stopThr:
+ loop=0
+
+ if verbose:
+ if it%20==0:
+ print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
+ print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],abs(vloss[-1]-vloss[-2])/abs(vloss[-2])))
+ if log:
+ log['loss']=vloss
+ return G,L,log
+ else:
+ return G,L
class OTDA(object):
@@ -294,6 +401,7 @@ class OTDA(object):
class OTDA_sinkhorn(OTDA):
"""Class for domain adaptation with optimal transport with entropic regularization"""
+
def fit(self,xs,xt,reg=1,ws=None,wt=None,**kwargs):
""" Fit domain adaptation between samples is xs and xt (with optional
weights)"""
@@ -335,3 +443,51 @@ class OTDA_lpl1(OTDA):
self.G=sinkhorn_lpl1_mm(ws,ys,wt,self.M,reg,eta,**kwargs)
self.computed=True
+class OTDA_mapping(OTDA):
+ """Class for optimal transport with joint linear mapping estimation"""
+
+
+ def __init__(self,metric='sqeuclidean'):
+ """ Class initialization"""
+
+
+ self.xs=0
+ self.xt=0
+ self.G=0
+ self.L=0
+ self.bias=False
+ self.metric=metric
+ self.computed=False
+
+ def fit(self,xs,xt,mu=1,eta=1,bias=False,**kwargs):
+ """ Fit domain adaptation between samples in xs and xt (with uniform
+ weights)"""
+ self.xs=xs
+ self.xt=xt
+ self.bias=bias
+
+ self.ws=unif(xs.shape[0])
+ self.wt=unif(xt.shape[0])
+
+ self.G,self.L=joint_OT_mapping_linear(xs,xt,mu=mu,eta=eta,bias=bias,**kwargs)
+ self.computed=True
+
+ def mapping(self):
+ return lambda x: self.predict(x)
+
+
+ def predict(self,x):
+ """ Out of sample mapping using the estimated linear mapping L
+
+ It appends a column of ones to the new samples when a bias was fitted and
+ applies the linear mapping L to them.
+
+ """
+ if self.computed:
+ if self.bias:
+ x=np.hstack((x,np.ones((x.shape[0],1))))
+ return x.dot(self.L) # apply the estimated linear mapping
+ else:
+ print("Warning, model not fitted yet, returning None")
+ return None
+