diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2018-03-20 15:05:57 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2018-03-20 15:05:57 +0100 |
commit | cb739f625921e7fc19113d6d758e27ac69eac24b (patch) | |
tree | 89fc75a25e6b309441523f66768c974c0abc4deb | |
parent | f31d7259bd8d02774301d478d8e2027abd8b10cf (diff) |
add linear mapping function
-rw-r--r-- | examples/plot_otda_linear_mapping.py | 54 | ||||
-rw-r--r-- | ot/da.py | 36 |
2 files changed, 90 insertions, 0 deletions
diff --git a/examples/plot_otda_linear_mapping.py b/examples/plot_otda_linear_mapping.py new file mode 100644 index 0000000..eff2648 --- /dev/null +++ b/examples/plot_otda_linear_mapping.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Tue Mar 20 14:31:15 2018 + +@author: rflamary +""" + +import numpy as np +import pylab as pl +import ot + + + +#%% + + +n=1000 +d=2 +sigma=.1 + +angles=np.random.rand(n,1)*2*np.pi +xs=np.concatenate((np.sin(angles),np.cos(angles)),axis=1)+sigma*np.random.randn(n,2) + +xs[:n//2,1]+=2 + +anglet=np.random.rand(n,1)*2*np.pi +xt=np.concatenate((np.sin(anglet),np.cos(anglet)),axis=1)+sigma*np.random.randn(n,2) +xt[:n//2,1]+=2 + + +A=np.array([[1.5,.7],[.7,1.5]]) +b=np.array([[4,2]]) +xt=xt.dot(A)+b + +#%% + +pl.figure(1,(5,5)) +pl.plot(xs[:,0],xs[:,1],'+') +pl.plot(xt[:,0],xt[:,1],'o') + +#%% + +Ae,be=ot.da.OT_mapping_linear(xs,xt) + +xst=xs.dot(Ae)+be + +##%% + +pl.figure(1,(5,5)) +pl.clf() +pl.plot(xs[:,0],xs[:,1],'+') +pl.plot(xt[:,0],xt[:,1],'o') +pl.plot(xst[:,0],xst[:,1],'+')
\ No newline at end of file @@ -10,6 +10,7 @@ Domain adaptation with optimal transport # License: MIT License import numpy as np +import scipy.linalg as linalg from .bregman import sinkhorn from .lp import emd @@ -633,6 +634,41 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', return G, L +def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,log=False): + """ return OT linear operator between samples""" + + d=xs.shape[1] + + mxs=xs.mean(0,keepdims=True) + mxt=xt.mean(0,keepdims=True) + + + if ws is None: + ws=np.ones((xs.shape[0],1))/xs.shape[0] + + if wt is None: + wt=np.ones((xt.shape[0],1))/xt.shape[0] + + Cs=(xs*ws).T.dot(xs)/ws.sum()+reg*np.eye(d) + Ct=(xt*wt).T.dot(xt)/wt.sum()+reg*np.eye(d) + + + Cs12=linalg.sqrtm(Cs) + Cs_12=linalg.inv(Cs12) + + M0=linalg.sqrtm(Cs12.dot(Ct.dot(Cs12))) + + A=Cs_12.dot(M0.dot(Cs_12)).T + + b=mxt-mxs.dot(A) + + if log: + pass + else: + return A,b + + + @deprecated("The class OTDA is deprecated in 0.3.1 and will be " "removed in 0.5" "\n\tfor standard transport use class EMDTransport instead.") |