From 041f9bca2e89d1d111b553159aed862291630b00 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Mon, 24 Oct 2016 17:12:01 +0200 Subject: add domain adaptation --- ot/__init__.py | 4 +++- ot/da.py | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ ot/datasets.py | 39 ++++++++++++++++++++++++++++++++++++++- 3 files changed, 91 insertions(+), 2 deletions(-) create mode 100644 ot/da.py diff --git a/ot/__init__.py b/ot/__init__.py index dd74590..f63d1c2 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -5,12 +5,14 @@ import utils import datasets import plot import bregman +import da # OT functions from emd import emd from bregman import sinkhorn,barycenter +from da import sinkhorn_lpl1_mm # utils functions from utils import dist,unif -__all__ = ["emd","sinkhorn","utils",'datasets','bregman','plot','dist','unif','barycenter'] +__all__ = ["emd","sinkhorn","utils",'datasets','bregman','plot','dist','unif','barycenter','sinkhorn_lpl1_mm'] diff --git a/ot/da.py b/ot/da.py new file mode 100644 index 0000000..8ecd952 --- /dev/null +++ b/ot/da.py @@ -0,0 +1,50 @@ +""" +domain adaptation with optimal transport +""" +import numpy as np +from scipy.spatial.distance import cdist + +from bregman import sinkhorn + + + +def indices(a, func): + return [i for (i, val) in enumerate(a) if func(val)] + +def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1): + p=0.5 + epsilon = 1e-3 + + # init data + Nini = len(a) + Nfin = len(b) + + indices_labels = [] + idx_begin = np.min(labels_a) + for c in range(idx_begin,np.max(labels_a)+1): + idxc = indices(labels_a, lambda x: x==c) + indices_labels.append(idxc) + + W=np.zeros(M.shape) + + for cpt in range(10): + Mreg = M + eta*W + transp=sinkhorn(a,b,Mreg,reg,numItermax = 200) + # the transport has been computed. Check if classes are really separated + W = np.ones((Nini,Nfin)) + for t in range(Nfin): + column = transp[:,t] + all_maj = [] + for c in range(idx_begin,np.max(labels_a)+1): + col_c = column[indices_labels[c-idx_begin]] + if c!=-1: + maj = p*((sum(col_c)+epsilon)**(p-1)) + W[indices_labels[c-idx_begin],t]=maj + all_maj.append(maj) + + # now we majorize the unlabelled by the min of the majorizations + # do it only for unlabbled data + if idx_begin==-1: + W[indices_labels[0],t]=np.min(all_maj) + + return transp \ No newline at end of file diff --git a/ot/datasets.py b/ot/datasets.py index cebfdac..edc29a9 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -24,4 +24,41 @@ def get_2D_samples_gauss(n,m,sigma): else: res= np.random.randn(n,2)*np.sqrt(sigma)+m return res - \ No newline at end of file + +def get_data_classif(dataset,n,nz=.5,**kwargs): + """ + dataset generation + """ + if dataset.lower()=='3gauss': + y=np.floor((np.arange(n)*1.0/n*3))+1 + x=np.zeros((n,2)) + # class 1 + x[y==1,0]=-1.; x[y==1,1]=-1. + x[y==2,0]=-1.; x[y==2,1]=1. + x[y==3,0]=1. ; x[y==3,1]=0 + + x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2) + x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2) + + elif dataset.lower()=='3gauss2': + y=np.floor((np.arange(n)*1.0/n*4))+1 + x=np.zeros((n,2)) + y[y==4]=3 + # class 1 + x[y==1,0]=-1.; x[y==1,1]=-1. + x[y==2,0]=-1.; x[y==2,1]=1. + x[y==3,0]=1. ; x[y==3,1]=0 + + x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2) + x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2) +# elif dataset.lower()=='sinreg': +# +# x=np.random.rand(n,1) +# y=4*x+np.sin(2*np.pi*x)+nz*np.random.randn(n,1) + + else: + x=0 + y=0 + print("unknown dataset") + + return x,y \ No newline at end of file -- cgit v1.2.3