summaryrefslogtreecommitdiff
path: root/test/test_gpu.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-07-24 11:15:33 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-07-24 11:15:33 +0200
commit5a6b5de9b2f28c93bef1a9db2e3b94693c05ff4f (patch)
tree1f7457a028ef71253be36c44fb87c2e4131e909a /test/test_gpu.py
parent82da63f1020835a412f6174500099694a78ab6be (diff)
add proper testing
Diffstat (limited to 'test/test_gpu.py')
-rw-r--r--test/test_gpu.py59
1 files changed, 59 insertions, 0 deletions
diff --git a/test/test_gpu.py b/test/test_gpu.py
new file mode 100644
index 0000000..312a2d4
--- /dev/null
+++ b/test/test_gpu.py
@@ -0,0 +1,59 @@
+import ot
+import numpy as np
+import time
+import pytest
+
+
+@pytest.mark.skip(reason="No way to test GPU on travis yet")
+def test_gpu_sinkhorn():
+ import ot.gpu
+
+ def describeRes(r):
+ print("min:{:.3E}, max::{:.3E}, mean::{:.3E}, std::{:.3E}".format(
+ np.min(r), np.max(r), np.mean(r), np.std(r)))
+
+ for n in [5000]:
+ print(n)
+ a = np.random.rand(n // 4, 100)
+ b = np.random.rand(n, 100)
+ time1 = time.time()
+ transport = ot.da.OTDA_sinkhorn()
+ transport.fit(a, b)
+ G1 = transport.G
+ time2 = time.time()
+ transport = ot.gpu.da.OTDA_sinkhorn()
+ transport.fit(a, b)
+ G2 = transport.G
+ time3 = time.time()
+ print("Normal sinkhorn, time: {:6.2f} sec ".format(time2 - time1))
+ describeRes(G1)
+ print(" GPU sinkhorn, time: {:6.2f} sec ".format(time3 - time2))
+ describeRes(G2)
+
+
+@pytest.mark.skip(reason="No way to test GPU on travis yet")
+def test_gpu_sinkhorn_lpl1():
+ def describeRes(r):
+ print("min:{:.3E}, max:{:.3E}, mean:{:.3E}, std:{:.3E}"
+ .format(np.min(r), np.max(r), np.mean(r), np.std(r)))
+
+ for n in [5000]:
+ print(n)
+ a = np.random.rand(n // 4, 100)
+ labels_a = np.random.randint(10, size=(n // 4))
+ b = np.random.rand(n, 100)
+ time1 = time.time()
+ transport = ot.da.OTDA_lpl1()
+ transport.fit(a, labels_a, b)
+ G1 = transport.G
+ time2 = time.time()
+ transport = ot.gpu.da.OTDA_lpl1()
+ transport.fit(a, labels_a, b)
+ G2 = transport.G
+ time3 = time.time()
+ print("Normal sinkhorn lpl1, time: {:6.2f} sec ".format(
+ time2 - time1))
+ describeRes(G1)
+ print(" GPU sinkhorn lpl1, time: {:6.2f} sec ".format(
+ time3 - time2))
+ describeRes(G2)