summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorLeo gautheron <gautheron@iv-cm-359.creatis.insa-lyon.fr>2017-04-20 12:12:15 +0200
committerLeo gautheron <gautheron@iv-cm-359.creatis.insa-lyon.fr>2017-04-20 12:12:15 +0200
commit16f51f971607efab2c73958d207c582b389406c8 (patch)
tree299a4f6f13faf8545d2144767e9a7791098aacf8 /ot/da.py
parent48ec27d8e1c2599bd6d9015d15f4204b8116af28 (diff)
sinkhorn GPU implementation
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py23
1 files changed, 11 insertions, 12 deletions
diff --git a/ot/da.py b/ot/da.py
index 81b6a35..d7b8492 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -193,30 +193,30 @@ def sinkhorn_l1l2_gl(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
"""
lstlab=np.unique(labels_a)
-
+
def f(G):
res=0
for i in range(G.shape[1]):
for lab in lstlab:
temp=G[labels_a==lab,i]
- res+=np.linalg.norm(temp)
+ res+=np.linalg.norm(temp)
return res
-
+
def df(G):
- W=np.zeros(G.shape)
+ W=np.zeros(G.shape)
for i in range(G.shape[1]):
for lab in lstlab:
temp=G[labels_a==lab,i]
n=np.linalg.norm(temp)
if n:
- W[labels_a==lab,i]=temp/n
- return W
+ W[labels_a==lab,i]=temp/n
+ return W
+
-
return gcg(a,b,M,reg,eta,f,df,G0=None,numItermax = numItermax,numInnerItermax=numInnerItermax, stopThr=stopInnerThr,verbose=verbose,log=log)
-
-
-
+
+
+
def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 10,stopInnerThr=1e-6,stopThr=1e-5,log=False,**kwargs):
"""Joint OT and linear mapping estimation as proposed in [8]
@@ -685,7 +685,6 @@ class OTDA(object):
return xf[idx,:]+x-x0[idx,:] # aply the delta to the interpolation
-
class OTDA_sinkhorn(OTDA):
"""Class for domain adaptation with optimal transport with entropic regularization"""
@@ -727,7 +726,7 @@ class OTDA_lpl1(OTDA):
self.M=dist(xs,xt,metric=self.metric)
self.G=sinkhorn_lpl1_mm(ws,ys,wt,self.M,reg,eta,**kwargs)
self.computed=True
-
+
class OTDA_l1l2(OTDA):
"""Class for domain adaptation with optimal transport with entropic and group lasso regularization"""