diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2018-03-20 15:05:57 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2018-03-20 15:05:57 +0100 |
commit | cb739f625921e7fc19113d6d758e27ac69eac24b (patch) | |
tree | 89fc75a25e6b309441523f66768c974c0abc4deb /ot/da.py | |
parent | f31d7259bd8d02774301d478d8e2027abd8b10cf (diff) |
add linear mapping function
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 36 |
1 files changed, 36 insertions, 0 deletions
@@ -10,6 +10,7 @@ Domain adaptation with optimal transport # License: MIT License import numpy as np +import scipy.linalg as linalg from .bregman import sinkhorn from .lp import emd @@ -633,6 +634,41 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', return G, L +def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,log=False): + """ return OT linear operator between samples""" + + d=xs.shape[1] + + mxs=xs.mean(0,keepdims=True) + mxt=xt.mean(0,keepdims=True) + + + if ws is None: + ws=np.ones((xs.shape[0],1))/xs.shape[0] + + if wt is None: + wt=np.ones((xt.shape[0],1))/xt.shape[0] + + Cs=(xs*ws).T.dot(xs)/ws.sum()+reg*np.eye(d) + Ct=(xt*wt).T.dot(xt)/wt.sum()+reg*np.eye(d) + + + Cs12=linalg.sqrtm(Cs) + Cs_12=linalg.inv(Cs12) + + M0=linalg.sqrtm(Cs12.dot(Ct.dot(Cs12))) + + A=Cs_12.dot(M0.dot(Cs_12)).T + + b=mxt-mxs.dot(A) + + if log: + pass + else: + return A,b + + + @deprecated("The class OTDA is deprecated in 0.3.1 and will be " "removed in 0.5" "\n\tfor standard transport use class EMDTransport instead.") |