diff options
-rw-r--r-- | .github/requirements_strict.txt | 9 | ||||
-rw-r--r-- | .github/workflows/build_tests.yml | 37 | ||||
-rw-r--r-- | .gitignore | 3 | ||||
-rw-r--r-- | Makefile | 4 | ||||
-rw-r--r-- | ot/lp/__init__.py | 2 | ||||
-rw-r--r-- | test/test_da.py | 8 |
6 files changed, 42 insertions, 21 deletions
diff --git a/.github/requirements_strict.txt b/.github/requirements_strict.txt index d7539c5..9a1ada4 100644 --- a/.github/requirements_strict.txt +++ b/.github/requirements_strict.txt @@ -1,7 +1,4 @@ -numpy==1.16.* -scipy==1.0.* -cython==0.23.* -matplotlib -cvxopt -scikit-learn +numpy +scipy>=1.3 +cython pytest diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index 41b08b3..fa814ba 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -30,14 +30,7 @@ jobs: run: | python -m pip install --upgrade pip pip install -r requirements.txt - pip install flake8 pytest "pytest-cov<2.6" codecov - pip install -U "sklearn" - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 examples/ ot/ test/ --count --max-line-length=127 --statistics + pip install pytest "pytest-cov<2.6" codecov - name: Install POT run: | pip install -e . @@ -48,6 +41,29 @@ jobs: run: | codecov + pep8: + runs-on: ubuntu-latest + strategy: + max-parallel: 4 + matrix: + python-version: [3.8] + + steps: + - uses: actions/checkout@v1 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install flake8 + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 examples/ ot/ test/ --count --max-line-length=127 --statistics linux-minimal-deps: @@ -55,7 +71,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: [3.6] + python-version: [3.8] steps: - uses: actions/checkout@v1 @@ -68,7 +84,6 @@ jobs: python -m pip install --upgrade pip pip install -r .github/requirements_strict.txt pip install pytest - pip install -U "sklearn" - name: Install POT run: | pip install -e . @@ -95,7 +110,6 @@ jobs: python -m pip install --upgrade pip pip install -r requirements.txt pip install pytest "pytest-cov<2.6" - pip install -U "sklearn" - name: Install POT run: | pip install -e . @@ -122,7 +136,6 @@ jobs: python -m pip install --upgrade pip pip install -r requirements.txt pip install pytest "pytest-cov<2.6" - pip install -U "sklearn" - name: Install POT run: | pip install -e . @@ -40,6 +40,9 @@ var/ *.manifest *.spec +# env +pythonenv3.8/ + # Installer logs pip-log.txt pip-delete-this-directory.txt @@ -45,10 +45,10 @@ pep8 : flake8 examples/ ot/ test/ test : FORCE pep8 - $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/ --cov=ot --cov-report html:cov_html + $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/ pytest : FORCE - $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/ --cov=ot + $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/ release : twine upload dist/* diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 2a1b082..f08e020 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -426,7 +426,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), nb = b.shape[1] if processes > 1: - res = parmap(f, [b[:, i] for i in range(nb)], processes) + res = parmap(f, [b[:, i].copy() for i in range(nb)], processes) else: res = list(map(f, [b[:, i].copy() for i in range(nb)])) diff --git a/test/test_da.py b/test/test_da.py index 3b28119..52c6a48 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -6,11 +6,18 @@ import numpy as np from numpy.testing import assert_allclose, assert_equal +import pytest import ot from ot.datasets import make_data_classif from ot.utils import unif +try: # test if cudamat installed + import sklearn # noqa: F401 + nosklearn = False +except ImportError: + nosklearn = True + def test_sinkhorn_lpl1_transport_class(): """test_sinkhorn_transport @@ -691,6 +698,7 @@ def test_jcpot_barycenter(): np.testing.assert_allclose(prop, [1 - pt, pt], rtol=1e-3, atol=1e-3) +@pytest.mark.skipif(nosklearn, reason="No sklearn available") def test_emd_laplace_class(): """test_emd_laplace_transport """ |