summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-03-20 16:27:49 +0100
committerRémi Flamary <remi.flamary@gmail.com>2018-03-20 16:27:49 +0100
commitc1046238d826fe9cf1294f8ea60b8d44743fac78 (patch)
tree505a707c8458c74901952fd84d9054f7922a86c6
parent8fc9fce6c920c646ea7324ac0af54ad53e9aa1bf (diff)
passing tests
-rw-r--r--examples/plot_otda_linear_mapping.py74
-rw-r--r--ot/da.py147
2 files changed, 110 insertions, 111 deletions
diff --git a/examples/plot_otda_linear_mapping.py b/examples/plot_otda_linear_mapping.py
index 44aa9c5..143f129 100644
--- a/examples/plot_otda_linear_mapping.py
+++ b/examples/plot_otda_linear_mapping.py
@@ -15,69 +15,71 @@ import scipy.linalg as linalg
#%%
-n=1000
-d=2
-sigma=.1
+n = 1000
+d = 2
+sigma = .1
# source samples
-angles=np.random.rand(n,1)*2*np.pi
-xs=np.concatenate((np.sin(angles),np.cos(angles)),axis=1)+sigma*np.random.randn(n,2)
-xs[:n//2,1]+=2
+angles = np.random.rand(n, 1) * 2 * np.pi
+xs = np.concatenate((np.sin(angles), np.cos(angles)),
+ axis=1) + sigma * np.random.randn(n, 2)
+xs[:n // 2, 1] += 2
# target samples
-anglet=np.random.rand(n,1)*2*np.pi
-xt=np.concatenate((np.sin(anglet),np.cos(anglet)),axis=1)+sigma*np.random.randn(n,2)
-xt[:n//2,1]+=2
+anglet = np.random.rand(n, 1) * 2 * np.pi
+xt = np.concatenate((np.sin(anglet), np.cos(anglet)),
+ axis=1) + sigma * np.random.randn(n, 2)
+xt[:n // 2, 1] += 2
-A=np.array([[1.5,.7],[.7,1.5]])
-b=np.array([[4,2]])
-xt=xt.dot(A)+b
+A = np.array([[1.5, .7], [.7, 1.5]])
+b = np.array([[4, 2]])
+xt = xt.dot(A) + b
#%%
-pl.figure(1,(5,5))
-pl.plot(xs[:,0],xs[:,1],'+')
-pl.plot(xt[:,0],xt[:,1],'o')
+pl.figure(1, (5, 5))
+pl.plot(xs[:, 0], xs[:, 1], '+')
+pl.plot(xt[:, 0], xt[:, 1], 'o')
#%%
-Ae,be=ot.da.OT_mapping_linear(xs,xt)
+Ae, be = ot.da.OT_mapping_linear(xs, xt)
-Ae1=linalg.inv(Ae)
-be1=-be.dot(Ae1)
+Ae1 = linalg.inv(Ae)
+be1 = -be.dot(Ae1)
-xst=xs.dot(Ae)+be
-xts=xt.dot(Ae1)+be1
+xst = xs.dot(Ae) + be
+xts = xt.dot(Ae1) + be1
-##%%
+# %%
-pl.figure(1,(5,5))
+pl.figure(1, (5, 5))
pl.clf()
-pl.plot(xs[:,0],xs[:,1],'+')
-pl.plot(xt[:,0],xt[:,1],'o')
-pl.plot(xst[:,0],xst[:,1],'+')
-pl.plot(xts[:,0],xts[:,1],'o')
+pl.plot(xs[:, 0], xs[:, 1], '+')
+pl.plot(xt[:, 0], xt[:, 1], 'o')
+pl.plot(xst[:, 0], xst[:, 1], '+')
+pl.plot(xts[:, 0], xts[:, 1], 'o')
pl.show()
#%% Example class with on images
-mapping=ot.da.LinearTransport()
+mapping = ot.da.LinearTransport()
-mapping.fit(Xs=xs,Xt=xt)
+mapping.fit(Xs=xs, Xt=xt)
-xst=mapping.transform(Xs=xs)
-xts=mapping.inverse_transform(Xt=xt)
+xst = mapping.transform(Xs=xs)
+xts = mapping.inverse_transform(Xt=xt)
-##%%
+# %%
-pl.figure(1,(5,5))
+pl.figure(1, (5, 5))
pl.clf()
-pl.plot(xs[:,0],xs[:,1],'+')
-pl.plot(xt[:,0],xt[:,1],'o')
-pl.plot(xst[:,0],xst[:,1],'+')
-pl.plot(xts[:,0],xts[:,1],'o')
+pl.plot(xs[:, 0], xs[:, 1], '+')
+pl.plot(xt[:, 0], xt[:, 1], 'o')
+pl.plot(xst[:, 0], xst[:, 1], '+')
+pl.plot(xts[:, 0], xts[:, 1], 'o')
diff --git a/ot/da.py b/ot/da.py
index ab5f860..f789396 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -357,7 +357,8 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
def loss(L, G):
"""Compute full loss"""
- return np.sum((xs1.dot(L) - ns * G.dot(xt))**2) + mu * np.sum(G * M) + eta * np.sum(sel(L - I0)**2)
+ return np.sum((xs1.dot(L) - ns * G.dot(xt))**2) + mu * \
+ np.sum(G * M) + eta * np.sum(sel(L - I0)**2)
def solve_L(G):
""" solve L problem with fixed G (least square)"""
@@ -557,7 +558,8 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
def loss(L, G):
"""Compute full loss"""
- return np.sum((K1.dot(L) - ns * G.dot(xt))**2) + mu * np.sum(G * M) + eta * np.trace(L.T.dot(Kreg).dot(L))
+ return np.sum((K1.dot(L) - ns * G.dot(xt))**2) + mu * \
+ np.sum(G * M) + eta * np.trace(L.T.dot(Kreg).dot(L))
def solve_L_nobias(G):
""" solve L problem with fixed G (least square)"""
@@ -634,25 +636,26 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
return G, L
-def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,bias=True,log=False):
+def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
+ wt=None, bias=True, log=False):
""" return OT linear operator between samples
The function estimate the optimal linear operator that align the two
- empirical distributions. This is equivalent to estimating the closed
- form mapping between two Gaussian distribution :math:`N(\mu_s,\Sigma_s)`
+ empirical distributions. This is equivalent to estimating the closed
+ form mapping between two Gaussian distribution :math:`N(\mu_s,\Sigma_s)`
and :math:`N(\mu_t,\Sigma_t)` as proposed in [14].
-
+
The linear operator from source to target :math:`M`
.. math::
M(x)=Ax+b
-
+
where :
-
+
.. math::
A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2}
\Sigma_s^{-1/2}
- .. math::
+ .. math::
b=\mu_t-A\mu_s
Parameters
@@ -666,7 +669,7 @@ def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,bias=True,log=False):
ws : np.ndarray (ns,1), optional
weights for the source samples
wt : np.ndarray (ns,1), optional
- weights for the target samples
+ weights for the target samples
bias: boolean, optional
estimate bias b else b=0 (default:True)
log : bool, optional
@@ -686,55 +689,52 @@ def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,bias=True,log=False):
References
----------
- .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of
+ .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of
distributions", Journal of Optimization Theory and Applications
Vol 43, 1984
"""
- d=xs.shape[1]
-
+ d = xs.shape[1]
+
if bias:
- mxs=xs.mean(0,keepdims=True)
- mxt=xt.mean(0,keepdims=True)
-
- xs=xs-mxs
- xt=xt-mxt
+ mxs = xs.mean(0, keepdims=True)
+ mxt = xt.mean(0, keepdims=True)
+
+ xs = xs - mxs
+ xt = xt - mxt
else:
- mxs=np.zeros((1,d))
- mxt=np.zeros((1,d))
+ mxs = np.zeros((1, d))
+ mxt = np.zeros((1, d))
-
if ws is None:
- ws=np.ones((xs.shape[0],1))/xs.shape[0]
-
+ ws = np.ones((xs.shape[0], 1)) / xs.shape[0]
+
if wt is None:
- wt=np.ones((xt.shape[0],1))/xt.shape[0]
-
- Cs=(xs*ws).T.dot(xs)/ws.sum()+reg*np.eye(d)
- Ct=(xt*wt).T.dot(xt)/wt.sum()+reg*np.eye(d)
-
-
- Cs12=linalg.sqrtm(Cs)
- Cs_12=linalg.inv(Cs12)
-
- M0=linalg.sqrtm(Cs12.dot(Ct.dot(Cs12)))
-
- A=Cs_12.dot(M0.dot(Cs_12))
-
- b=mxt-mxs.dot(A)
-
+ wt = np.ones((xt.shape[0], 1)) / xt.shape[0]
+
+ Cs = (xs * ws).T.dot(xs) / ws.sum() + reg * np.eye(d)
+ Ct = (xt * wt).T.dot(xt) / wt.sum() + reg * np.eye(d)
+
+ Cs12 = linalg.sqrtm(Cs)
+ Cs_12 = linalg.inv(Cs12)
+
+ M0 = linalg.sqrtm(Cs12.dot(Ct.dot(Cs12)))
+
+ A = Cs_12.dot(M0.dot(Cs_12))
+
+ b = mxt - mxs.dot(A)
+
if log:
- log={}
- log['Cs']=Cs
- log['Ct']=Ct
- log['Cs12']=Cs12
- log['Cs_12']=Cs_12
- return A,b,log
+ log = {}
+ log['Cs'] = Cs
+ log['Ct'] = Ct
+ log['Cs12'] = Cs12
+ log['Cs_12'] = Cs_12
+ return A, b, log
else:
- return A,b
-
+ return A, b
@deprecated("The class OTDA is deprecated in 0.3.1 and will be "
@@ -1288,42 +1288,42 @@ class LinearTransport(BaseTransport):
""" OT linear operator between empirical distributions
The function estimate the optimal linear operator that align the two
- empirical distributions. This is equivalent to estimating the closed
- form mapping between two Gaussian distribution :math:`N(\mu_s,\Sigma_s)`
+ empirical distributions. This is equivalent to estimating the closed
+ form mapping between two Gaussian distribution :math:`N(\mu_s,\Sigma_s)`
and :math:`N(\mu_t,\Sigma_t)` as proposed in [14].
-
+
The linear operator from source to target :math:`M`
.. math::
M(x)=Ax+b
-
+
where :
-
+
.. math::
A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2}
\Sigma_s^{-1/2}
- .. math::
+ .. math::
b=\mu_t-A\mu_s
Parameters
----------
reg : float,optional
- regularization added to the daigonals of convariances (>0)
+ regularization added to the daigonals of convariances (>0)
bias: boolean, optional
estimate bias b else b=0 (default:True)
log : bool, optional
record log if True
-
-
+
+
"""
-
- def __init__(self, reg=1e-8,bias=True,log=False,
+
+ def __init__(self, reg=1e-8, bias=True, log=False,
distribution_estimation=distribution_estimation_uniform):
-
- self.bias=bias
- self.log=log
- self.reg=reg
- self.distribution_estimation=distribution_estimation
+
+ self.bias = bias
+ self.log = log
+ self.reg = reg
+ self.distribution_estimation = distribution_estimation
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples
@@ -1349,17 +1349,15 @@ class LinearTransport(BaseTransport):
self : object
Returns self.
"""
-
+
self.mu_s = self.distribution_estimation(Xs)
self.mu_t = self.distribution_estimation(Xt)
-
-
# coupling estimation
- returned_ = OT_mapping_linear(Xs,Xt,reg=self.reg,
- ws=self.mu_s.reshape((-1,1)),
- wt=self.mu_t.reshape((-1,1)),
- bias=self.bias,log=self.log)
+ returned_ = OT_mapping_linear(Xs, Xt, reg=self.reg,
+ ws=self.mu_s.reshape((-1, 1)),
+ wt=self.mu_t.reshape((-1, 1)),
+ bias=self.bias, log=self.log)
# deal with the value of log
if self.log:
@@ -1367,10 +1365,10 @@ class LinearTransport(BaseTransport):
else:
self.A_, self.B_, = returned_
self.log_ = dict()
-
+
# re compute inverse mapping
- self.A1_=linalg.inv(self.A_)
- self.B1_=-self.B_.dot(self.A1_)
+ self.A1_ = linalg.inv(self.A_)
+ self.B1_ = -self.B_.dot(self.A1_)
return self
@@ -1403,7 +1401,7 @@ class LinearTransport(BaseTransport):
# check the necessary inputs parameters are here
if check_params(Xs=Xs):
- transp_Xs= Xs.dot(self.A_)+self.B_
+ transp_Xs = Xs.dot(self.A_) + self.B_
return transp_Xs
@@ -1437,12 +1435,11 @@ class LinearTransport(BaseTransport):
# check the necessary inputs parameters are here
if check_params(Xt=Xt):
- transp_Xt= Xt.dot(self.A1_)+self.B1_
+ transp_Xt = Xt.dot(self.A1_) + self.B1_
return transp_Xt
-
class SinkhornTransport(BaseTransport):
"""Domain Adapatation OT method based on Sinkhorn Algorithm