summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py46
1 files changed, 15 insertions, 31 deletions
diff --git a/ot/da.py b/ot/da.py
index 44ce829..557e2aa 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -11,10 +11,6 @@ from .optim import cg
from .optim import gcg
-def indices(a, func):
- return [i for (i, val) in enumerate(a) if func(val)]
-
-
def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerItermax = 200,stopInnerThr=1e-9,verbose=False,log=False):
"""
Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization
@@ -46,7 +42,7 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
labels_a : np.ndarray (ns,)
labels of samples in the source domain
b : np.ndarray (nt,)
- samples in the target domain
+ samples weights in the target domain
M : np.ndarray (ns,nt)
loss matrix
reg : float
@@ -86,40 +82,28 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
ot.optim.cg : General regularized OT
"""
- p=0.5
+ p = 0.5
epsilon = 1e-3
- # init data
- Nini = len(a)
- Nfin = len(b)
-
indices_labels = []
- idx_begin = np.min(labels_a)
- for c in range(idx_begin,np.max(labels_a)+1):
- idxc = indices(labels_a, lambda x: x==c)
+ classes = np.unique(labels_a)
+ for c in classes:
+ idxc, = np.where(labels_a == c)
indices_labels.append(idxc)
- W=np.zeros(M.shape)
+ W = np.zeros(M.shape)
for cpt in range(numItermax):
Mreg = M + eta*W
- transp=sinkhorn(a,b,Mreg,reg,numItermax=numInnerItermax, stopThr=stopInnerThr)
- # the transport has been computed. Check if classes are really separated
- W = np.ones((Nini,Nfin))
- for t in range(Nfin):
- column = transp[:,t]
- all_maj = []
- for c in range(idx_begin,np.max(labels_a)+1):
- col_c = column[indices_labels[c-idx_begin]]
- if c!=-1:
- maj = p*((sum(col_c)+epsilon)**(p-1))
- W[indices_labels[c-idx_begin],t]=maj
- all_maj.append(maj)
-
- # now we majorize the unlabelled by the min of the majorizations
- # do it only for unlabbled data
- if idx_begin==-1:
- W[indices_labels[0],t]=np.min(all_maj)
+ transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
+ stopThr=stopInnerThr)
+ # the transport has been computed. Check if classes are really
+ # separated
+ W = np.ones(M.shape)
+ for (i, c) in enumerate(classes):
+ majs = np.sum(transp[indices_labels[i]], axis=0)
+ majs = p*((majs+epsilon)**(p-1))
+ W[indices_labels[i]] = majs
return transp