summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.github/requirements_strict.txt9
-rw-r--r--.github/workflows/build_tests.yml37
-rw-r--r--.gitignore3
-rw-r--r--Makefile4
-rw-r--r--ot/lp/__init__.py2
-rw-r--r--test/test_da.py8
6 files changed, 42 insertions, 21 deletions
diff --git a/.github/requirements_strict.txt b/.github/requirements_strict.txt
index d7539c5..9a1ada4 100644
--- a/.github/requirements_strict.txt
+++ b/.github/requirements_strict.txt
@@ -1,7 +1,4 @@
-numpy==1.16.*
-scipy==1.0.*
-cython==0.23.*
-matplotlib
-cvxopt
-scikit-learn
+numpy
+scipy>=1.3
+cython
pytest
diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml
index 41b08b3..fa814ba 100644
--- a/.github/workflows/build_tests.yml
+++ b/.github/workflows/build_tests.yml
@@ -30,14 +30,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
- pip install flake8 pytest "pytest-cov<2.6" codecov
- pip install -U "sklearn"
- - name: Lint with flake8
- run: |
- # stop the build if there are Python syntax errors or undefined names
- flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
- # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
- flake8 examples/ ot/ test/ --count --max-line-length=127 --statistics
+ pip install pytest "pytest-cov<2.6" codecov
- name: Install POT
run: |
pip install -e .
@@ -48,6 +41,29 @@ jobs:
run: |
codecov
+ pep8:
+ runs-on: ubuntu-latest
+ strategy:
+ max-parallel: 4
+ matrix:
+ python-version: [3.8]
+
+ steps:
+ - uses: actions/checkout@v1
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v1
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install flake8
+ - name: Lint with flake8
+ run: |
+ # stop the build if there are Python syntax errors or undefined names
+ flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+ # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+ flake8 examples/ ot/ test/ --count --max-line-length=127 --statistics
linux-minimal-deps:
@@ -55,7 +71,7 @@ jobs:
strategy:
max-parallel: 4
matrix:
- python-version: [3.6]
+ python-version: [3.8]
steps:
- uses: actions/checkout@v1
@@ -68,7 +84,6 @@ jobs:
python -m pip install --upgrade pip
pip install -r .github/requirements_strict.txt
pip install pytest
- pip install -U "sklearn"
- name: Install POT
run: |
pip install -e .
@@ -95,7 +110,6 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install pytest "pytest-cov<2.6"
- pip install -U "sklearn"
- name: Install POT
run: |
pip install -e .
@@ -122,7 +136,6 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install pytest "pytest-cov<2.6"
- pip install -U "sklearn"
- name: Install POT
run: |
pip install -e .
diff --git a/.gitignore b/.gitignore
index a2ace7c..b44ea43 100644
--- a/.gitignore
+++ b/.gitignore
@@ -40,6 +40,9 @@ var/
*.manifest
*.spec
+# env
+pythonenv3.8/
+
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
diff --git a/Makefile b/Makefile
index 70cdbdd..32332b4 100644
--- a/Makefile
+++ b/Makefile
@@ -45,10 +45,10 @@ pep8 :
flake8 examples/ ot/ test/
test : FORCE pep8
- $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/ --cov=ot --cov-report html:cov_html
+ $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/
pytest : FORCE
- $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/ --cov=ot
+ $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/
release :
twine upload dist/*
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 2a1b082..f08e020 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -426,7 +426,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
nb = b.shape[1]
if processes > 1:
- res = parmap(f, [b[:, i] for i in range(nb)], processes)
+ res = parmap(f, [b[:, i].copy() for i in range(nb)], processes)
else:
res = list(map(f, [b[:, i].copy() for i in range(nb)]))
diff --git a/test/test_da.py b/test/test_da.py
index 3b28119..52c6a48 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -6,11 +6,18 @@
import numpy as np
from numpy.testing import assert_allclose, assert_equal
+import pytest
import ot
from ot.datasets import make_data_classif
from ot.utils import unif
+try: # test if cudamat installed
+ import sklearn # noqa: F401
+ nosklearn = False
+except ImportError:
+ nosklearn = True
+
def test_sinkhorn_lpl1_transport_class():
"""test_sinkhorn_transport
@@ -691,6 +698,7 @@ def test_jcpot_barycenter():
np.testing.assert_allclose(prop, [1 - pt, pt], rtol=1e-3, atol=1e-3)
+@pytest.mark.skipif(nosklearn, reason="No sklearn available")
def test_emd_laplace_class():
"""test_emd_laplace_transport
"""