summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-24 17:12:01 +0200
committerRémi Flamary <remi.flamary@gmail.com>2016-10-24 17:12:01 +0200
commit041f9bca2e89d1d111b553159aed862291630b00 (patch)
tree8ce24333e2f916592e7667fa8e0cbc8ebc6dda4a /ot/da.py
parent176ff069483b9ba630af8a00ce5edc104168c0a2 (diff)
add domain adaptation
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py50
1 files changed, 50 insertions, 0 deletions
diff --git a/ot/da.py b/ot/da.py
new file mode 100644
index 0000000..8ecd952
--- /dev/null
+++ b/ot/da.py
@@ -0,0 +1,50 @@
+"""
+domain adaptation with optimal transport
+"""
+import numpy as np
+from scipy.spatial.distance import cdist
+
+from bregman import sinkhorn
+
+
+
+def indices(a, func):
+ return [i for (i, val) in enumerate(a) if func(val)]
+
+def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1):
+ p=0.5
+ epsilon = 1e-3
+
+ # init data
+ Nini = len(a)
+ Nfin = len(b)
+
+ indices_labels = []
+ idx_begin = np.min(labels_a)
+ for c in range(idx_begin,np.max(labels_a)+1):
+ idxc = indices(labels_a, lambda x: x==c)
+ indices_labels.append(idxc)
+
+ W=np.zeros(M.shape)
+
+ for cpt in range(10):
+ Mreg = M + eta*W
+ transp=sinkhorn(a,b,Mreg,reg,numItermax = 200)
+ # the transport has been computed. Check if classes are really separated
+ W = np.ones((Nini,Nfin))
+ for t in range(Nfin):
+ column = transp[:,t]
+ all_maj = []
+ for c in range(idx_begin,np.max(labels_a)+1):
+ col_c = column[indices_labels[c-idx_begin]]
+ if c!=-1:
+ maj = p*((sum(col_c)+epsilon)**(p-1))
+ W[indices_labels[c-idx_begin],t]=maj
+ all_maj.append(maj)
+
+ # now we majorize the unlabelled by the min of the majorizations
+ # do it only for unlabbled data
+ if idx_begin==-1:
+ W[indices_labels[0],t]=np.min(all_maj)
+
+ return transp \ No newline at end of file