diff options
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 46 |
1 files changed, 15 insertions, 31 deletions
@@ -11,10 +11,6 @@ from .optim import cg from .optim import gcg -def indices(a, func): - return [i for (i, val) in enumerate(a) if func(val)] - - def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerItermax = 200,stopInnerThr=1e-9,verbose=False,log=False): """ Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization @@ -46,7 +42,7 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter labels_a : np.ndarray (ns,) labels of samples in the source domain b : np.ndarray (nt,) - samples in the target domain + samples weights in the target domain M : np.ndarray (ns,nt) loss matrix reg : float @@ -86,40 +82,28 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter ot.optim.cg : General regularized OT """ - p=0.5 + p = 0.5 epsilon = 1e-3 - # init data - Nini = len(a) - Nfin = len(b) - indices_labels = [] - idx_begin = np.min(labels_a) - for c in range(idx_begin,np.max(labels_a)+1): - idxc = indices(labels_a, lambda x: x==c) + classes = np.unique(labels_a) + for c in classes: + idxc, = np.where(labels_a == c) indices_labels.append(idxc) - W=np.zeros(M.shape) + W = np.zeros(M.shape) for cpt in range(numItermax): Mreg = M + eta*W - transp=sinkhorn(a,b,Mreg,reg,numItermax=numInnerItermax, stopThr=stopInnerThr) - # the transport has been computed. Check if classes are really separated - W = np.ones((Nini,Nfin)) - for t in range(Nfin): - column = transp[:,t] - all_maj = [] - for c in range(idx_begin,np.max(labels_a)+1): - col_c = column[indices_labels[c-idx_begin]] - if c!=-1: - maj = p*((sum(col_c)+epsilon)**(p-1)) - W[indices_labels[c-idx_begin],t]=maj - all_maj.append(maj) - - # now we majorize the unlabelled by the min of the majorizations - # do it only for unlabbled data - if idx_begin==-1: - W[indices_labels[0],t]=np.min(all_maj) + transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, + stopThr=stopInnerThr) + # the transport has been computed. Check if classes are really + # separated + W = np.ones(M.shape) + for (i, c) in enumerate(classes): + majs = np.sum(transp[indices_labels[i]], axis=0) + majs = p*((majs+epsilon)**(p-1)) + W[indices_labels[i]] = majs return transp |