summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-04-20 15:47:51 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-04-20 15:47:51 +0200
commit26c128963fefaf7610fa6a1d48a676efc71dd137 (patch)
tree27810e0c7d4ea3d4b7ced7ce150db3efc844cb7f /test
parent2cc2f069b9241efdf2d0e0be1f7fbb6e6ab9dc45 (diff)
add test gpu
Diffstat (limited to 'test')
-rw-r--r--test/test_gpu_sinkhorn.py26
1 files changed, 26 insertions, 0 deletions
diff --git a/test/test_gpu_sinkhorn.py b/test/test_gpu_sinkhorn.py
new file mode 100644
index 0000000..bfa2cd2
--- /dev/null
+++ b/test/test_gpu_sinkhorn.py
@@ -0,0 +1,26 @@
+import ot
+import numpy as np
+import time
+import ot.gpu
+
+def describeRes(r):
+ print("min:{:.3E}, max::{:.3E}, mean::{:.3E}, std::{:.3E}".format(np.min(r),np.max(r),np.mean(r),np.std(r)))
+
+
+for n in [5000, 10000, 15000, 20000]:
+ print(n)
+ a = np.random.rand(n // 4, 100)
+ b = np.random.rand(n, 100)
+ time1 = time.time()
+ transport = ot.da.OTDA_sinkhorn()
+ transport.fit(a, b)
+ G1 = transport.G
+ time2 = time.time()
+ transport = ot.gpu.da.OTDA_sinkhorn()
+ transport.fit(a, b)
+ G2 = transport.G
+ time3 = time.time()
+ print("Normal sinkhorn, time: {:6.2f} sec ".format(time2 - time1))
+ describeRes(G1)
+ print(" GPU sinkhorn, time: {:6.2f} sec ".format(time3 - time2))
+ describeRes(G2) \ No newline at end of file