summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ot/da.py45
1 files changed, 22 insertions, 23 deletions
diff --git a/ot/da.py b/ot/da.py
index 44ce829..c944c0d 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -11,11 +11,7 @@ from .optim import cg
from .optim import gcg
-def indices(a, func):
- return [i for (i, val) in enumerate(a) if func(val)]
-
-
-def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerItermax = 200,stopInnerThr=1e-9,verbose=False,log=False):
+def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerItermax = 200,stopInnerThr=1e-9,unlabelledValue=-99,verbose=False,log=False):
"""
Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization
@@ -46,7 +42,7 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
labels_a : np.ndarray (ns,)
labels of samples in the source domain
b : np.ndarray (nt,)
- samples in the target domain
+ samples weights in the target domain
M : np.ndarray (ns,nt)
loss matrix
reg : float
@@ -59,6 +55,8 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
Max number of iterations (inner sinkhorn solver)
stopInnerThr : float, optional
Stop threshold on error (inner sinkhorn solver) (>0)
+ unlabelledValue : int, optional
+ this value in array labels_a means this is an unlabelled example
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -94,9 +92,9 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
Nfin = len(b)
indices_labels = []
- idx_begin = np.min(labels_a)
- for c in range(idx_begin,np.max(labels_a)+1):
- idxc = indices(labels_a, lambda x: x==c)
+ classes = np.unique(labels_a)
+ for c in classes:
+ idxc, = np.where(labels_a == c)
indices_labels.append(idxc)
W=np.zeros(M.shape)
@@ -106,20 +104,21 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
transp=sinkhorn(a,b,Mreg,reg,numItermax=numInnerItermax, stopThr=stopInnerThr)
# the transport has been computed. Check if classes are really separated
W = np.ones((Nini,Nfin))
- for t in range(Nfin):
- column = transp[:,t]
- all_maj = []
- for c in range(idx_begin,np.max(labels_a)+1):
- col_c = column[indices_labels[c-idx_begin]]
- if c!=-1:
- maj = p*((sum(col_c)+epsilon)**(p-1))
- W[indices_labels[c-idx_begin],t]=maj
- all_maj.append(maj)
-
- # now we majorize the unlabelled by the min of the majorizations
- # do it only for unlabbled data
- if idx_begin==-1:
- W[indices_labels[0],t]=np.min(all_maj)
+ all_majs = []
+ idx_unlabelled = -1
+ for (i, c) in enumerate(classes):
+ if c != unlabelledValue:
+ majs = np.sum(transp[indices_labels[i]], axis=0)
+ majs = p*((majs+epsilon)**(p-1))
+ W[indices_labels[i]] = majs
+ all_majs.append(majs)
+ else:
+ idx_unlabelled = i
+
+ # now we majorize the unlabelled (if there are any) by the min of
+ # the majorizations. do it only for unlabbled data
+ if idx_unlabelled != -1:
+ W[indices_labels[idx_unlabelled]] = np.min(all_majs, axis=0)
return transp