summaryrefslogtreecommitdiff
path: root/benchmarks/emd.py
diff options
context:
space:
mode:
authorGard Spreemann <gspr@nonempty.org>2021-12-29 19:26:32 +0100
committerGard Spreemann <gspr@nonempty.org>2021-12-29 19:26:32 +0100
commit367366a649f57a147456f11f7e803de12ced3b8f (patch)
treea900af1302f4a6923323d203ae8cc22550b59e8f /benchmarks/emd.py
parent88d850422a838c29d70ef757d04ab57707d7cd26 (diff)
parentedab1c60630f95b38db430017585d06253c92817 (diff)
Merge branch 'dfsg/latest' into debian/sid
Diffstat (limited to 'benchmarks/emd.py')
-rw-r--r--benchmarks/emd.py40
1 files changed, 40 insertions, 0 deletions
diff --git a/benchmarks/emd.py b/benchmarks/emd.py
new file mode 100644
index 0000000..9f64863
--- /dev/null
+++ b/benchmarks/emd.py
@@ -0,0 +1,40 @@
+# /usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import numpy as np
+import ot
+from .benchmark import (
+ setup_backends,
+ exec_bench,
+ convert_to_html_table
+)
+
+
+def setup(n_samples):
+ rng = np.random.RandomState(789465132)
+ x = rng.randn(n_samples, 2)
+ y = rng.randn(n_samples, 2)
+
+ a = ot.utils.unif(n_samples)
+ M = ot.dist(x, y)
+ return a, M
+
+
+if __name__ == "__main__":
+ n_runs = 100
+ warmup_runs = 10
+ param_list = [50, 100, 500, 1000, 2000, 5000]
+
+ setup_backends()
+ results = exec_bench(
+ setup=setup,
+ tested_function=lambda a, M: ot.emd(a, a, M),
+ param_list=param_list,
+ n_runs=n_runs,
+ warmup_runs=warmup_runs
+ )
+ print(convert_to_html_table(
+ results,
+ param_name="Sample size",
+ main_title=f"EMD - Averaged on {n_runs} runs"
+ ))