summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-11-02 17:13:43 +0100
committerRémi Flamary <remi.flamary@gmail.com>2016-11-02 17:13:43 +0100
commit981351165dbab740145d109b00782f0c41f2244b (patch)
treed48f811d79b215165b72f1887fa1303a7563f571 /ot/da.py
parent3c4944cf705477c913e5feb8f2483d0fa40ed5e1 (diff)
add mapping estimation (still debugging)
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py228
1 files changed, 155 insertions, 73 deletions
diff --git a/ot/da.py b/ot/da.py
index 3447437..7cfbca1 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -7,6 +7,7 @@ import numpy as np
from .bregman import sinkhorn
from .lp import emd
from .utils import unif,dist
+from .optim import cg
def indices(a, func):
@@ -15,81 +16,81 @@ def indices(a, func):
def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerItermax = 200,stopInnerThr=1e-9,verbose=False,log=False):
"""
Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization
-
+
The function solves the following optimization problem:
-
+
.. math::
\gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)+ \eta \Omega_g(\gamma)
-
+
s.t. \gamma 1 = a
-
- \gamma^T 1= b
-
+
+ \gamma^T 1= b
+
\gamma\geq 0
where :
-
+
- M is the (ns,nt) metric cost matrix
- :math:`\Omega_e` is the entropic regularization term :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\Omega_g` is the group lasso regularization term :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1` where :math:`\mathcal{I}_c` are the indices of samples from class c in the source domain.
- a and b are source and target weights (sum to 1)
-
+
The algorithm used for solving the problem is the generalised conditional gradient as proposed in [5]_ [7]_
-
-
+
+
Parameters
----------
a : np.ndarray (ns,)
samples weights in the source domain
labels_a : np.ndarray (ns,)
- labels of samples in the source domain
+ labels of samples in the source domain
b : np.ndarray (nt,)
samples weights in the target domain
M : np.ndarray (ns,nt)
- loss matrix
+ loss matrix
reg: float
Regularization term for entropic regularization >0
eta: float, optional
- Regularization term for group lasso regularization >0
+ Regularization term for group lasso regularization >0
numItermax: int, optional
Max number of iterations
numInnerItermax: int, optional
Max number of iterations (inner sinkhorn solver)
stopInnerThr: float, optional
- Stop threshold on error (inner sinkhorn solver) (>0)
+ Stop threshold on error (inner sinkhorn solver) (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
- record log if True
-
-
+ record log if True
+
+
Returns
-------
gamma: (ns x nt) ndarray
Optimal transportation matrix for the given parameters
log: dict
- log dictionary return only if log==True in parameters
-
-
+ log dictionary returned only if log==True in parameters
+
+
References
----------
-
+
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
.. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
-
+
See Also
--------
ot.lp.emd : Unregularized OT
ot.bregman.sinkhorn : Entropic regularized OT
ot.optim.cg : General regularized OT
-
- """
+
+ """
p=0.5
epsilon = 1e-3
# init data
Nini = len(a)
Nfin = len(b)
-
+
indices_labels = []
idx_begin = np.min(labels_a)
for c in range(idx_begin,np.max(labels_a)+1):
@@ -117,14 +118,96 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
# do it only for unlabeled data
if idx_begin==-1:
W[indices_labels[0],t]=np.min(all_maj)
-
+
return transp
+def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 20,stopInnerThr=1e-9,stopThr=1e-6,log=False,**kwargs):
+    """Joint OT and linear mapping estimation (uniform sample weights)
+
+    Alternately updates a transport matrix G and a linear map L so as to
+    decrease the loss
+
+        sum((xs1.L - ns*G.xt)**2) + mu*sum(G*M) + eta*sum(sel(L - I0)**2)
+
+    where xs1 is xs (with an extra column of ones when bias=True), M is
+    dist(xs,xt), I0 is the identity (padded with a zero row when bias=True)
+    and sel() drops the bias row so that it is not regularized.
+
+    G is initialised with emd on uniform weights, then each outer iteration
+    updates G with ot.optim.cg (L fixed) and L with a regularized least
+    squares solve (G fixed).
+
+    Parameters
+    ----------
+    xs : np.ndarray (ns,d)
+        samples in the source domain
+    xt : np.ndarray (nt,d)
+        samples in the target domain (d is read from xt; xs is assumed to
+        have the same number of columns -- TODO confirm)
+    mu : float, optional
+        Weight of the transport cost term (cg is called with reg=1.0/mu)
+    eta : float, optional
+        Weight of the regularization pulling L toward the identity
+    bias : bool, optional
+        Estimate an additional, unregularized, bias row in L
+    verbose : bool, optional
+        Print the loss along iterations
+    verbose2 : bool, optional
+        Unused for the moment
+    numItermax : int, optional
+        Max number of outer (alternating) iterations
+    numInnerItermax : int, optional
+        Max number of iterations (inner cg solver)
+    stopInnerThr : float, optional
+        Stop threshold passed to the inner cg solver
+    stopThr : float, optional
+        Stop threshold on the absolute change of the loss between two
+        outer iterations
+    log : bool, optional
+        Accepted but has no effect for the moment (no log is returned)
+    **kwargs
+        Unused for the moment
+
+    Returns
+    -------
+    G : ndarray
+        Transport matrix after the last update
+    L : (d x d) ndarray, or (d+1 x d) if bias is True
+        Linear mapping such that xs1.dot(L) approximates ns*G.dot(xt)
+
+    """
+
+    ns,nt,d=xs.shape[0],xt.shape[0],xt.shape[1]
+
+    # Id marks the rows of L that are regularized, I0 is the matrix L is
+    # pulled toward. Both are kept unscaled: eta is applied where they are
+    # used so that loss() and solve_L() refer to the same objective.
+    if bias:
+        xs1=np.hstack((xs,np.ones((ns,1))))
+        Id=np.eye(d+1)
+        Id[-1]=0  # bias row is not regularized
+        I0=Id[:,:-1]
+        sel=lambda x : x[:-1,:]
+    else:
+        xs1=xs
+        Id=np.eye(d)
+        I0=Id
+        sel=lambda x : x
+
+    # loop-invariant Gram matrix of the (augmented) source samples
+    xstxs=xs1.T.dot(xs1)
+
+    a,b=unif(ns),unif(nt)
+    M=dist(xs,xt)
+    G=emd(a,b,M)
+
+    vloss=[]
+
+    def loss(L,G):
+        """ full objective for a given (L,G) pair"""
+        return np.sum((xs1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.sum(sel(L-I0)**2)
+
+    def solve_L(G):
+        """ solve problem with fixed G (regularized least squares)"""
+        xst=ns*G.dot(xt)
+        return np.linalg.solve(xstxs+eta*Id,xs1.T.dot(xst)+eta*I0)
+
+    def solve_G(L,G0):
+        """ update G with fixed L (conditional gradient started from G0)"""
+        xsi=xs1.dot(L)
+        def f(G):
+            return np.sum((xsi-ns*G.dot(xt))**2)
+        def df(G):
+            return -2*ns*(xsi-ns*G.dot(xt)).dot(xt.T)
+        G=cg(a,b,M,1.0/mu,f,df,G0=G0,numItermax=numInnerItermax,stopThr=stopInnerThr)
+        return G
+
+
+    L=solve_L(G)
+
+    vloss.append(loss(L,G))
+
+    if verbose:
+        print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
+        print('{:5d}|{:8e}|{:8e}'.format(0,vloss[-1],0))
+
+
+    it=0
+
+    # numItermax bounds the alternation so it cannot spin forever when the
+    # loss never settles below stopThr
+    while it<numItermax:
+
+        it+=1
+
+        # update G
+        G=solve_G(L,G)
+
+        #update L
+        L=solve_L(G)
+
+        vloss.append(loss(L,G))
+
+        delta=abs(vloss[-1]-vloss[-2])
+
+        if verbose:
+            if it%20==0:
+                print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
+            print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],delta/abs(vloss[-2])))
+
+        if delta<stopThr:
+            break
+
+    return G,L
+
+
+
class OTDA(object):
"""Class for domain adaptation with optimal transport"""
-
+
def __init__(self,metric='sqeuclidean'):
""" Class initialization"""
self.xs=0
@@ -132,42 +215,42 @@ class OTDA(object):
self.G=0
self.metric=metric
self.computed=False
-
-
+
+
def fit(self,xs,xt,ws=None,wt=None):
- """ Fit domain adaptation between samples is xs and xt (with optional
+ """ Fit domain adaptation between samples is xs and xt (with optional
weights)"""
self.xs=xs
self.xt=xt
-
+
if wt is None:
wt=unif(xt.shape[0])
if ws is None:
ws=unif(xs.shape[0])
-
+
self.ws=ws
self.wt=wt
-
+
self.M=dist(xs,xt,metric=self.metric)
self.G=emd(ws,wt,self.M)
self.computed=True
-
+
def interp(self,direction=1):
"""Barycentric interpolation for the source (1) or target (-1)
-
- This Barycentric interpolation solves for each source (resp target)
+
+ This Barycentric interpolation solves for each source (resp target)
sample xs (resp xt) the following optimization problem:
-
+
.. math::
arg\min_x \sum_i \gamma_{k,i} c(x,x_i^t)
-
+
where k is the index of the sample in xs
-
- For the moment only squared euclidean distance is provided but more
- metric could be used in the future.
-
+
+ For the moment only squared euclidean distance is provided but more
+ metric could be used in the future.
+
"""
- if direction>0: # >0 then source to target
+ if direction>0: # >0 then source to target
G=self.G
w=self.ws.reshape((self.xs.shape[0],1))
x=self.xt
@@ -175,81 +258,80 @@ class OTDA(object):
G=self.G.T
w=self.wt.reshape((self.xt.shape[0],1))
x=self.xs
-
+
if self.computed:
if self.metric=='sqeuclidean':
return np.dot(G/w,x) # weighted mean
else:
print("Warning, metric not handled yet, using weighted average")
- return np.dot(G/w,x) # weighted mean
- return None
+ return np.dot(G/w,x) # weighted mean
+ return None
else:
print("Warning, model not fitted yet, returning None")
return None
-
-
+
+
def predict(self,x,direction=1):
- """ Out of sample mapping using the formulation from Ferradans
-
- It basically find the source sample the nearset to the nex sample and
+ """ Out of sample mapping using the formulation from Ferradans
+
+ It basically finds the source sample nearest to the new sample and
applies the difference to the displaced source sample.
-
+
"""
- if direction>0: # >0 then source to target
+ if direction>0: # >0 then source to target
xf=self.xt
x0=self.xs
else:
- xf=self.xs
+ xf=self.xs
x0=self.xt
-
+
D0=dist(x,x0) # dist netween new samples an source
idx=np.argmin(D0,1) # closest one
xf=self.interp(direction)# interp the source samples
return xf[idx,:]+x-x0[idx,:] # aply the delta to the interpolation
-
-
+
+
class OTDA_sinkhorn(OTDA):
"""Class for domain adaptation with optimal transport with entropic regularization"""
def fit(self,xs,xt,reg=1,ws=None,wt=None,**kwargs):
- """ Fit domain adaptation between samples is xs and xt (with optional
+ """ Fit domain adaptation between samples is xs and xt (with optional
weights)"""
self.xs=xs
self.xt=xt
-
+
if wt is None:
wt=unif(xt.shape[0])
if ws is None:
ws=unif(xs.shape[0])
-
+
self.ws=ws
self.wt=wt
-
+
self.M=dist(xs,xt,metric=self.metric)
self.G=sinkhorn(ws,wt,self.M,reg,**kwargs)
- self.computed=True
-
-
+ self.computed=True
+
+
class OTDA_lpl1(OTDA):
"""Class for domain adaptation with optimal transport with entropic an group regularization"""
-
-
+
+
def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
- """ Fit domain adaptation between samples is xs and xt (with optional
+ """ Fit domain adaptation between samples is xs and xt (with optional
weights)"""
self.xs=xs
self.xt=xt
-
+
if wt is None:
wt=unif(xt.shape[0])
if ws is None:
ws=unif(xs.shape[0])
-
+
self.ws=ws
self.wt=wt
-
+
self.M=dist(xs,xt,metric=self.metric)
self.G=sinkhorn_lpl1_mm(ws,ys,wt,self.M,reg,eta,**kwargs)
- self.computed=True
-
- \ No newline at end of file
+ self.computed=True
+