summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
author: Rémi Flamary <remi.flamary@gmail.com>, 2016-10-24 17:12:01 +0200
committer: Rémi Flamary <remi.flamary@gmail.com>, 2016-10-24 17:12:01 +0200
commit: 041f9bca2e89d1d111b553159aed862291630b00 (patch)
tree: 8ce24333e2f916592e7667fa8e0cbc8ebc6dda4a /ot
parent: 176ff069483b9ba630af8a00ce5edc104168c0a2 (diff)
add domain adaptation
Diffstat (limited to 'ot')
-rw-r--r-- ot/__init__.py | 4
-rw-r--r-- ot/da.py | 50
-rw-r--r-- ot/datasets.py | 39
3 files changed, 91 insertions, 2 deletions
diff --git a/ot/__init__.py b/ot/__init__.py
index dd74590..f63d1c2 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -5,12 +5,14 @@ import utils
import datasets
import plot
import bregman
+import da
# OT functions
from emd import emd
from bregman import sinkhorn,barycenter
+from da import sinkhorn_lpl1_mm
# utils functions
from utils import dist,unif
-__all__ = ["emd","sinkhorn","utils",'datasets','bregman','plot','dist','unif','barycenter']
+__all__ = ["emd","sinkhorn","utils",'datasets','bregman','plot','dist','unif','barycenter','sinkhorn_lpl1_mm']
diff --git a/ot/da.py b/ot/da.py
new file mode 100644
index 0000000..8ecd952
--- /dev/null
+++ b/ot/da.py
@@ -0,0 +1,50 @@
+"""
+domain adaptation with optimal transport
+"""
+import numpy as np
+from scipy.spatial.distance import cdist
+
+from bregman import sinkhorn
+
+
+
def indices(a, func):
    """
    Return the positions of the elements of ``a`` accepted by ``func``.

    ``a`` is any iterable; ``func`` is called once per element, in order,
    and the 0-based position of every element for which it returns a truthy
    value is collected into a list.
    """
    matches = []
    for position, item in enumerate(a):
        if func(item):
            matches.append(position)
    return matches
+
def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
                     numInnerItermax=200):
    """
    Sinkhorn transport with an iteratively re-weighted class penalty.

    Each outer iteration calls ``sinkhorn`` on the cost ``M + eta * W`` and
    then rebuilds ``W`` from the resulting plan: for a source sample ``i`` of
    class ``c`` and a target sample ``t``,

        W[i, t] = p * (sum of plan[j, t] over j in class c + epsilon) ** (p - 1)

    with ``p = 0.5`` and ``epsilon = 1e-3``.  A target that already receives
    a lot of mass from one class therefore becomes cheaper for that class.

    Parameters
    ----------
    a : sequence of length ns
        Source weights, forwarded to ``sinkhorn``.
    labels_a : sequence of length ns
        Class label of every source sample.  Values are compared against the
        integers from ``min(labels_a)`` to ``max(labels_a)``; float arrays
        holding integral values are accepted.  When the smallest label is
        ``-1`` those samples are treated as unlabelled and receive, for each
        target, the smallest weight computed for any labelled class.
    b : sequence of length nt
        Target weights, forwarded to ``sinkhorn``.
    M : ndarray, shape (ns, nt)
        Ground cost matrix.
    reg : float
        Forwarded unchanged to ``sinkhorn`` as its regularization argument.
    eta : float, optional
        Multiplier of the class penalty ``W`` added to ``M``.
    numItermax : int, optional
        Number of outer re-weighting iterations (must be >= 1).
    numInnerItermax : int, optional
        ``numItermax`` value passed to every ``sinkhorn`` call.

    Returns
    -------
    transp : ndarray, shape (ns, nt)
        Transport plan returned by the last ``sinkhorn`` call.

    Raises
    ------
    ValueError
        If ``numItermax`` is smaller than 1 (no plan would be computed).
    """
    if numItermax < 1:
        raise ValueError("numItermax must be at least 1")

    p = 0.5
    epsilon = 1e-3

    # init data
    Nini = len(a)
    Nfin = len(b)

    labels = np.asarray(labels_a)
    # range() needs true ints; labels frequently arrive as a float array.
    idx_begin = int(labels.min())
    idx_end = int(labels.max())
    classes = range(idx_begin, idx_end + 1)
    # Row indices of the source samples belonging to each candidate class
    # (classes without any sample get an empty index array).
    indices_labels = [np.flatnonzero(labels == c) for c in classes]

    W = np.zeros(M.shape)

    for cpt in range(numItermax):
        Mreg = M + eta * W
        transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax)

        # The transport has been computed: rebuild the majorization weights
        # from the mass each class sends to each target.
        W = np.ones((Nini, Nfin))
        all_maj = []
        for c, idxc in zip(classes, indices_labels):
            if c == -1:
                # -1 marks unlabelled samples; they are handled below.
                continue
            # One weight per target column, shared by all rows of the class.
            maj = p * (transp[idxc, :].sum(axis=0) + epsilon) ** (p - 1)
            W[idxc, :] = maj
            all_maj.append(maj)

        # Majorize the unlabelled samples by the min of the majorizations.
        # Only done when -1 is the smallest label and at least one labelled
        # class exists (otherwise their weight stays at 1).
        if idx_begin == -1 and all_maj:
            W[indices_labels[0], :] = np.min(all_maj, axis=0)

    return transp
