summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ot/da.py174
-rw-r--r--ot/datasets.py62
-rw-r--r--ot/utils.py51
3 files changed, 234 insertions, 53 deletions
diff --git a/ot/da.py b/ot/da.py
index 7cfbca1..66680cd 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -6,13 +6,15 @@ Domain adaptation with optimal transport
import numpy as np
from .bregman import sinkhorn
from .lp import emd
-from .utils import unif,dist
+from .utils import unif,dist,kernel
from .optim import cg
def indices(a, func):
return [i for (i, val) in enumerate(a) if func(val)]
+
+
def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerItermax = 200,stopInnerThr=1e-9,verbose=False,log=False):
"""
Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization
@@ -129,13 +131,15 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
if bias:
xs1=np.hstack((xs,np.ones((ns,1))))
- I=eta*np.eye(d+1)
+ xstxs=xs1.T.dot(xs1)
+ I=np.eye(d+1)
I[-1]=0
I0=I[:,:-1]
sel=lambda x : x[:-1,:]
else:
xs1=xs
- I=eta*np.eye(d)
+ xstxs=xs1.T.dot(xs1)
+ I=np.eye(d)
I0=I
sel=lambda x : x
@@ -143,20 +147,22 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
log={'err':[]}
a,b=unif(ns),unif(nt)
- M=dist(xs,xt)
+ M=dist(xs,xt)*ns
G=emd(a,b,M)
vloss=[]
def loss(L,G):
+ """Compute full loss"""
return np.sum((xs1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.sum(sel(L-I0)**2)
def solve_L(G):
- """ solve problem with fixed G"""
+ """ solve L problem with fixed G (least square)"""
xst=ns*G.dot(xt)
- return np.linalg.solve(xs1.T.dot(xs1)+I,xs1.T.dot(xst)+I0)
+ return np.linalg.solve(xstxs+eta*I,xs1.T.dot(xst)+eta*I0)
def solve_G(L,G0):
+ """Update G with CG algorithm"""
xsi=xs1.dot(L)
def f(G):
return np.sum((xsi-ns*G.dot(xt))**2)
@@ -175,8 +181,11 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
print('{:5d}|{:8e}|{:8e}'.format(0,vloss[-1],0))
- # regul matrix
- loop=1
+ # init loop
+ if numItermax>0:
+ loop=1
+ else:
+ loop=0
it=0
while loop:
@@ -191,6 +200,9 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
vloss.append(loss(L,G))
+ if it>=numItermax:
+ loop=0
+
if abs(vloss[-1]-vloss[-2])<stopThr:
loop=0
@@ -198,11 +210,106 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
if it%20==0:
print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],abs(vloss[-1]-vloss[-2])/abs(vloss[-2])))
+ if log:
+ log['loss']=vloss
+ return G,L,log
+ else:
+ return G,L
- return G,L
+def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kernel='gaussian',sigma=1,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 20,stopInnerThr=1e-9,stopThr=1e-6,log=False,**kwargs):
+ """Joint Ot and mapping estimation (uniform weights and )
+ """
+ ns,nt,d=xs.shape[0],xt.shape[0],xt.shape[1]
+ if bias:
+ K=
+ xs1=np.hstack((xs,np.ones((ns,1))))
+ xstxs=xs1.T.dot(xs1)
+ I=np.eye(d+1)
+ I[-1]=0
+ I0=I[:,:-1]
+ sel=lambda x : x[:-1,:]
+ else:
+ xs1=xs
+ xstxs=xs1.T.dot(xs1)
+ I=np.eye(d)
+ I0=I
+ sel=lambda x : x
+
+ if log:
+ log={'err':[]}
+
+ a,b=unif(ns),unif(nt)
+ M=dist(xs,xt)*ns
+ G=emd(a,b,M)
+
+ vloss=[]
+
+ def loss(L,G):
+ """Compute full loss"""
+ return np.sum((xs1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.sum(sel(L-I0)**2)
+
+ def solve_L(G):
+ """ solve L problem with fixed G (least square)"""
+ xst=ns*G.dot(xt)
+ return np.linalg.solve(xstxs+eta*I,xs1.T.dot(xst)+eta*I0)
+
+ def solve_G(L,G0):
+ """Update G with CG algorithm"""
+ xsi=xs1.dot(L)
+ def f(G):
+ return np.sum((xsi-ns*G.dot(xt))**2)
+ def df(G):
+ return -2*ns*(xsi-ns*G.dot(xt)).dot(xt.T)
+ G=cg(a,b,M,1.0/mu,f,df,G0=G0,numItermax=numInnerItermax,stopThr=stopInnerThr)
+ return G
+
+
+ L=solve_L(G)
+
+ vloss.append(loss(L,G))
+
+ if verbose:
+ print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
+ print('{:5d}|{:8e}|{:8e}'.format(0,vloss[-1],0))
+
+
+ # init loop
+ if numItermax>0:
+ loop=1
+ else:
+ loop=0
+ it=0
+
+ while loop:
+
+ it+=1
+
+ # update G
+ G=solve_G(L,G)
+
+ #update L
+ L=solve_L(G)
+
+ vloss.append(loss(L,G))
+
+ if it>=numItermax:
+ loop=0
+
+ if abs(vloss[-1]-vloss[-2])<stopThr:
+ loop=0
+
+ if verbose:
+ if it%20==0:
+ print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
+ print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],abs(vloss[-1]-vloss[-2])/abs(vloss[-2])))
+ if log:
+ log['loss']=vloss
+ return G,L,log
+ else:
+ return G,L
class OTDA(object):
@@ -294,6 +401,7 @@ class OTDA(object):
class OTDA_sinkhorn(OTDA):
"""Class for domain adaptation with optimal transport with entropic regularization"""
+
def fit(self,xs,xt,reg=1,ws=None,wt=None,**kwargs):
""" Fit domain adaptation between samples is xs and xt (with optional
weights)"""
@@ -335,3 +443,51 @@ class OTDA_lpl1(OTDA):
self.G=sinkhorn_lpl1_mm(ws,ys,wt,self.M,reg,eta,**kwargs)
self.computed=True
+class OTDA_mapping(OTDA):
+ """Class for optimal transport with joint linear mapping estimation"""
+
+
+ def __init__(self,metric='sqeuclidean'):
+ """ Class initialization"""
+
+
+ self.xs=0
+ self.xt=0
+ self.G=0
+ self.L=0
+ self.bias=False
+ self.metric=metric
+ self.computed=False
+
+ def fit(self,xs,xt,mu=1,eta=1,bias=False,**kwargs):
+ """ Fit domain adaptation between samples is xs and xt (with optional
+ weights)"""
+ self.xs=xs
+ self.xt=xt
+ self.bias=bias
+
+ self.ws=unif(xs.shape[0])
+ self.wt=unif(xt.shape[0])
+
+ self.G,self.L=joint_OT_mapping_linear(xs,xt,mu=mu,eta=eta,bias=bias,**kwargs)
+ self.computed=True
+
+ def mapping(self):
+ return lambda x: self.predict(x)
+
+
+ def predict(self,x):
+ """ Out of sample mapping using the formulation from Ferradans
+
+ It basically find the source sample the nearset to the nex sample and
+ apply the difference to the displaced source sample.
+
+ """
+ if self.computed:
+ if self.bias:
+ x=np.hstack((x,np.ones((x.shape[0],1))))
+ return x.dot(self.L) # aply the delta to the interpolation
+ else:
+ print("Warning, model not fitted yet, returning None")
+ return None
+
diff --git a/ot/datasets.py b/ot/datasets.py
index 6388d94..588f501 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -8,8 +8,8 @@ import scipy as sp
def get_1D_gauss(n,m,s):
- """return a 1D histogram for a gaussian distribution (n bins, mean m and std s)
-
+ """return a 1D histogram for a gaussian distribution (n bins, mean m and std s)
+
Parameters
----------
@@ -20,21 +20,21 @@ def get_1D_gauss(n,m,s):
s : float
standard deviaton of the gaussian distribution
-
+
Returns
-------
h : np.array (n,)
- 1D histogram for a gaussian distribution
-
+ 1D histogram for a gaussian distribution
+
"""
x=np.arange(n,dtype=np.float64)
h=np.exp(-(x-m)**2/(2*s^2))
return h/h.sum()
-
-
+
+
def get_2D_samples_gauss(n,m,sigma):
- """return n samples drawn from 2D gaussian N(m,sigma)
-
+ """return n samples drawn from 2D gaussian N(m,sigma)
+
Parameters
----------
@@ -45,12 +45,12 @@ def get_2D_samples_gauss(n,m,sigma):
sigma : np.array (2,2)
covariance matrix of the gaussian distribution
-
+
Returns
-------
X : np.array (n,2)
- n samples drawn from N(m,sigma)
-
+ n samples drawn from N(m,sigma)
+
"""
if np.isscalar(sigma):
sigma=np.array([sigma,])
@@ -61,9 +61,10 @@ def get_2D_samples_gauss(n,m,sigma):
res= np.random.randn(n,2)*np.sqrt(sigma)+m
return res
-def get_data_classif(dataset,n,nz=.5,**kwargs):
+
+def get_data_classif(dataset,n,nz=.5,theta=0,**kwargs):
""" dataset generation for classification problems
-
+
Parameters
----------
@@ -74,13 +75,13 @@ def get_data_classif(dataset,n,nz=.5,**kwargs):
nz : float
noise level (>0)
-
+
Returns
-------
X : np.array (n,d)
- n observation of size d
+ n observation of size d
y : np.array (n,)
- labels of the samples
+ labels of the samples
"""
if dataset.lower()=='3gauss':
@@ -90,10 +91,10 @@ def get_data_classif(dataset,n,nz=.5,**kwargs):
x[y==1,0]=-1.; x[y==1,1]=-1.
x[y==2,0]=-1.; x[y==2,1]=1.
x[y==3,0]=1. ; x[y==3,1]=0
-
+
x[y!=3,:]+=1.5*nz*np.random.randn(sum(y!=3),2)
x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
-
+
elif dataset.lower()=='3gauss2':
y=np.floor((np.arange(n)*1.0/n*3))+1
x=np.zeros((n,2))
@@ -102,12 +103,29 @@ def get_data_classif(dataset,n,nz=.5,**kwargs):
x[y==1,0]=-2.; x[y==1,1]=-2.
x[y==2,0]=-2.; x[y==2,1]=2.
x[y==3,0]=2. ; x[y==3,1]=0
-
+
x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2)
- x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
+ x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
+
+ elif dataset.lower()=='gaussrot' :
+ rot=np.array([[np.cos(theta),-np.sin(theta)],[np.sin(theta),np.cos(theta)]])
+ m1=np.array([-1,-1])
+ m2=np.array([1,1])
+ y=np.floor((np.arange(n)*1.0/n*2))+1
+ n1=np.sum(y==1)
+ n2=np.sum(y==2)
+ x=np.zeros((n,2))
+
+ x[y==1,:]=get_2D_samples_gauss(n1,m1,nz)
+ x[y==2,:]=get_2D_samples_gauss(n2,m2,nz)
+
+ x=x.dot(rot)
+
+
+
else:
x=0
y=0
print("unknown dataset")
-
+
return x,y.astype(int) \ No newline at end of file
diff --git a/ot/utils.py b/ot/utils.py
index 24f65a8..47fe77f 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -6,28 +6,34 @@ import numpy as np
from scipy.spatial.distance import cdist
+def kernel(x1,x2,method='gaussian',sigma=1,**kwargs):
+ """Compute kernel matrix"""
+ if method.lower() in ['gaussian','gauss','rbf']:
+ K=np.exp(dist(x1,x2)/(2*sigma**2))
+ return K
+
def unif(n):
- """ return a uniform histogram of length n (simplex)
-
+ """ return a uniform histogram of length n (simplex)
+
Parameters
----------
n : int
number of bins in the histogram
-
+
Returns
-------
h : np.array (n,)
- histogram of length n such that h_i=1/n for all i
-
-
+ histogram of length n such that h_i=1/n for all i
+
+
"""
return np.ones((n,))/n
def dist(x1,x2=None,metric='sqeuclidean'):
"""Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist
-
+
Parameters
----------
@@ -36,28 +42,29 @@ def dist(x1,x2=None,metric='sqeuclidean'):
x2 : np.array (n2,d), optional
matrix with n2 samples of size d (if None then x2=x1)
metric : str, fun, optional
- name of the metric to be computed (full list in the doc of scipy), If a string,
+ name of the metric to be computed (full list in the doc of scipy), If a string,
the distance function can be ‘braycurtis’, ‘canberra’, ‘chebyshev’, ‘cityblock’,
‘correlation’, ‘cosine’, ‘dice’, ‘euclidean’, ‘hamming’, ‘jaccard’, ‘kulsinski’,
‘mahalanobis’, ‘matching’, ‘minkowski’, ‘rogerstanimoto’, ‘russellrao’, ‘seuclidean’,
‘sokalmichener’, ‘sokalsneath’, ‘sqeuclidean’, ‘wminkowski’, ‘yule’.
-
+
Returns
-------
-
+
M : np.array (n1,n2)
distance matrix computed with given metric
-
+
"""
if x2 is None:
- return cdist(x1,x1,metric=metric)
- else:
- return cdist(x1,x2,metric=metric)
-
+ x2=x1
+
+ return cdist(x1,x2,metric=metric)
+
+
def dist0(n,method='lin_square'):
"""Compute standard cost matrices of size (n,n) for OT problems
-
+
Parameters
----------
@@ -68,21 +75,21 @@ def dist0(n,method='lin_square'):
* 'lin_square' : linear sampling between 0 and n-1, quadratic loss
-
+
Returns
-------
-
+
M : np.array (n1,n2)
- distance matrix computed with given metric
-
-
+ distance matrix computed with given metric
+
+
"""
res=0
if method=='lin_square':
x=np.arange(n,dtype=np.float64).reshape((n,1))
res=dist(x,x)
return res
-
+
def dots(*args):
""" dots function for multiple matrix multiply """