diff options
Diffstat (limited to 'ot/helpers')
-rw-r--r-- | ot/helpers/__init__.py | 3 | ||||
-rw-r--r-- | ot/helpers/openmp_helpers.py | 85 | ||||
-rw-r--r-- | ot/helpers/pre_build_helpers.py | 87 |
3 files changed, 175 insertions, 0 deletions
diff --git a/ot/helpers/__init__.py b/ot/helpers/__init__.py new file mode 100644 index 0000000..b948671 --- /dev/null +++ b/ot/helpers/__init__.py @@ -0,0 +1,3 @@ +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License diff --git a/ot/helpers/openmp_helpers.py b/ot/helpers/openmp_helpers.py new file mode 100644 index 0000000..a6ad38b --- /dev/null +++ b/ot/helpers/openmp_helpers.py @@ -0,0 +1,85 @@ +"""Helpers for OpenMP support during the build.""" + +# This code is adapted for a large part from the astropy openmp helpers, which +# can be found at: https://github.com/astropy/extension-helpers/blob/master/extension_helpers/_openmp_helpers.py # noqa + + +import os +import sys +import textwrap +import subprocess + +from distutils.errors import CompileError, LinkError + +from pre_build_helpers import compile_test_program + + +def get_openmp_flag(compiler): + """Get openmp flags for a given compiler""" + + if hasattr(compiler, 'compiler'): + compiler = compiler.compiler[0] + else: + compiler = compiler.__class__.__name__ + + if sys.platform == "win32" and ('icc' in compiler or 'icl' in compiler): + omp_flag = ['/Qopenmp'] + elif sys.platform == "win32": + omp_flag = ['/openmp'] + elif sys.platform in ("darwin", "linux") and "icc" in compiler: + omp_flag = ['-qopenmp'] + elif sys.platform == "darwin" and 'openmp' in os.getenv('CPPFLAGS', ''): + omp_flag = [] + else: + # Default flag for GCC and clang: + omp_flag = ['-fopenmp'] + if sys.platform.startswith("darwin"): + omp_flag += ["-Xpreprocessor", "-lomp"] + return omp_flag + + +def check_openmp_support(): + """Check whether OpenMP test code can be compiled and run""" + + code = textwrap.dedent( + """\ + #include <omp.h> + #include <stdio.h> + int main(void) { + #pragma omp parallel + printf("nthreads=%d\\n", omp_get_num_threads()); + return 0; + } + """) + + extra_preargs = os.getenv('LDFLAGS', None) + if extra_preargs is not None: + extra_preargs = extra_preargs.strip().split(" ") + extra_preargs = [ + flag for flag in extra_preargs + if flag.startswith(('-L', '-Wl,-rpath', '-l'))] + + extra_postargs = get_openmp_flag + + try: + output, compile_flags = compile_test_program( + code, + extra_preargs=extra_preargs, + extra_postargs=extra_postargs + ) + + if output and 'nthreads=' in output[0]: + nthreads = int(output[0].strip().split('=')[1]) + openmp_supported = len(output) == nthreads + elif "PYTHON_CROSSENV" in os.environ: + # Since we can't run the test program when cross-compiling + # assume that openmp is supported if the program can be + # compiled. + openmp_supported = True + else: + openmp_supported = False + + except (CompileError, LinkError, subprocess.CalledProcessError): + openmp_supported = False + compile_flags = [] + return openmp_supported, compile_flags diff --git a/ot/helpers/pre_build_helpers.py b/ot/helpers/pre_build_helpers.py new file mode 100644 index 0000000..93ecd6a --- /dev/null +++ b/ot/helpers/pre_build_helpers.py @@ -0,0 +1,87 @@ +"""Helpers to check build environment before actual build of POT""" + +import os +import sys +import glob +import tempfile +import setuptools # noqa +import subprocess + +from distutils.dist import Distribution +from distutils.sysconfig import customize_compiler +from numpy.distutils.ccompiler import new_compiler +from numpy.distutils.command.config_compiler import config_cc + + +def _get_compiler(): + """Get a compiler equivalent to the one that will be used to build POT + Handles compiler specified as follows: + - python setup.py build_ext --compiler=<compiler> + - CC=<compiler> python setup.py build_ext + """ + dist = Distribution({'script_name': os.path.basename(sys.argv[0]), + 'script_args': sys.argv[1:], + 'cmdclass': {'config_cc': config_cc}}) + + cmd_opts = dist.command_options.get('build_ext') + if cmd_opts is not None and 'compiler' in cmd_opts: + compiler = cmd_opts['compiler'][1] + else: + compiler = None + + ccompiler = new_compiler(compiler=compiler) + customize_compiler(ccompiler) + + return ccompiler + + +def compile_test_program(code, extra_preargs=[], extra_postargs=[]): + """Check that some C code can be compiled and run""" + ccompiler = _get_compiler() + + # extra_(pre/post)args can be a callable to make it possible to get its + # value from the compiler + if callable(extra_preargs): + extra_preargs = extra_preargs(ccompiler) + if callable(extra_postargs): + extra_postargs = extra_postargs(ccompiler) + + start_dir = os.path.abspath('.') + + with tempfile.TemporaryDirectory() as tmp_dir: + try: + os.chdir(tmp_dir) + + # Write test program + with open('test_program.c', 'w') as f: + f.write(code) + + os.mkdir('objects') + + # Compile, test program + ccompiler.compile(['test_program.c'], output_dir='objects', + extra_postargs=extra_postargs) + + # Link test program + objects = glob.glob( + os.path.join('objects', '*' + ccompiler.obj_extension)) + ccompiler.link_executable(objects, 'test_program', + extra_preargs=extra_preargs, + extra_postargs=extra_postargs) + + if "PYTHON_CROSSENV" not in os.environ: + # Run test program if not cross compiling + # will raise a CalledProcessError if return code was non-zero + output = subprocess.check_output('./test_program') + output = output.decode( + sys.stdout.encoding or 'utf-8').splitlines() + else: + # Return an empty output if we are cross compiling + # as we cannot run the test_program + output = [] + except Exception: + raise + finally: + os.chdir(start_dir) + + return output, extra_postargs |