summaryrefslogtreecommitdiff
path: root/examples/plot_otda_linear_mapping.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-03-20 16:27:49 +0100
committerRémi Flamary <remi.flamary@gmail.com>2018-03-20 16:27:49 +0100
commitc1046238d826fe9cf1294f8ea60b8d44743fac78 (patch)
tree505a707c8458c74901952fd84d9054f7922a86c6 /examples/plot_otda_linear_mapping.py
parent8fc9fce6c920c646ea7324ac0af54ad53e9aa1bf (diff)
passing tests
Diffstat (limited to 'examples/plot_otda_linear_mapping.py')
-rw-r--r--examples/plot_otda_linear_mapping.py74
1 files changed, 38 insertions, 36 deletions
diff --git a/examples/plot_otda_linear_mapping.py b/examples/plot_otda_linear_mapping.py
index 44aa9c5..143f129 100644
--- a/examples/plot_otda_linear_mapping.py
+++ b/examples/plot_otda_linear_mapping.py
@@ -15,69 +15,71 @@ import scipy.linalg as linalg
#%%
-n=1000
-d=2
-sigma=.1
+n = 1000
+d = 2
+sigma = .1
# source samples
-angles=np.random.rand(n,1)*2*np.pi
-xs=np.concatenate((np.sin(angles),np.cos(angles)),axis=1)+sigma*np.random.randn(n,2)
-xs[:n//2,1]+=2
+angles = np.random.rand(n, 1) * 2 * np.pi
+xs = np.concatenate((np.sin(angles), np.cos(angles)),
+ axis=1) + sigma * np.random.randn(n, 2)
+xs[:n // 2, 1] += 2
# target samples
-anglet=np.random.rand(n,1)*2*np.pi
-xt=np.concatenate((np.sin(anglet),np.cos(anglet)),axis=1)+sigma*np.random.randn(n,2)
-xt[:n//2,1]+=2
+anglet = np.random.rand(n, 1) * 2 * np.pi
+xt = np.concatenate((np.sin(anglet), np.cos(anglet)),
+ axis=1) + sigma * np.random.randn(n, 2)
+xt[:n // 2, 1] += 2
-A=np.array([[1.5,.7],[.7,1.5]])
-b=np.array([[4,2]])
-xt=xt.dot(A)+b
+A = np.array([[1.5, .7], [.7, 1.5]])
+b = np.array([[4, 2]])
+xt = xt.dot(A) + b
#%%
-pl.figure(1,(5,5))
-pl.plot(xs[:,0],xs[:,1],'+')
-pl.plot(xt[:,0],xt[:,1],'o')
+pl.figure(1, (5, 5))
+pl.plot(xs[:, 0], xs[:, 1], '+')
+pl.plot(xt[:, 0], xt[:, 1], 'o')
#%%
-Ae,be=ot.da.OT_mapping_linear(xs,xt)
+Ae, be = ot.da.OT_mapping_linear(xs, xt)
-Ae1=linalg.inv(Ae)
-be1=-be.dot(Ae1)
+Ae1 = linalg.inv(Ae)
+be1 = -be.dot(Ae1)
-xst=xs.dot(Ae)+be
-xts=xt.dot(Ae1)+be1
+xst = xs.dot(Ae) + be
+xts = xt.dot(Ae1) + be1
-##%%
+# %%
-pl.figure(1,(5,5))
+pl.figure(1, (5, 5))
pl.clf()
-pl.plot(xs[:,0],xs[:,1],'+')
-pl.plot(xt[:,0],xt[:,1],'o')
-pl.plot(xst[:,0],xst[:,1],'+')
-pl.plot(xts[:,0],xts[:,1],'o')
+pl.plot(xs[:, 0], xs[:, 1], '+')
+pl.plot(xt[:, 0], xt[:, 1], 'o')
+pl.plot(xst[:, 0], xst[:, 1], '+')
+pl.plot(xts[:, 0], xts[:, 1], 'o')
pl.show()
#%% Example class with on images
-mapping=ot.da.LinearTransport()
+mapping = ot.da.LinearTransport()
-mapping.fit(Xs=xs,Xt=xt)
+mapping.fit(Xs=xs, Xt=xt)
-xst=mapping.transform(Xs=xs)
-xts=mapping.inverse_transform(Xt=xt)
+xst = mapping.transform(Xs=xs)
+xts = mapping.inverse_transform(Xt=xt)
-##%%
+# %%
-pl.figure(1,(5,5))
+pl.figure(1, (5, 5))
pl.clf()
-pl.plot(xs[:,0],xs[:,1],'+')
-pl.plot(xt[:,0],xt[:,1],'o')
-pl.plot(xst[:,0],xst[:,1],'+')
-pl.plot(xts[:,0],xts[:,1],'o')
+pl.plot(xs[:, 0], xs[:, 1], '+')
+pl.plot(xt[:, 0], xt[:, 1], 'o')
+pl.plot(xst[:, 0], xst[:, 1], '+')
+pl.plot(xts[:, 0], xts[:, 1], 'o')