diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-24 17:12:01 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-24 17:12:01 +0200 |
commit | 041f9bca2e89d1d111b553159aed862291630b00 (patch) | |
tree | 8ce24333e2f916592e7667fa8e0cbc8ebc6dda4a /ot/da.py | |
parent | 176ff069483b9ba630af8a00ce5edc104168c0a2 (diff) |
add domain adaptation
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 50 |
1 files changed, 50 insertions, 0 deletions
diff --git a/ot/da.py b/ot/da.py new file mode 100644 index 0000000..8ecd952 --- /dev/null +++ b/ot/da.py @@ -0,0 +1,50 @@ +""" +domain adaptation with optimal transport +""" +import numpy as np +from scipy.spatial.distance import cdist + +from bregman import sinkhorn + + + +def indices(a, func): + return [i for (i, val) in enumerate(a) if func(val)] + +def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1): + p=0.5 + epsilon = 1e-3 + + # init data + Nini = len(a) + Nfin = len(b) + + indices_labels = [] + idx_begin = np.min(labels_a) + for c in range(idx_begin,np.max(labels_a)+1): + idxc = indices(labels_a, lambda x: x==c) + indices_labels.append(idxc) + + W=np.zeros(M.shape) + + for cpt in range(10): + Mreg = M + eta*W + transp=sinkhorn(a,b,Mreg,reg,numItermax = 200) + # the transport has been computed. Check if classes are really separated + W = np.ones((Nini,Nfin)) + for t in range(Nfin): + column = transp[:,t] + all_maj = [] + for c in range(idx_begin,np.max(labels_a)+1): + col_c = column[indices_labels[c-idx_begin]] + if c!=-1: + maj = p*((sum(col_c)+epsilon)**(p-1)) + W[indices_labels[c-idx_begin],t]=maj + all_maj.append(maj) + + # now we majorize the unlabelled by the min of the majorizations + # do it only for unlabbled data + if idx_begin==-1: + W[indices_labels[0],t]=np.min(all_maj) + + return transp
\ No newline at end of file |