summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-03-20 15:05:57 +0100
committerRémi Flamary <remi.flamary@gmail.com>2018-03-20 15:05:57 +0100
commitcb739f625921e7fc19113d6d758e27ac69eac24b (patch)
tree89fc75a25e6b309441523f66768c974c0abc4deb
parentf31d7259bd8d02774301d478d8e2027abd8b10cf (diff)
add linear mapping function
-rw-r--r--examples/plot_otda_linear_mapping.py54
-rw-r--r--ot/da.py36
2 files changed, 90 insertions, 0 deletions
diff --git a/examples/plot_otda_linear_mapping.py b/examples/plot_otda_linear_mapping.py
new file mode 100644
index 0000000..eff2648
--- /dev/null
+++ b/examples/plot_otda_linear_mapping.py
@@ -0,0 +1,54 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Tue Mar 20 14:31:15 2018
+
+@author: rflamary
+"""
+
+import numpy as np
+import pylab as pl
+import ot
+
+
+
+#%%
+
+
+n=1000
+d=2
+sigma=.1
+
+angles=np.random.rand(n,1)*2*np.pi
+xs=np.concatenate((np.sin(angles),np.cos(angles)),axis=1)+sigma*np.random.randn(n,2)
+
+xs[:n//2,1]+=2
+
+anglet=np.random.rand(n,1)*2*np.pi
+xt=np.concatenate((np.sin(anglet),np.cos(anglet)),axis=1)+sigma*np.random.randn(n,2)
+xt[:n//2,1]+=2
+
+
+A=np.array([[1.5,.7],[.7,1.5]])
+b=np.array([[4,2]])
+xt=xt.dot(A)+b
+
+#%%
+
+pl.figure(1,(5,5))
+pl.plot(xs[:,0],xs[:,1],'+')
+pl.plot(xt[:,0],xt[:,1],'o')
+
+#%%
+
+Ae,be=ot.da.OT_mapping_linear(xs,xt)
+
+xst=xs.dot(Ae)+be
+
+##%%
+
+pl.figure(1,(5,5))
+pl.clf()
+pl.plot(xs[:,0],xs[:,1],'+')
+pl.plot(xt[:,0],xt[:,1],'o')
+pl.plot(xst[:,0],xst[:,1],'+') \ No newline at end of file
diff --git a/ot/da.py b/ot/da.py
index c688654..63bee5a 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -10,6 +10,7 @@ Domain adaptation with optimal transport
# License: MIT License
import numpy as np
+import scipy.linalg as linalg
from .bregman import sinkhorn
from .lp import emd
@@ -633,6 +634,41 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
return G, L
+def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,log=False):
+ """ return OT linear operator between samples"""
+
+ d=xs.shape[1]
+
+ mxs=xs.mean(0,keepdims=True)
+ mxt=xt.mean(0,keepdims=True)
+
+
+ if ws is None:
+ ws=np.ones((xs.shape[0],1))/xs.shape[0]
+
+ if wt is None:
+ wt=np.ones((xt.shape[0],1))/xt.shape[0]
+
+ Cs=(xs*ws).T.dot(xs)/ws.sum()+reg*np.eye(d)
+ Ct=(xt*wt).T.dot(xt)/wt.sum()+reg*np.eye(d)
+
+
+ Cs12=linalg.sqrtm(Cs)
+ Cs_12=linalg.inv(Cs12)
+
+ M0=linalg.sqrtm(Cs12.dot(Ct.dot(Cs12)))
+
+ A=Cs_12.dot(M0.dot(Cs_12)).T
+
+ b=mxt-mxs.dot(A)
+
+ if log:
+ pass
+ else:
+ return A,b
+
+
+
@deprecated("The class OTDA is deprecated in 0.3.1 and will be "
"removed in 0.5"
"\n\tfor standard transport use class EMDTransport instead.")