From e2cf8538b05f026d73c6033777699af77e7508b5 Mon Sep 17 00:00:00 2001 From: Leo gautheron Date: Mon, 24 Apr 2017 10:43:44 +0200 Subject: add GPU implementation sinkhorn lpl1 --- test/test_gpu_sinkhorn_lpl1.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 test/test_gpu_sinkhorn_lpl1.py (limited to 'test') diff --git a/test/test_gpu_sinkhorn_lpl1.py b/test/test_gpu_sinkhorn_lpl1.py new file mode 100644 index 0000000..e6cdd31 --- /dev/null +++ b/test/test_gpu_sinkhorn_lpl1.py @@ -0,0 +1,28 @@ +import ot +import numpy as np +import time +import ot.gpu + + +def describeRes(r): + print("min:{:.3E}, max:{:.3E}, mean:{:.3E}, std:{:.3E}" + .format(np.min(r), np.max(r), np.mean(r), np.std(r))) + +for n in [5000, 10000, 15000, 20000]: + print(n) + a = np.random.rand(n // 4, 100) + labels_a = np.random.randint(10, size=(n // 4)) + b = np.random.rand(n, 100) + time1 = time.time() + transport = ot.da.OTDA_lpl1() + transport.fit(a, labels_a, b) + G1 = transport.G + time2 = time.time() + transport = ot.gpu.da.OTDA_lpl1() + transport.fit(a, labels_a, b) + G2 = transport.G + time3 = time.time() + print("Normal sinkhorn lpl1, time: {:6.2f} sec ".format(time2 - time1)) + describeRes(G1) + print(" GPU sinkhorn lpl1, time: {:6.2f} sec ".format(time3 - time2)) + describeRes(G2) -- cgit v1.2.3