summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-31 14:36:34 +0100
committerRémi Flamary <remi.flamary@gmail.com>2016-10-31 14:36:34 +0100
commitfdceacbac0d4b0380c90ca2c942e5abd0f69df64 (patch)
treeffe8fe373c7f7ae2f65172e820230c57762da84f /ot/da.py
parent104627b2f69eb22d3f9010955e6765ac2b179faa (diff)
add classes for entropic and group lasso regularization
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py25
1 files changed, 23 insertions, 2 deletions
diff --git a/ot/da.py b/ot/da.py
index 87354b9..56963d8 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -210,8 +210,8 @@ class OTDA():
class OTDA_sinkhorn(OTDA):
-
- def fit(self,xs,xt,ws=None,wt=None,reg=1,**kwargs):
+ """Class for domain adaptation with optimal transport with entropic regularization"""
+ def fit(self,xs,xt,reg=1,ws=None,wt=None,**kwargs):
""" Fit domain adaptation between samples is xs and xt (with optional
weights)"""
self.xs=xs
@@ -230,5 +230,26 @@ class OTDA_sinkhorn(OTDA):
self.computed=True
+class OTDA_lpl1(OTDA):
+ """Class for domain adaptation with optimal transport with entropic an group regularization"""
+
+ def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
+ """ Fit domain adaptation between samples is xs and xt (with optional
+ weights)"""
+ self.xs=xs
+ self.xt=xt
+
+ if wt is None:
+ wt=unif(xt.shape[0])
+ if ws is None:
+ ws=unif(xs.shape[0])
+
+ self.ws=ws
+ self.wt=wt
+
+ self.M=dist(xs,xt,metric=self.metric)
+ self.G=sinkhorn_lpl1_mm(ws,ys,wt,self.M,reg,eta,**kwargs)
+ self.computed=True
+
\ No newline at end of file