summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py37
1 files changed, 11 insertions, 26 deletions
diff --git a/ot/da.py b/ot/da.py
index c944c0d..557e2aa 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -11,7 +11,7 @@ from .optim import cg
from .optim import gcg
-def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerItermax = 200,stopInnerThr=1e-9,unlabelledValue=-99,verbose=False,log=False):
+def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerItermax = 200,stopInnerThr=1e-9,verbose=False,log=False):
"""
Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization
@@ -55,8 +55,6 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
Max number of iterations (inner sinkhorn solver)
stopInnerThr : float, optional
Stop threshold on error (inner sinkhorn solver) (>0)
- unlabelledValue : int, optional
- this value in array labels_a means this is an unlabelled example
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -84,41 +82,28 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
ot.optim.cg : General regularized OT
"""
- p=0.5
+ p = 0.5
epsilon = 1e-3
- # init data
- Nini = len(a)
- Nfin = len(b)
-
indices_labels = []
classes = np.unique(labels_a)
for c in classes:
idxc, = np.where(labels_a == c)
indices_labels.append(idxc)
- W=np.zeros(M.shape)
+ W = np.zeros(M.shape)
for cpt in range(numItermax):
Mreg = M + eta*W
- transp=sinkhorn(a,b,Mreg,reg,numItermax=numInnerItermax, stopThr=stopInnerThr)
- # the transport has been computed. Check if classes are really separated
- W = np.ones((Nini,Nfin))
- all_majs = []
- idx_unlabelled = -1
+ transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
+ stopThr=stopInnerThr)
+ # the transport has been computed. Check if classes are really
+ # separated
+ W = np.ones(M.shape)
for (i, c) in enumerate(classes):
- if c != unlabelledValue:
- majs = np.sum(transp[indices_labels[i]], axis=0)
- majs = p*((majs+epsilon)**(p-1))
- W[indices_labels[i]] = majs
- all_majs.append(majs)
- else:
- idx_unlabelled = i
-
- # now we majorize the unlabelled (if there are any) by the min of
- # the majorizations. do it only for unlabbled data
- if idx_unlabelled != -1:
- W[indices_labels[idx_unlabelled]] = np.min(all_majs, axis=0)
+ majs = np.sum(transp[indices_labels[i]], axis=0)
+ majs = p*((majs+epsilon)**(p-1))
+ W[indices_labels[i]] = majs
return transp