diff --git a/ot/datasets.py b/ot/datasets.py
index cebfdac..edc29a9 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -24,4 +24,41 @@ def get_2D_samples_gauss(n,m,sigma):
else:
res= np.random.randn(n,2)*np.sqrt(sigma)+m
return res
- \ No newline at end of file
+
def get_data_classif(dataset, n, nz=.5, **kwargs):
    """
    Generate a small 2D, three-class toy classification dataset.

    Parameters
    ----------
    dataset : str
        Dataset name, case-insensitive:

        * ``'3gauss'``  -- the ``n`` samples are split in three consecutive,
          roughly equal blocks labelled 1, 2 and 3;
        * ``'3gauss2'`` -- the samples are split in four consecutive blocks
          labelled 1, 2, 3, 3, so class 3 holds about half of the samples.
    n : int
        Number of samples.
    nz : float, optional
        Noise level: classes 1 and 2 get ``nz * randn`` added to their
        centre, class 3 gets ``2 * nz * randn``.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    x : ndarray, shape (n, 2)
        Sample coordinates.  Class centres are (-1, -1), (-1, 1) and (1, 0)
        for classes 1, 2 and 3.
    y : ndarray, shape (n,)
        Class labels stored as floats (1.0, 2.0, 3.0).

    For an unrecognised ``dataset`` the function prints ``unknown dataset``
    and returns ``(0, 0)``.
    """
    name = dataset.lower()
    if name == '3gauss':
        y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
    elif name == '3gauss2':
        y = np.floor((np.arange(n) * 1.0 / n * 4)) + 1
        # The fourth block is merged into class 3.
        y[y == 4] = 3
    else:
        # NOTE: a 'sinreg' regression dataset was sketched here in
        # commented-out code but never enabled.
        print("unknown dataset")
        return 0, 0

    # Place every sample on the centre of its class.
    x = np.zeros((n, 2))
    for label, center in ((1, (-1., -1.)), (2, (-1., 1.)), (3, (1., 0.))):
        x[y == label, :] = center

    # Class 3 is twice as spread out as the other two.  The draw order
    # (other classes first, then class 3) is kept so that a seeded random
    # state yields the same samples as before.
    is_class3 = y == 3
    x[~is_class3, :] += nz * np.random.randn(np.count_nonzero(~is_class3), 2)
    x[is_class3, :] += 2 * nz * np.random.randn(np.count_nonzero(is_class3), 2)

    return x, y