diff options
Diffstat (limited to 'benchmarks/emd.py')
-rw-r--r-- | benchmarks/emd.py | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/benchmarks/emd.py b/benchmarks/emd.py new file mode 100644 index 0000000..9f64863 --- /dev/null +++ b/benchmarks/emd.py @@ -0,0 +1,40 @@ +# /usr/bin/env python3 +# -*- coding: utf-8 -*- + +import numpy as np +import ot +from .benchmark import ( + setup_backends, + exec_bench, + convert_to_html_table +) + + +def setup(n_samples): + rng = np.random.RandomState(789465132) + x = rng.randn(n_samples, 2) + y = rng.randn(n_samples, 2) + + a = ot.utils.unif(n_samples) + M = ot.dist(x, y) + return a, M + + +if __name__ == "__main__": + n_runs = 100 + warmup_runs = 10 + param_list = [50, 100, 500, 1000, 2000, 5000] + + setup_backends() + results = exec_bench( + setup=setup, + tested_function=lambda a, M: ot.emd(a, a, M), + param_list=param_list, + n_runs=n_runs, + warmup_runs=warmup_runs + ) + print(convert_to_html_table( + results, + param_name="Sample size", + main_title=f"EMD - Averaged on {n_runs} runs" + )) |