summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com>2022-11-21 17:27:50 +0100
committerGitHub <noreply@github.com>2022-11-21 17:27:50 +0100
commitfa0d4f2afff73284f4b79bfebb085eed332c112f (patch)
tree2b4f175da95f9a4097ca0ab1513b71862a9ee610
parente433775c2015eb85c2683b6955618c2836f001bc (diff)
[MRG] Replaces numpy compiler with setuptools (#409)
* Numpy ccompiler deprecation handled with setuptools ccompiler * Remove useless OMP Macro, already provides _OPENMP * RELEASES.md * Remove forgotten temporary bug added for logging purposes
-rw-r--r--RELEASES.md1
-rw-r--r--ot/helpers/pre_build_helpers.py24
-rw-r--r--ot/lp/network_simplex_simple_omp.h6
-rw-r--r--setup.py2
-rw-r--r--test/test_ot.py16
5 files changed, 23 insertions, 26 deletions
diff --git a/RELEASES.md b/RELEASES.md
index 1c7b7da..564fd4a 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -25,6 +25,7 @@ roughly 2^31) (PR #381)
- Fixed an issue where a pytorch example would throw an error if executed on a GPU (Issue #389, PR #391)
- Added a work-around for scipy's bug, where you cannot compute the Hamming distance with a "None" weight attribute. (Issue #400, PR #402)
- Fixed an issue where the doc could not be built due to some changes in matplotlib's API (Issue #403, PR #402)
+- Replaced Numpy C Compiler with Setuptools C Compiler due to deprecation issues (Issue #408, PR #409)
## 0.8.2
diff --git a/ot/helpers/pre_build_helpers.py b/ot/helpers/pre_build_helpers.py
index 93ecd6a..2930036 100644
--- a/ot/helpers/pre_build_helpers.py
+++ b/ot/helpers/pre_build_helpers.py
@@ -4,34 +4,14 @@ import os
import sys
import glob
import tempfile
-import setuptools # noqa
import subprocess
-from distutils.dist import Distribution
-from distutils.sysconfig import customize_compiler
-from numpy.distutils.ccompiler import new_compiler
-from numpy.distutils.command.config_compiler import config_cc
+from setuptools.command.build_ext import customize_compiler, new_compiler
def _get_compiler():
- """Get a compiler equivalent to the one that will be used to build POT
- Handles compiler specified as follows:
- - python setup.py build_ext --compiler=<compiler>
- - CC=<compiler> python setup.py build_ext
- """
- dist = Distribution({'script_name': os.path.basename(sys.argv[0]),
- 'script_args': sys.argv[1:],
- 'cmdclass': {'config_cc': config_cc}})
-
- cmd_opts = dist.command_options.get('build_ext')
- if cmd_opts is not None and 'compiler' in cmd_opts:
- compiler = cmd_opts['compiler'][1]
- else:
- compiler = None
-
- ccompiler = new_compiler(compiler=compiler)
+ ccompiler = new_compiler()
customize_compiler(ccompiler)
-
return ccompiler
diff --git a/ot/lp/network_simplex_simple_omp.h b/ot/lp/network_simplex_simple_omp.h
index c324d4c..890b7ab 100644
--- a/ot/lp/network_simplex_simple_omp.h
+++ b/ot/lp/network_simplex_simple_omp.h
@@ -67,7 +67,7 @@
//#include "core.h"
//#include "lmath.h"
-#ifdef OMP
+#ifdef _OPENMP
#include <omp.h>
#endif
#include <cmath>
@@ -254,7 +254,7 @@ namespace lemon_omp {
// Reset data structures
reset();
max_iter = maxiters;
-#ifdef OMP
+#ifdef _OPENMP
if (max_threads < 0) {
max_threads = omp_get_max_threads();
}
@@ -513,7 +513,7 @@ namespace lemon_omp {
int j;
#pragma omp parallel
{
-#ifdef OMP
+#ifdef _OPENMP
int t = omp_get_thread_num();
#else
int t = 0;
diff --git a/setup.py b/setup.py
index c03191a..dc9066d 100644
--- a/setup.py
+++ b/setup.py
@@ -37,7 +37,7 @@ compile_args = ["/O2" if sys.platform == "win32" else "-O3"]
link_args = []
if openmp_supported:
- compile_args += flags + ["/DOMP" if sys.platform == 'win32' else "-DOMP"]
+ compile_args += flags
link_args += flags
if sys.platform.startswith('darwin'):
diff --git a/test/test_ot.py b/test/test_ot.py
index 9a4e175..f2338ac 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -204,6 +204,22 @@ def test_emd_emd2():
np.testing.assert_allclose(w, 0)
+def test_omp_emd2():
+ # test emd2 and emd2 with openmp for simple identity
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ w = ot.emd2(u, u, M)
+ w2 = ot.emd2(u, u, M, numThreads=2)
+
+ np.testing.assert_allclose(w, w2)
+
+
def test_emd_empty():
# test emd and emd2 for simple identity
n = 100