summaryrefslogtreecommitdiff
path: root/benchmarks/sinkhorn_knopp.py
diff options
context:
space:
mode:
Diffstat (limited to 'benchmarks/sinkhorn_knopp.py')
-rw-r--r--benchmarks/sinkhorn_knopp.py42
1 files changed, 42 insertions, 0 deletions
diff --git a/benchmarks/sinkhorn_knopp.py b/benchmarks/sinkhorn_knopp.py
new file mode 100644
index 0000000..3a1ef3f
--- /dev/null
+++ b/benchmarks/sinkhorn_knopp.py
@@ -0,0 +1,42 @@
+# /usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import numpy as np
+import ot
+from .benchmark import (
+ setup_backends,
+ exec_bench,
+ convert_to_html_table
+)
+
+
+def setup(n_samples):
+ rng = np.random.RandomState(123456789)
+ a = rng.rand(n_samples // 4, 100)
+ b = rng.rand(n_samples, 100)
+
+ wa = ot.unif(n_samples // 4)
+ wb = ot.unif(n_samples)
+
+ M = ot.dist(a.copy(), b.copy())
+ return wa, wb, M
+
+
+if __name__ == "__main__":
+ n_runs = 100
+ warmup_runs = 10
+ param_list = [50, 100, 500, 1000, 2000, 5000]
+
+ setup_backends()
+ results = exec_bench(
+ setup=setup,
+ tested_function=lambda *args: ot.bregman.sinkhorn(*args, reg=1, stopThr=1e-7),
+ param_list=param_list,
+ n_runs=n_runs,
+ warmup_runs=warmup_runs
+ )
+ print(convert_to_html_table(
+ results,
+ param_name="Sample size",
+ main_title=f"Sinkhorn Knopp - Averaged on {n_runs} runs"
+ ))