summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-03-20 16:27:49 +0100
committerRémi Flamary <remi.flamary@gmail.com>2018-03-20 16:27:49 +0100
commitc1046238d826fe9cf1294f8ea60b8d44743fac78 (patch)
tree505a707c8458c74901952fd84d9054f7922a86c6 /ot/da.py
parent8fc9fce6c920c646ea7324ac0af54ad53e9aa1bf (diff)
passing tests
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py147
1 files changed, 72 insertions, 75 deletions
diff --git a/ot/da.py b/ot/da.py
index ab5f860..f789396 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -357,7 +357,8 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
def loss(L, G):
"""Compute full loss"""
- return np.sum((xs1.dot(L) - ns * G.dot(xt))**2) + mu * np.sum(G * M) + eta * np.sum(sel(L - I0)**2)
+ return np.sum((xs1.dot(L) - ns * G.dot(xt))**2) + mu * \
+ np.sum(G * M) + eta * np.sum(sel(L - I0)**2)
def solve_L(G):
""" solve L problem with fixed G (least square)"""
@@ -557,7 +558,8 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
def loss(L, G):
"""Compute full loss"""
- return np.sum((K1.dot(L) - ns * G.dot(xt))**2) + mu * np.sum(G * M) + eta * np.trace(L.T.dot(Kreg).dot(L))
+ return np.sum((K1.dot(L) - ns * G.dot(xt))**2) + mu * \
+ np.sum(G * M) + eta * np.trace(L.T.dot(Kreg).dot(L))
def solve_L_nobias(G):
""" solve L problem with fixed G (least square)"""
@@ -634,25 +636,26 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
return G, L
-def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,bias=True,log=False):
+def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
+ wt=None, bias=True, log=False):
""" return OT linear operator between samples
The function estimate the optimal linear operator that align the two
- empirical distributions. This is equivalent to estimating the closed
- form mapping between two Gaussian distribution :math:`N(\mu_s,\Sigma_s)`
+ empirical distributions. This is equivalent to estimating the closed
+ form mapping between two Gaussian distribution :math:`N(\mu_s,\Sigma_s)`
and :math:`N(\mu_t,\Sigma_t)` as proposed in [14].
-
+
The linear operator from source to target :math:`M`
.. math::
M(x)=Ax+b
-
+
where :
-
+
.. math::
A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2}
\Sigma_s^{-1/2}
- .. math::
+ .. math::
b=\mu_t-A\mu_s
Parameters
@@ -666,7 +669,7 @@ def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,bias=True,log=False):
ws : np.ndarray (ns,1), optional
weights for the source samples
wt : np.ndarray (ns,1), optional
- weights for the target samples
+ weights for the target samples
bias: boolean, optional
estimate bias b else b=0 (default:True)
log : bool, optional
@@ -686,55 +689,52 @@ def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,bias=True,log=False):
References
----------
- .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of
+ .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of
distributions", Journal of Optimization Theory and Applications
Vol 43, 1984
"""
- d=xs.shape[1]
-
+ d = xs.shape[1]
+
if bias:
- mxs=xs.mean(0,keepdims=True)
- mxt=xt.mean(0,keepdims=True)
-
- xs=xs-mxs
- xt=xt-mxt
+ mxs = xs.mean(0, keepdims=True)
+ mxt = xt.mean(0, keepdims=True)
+
+ xs = xs - mxs
+ xt = xt - mxt
else:
- mxs=np.zeros((1,d))
- mxt=np.zeros((1,d))
+ mxs = np.zeros((1, d))
+ mxt = np.zeros((1, d))
-
if ws is None:
- ws=np.ones((xs.shape[0],1))/xs.shape[0]
-
+ ws = np.ones((xs.shape[0], 1)) / xs.shape[0]
+
if wt is None:
- wt=np.ones((xt.shape[0],1))/xt.shape[0]
-
- Cs=(xs*ws).T.dot(xs)/ws.sum()+reg*np.eye(d)
- Ct=(xt*wt).T.dot(xt)/wt.sum()+reg*np.eye(d)
-
-
- Cs12=linalg.sqrtm(Cs)
- Cs_12=linalg.inv(Cs12)
-
- M0=linalg.sqrtm(Cs12.dot(Ct.dot(Cs12)))
-
- A=Cs_12.dot(M0.dot(Cs_12))
-
- b=mxt-mxs.dot(A)
-
+ wt = np.ones((xt.shape[0], 1)) / xt.shape[0]
+
+ Cs = (xs * ws).T.dot(xs) / ws.sum() + reg * np.eye(d)
+ Ct = (xt * wt).T.dot(xt) / wt.sum() + reg * np.eye(d)
+
+ Cs12 = linalg.sqrtm(Cs)
+ Cs_12 = linalg.inv(Cs12)
+
+ M0 = linalg.sqrtm(Cs12.dot(Ct.dot(Cs12)))
+
+ A = Cs_12.dot(M0.dot(Cs_12))
+
+ b = mxt - mxs.dot(A)
+
if log:
- log={}
- log['Cs']=Cs
- log['Ct']=Ct
- log['Cs12']=Cs12
- log['Cs_12']=Cs_12
- return A,b,log
+ log = {}
+ log['Cs'] = Cs
+ log['Ct'] = Ct
+ log['Cs12'] = Cs12
+ log['Cs_12'] = Cs_12
+ return A, b, log
else:
- return A,b
-
+ return A, b
@deprecated("The class OTDA is deprecated in 0.3.1 and will be "
@@ -1288,42 +1288,42 @@ class LinearTransport(BaseTransport):
""" OT linear operator between empirical distributions
The function estimate the optimal linear operator that align the two
- empirical distributions. This is equivalent to estimating the closed
- form mapping between two Gaussian distribution :math:`N(\mu_s,\Sigma_s)`
+ empirical distributions. This is equivalent to estimating the closed
+ form mapping between two Gaussian distribution :math:`N(\mu_s,\Sigma_s)`
and :math:`N(\mu_t,\Sigma_t)` as proposed in [14].
-
+
The linear operator from source to target :math:`M`
.. math::
M(x)=Ax+b
-
+
where :
-
+
.. math::
A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2}
\Sigma_s^{-1/2}
- .. math::
+ .. math::
b=\mu_t-A\mu_s
Parameters
----------
reg : float,optional
- regularization added to the daigonals of convariances (>0)
+ regularization added to the daigonals of convariances (>0)
bias: boolean, optional
estimate bias b else b=0 (default:True)
log : bool, optional
record log if True
-
-
+
+
"""
-
- def __init__(self, reg=1e-8,bias=True,log=False,
+
+ def __init__(self, reg=1e-8, bias=True, log=False,
distribution_estimation=distribution_estimation_uniform):
-
- self.bias=bias
- self.log=log
- self.reg=reg
- self.distribution_estimation=distribution_estimation
+
+ self.bias = bias
+ self.log = log
+ self.reg = reg
+ self.distribution_estimation = distribution_estimation
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples
@@ -1349,17 +1349,15 @@ class LinearTransport(BaseTransport):
self : object
Returns self.
"""
-
+
self.mu_s = self.distribution_estimation(Xs)
self.mu_t = self.distribution_estimation(Xt)
-
-
# coupling estimation
- returned_ = OT_mapping_linear(Xs,Xt,reg=self.reg,
- ws=self.mu_s.reshape((-1,1)),
- wt=self.mu_t.reshape((-1,1)),
- bias=self.bias,log=self.log)
+ returned_ = OT_mapping_linear(Xs, Xt, reg=self.reg,
+ ws=self.mu_s.reshape((-1, 1)),
+ wt=self.mu_t.reshape((-1, 1)),
+ bias=self.bias, log=self.log)
# deal with the value of log
if self.log:
@@ -1367,10 +1365,10 @@ class LinearTransport(BaseTransport):
else:
self.A_, self.B_, = returned_
self.log_ = dict()
-
+
# re compute inverse mapping
- self.A1_=linalg.inv(self.A_)
- self.B1_=-self.B_.dot(self.A1_)
+ self.A1_ = linalg.inv(self.A_)
+ self.B1_ = -self.B_.dot(self.A1_)
return self
@@ -1403,7 +1401,7 @@ class LinearTransport(BaseTransport):
# check the necessary inputs parameters are here
if check_params(Xs=Xs):
- transp_Xs= Xs.dot(self.A_)+self.B_
+ transp_Xs = Xs.dot(self.A_) + self.B_
return transp_Xs
@@ -1437,12 +1435,11 @@ class LinearTransport(BaseTransport):
# check the necessary inputs parameters are here
if check_params(Xt=Xt):
- transp_Xt= Xt.dot(self.A1_)+self.B1_
+ transp_Xt = Xt.dot(self.A1_) + self.B1_
return transp_Xt
-
class SinkhornTransport(BaseTransport):
"""Domain Adapatation OT method based on Sinkhorn Algorithm