summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGard Spreemann <gspr@nonempty.org>2021-11-09 17:05:13 +0100
committerGard Spreemann <gspr@nonempty.org>2021-11-09 17:05:13 +0100
commita9fdc844907decddf54bed3ebeea8d8b2cf0fc5c (patch)
tree449a03fce8fafb78b6badd12b6e633f1e5d73a64
parenta16b9471d7114ec08977479b7249efe747702b97 (diff)
parentf1628794d521a8dfa00af383b5e06cd6d34af619 (diff)
Merge tag '0.8.0' into dfsg/latest
-rw-r--r--.circleci/config.yml75
-rw-r--r--.coveragerc6
-rw-r--r--.github/CODE_OF_CONDUCT.md (renamed from CODE_OF_CONDUCT.md)0
-rw-r--r--.github/CONTRIBUTING.md (renamed from CONTRIBUTING.md)0
-rw-r--r--.github/ISSUE_TEMPLATE/bug_report.md43
-rw-r--r--.github/ISSUE_TEMPLATE/feature_request.md23
-rw-r--r--.github/PULL_REQUEST_TEMPLATE.md28
-rw-r--r--.github/requirements_strict.txt9
-rw-r--r--.github/requirements_test_windows.txt10
-rw-r--r--.github/workflows/build_tests.yml97
-rw-r--r--.github/workflows/build_wheels.yml62
-rw-r--r--.github/workflows/build_wheels_weekly.yml54
-rw-r--r--.gitignore3
-rw-r--r--MANIFEST.in2
-rw-r--r--Makefile4
-rw-r--r--README.md127
-rw-r--r--RELEASES.md192
-rw-r--r--_config.yml1
-rw-r--r--codecov.yml43
-rw-r--r--data/manhattan.npzbin0 -> 2320938 bytes
-rw-r--r--docs/Makefile5
-rw-r--r--docs/source/_templates/versions.html43
-rw-r--r--docs/source/all.rst2
-rw-r--r--docs/source/auto_examples/images/bak.pngbin0 -> 304669 bytes
-rw-r--r--docs/source/auto_examples/images/sinkhorn.pngbin0 -> 37204 bytes
-rw-r--r--docs/source/conf.py8
-rw-r--r--docs/source/quickstart.rst522
-rw-r--r--docs/source/readme.rst189
-rw-r--r--docs/source/releases.rst134
-rw-r--r--examples/README.txt2
-rw-r--r--examples/backends/README.txt4
-rw-r--r--examples/backends/plot_optim_gromov_pytorch.py260
-rw-r--r--examples/backends/plot_sliced_wass_grad_flow_pytorch.py185
-rw-r--r--examples/backends/plot_unmix_optim_torch.py161
-rw-r--r--examples/backends/plot_wass1d_torch.py152
-rw-r--r--examples/backends/plot_wass2_gan_torch.py227
-rw-r--r--examples/barycenters/plot_barycenter_1D.py63
-rw-r--r--examples/barycenters/plot_barycenter_lp_vs_entropic.py2
-rw-r--r--examples/barycenters/plot_convolutional_barycenter.py53
-rw-r--r--examples/barycenters/plot_debiased_barycenter.py131
-rw-r--r--examples/barycenters/plot_free_support_barycenter.py4
-rw-r--r--examples/domain-adaptation/plot_otda_color_images.py128
-rw-r--r--examples/domain-adaptation/plot_otda_jcpot.py4
-rw-r--r--examples/domain-adaptation/plot_otda_linear_mapping.py81
-rw-r--r--examples/domain-adaptation/plot_otda_mapping_colors_images.py128
-rw-r--r--examples/gromov/plot_barycenter_fgw.py2
-rw-r--r--examples/gromov/plot_fgw.py10
-rw-r--r--examples/gromov/plot_gromov.py34
-rwxr-xr-xexamples/gromov/plot_gromov_barycenter.py94
-rw-r--r--examples/plot_Intro_OT.py373
-rw-r--r--examples/plot_OT_1D_smooth.py2
-rw-r--r--examples/plot_OT_2D_samples.py2
-rw-r--r--examples/sliced-wasserstein/README.txt4
-rw-r--r--examples/sliced-wasserstein/plot_variance.py86
-rw-r--r--examples/unbalanced-partial/plot_UOT_1D.py3
-rwxr-xr-xexamples/unbalanced-partial/plot_partial_wass_and_gromov.py23
-rw-r--r--examples/unbalanced-partial/plot_regpath.py135
-rw-r--r--ot/__init__.py24
-rw-r--r--ot/backend.py1502
-rw-r--r--ot/bregman.py2811
-rw-r--r--ot/da.py507
-rw-r--r--ot/datasets.py12
-rw-r--r--ot/dr.py156
-rw-r--r--ot/gpu/__init__.py12
-rw-r--r--ot/gpu/bregman.py12
-rw-r--r--ot/gpu/da.py2
-rw-r--r--ot/gromov.py1312
-rw-r--r--ot/helpers/__init__.py3
-rw-r--r--ot/helpers/openmp_helpers.py85
-rw-r--r--ot/helpers/pre_build_helpers.py87
-rw-r--r--ot/lp/EMD.h5
-rw-r--r--ot/lp/EMD_wrapper.cpp124
-rw-r--r--ot/lp/__init__.py597
-rw-r--r--ot/lp/cvx.py3
-rw-r--r--ot/lp/emd_wrap.pyx32
-rw-r--r--ot/lp/full_bipartitegraph.h27
-rw-r--r--ot/lp/full_bipartitegraph_omp.h234
-rw-r--r--ot/lp/network_simplex_simple.h212
-rw-r--r--ot/lp/network_simplex_simple_omp.h1699
-rw-r--r--ot/lp/solver_1d.py367
-rw-r--r--ot/optim.py189
-rwxr-xr-xot/partial.py352
-rw-r--r--ot/plot.py10
-rw-r--r--ot/regpath.py827
-rw-r--r--ot/sliced.py258
-rw-r--r--ot/smooth.py183
-rw-r--r--ot/stochastic.py192
-rw-r--r--ot/unbalanced.py220
-rw-r--r--ot/utils.py269
-rw-r--r--pyproject.toml3
-rw-r--r--requirements.txt3
-rw-r--r--setup.cfg2
-rw-r--r--[-rwxr-xr-x]setup.py138
-rw-r--r--test/conftest.py62
-rw-r--r--test/test_1d_solver.py172
-rw-r--r--test/test_backend.py577
-rw-r--r--test/test_bregman.py718
-rw-r--r--test/test_da.py24
-rw-r--r--test/test_dr.py62
-rw-r--r--test/test_gromov.py523
-rw-r--r--test/test_helpers.py26
-rw-r--r--test/test_optim.py103
-rw-r--r--test/test_ot.py183
-rwxr-xr-xtest/test_partial.py16
-rw-r--r--test/test_regpath.py64
-rw-r--r--test/test_sliced.py213
-rw-r--r--test/test_smooth.py12
-rw-r--r--test/test_stochastic.py52
-rw-r--r--test/test_unbalanced.py33
-rw-r--r--test/test_utils.py84
110 files changed, 15817 insertions, 3613 deletions
diff --git a/.circleci/config.yml b/.circleci/config.yml
index 9701ad1..85f8073 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -4,7 +4,8 @@ version: 2
jobs:
build_docs:
docker:
- - image: circleci/python:3.7-stretch
+ - image: cimg/python:3.9
+ resource_class: medium
steps:
- checkout
- run:
@@ -23,7 +24,7 @@ jobs:
echo ${CI_PULL_REQUEST//*pull\//} | tee merge.txt
if [[ $(cat merge.txt) != "" ]]; then
echo "Merging $(cat merge.txt)";
- git remote add upstream git://github.com/PythonOT/POT.git;
+ git remote add upstream https://github.com/PythonOT/POT.git;
git pull --ff-only upstream "refs/pull/$(cat merge.txt)/merge";
git fetch upstream master;
fi
@@ -35,25 +36,14 @@ jobs:
- pip-cache
- run:
- name: Spin up Xvfb
- command: |
- /sbin/start-stop-daemon --start --quiet --pidfile /tmp/custom_xvfb_99.pid --make-pidfile --background --exec /usr/bin/Xvfb -- :99 -screen 0 1400x900x24 -ac +extension GLX +render -noreset;
-
- # https://github.com/ContinuumIO/anaconda-issues/issues/9190#issuecomment-386508136
- # https://github.com/golemfactory/golem/issues/1019
- - run:
- name: Fix libgcc_s.so.1 pthread_cancel bug
- command: |
- sudo apt-get install qt5-default
-
- - run:
name: Get Python running
command: |
python -m pip install --user --upgrade --progress-bar off pip
- python -m pip install --user --upgrade --progress-bar off -r requirements.txt
+ python -m pip install --user -e .
+ python -m pip install --user --upgrade --no-cache-dir --progress-bar off -r requirements.txt
python -m pip install --user --upgrade --progress-bar off -r docs/requirements.txt
python -m pip install --user --upgrade --progress-bar off ipython "https://api.github.com/repos/sphinx-gallery/sphinx-gallery/zipball/master" memory_profiler
- python -m pip install --user -e .
+
- save_cache:
key: pip-cache
@@ -73,6 +63,7 @@ jobs:
command: |
cd docs;
make html;
+ no_output_timeout: 30m
# Save the outputs
- store_artifacts:
@@ -83,7 +74,47 @@ jobs:
paths:
- html
- deploy:
+ deploy_master:
+ docker:
+ - image: circleci/python:3.6-jessie
+ steps:
+ - attach_workspace:
+ at: /tmp/build
+ - run:
+ name: Fetch docs
+ command: |
+ set -e
+ mkdir -p ~/.ssh
+ echo -e "Host *\nStrictHostKeyChecking no" > ~/.ssh/config
+ chmod og= ~/.ssh/config
+ if [ ! -d ~/PythonOT.github.io ]; then
+ git clone git@github.com:/PythonOT/PythonOT.github.io.git ~/PythonOT.github.io --depth=1
+ fi
+ - run:
+ name: Deploy docs
+ command: |
+ set -e;
+ if [ "${CIRCLE_BRANCH}" == "master" ]; then
+ git config --global user.email "circle@PythonOT.com";
+ git config --global user.name "Circle CI";
+ cd ~/PythonOT.github.io;
+ git checkout master
+ git remote -v
+ git fetch origin
+ git reset --hard origin/master
+ git clean -xdf
+ echo "Deploying dev docs for ${CIRCLE_BRANCH}.";
+ cd master
+ cp -a /tmp/build/html/* .;
+ touch .nojekyll;
+ git add -A;
+ git commit -m "CircleCI update of dev docs (${CIRCLE_BUILD_NUM}).";
+ git push origin master;
+ else
+ echo "No deployment (build: ${CIRCLE_BRANCH}).";
+ fi
+
+ deploy_tag:
docker:
- image: circleci/python:3.6-jessie
steps:
@@ -122,16 +153,24 @@ jobs:
echo "No deployment (build: ${CIRCLE_BRANCH}).";
fi
+
workflows:
version: 2
default:
+
jobs:
- build_docs
- - deploy:
+ - deploy_master:
requires:
- build_docs
filters:
branches:
only:
- master
+ - deploy_tag:
+ filters:
+ branches:
+ ignore: /.*/
+ tags:
+ only: /[0-9]+(\.[0-9]+)*$/ \ No newline at end of file
diff --git a/.coveragerc b/.coveragerc
deleted file mode 100644
index 2114fb4..0000000
--- a/.coveragerc
+++ /dev/null
@@ -1,6 +0,0 @@
-[run]
-
-omit=
- ot/externals/*
- ot/externals/funcsigs.py
- ot/gpu/*
diff --git a/CODE_OF_CONDUCT.md b/.github/CODE_OF_CONDUCT.md
index 9c1c621..9c1c621 100644
--- a/CODE_OF_CONDUCT.md
+++ b/.github/CODE_OF_CONDUCT.md
diff --git a/CONTRIBUTING.md b/.github/CONTRIBUTING.md
index 54e7e42..54e7e42 100644
--- a/CONTRIBUTING.md
+++ b/.github/CONTRIBUTING.md
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
index 7f8acda..f24d993 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.md
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -1,27 +1,42 @@
---
name: Bug report
about: Create a report to help us improve POT
+title: ''
+labels: bug, help wanted
+assignees: ''
---
-**Describe the bug**
-A clear and concise description of what the bug is.
+## Describe the bug
+<!-- A clear and concise description of what the bug is. -->
-**To Reproduce**
+### To Reproduce
Steps to reproduce the behavior:
-1 ...
+1. ...
2.
-**Expected behavior**
-A clear and concise description of what you expected to happen.
+<!-- If you have error messages or stack traces, please provide it here as well -->
-**Screenshots**
-If applicable, add screenshots to help explain your problem.
+#### Screenshots
+<!-- If applicable, add screenshots to help explain your problem. -->
-**Desktop (please complete the following information):**
- - OS: [e.g. MacOSX, Windows, Ubuntu]
- - Python version [2.7,3.6]
-- How was POT installed [source, pip, conda]
+#### Code sample
+<!-- Ideally attach a minimal code sample to reproduce the decried issue.
+Minimal means having the shortest code but still preserving the bug. -->
+
+### Expected behavior
+<!-- A clear and concise description of what you expected to happen. -->
+
+
+### Environment (please complete the following information):
+- OS (e.g. MacOS, Windows, Linux):
+- Python version:
+- How was POT installed (source, `pip`, `conda`):
+- Build command you used (if compiling from source):
+- Only for GPU related bugs:
+ - CUDA version:
+ - GPU models and configuration:
+ - Any other relevant information:
Output of the following code snippet:
```python
@@ -32,5 +47,5 @@ import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
```
-**Additional context**
-Add any other context about the problem here.
+### Additional context
+<!-- Add any other context about the problem here. -->
diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md
new file mode 100644
index 0000000..2ee07e0
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature_request.md
@@ -0,0 +1,23 @@
+---
+name: Feature request
+about: Suggest an idea for this project
+title: ''
+labels: enhancement, feature request
+assignees: ''
+
+---
+
+## 🚀 Feature
+<!-- A clear and concise description of the feature proposal -->
+
+### Motivation
+<!-- Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too -->
+
+### Pitch
+<!-- A clear and concise description of what you want to happen. -->
+
+### Alternatives
+<!-- A clear and concise description of any alternative solutions or features you've considered, if any. -->
+
+### Additional context
+<!-- Add any other context or screenshots about the feature request here. -->
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
new file mode 100644
index 0000000..7cfe4e6
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,28 @@
+## Types of changes
+<!--- What types of changes does your code introduce? Put an `x` in all the boxes that apply: -->
+
+- [ ] Docs change / refactoring / dependency upgrade
+- [ ] Bug fix (non-breaking change which fixes an issue)
+- [ ] New feature (non-breaking change which adds functionality)
+- [ ] Breaking change (fix or feature that would cause existing functionality to change)
+
+
+## Motivation and context / Related issue
+<!--- Why is this change required? What problem does it solve? -->
+<!--- Please link to an existing issue here if one exists. -->
+<!--- (we recommend to have an existing issue for each pull request) -->
+
+
+## How has this been tested (if it applies)
+<!--- Please describe here how your modifications have been tested. -->
+
+
+## Checklist
+<!-- - Go over all the following points, and put an `x` in all the boxes that apply. -->
+<!--- If you're unsure about any of these, don't hesitate to ask. We're here to help! -->
+
+- [ ] The documentation is up-to-date with the changes I made.
+- [ ] I have read the [**CONTRIBUTING**](CONTRIBUTING.md) document.
+- [ ] All tests passed, and additional code has been covered with new tests.
+
+<!--- In any case, don't hesitate to join and ask questions if you need on slack (https://pot-toolbox.slack.com/), gitter (https://gitter.im/PythonOT/community), or the mailing list (https://mail.python.org/mm3/mailman3/lists/pot.python.org/). -->
diff --git a/.github/requirements_strict.txt b/.github/requirements_strict.txt
index d7539c5..9a1ada4 100644
--- a/.github/requirements_strict.txt
+++ b/.github/requirements_strict.txt
@@ -1,7 +1,4 @@
-numpy==1.16.*
-scipy==1.0.*
-cython==0.23.*
-matplotlib
-cvxopt
-scikit-learn
+numpy
+scipy>=1.3
+cython
pytest
diff --git a/.github/requirements_test_windows.txt b/.github/requirements_test_windows.txt
new file mode 100644
index 0000000..331dd57
--- /dev/null
+++ b/.github/requirements_test_windows.txt
@@ -0,0 +1,10 @@
+numpy
+scipy>=1.3
+cython
+matplotlib
+autograd
+pymanopt==0.2.4; python_version <'3'
+pymanopt; python_version >= '3'
+cvxopt
+scikit-learn
+pytest \ No newline at end of file
diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml
index 41b08b3..ee5a435 100644
--- a/.github/workflows/build_tests.yml
+++ b/.github/workflows/build_tests.yml
@@ -1,10 +1,13 @@
-name: build
+name: Tests
on:
- push:
-
+ workflow_dispatch:
pull_request:
-
+ branches:
+ - 'master'
+ push:
+ branches:
+ - 'master'
create:
branches:
- 'master'
@@ -15,10 +18,11 @@ jobs:
linux:
runs-on: ubuntu-latest
+ if: "!contains(github.event.head_commit.message, 'no ci')"
strategy:
max-parallel: 4
matrix:
- python-version: [3.5, 3.6, 3.7, 3.8]
+ python-version: [ "3.6", "3.7", "3.8", "3.9"]
steps:
- uses: actions/checkout@v1
@@ -26,63 +30,70 @@ jobs:
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
+ - name: Install POT
+ run: |
+ pip install -e .
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
- pip install flake8 pytest "pytest-cov<2.6" codecov
- pip install -U "sklearn"
- - name: Lint with flake8
- run: |
- # stop the build if there are Python syntax errors or undefined names
- flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
- # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
- flake8 examples/ ot/ test/ --count --max-line-length=127 --statistics
- - name: Install POT
- run: |
- pip install -e .
+ pip install pytest "pytest-cov<2.6" codecov
- name: Run tests
run: |
- python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot
+ python -m pytest --durations=20 -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes
- name: Upload codecov
run: |
codecov
+ pep8:
+ runs-on: ubuntu-latest
+ if: "!contains(github.event.head_commit.message, 'no pep8')"
+ steps:
+ - uses: actions/checkout@v1
+ - name: Set up Python
+ uses: actions/setup-python@v1
+ with:
+ python-version: 3.9
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install flake8
+ - name: Lint with flake8
+ run: |
+ # stop the build if there are Python syntax errors or undefined names
+ flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+ # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+ flake8 examples/ ot/ test/ --count --max-line-length=127 --statistics
linux-minimal-deps:
runs-on: ubuntu-latest
- strategy:
- max-parallel: 4
- matrix:
- python-version: [3.6]
-
+ if: "!contains(github.event.head_commit.message, 'no ci')"
steps:
- uses: actions/checkout@v1
- - name: Set up Python ${{ matrix.python-version }}
+ - name: Set up Python
uses: actions/setup-python@v1
with:
- python-version: ${{ matrix.python-version }}
+ python-version: 3.9
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- pip install -r .github/requirements_strict.txt
pip install pytest
- pip install -U "sklearn"
- name: Install POT
run: |
pip install -e .
- name: Run tests
run: |
- python -m pytest -v test/ ot/ --ignore ot/gpu/
+ python -m pytest --durations=20 -v test/ ot/ --ignore ot/gpu/ --color=yes
macos:
- runs-on: macOS-latest
+ runs-on: macos-latest
+ if: "!contains(github.event.head_commit.message, 'no ci')"
strategy:
max-parallel: 4
matrix:
- python-version: [3.7]
+ python-version: ["3.7", "3.8", "3.9"]
steps:
- uses: actions/checkout@v1
@@ -90,26 +101,26 @@ jobs:
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
+ - name: Install POT
+ run: |
+ pip install -e .
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install pytest "pytest-cov<2.6"
- pip install -U "sklearn"
- - name: Install POT
- run: |
- pip install -e .
- name: Run tests
run: |
- python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot
+ python -m pytest --durations=20 -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes
windows:
- runs-on: windows-2019
+ runs-on: windows-latest
+ if: "!contains(github.event.head_commit.message, 'no ci')"
strategy:
max-parallel: 4
matrix:
- python-version: [3.7]
+ python-version: ["3.7", "3.8", "3.9"]
steps:
- uses: actions/checkout@v1
@@ -117,15 +128,15 @@ jobs:
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
+ - name: Install POT
+ run: |
+ python -m pip install -e .
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- pip install -r requirements.txt
- pip install pytest "pytest-cov<2.6"
- pip install -U "sklearn"
- - name: Install POT
- run: |
- pip install -e .
+ python -m pip install -r .github/requirements_test_windows.txt
+ python -m pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
+ python -m pip install pytest "pytest-cov<2.6"
- name: Run tests
run: |
- python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot
+ python -m pytest --durations=20 -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes
diff --git a/.github/workflows/build_wheels.yml b/.github/workflows/build_wheels.yml
index 662a604..a935a5e 100644
--- a/.github/workflows/build_wheels.yml
+++ b/.github/workflows/build_wheels.yml
@@ -1,19 +1,21 @@
-name: Build dist and wheels
+name: Build wheels
on:
+ workflow_dispatch:
release:
+ pull_request:
push:
branches:
- - "master"
+ - "*"
jobs:
build_wheels:
name: ${{ matrix.os }}
runs-on: ${{ matrix.os }}
+ if: "contains(github.event.head_commit.message, 'build wheels')"
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
- # macos-latest, windows-latest
steps:
- uses: actions/checkout@v1
@@ -30,17 +32,58 @@ jobs:
- name: Install cibuildwheel
run: |
- python -m pip install cibuildwheel==1.3.0
+ python -m pip install cibuildwheel==2.2.2
- - name: Install Visual C++ for Python 2.7
- if: startsWith(matrix.os, 'windows')
+ - name: Build wheels
+ env:
+ CIBW_SKIP: "pp*-win* pp*-macosx* cp2* pp*" # remove pypy on mac and win (wrong version)
+ CIBW_BEFORE_BUILD: "pip install numpy cython"
+ run: |
+ python -m cibuildwheel --output-dir wheelhouse
+
+ - uses: actions/upload-artifact@v1
+ with:
+ name: wheels
+ path: ./wheelhouse
+
+
+ build_all_wheels:
+ name: ${{ matrix.os }}
+ runs-on: ${{ matrix.os }}
+ if: "contains(github.event.head_commit.message, 'build all wheels')"
+ strategy:
+ matrix:
+ os: [ubuntu-latest, macos-latest, windows-latest]
+
+ steps:
+ - uses: actions/checkout@v1
+ - name: Set up Python 3.8
+ uses: actions/setup-python@v1
+ with:
+ python-version: 3.8
+
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install -r requirements.txt
+ pip install -U "cython"
+
+ - name: Install cibuildwheel
run: |
- choco install vcpython27 -f -y
+ python -m pip install cibuildwheel==2.2.2
- - name: Build wheel
+ - name: Set up QEMU
+ if: runner.os == 'Linux'
+ uses: docker/setup-qemu-action@v1
+ with:
+ platforms: all
+
+ - name: Build wheels
env:
- CIBW_SKIP: "pp*-win* pp*-macosx* cp2* pp*" # remove pypy on mac and win (wrong version)
+ CIBW_SKIP: "pp*-win* pp*-macosx* cp2* pp* cp*musl*" # remove pypy on mac and win (wrong version)
CIBW_BEFORE_BUILD: "pip install numpy cython"
+ CIBW_ARCHS_LINUX: auto aarch64 # force aarch64 with QEMU
+ CIBW_ARCHS_MACOS: x86_64 universal2 arm64
run: |
python -m cibuildwheel --output-dir wheelhouse
@@ -48,3 +91,4 @@ jobs:
with:
name: wheels
path: ./wheelhouse
+
diff --git a/.github/workflows/build_wheels_weekly.yml b/.github/workflows/build_wheels_weekly.yml
new file mode 100644
index 0000000..2964844
--- /dev/null
+++ b/.github/workflows/build_wheels_weekly.yml
@@ -0,0 +1,54 @@
+name: Build all wheels
+
+on:
+ workflow_dispatch:
+ schedule:
+ - cron: '30 0 * * 1'
+ push:
+ branches:
+ - "master"
+
+jobs:
+ build_wheels:
+ name: ${{ matrix.os }}
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-latest, macos-latest, windows-latest]
+
+ steps:
+ - uses: actions/checkout@v1
+ - name: Set up Python 3.8
+ uses: actions/setup-python@v1
+ with:
+ python-version: 3.8
+
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install -r requirements.txt
+ pip install -U "cython"
+
+ - name: Install cibuildwheel
+ run: |
+ python -m pip install cibuildwheel==2.2.2
+
+ - name: Set up QEMU
+ if: runner.os == 'Linux'
+ uses: docker/setup-qemu-action@v1
+ with:
+ platforms: all
+
+ - name: Build wheels
+ env:
+ CIBW_SKIP: "pp*-win* pp*-macosx* cp2* pp* cp*musl*" # remove pypy on mac and win (wrong version)
+ CIBW_BEFORE_BUILD: "pip install numpy cython"
+ CIBW_ARCHS_LINUX: auto aarch64 # force aarch64 with QEMU
+ CIBW_ARCHS_MACOS: x86_64 universal2 arm64
+ run: |
+ python -m cibuildwheel --output-dir wheelhouse
+
+ - uses: actions/upload-artifact@v1
+ with:
+ name: wheels
+ path: ./wheelhouse
diff --git a/.gitignore b/.gitignore
index a2ace7c..b44ea43 100644
--- a/.gitignore
+++ b/.gitignore
@@ -40,6 +40,9 @@ var/
*.manifest
*.spec
+# env
+pythonenv3.8/
+
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
diff --git a/MANIFEST.in b/MANIFEST.in
index df4e139..da67c77 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -6,4 +6,6 @@ include ot/lp/EMD.h
include ot/lp/EMD_wrapper.cpp
include ot/lp/emd_wrap.pyx
include ot/lp/full_bipartitegraph.h
+include ot/lp/full_bipartitegraph_omp.h
include ot/lp/network_simplex_simple.h
+include ot/lp/network_simplex_simple_omp.h
diff --git a/Makefile b/Makefile
index 70cdbdd..315218d 100644
--- a/Makefile
+++ b/Makefile
@@ -45,10 +45,10 @@ pep8 :
flake8 examples/ ot/ test/
test : FORCE pep8
- $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/ --cov=ot --cov-report html:cov_html
+ $(PYTHON) -m pytest --durations=20 -v test/ --doctest-modules --ignore ot/gpu/
pytest : FORCE
- $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/ --cov=ot
+ $(PYTHON) -m pytest --durations=20 -v test/ --doctest-modules --ignore ot/gpu/
release :
twine upload dist/*
diff --git a/README.md b/README.md
index 5626742..08db003 100644
--- a/README.md
+++ b/README.md
@@ -2,13 +2,12 @@
[![PyPI version](https://badge.fury.io/py/POT.svg)](https://badge.fury.io/py/POT)
[![Anaconda Cloud](https://anaconda.org/conda-forge/pot/badges/version.svg)](https://anaconda.org/conda-forge/pot)
-[![Build Status](https://github.com/PythonOT/POT/workflows/build/badge.svg)](https://github.com/PythonOT/POT/actions)
+[![Build Status](https://github.com/PythonOT/POT/workflows/build/badge.svg?branch=master&event=push)](https://github.com/PythonOT/POT/actions)
[![Codecov Status](https://codecov.io/gh/PythonOT/POT/branch/master/graph/badge.svg)](https://codecov.io/gh/PythonOT/POT)
[![Downloads](https://pepy.tech/badge/pot)](https://pepy.tech/project/pot)
[![Anaconda downloads](https://anaconda.org/conda-forge/pot/badges/downloads.svg)](https://anaconda.org/conda-forge/pot)
[![License](https://anaconda.org/conda-forge/pot/badges/license.svg)](https://github.com/PythonOT/POT/blob/master/LICENSE)
-
This open source Python library provide several solvers for optimization
problems related to Optimal Transport for signal, image processing and machine
learning.
@@ -21,18 +20,22 @@ POT provides the following generic OT solvers (links to examples):
* [OT Network Simplex solver](https://pythonot.github.io/auto_examples/plot_OT_1D.html) for the linear program/ Earth Movers Distance [1] .
* [Conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) [6] and [Generalized conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) for regularized OT [7].
-* Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html) with optional GPU implementation (requires cupy).
+* Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10] [34], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html).
* Bregman projections for [Wasserstein barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) [3], [convolutional barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_convolutional_barycenter.html) [21] and unmixing [4].
-* Sinkhorn divergence [23] and entropic regularization OT from empirical data.
+* Sinkhorn divergence [23] and entropic regularization OT from empirical data.
+* Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37]
* [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17].
* Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html)) with LP solver (only small scale).
-* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12])
+* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from
* [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24]
* [Stochastic solver](https://pythonot.github.io/auto_examples/plot_stochastic.html) for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19])
+* [Stochastic solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33]
* Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) [20].
* [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25].
* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3]
formulations).
+* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36].
+* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/) arrays.
POT provides the following Machine Learning related solvers:
@@ -49,19 +52,25 @@ Some other examples are available in the [documentation](https://pythonot.githu
#### Using and citing the toolbox
If you use this toolbox in your research and find it useful, please cite POT
-using the following reference:
-```
-Rémi Flamary and Nicolas Courty, POT Python Optimal Transport library,
-Website: https://pythonot.github.io/, 2017
-```
+using the following reference from our [JMLR paper](https://jmlr.org/papers/v22/20-451.html):
+
+ Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer,
+ POT Python Optimal Transport library,
+ Journal of Machine Learning Research, 22(78):1−8, 2021.
+ Website: https://pythonot.github.io/
In Bibtex format:
-```
-@misc{flamary2017pot,
-title={POT Python Optimal Transport library},
-author={Flamary, R{'e}mi and Courty, Nicolas},
-url={https://pythonot.github.io/},
-year={2017}
+
+```bibtex
+@article{flamary2021pot,
+ author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer},
+ title = {POT: Python Optimal Transport},
+ journal = {Journal of Machine Learning Research},
+ year = {2021},
+ volume = {22},
+ number = {78},
+ pages = {1-8},
+ url = {http://jmlr.org/papers/v22/20-451.html}
}
```
@@ -71,41 +80,39 @@ The library has been tested on Linux, MacOSX and Windows. It requires a C++ comp
- Numpy (>=1.16)
- Scipy (>=1.0)
-- Cython (>=0.23)
-- Matplotlib (>=1.5)
+- Cython (>=0.23) (build only, not necessary when installing from pip or conda)
#### Pip installation
-Note that due to a limitation of pip, `cython` and `numpy` need to be installed
-prior to installing POT. This can be done easily with
-```
-pip install numpy cython
-```
You can install the toolbox through PyPI with:
-```
+
+```console
pip install POT
```
+
or get the very latest version by running:
-```
+
+```console
pip install -U https://github.com/PythonOT/POT/archive/master.zip # with --user for user install (no root)
```
-
-
#### Anaconda installation with conda-forge
If you use the Anaconda python distribution, POT is available in [conda-forge](https://conda-forge.org). To install it and the required dependencies:
-```
+
+```console
conda install -c conda-forge pot
```
#### Post installation check
After a correct installation, you should be able to import the module without errors:
+
```python
import ot
```
-Note that for easier access the module is name ot instead of pot.
+
+Note that for easier access the module is named `ot` instead of `pot`.
### Dependencies
@@ -113,42 +120,48 @@ Note that for easier access the module is name ot instead of pot.
Some sub-modules require additional dependences which are discussed below
* **ot.dr** (Wasserstein dimensionality reduction) depends on autograd and pymanopt that can be installed with:
-```
+
+```shell
pip install pymanopt autograd
```
-* **ot.gpu** (GPU accelerated OT) depends on cupy that have to be installed following instructions on [this page](https://docs-cupy.chainer.org/en/stable/install.html).
-
-obviously you need CUDA installed and a compatible GPU.
+* **ot.gpu** (GPU accelerated OT) depends on cupy that have to be installed following instructions on [this page](https://docs-cupy.chainer.org/en/stable/install.html). Obviously you will need CUDA installed and a compatible GPU. Note that this module is deprecated since version 0.8 and will be deleted in the future. GPU is now handled automatically through the backends and several solver already can run on GPU using the Pytorch backend.
## Examples
### Short examples
* Import the toolbox
+
```python
import ot
```
+
* Compute Wasserstein distances
+
```python
# a,b are 1D histograms (sum to 1 and positive)
# M is the ground cost matrix
-Wd=ot.emd2(a,b,M) # exact linear program
-Wd_reg=ot.sinkhorn2(a,b,M,reg) # entropic regularized OT
+Wd = ot.emd2(a, b, M) # exact linear program
+Wd_reg = ot.sinkhorn2(a, b, M, reg) # entropic regularized OT
# if b is a matrix compute all distances to a and return a vector
```
+
* Compute OT matrix
+
```python
# a,b are 1D histograms (sum to 1 and positive)
# M is the ground cost matrix
-T=ot.emd(a,b,M) # exact linear program
-T_reg=ot.sinkhorn(a,b,M,reg) # entropic regularized OT
+T = ot.emd(a, b, M) # exact linear program
+T_reg = ot.sinkhorn(a, b, M, reg) # entropic regularized OT
```
+
* Compute Wasserstein barycenter
+
```python
# A is a n*d matrix containing d 1D histograms
# M is the ground cost matrix
-ba=ot.barycenter(A,M,reg) # reg is regularization parameter
+ba = ot.barycenter(A, M, reg) # reg is regularization parameter
```
### Examples and Notebooks
@@ -176,31 +189,34 @@ The contributors to this library are
* [Kilian Fatras](https://kilianfatras.github.io/) (Stochastic solvers)
* [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home)
* [Vayer Titouan](https://tvayer.github.io/) (Gromov-Wasserstein -, Fused-Gromov-Wasserstein)
-* [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT)
+* [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT, Debiased barycenters)
* [Romain Tavenard](https://rtavenar.github.io/) (1d Wasserstein)
* [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn)
* [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT)
+* [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance)
+* [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein)
+* [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance)
This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
* [Gabriel Peyré](http://gpeyre.github.io/) (Wasserstein Barycenters in Matlab)
+* [Mathieu Blondel](https://mblondel.org/) (original implementation smooth OT)
* [Nicolas Bonneel](http://liris.cnrs.fr/~nbonneel/) ( C++ code for EMD)
* [Marco Cuturi](http://marcocuturi.net/) (Sinkhorn Knopp in Matlab/Cuda)
-
## Contributions and code of conduct
-Every contribution is welcome and should respect the [contribution guidelines](CONTRIBUTING.md). Each member of the project is expected to follow the [code of conduct](CODE_OF_CONDUCT.md).
+Every contribution is welcome and should respect the [contribution guidelines](.github/CONTRIBUTING.md). Each member of the project is expected to follow the [code of conduct](.github/CODE_OF_CONDUCT.md).
## Support
You can ask questions and join the development discussion:
-* On the [POT Slack channel](https://pot-toolbox.slack.com)
+* On the POT [slack channel](https://pot-toolbox.slack.com)
+* On the POT [gitter channel](https://gitter.im/PythonOT/community)
* On the POT [mailing list](https://mail.python.org/mm3/mailman3/lists/pot.python.org/)
-
-You can also post bug reports and feature requests in Github issues. Make sure to read our [guidelines](CONTRIBUTING.md) first.
+You can also post bug reports and feature requests in Github issues. Make sure to read our [guidelines](.github/CONTRIBUTING.md) first.
## References
@@ -260,6 +276,27 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[28] Caffarelli, L. A., McCann, R. J. (2010). [Free boundaries in optimal transport and Monge-Ampere obstacle problems](http://www.math.toronto.edu/~mccann/papers/annals2010.pdf), Annals of mathematics, 673-730.
-[29] Chapel, L., Alaya, M., Gasso, G. (2019). [Partial Gromov-Wasserstein with Applications on Positive-Unlabeled Learning](https://arxiv.org/abs/2002.08276), arXiv preprint arXiv:2002.08276.
+[29] Chapel, L., Alaya, M., Gasso, G. (2020). [Partial Optimal Transport with Applications on Positive-Unlabeled Learning](https://arxiv.org/abs/2002.08276), Advances in Neural Information Processing Systems (NeurIPS), 2020.
[30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). [Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching](https://remi.flamary.com/biblio/flamary2014optlaplace.pdf), NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
+
+[31] Bonneel, Nicolas, et al. [Sliced and radon wasserstein barycenters of measures](https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf), Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
+
+[32] Huang, M., Ma S., Lai, L. (2021). [A Riemannian Block Coordinate Descent Method for Computing the Projection Robust Wasserstein Distance](http://proceedings.mlr.press/v139/huang21e.html), Proceedings of the 38th International Conference on Machine Learning (ICML).
+
+[33] Kerdoncuff T., Emonet R., Marc S. [Sampled Gromov Wasserstein](https://hal.archives-ouvertes.fr/hal-03232509/document), Machine Learning Journal (MJL), 2021
+
+[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). [Interpolating between optimal transport and MMD using Sinkhorn divergences](http://proceedings.mlr.press/v89/feydy19a/feydy19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
+
+[35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., ... & Schwing, A. G. (2019). [Max-sliced wasserstein distance and its use for gans](https://openaccess.thecvf.com/content_CVPR_2019/papers/Deshpande_Max-Sliced_Wasserstein_Distance_and_Its_Use_for_GANs_CVPR_2019_paper.pdf). In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656).
+
+[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R.
+(2019, May). [Sliced-Wasserstein flows: Nonparametric generative modeling
+via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on
+Machine Learning (pp. 4104-4113). PMLR.
+
+[37] Janati, H., Cuturi, M., Gramfort, A. [Debiased sinkhorn barycenters](http://proceedings.mlr.press/v119/janati20a/janati20a.pdf) Proceedings of the 37th International
+Conference on Machine Learning, PMLR 119:4692-4701, 2020
+
+[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph
+Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021. \ No newline at end of file
diff --git a/RELEASES.md b/RELEASES.md
index adb7fc1..6eb1502 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -1,19 +1,167 @@
# Releases
-## 0.7.0
-*May 2020*
-This is the new stable release for POT. We made a lot of changes in the documentation and added several new features such as Partial OT, Unbalanced and Multi Sources OT Domain Adaptation and several bug fixes. One important change is that we have created the GitHub organization [PythonOT](https://github.com/PythonOT) that now owns the main POT repository [https://github.com/PythonOT/POT](https://github.com/PythonOT/POT) and the repository for the new documentation is now hosted at [https://PythonOT.github.io/](https://PythonOT.github.io/).
+## 0.8.0
+*November 2021*
+
+This new stable release introduces several important features.
+
+First we now have
+an OpenMP compatible exact ot solver in `ot.emd`. The OpenMP version is used
+when the parameter `numThreads` is greater than one and can lead to nice
+speedups on multi-core machines.
+
+Second we have introduced a backend mechanism that allows to use standard POT
+function seamlessly on Numpy, Pytorch and Jax arrays. Other backends are coming
+but right now POT can be used seamlessly for training neural networks in
+Pytorch. Notably we propose the first differentiable computation of the exact OT
+loss with `ot.emd2` (can be differentiated w.r.t. both cost matrix and sample
+weights), but also for the classical Sinkhorn loss with `ot.sinkhorn2`, the
+Wasserstein distance in 1D with `ot.wasserstein_1d`, sliced Wasserstein with
+`ot.sliced_wasserstein_distance` and Gromov-Wasserstein with `ot.gromov_wasserstein2`. Examples of how
+this new feature can be used are now available in the documentation where the
+Pytorch backend is used to estimate a [minimal Wasserstein
+estimator](https://PythonOT.github.io/auto_examples/backends/plot_unmix_optim_torch.html),
+a [Generative Network
+(GAN)](https://PythonOT.github.io/auto_examples/backends/plot_wass2_gan_torch.html),
+for a [sliced Wasserstein gradient
+flow](https://PythonOT.github.io/auto_examples/backends/plot_sliced_wass_grad_flow_pytorch.html)
+and [optimizing the Gromov-Wassersein distance](https://PythonOT.github.io/auto_examples/backends/plot_optim_gromov_pytorch.html). Note that the Jax backend is still in early development and quite
+slow at the moment, we strongly recommend for Jax users to use the [OTT
+toolbox](https://github.com/google-research/ott) when possible.
+ As a result of this new feature,
+ the old `ot.gpu` submodule is now deprecated since GPU
+implementations can be done using GPU arrays on the torch backends.
+
+Other novel features include implementation for [Sampled Gromov Wasserstein and
+Pointwise Gromov
+Wasserstein](https://PythonOT.github.io/auto_examples/gromov/plot_gromov.html#compute-gw-with-a-scalable-stochastic-method-with-any-loss-function),
+Sinkhorn in log space with `method='sinkhorn_log'`, [Projection Robust
+Wasserstein](https://PythonOT.github.io/gen_modules/ot.dr.html?highlight=robust#ot.dr.projection_robust_wasserstein),
+ans [deviased Sinkorn barycenters](https://PythonOT.github.ioauto_examples/barycenters/plot_debiased_barycenter.html).
+
+This release will also simplify the installation process. We have now a
+`pyproject.toml` that defines the build dependency and POT should now build even
+when cython is not installed yet. Also we now provide pe-compiled wheels for
+linux `aarch64` that is used on Raspberry PI and android phones and for MacOS on
+ARM processors.
+
+
+Finally POT was accepted for publication in the Journal of Machine Learning
+Research (JMLR) open source software track and we ask the POT users to cite [this
+paper](https://www.jmlr.org/papers/v22/20-451.html) from now on. The documentation has been improved in particular by adding a
+"Why OT?" section to the quick start guide and several new examples illustrating
+the new features. The documentation now has two version : the stable version
+[https://pythonot.github.io/](https://pythonot.github.io/)
+corresponding to the last release and the master version [https://pythonot.github.io/master](https://pythonot.github.io/master) that corresponds to the
+current master branch on GitHub.
+
+
+As usual, we want to thank all the POT contributors (now 37 people have
+contributed to the toolbox). But for this release we thank in particular Nathan
+Cassereau and Kamel Guerda from the AI support team at
+[IDRIS](http://www.idris.fr/) for their support to the development of the
+backend and OpenMP implementations.
+
+
+#### New features
+
+- OpenMP support for exact OT solvers (PR #260)
+- Backend for running POT in numpy/torch + exact solver (PR #249)
+- Backend implementation of most functions in `ot.bregman` (PR #280)
+- Backend implementation of most functions in `ot.optim` (PR #282)
+- Backend implementation of most functions in `ot.gromov` (PR #294, PR #302)
+- Test for arrays of different type and device (CPU/GPU) (PR #304, #303)
+- Implementation of Sinkhorn in log space with `method='sinkhorn_log'` (PR #290)
+- Implementation of regularization path for L2 Unbalanced OT (PR #274)
+- Implementation of Projection Robust Wasserstein (PR #267)
+- Implementation of Debiased Sinkhorn Barycenters (PR #291)
+- Implementation of Sampled Gromov Wasserstein and Pointwise Gromov Wasserstein
+ (PR #275)
+- Add `pyproject.toml` and build POT without installing cython first (PR #293)
+- Lazy implementation in log space for sinkhorn on samples (PR #259)
+- Documentation cleanup (PR #298)
+- Two up-to-date documentations [for stable
+ release](https://PythonOT.github.io/) and for [master branch](https://pythonot.github.io/master/).
+- Building wheels on ARM for Raspberry PI and smartphones (PR #238)
+- Update build wheels to new version and new pythons (PR #236, #253)
+- Implementation of sliced Wasserstein distance (Issue #202, PR #203)
+- Add minimal build to CI and perform pep8 test separately (PR #210)
+- Speedup of tests and return run time (PR #262)
+- Add "Why OT" discussion to the documentation (PR #220)
+- New introductory example to discrete OT in the documentation (PR #191)
+- Add templates for Issues/PR on Github (PR#181)
-This is the first release where the Python 2.7 tests have been removed. Most of the toolbox should still work but we do not offer support for Python 2.7 and will close related Issues.
+#### Closed issues
-A lot of changes have been done to the documentation that is now hosted on [https://PythonOT.github.io/](https://PythonOT.github.io/) instead of readthedocs. It was a hard choice but readthedocs did not allow us to run sphinx-gallery to update our beautiful examples and it was a huge amount of work to maintain. The documentation is now automatically compiled and updated on merge. We also removed the notebooks from the repository for space reason and also because they are all available in the [example gallery](https://pythonot.github.io/auto_examples/index.html). Note that now the output of the documentation build for each commit in the PR is available to check that the doc builds correctly before merging which was not possible with readthedocs.
+- Debug Memory leak in GAN example (#254)
+- DEbug GPU bug (Issue #284, #287, PR #288)
+- set_gradients method for JAX backend (PR #278)
+- Quicker GAN example for CircleCI build (PR #258)
+- Better formatting in Readme (PR #234)
+- Debug CI tests (PR #240, #241, #242)
+- Bug in Partial OT solver dummy points (PR #215)
+- Bug when Armijo linesearch (Issue #184, #198, #281, PR #189, #199, #286)
+- Bug Barycenter Sinkhorn (Issue 134, PR #195)
+- Infeasible solution in exact OT (Issues #126,#93, PR #217)
+- Doc for SUpport Barycenters (Issue #200, PR #201)
+- Fix labels transport in BaseTransport (Issue #207, PR #208)
+- Bug in `emd_1d`, non respected bounds (Issue #169, PR #170)
+- Removed Python 2.7 support and update codecov file (PR #178)
+- Add normalization for WDA and test it (PR #172, #296)
+- Cleanup code for new version of `flake8` (PR #176)
+- Fixed requirements in `setup.py` (PR #174)
+- Removed specific MacOS flags (PR #175)
-The CI framework has also been changed with a move from Travis to Github Action which allows to get faster tests on Windows, MacOS and Linux. We also now report our coverage on [Codecov.io](https://codecov.io/gh/PythonOT/POT) and we have a reasonable 92% coverage. We also now generate wheels for a number of OS and Python versions at each merge in the master branch. They are available as outputs of this [action](https://github.com/PythonOT/POT/actions?query=workflow%3A%22Build+dist+and+wheels%22). This will allow simpler multi-platform releases from now on.
-In terms of new features we now have [OTDA Classes for unbalanced OT](https://pythonot.github.io/gen_modules/ot.da.html#ot.da.UnbalancedSinkhornTransport), a new Domain adaptation class form [multi domain problems (JCPOT)](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_jcpot.html#sphx-glr-auto-examples-domain-adaptation-plot-otda-jcpot-py), and several solvers to solve the [Partial Optimal Transport](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html#sphx-glr-auto-examples-unbalanced-partial-plot-partial-wass-and-gromov-py) problems.
+## 0.7.0
+*May 2020*
-This release is also the moment to thank all the POT contributors (old and new) for helping making POT such a nice toolbox. A lot of changes (also in the API) are comming for the next versions.
+This is the new stable release for POT. We made a lot of changes in the
+documentation and added several new features such as Partial OT, Unbalanced and
+Multi Sources OT Domain Adaptation and several bug fixes. One important change
+is that we have created the GitHub organization
+[PythonOT](https://github.com/PythonOT) that now owns the main POT repository
+[https://github.com/PythonOT/POT](https://github.com/PythonOT/POT) and the
+repository for the new documentation is now hosted at
+[https://PythonOT.github.io/](https://PythonOT.github.io/).
+
+This is the first release where the Python 2.7 tests have been removed. Most of
+the toolbox should still work but we do not offer support for Python 2.7 and
+will close related Issues.
+
+A lot of changes have been done to the documentation that is now hosted on
+[https://PythonOT.github.io/](https://PythonOT.github.io/) instead of
+readthedocs. It was a hard choice but readthedocs did not allow us to run
+sphinx-gallery to update our beautiful examples and it was a huge amount of work
+to maintain. The documentation is now automatically compiled and updated on
+merge. We also removed the notebooks from the repository for space reason and
+also because they are all available in the [example
+gallery](https://pythonot.github.io/auto_examples/index.html). Note that now the
+output of the documentation build for each commit in the PR is available to
+check that the doc builds correctly before merging which was not possible with
+readthedocs.
+
+The CI framework has also been changed with a move from Travis to Github Action
+which allows to get faster tests on Windows, MacOS and Linux. We also now report
+our coverage on [Codecov.io](https://codecov.io/gh/PythonOT/POT) and we have a
+reasonable 92% coverage. We also now generate wheels for a number of OS and
+Python versions at each merge in the master branch. They are available as
+outputs of this
+[action](https://github.com/PythonOT/POT/actions?query=workflow%3A%22Build+dist+and+wheels%22).
+This will allow simpler multi-platform releases from now on.
+
+In terms of new features we now have [OTDA Classes for unbalanced
+OT](https://pythonot.github.io/gen_modules/ot.da.html#ot.da.UnbalancedSinkhornTransport),
+a new Domain adaptation class form [multi domain problems
+(JCPOT)](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_jcpot.html#sphx-glr-auto-examples-domain-adaptation-plot-otda-jcpot-py),
+and several solvers to solve the [Partial Optimal
+Transport](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html#sphx-glr-auto-examples-unbalanced-partial-plot-partial-wass-and-gromov-py)
+problems.
+
+This release is also the moment to thank all the POT contributors (old and new)
+for helping making POT such a nice toolbox. A lot of changes (also in the API)
+are coming for the next versions.
#### Features
@@ -31,6 +179,8 @@ This release is also the moment to thank all the POT contributors (old and new)
#### Closed issues
+- Add JMLR paper to teh readme ad Mathieu Blondel to the Acknoledgments (PR
+ #231, #232)
- Bug in Unbalanced OT example (Issue #127)
- Clean Cython output when calling setup.py clean (Issue #122)
- Various Macosx compilation problems (Issue #113, Issue #118, PR#130)
@@ -54,18 +204,24 @@ https://python3statement.org/ for more reasons). For next release we will keep
the travis tests for Python 2 but will make them non necessary for merge in 2020.
The features are never complete in a toolbox designed for solving mathematical
-problems and research but with the new contributions we now implement algorithms and solvers
-from 24 scientific papers (listed in the README.md file). New features include a
-direct implementation of the [empirical Sinkhorn divergence](https://pot.readthedocs.io/en/latest/all.html#ot.bregman.empirical_sinkhorn_divergence)
-, a new efficient (Cython implementation) solver for [EMD in 1D](https://pot.readthedocs.io/en/latest/all.html#ot.lp.emd_1d)
-and corresponding [Wasserstein
-1D](https://pot.readthedocs.io/en/latest/all.html#ot.lp.wasserstein_1d). We now also
-have implementations for [Unbalanced OT](https://github.com/rflamary/POT/blob/master/notebooks/plot_UOT_1D.ipynb)
-and a solver for [Unbalanced OT barycenters](https://github.com/rflamary/POT/blob/master/notebooks/plot_UOT_barycenter_1D.ipynb).
+problems and research but with the new contributions we now implement algorithms
+and solvers from 24 scientific papers (listed in the README.md file). New
+features include a direct implementation of the [empirical Sinkhorn
+divergence](https://pot.readthedocs.io/en/latest/all.html#ot.bregman.empirical_sinkhorn_divergence),
+a new efficient (Cython implementation) solver for [EMD in
+1D](https://pot.readthedocs.io/en/latest/all.html#ot.lp.emd_1d) and
+corresponding [Wasserstein
+1D](https://pot.readthedocs.io/en/latest/all.html#ot.lp.wasserstein_1d). We now
+also have implementations for [Unbalanced
+OT](https://github.com/rflamary/POT/blob/master/notebooks/plot_UOT_1D.ipynb) and
+a solver for [Unbalanced OT
+barycenters](https://github.com/rflamary/POT/blob/master/notebooks/plot_UOT_barycenter_1D.ipynb).
A new variant of Gromov-Wasserstein divergence called [Fused
Gromov-Wasserstein](https://pot.readthedocs.io/en/latest/all.html?highlight=fused_#ot.gromov.fused_gromov_wasserstein)
- has been also contributed with exemples of use on [structured data](https://github.com/rflamary/POT/blob/master/notebooks/plot_fgw.ipynb)
-and computing [barycenters of labeld graphs](https://github.com/rflamary/POT/blob/master/notebooks/plot_barycenter_fgw.ipynb).
+has been also contributed with exemples of use on [structured
+data](https://github.com/rflamary/POT/blob/master/notebooks/plot_fgw.ipynb) and
+computing [barycenters of labeld
+graphs](https://github.com/rflamary/POT/blob/master/notebooks/plot_barycenter_fgw.ipynb).
A lot of work has been done on the documentation with several new
diff --git a/_config.yml b/_config.yml
deleted file mode 100644
index c741881..0000000
--- a/_config.yml
+++ /dev/null
@@ -1 +0,0 @@
-theme: jekyll-theme-slate \ No newline at end of file
diff --git a/codecov.yml b/codecov.yml
index fbd1b07..1447ced 100644
--- a/codecov.yml
+++ b/codecov.yml
@@ -1,3 +1,15 @@
+# Docs ref: https://docs.codecov.io/docs/codecovyml-reference
+# Validation check: $ curl --data-binary @codecov.yml https://codecov.io/validate
+
+
+codecov:
+ token: 057953e4-d263-41c0-913c-5d45c0371df9
+ bot: "codecov-io"
+ strict_yaml_branch: "yaml-config"
+ require_ci_to_pass: yes
+ notify:
+ wait_for_ci: yes
+
coverage:
precision: 2
round: down
@@ -5,12 +17,31 @@ coverage:
status:
project:
default:
- target: auto
- threshold: 0.01
- patch: false
+ base: auto # target to compare against
+ target: auto # target "X%" coverage to hit on project
+ threshold: 1% # allow this much decrease from base
+ if_ci_failed: error
+ patch:
+ default:
+ base: auto # target to compare against
+ target: 50% # target "X%" coverage to hit on patch
+ # threshold: 50% # allow this much decrease on patch
changes: false
+
+parsers:
+ gcov:
+ branch_detection:
+ conditional: yes
+ loop: yes
+ method: no
+ macro: no
+
+# https://docs.codecov.io/docs/ignoring-paths
+ignore:
+ - "ot/gpu/*"
+
+# https://docs.codecov.io/docs/pull-request-comments
comment:
- layout: "header, diff, sunburst, uncovered"
+ layout: header, diff, sunburst, uncovered
behavior: default
-codecov:
- token: 057953e4-d263-41c0-913c-5d45c0371df9 \ No newline at end of file
+ require_changes: true # if true: only post the comment if coverage changes
diff --git a/data/manhattan.npz b/data/manhattan.npz
new file mode 100644
index 0000000..37808fb
--- /dev/null
+++ b/data/manhattan.npz
Binary files differ
diff --git a/docs/Makefile b/docs/Makefile
index 3511a59..9892785 100644
--- a/docs/Makefile
+++ b/docs/Makefile
@@ -57,6 +57,11 @@ html:
@echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
+html-noplot:
+ $(SPHINXBUILD) -D plot_gallery=0 -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
+ @echo
+ @echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
+
.PHONY: dirhtml
dirhtml:
$(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
diff --git a/docs/source/_templates/versions.html b/docs/source/_templates/versions.html
new file mode 100644
index 0000000..10d60d7
--- /dev/null
+++ b/docs/source/_templates/versions.html
@@ -0,0 +1,43 @@
+<div class="rst-versions shift-up" data-toggle="rst-versions" role="note" aria-label="versions">
+ <span class="rst-current-version" data-toggle="rst-current-version">
+ <span class="fa fa-book"> Python Optimal Transport</span>
+ versions
+ <span class="fa fa-caret-down"></span>
+ </span>
+ <div class="rst-other-versions"><!-- Inserted RTD Footer -->
+
+<div class="injected">
+
+
+
+ <dl>
+ <dt>Versions</dt>
+
+ <dd><a href="https://pythonot.github.io/master">latest</a></dd>
+
+ <dd><a href="https://pythonot.github.io/">stable</a></dd>
+
+ </dl>
+
+
+
+
+ <dl>
+ <dt>On GitHub</dt>
+ <dd>
+ <a href="https://github.com/PythonOT/POT">Code on Github</a>
+ </dd>
+
+ </dl>
+
+
+
+
+
+ <hr>
+
+
+
+</div>
+</div>
+ </div> \ No newline at end of file
diff --git a/docs/source/all.rst b/docs/source/all.rst
index d7b878f..6a07599 100644
--- a/docs/source/all.rst
+++ b/docs/source/all.rst
@@ -14,6 +14,7 @@ API and modules
:template: module.rst
lp
+ backend
bregman
smooth
gromov
@@ -27,6 +28,7 @@ API and modules
stochastic
unbalanced
partial
+ sliced
.. autosummary::
:toctree: ../modules/generated/
diff --git a/docs/source/auto_examples/images/bak.png b/docs/source/auto_examples/images/bak.png
new file mode 100644
index 0000000..25e7e8e
--- /dev/null
+++ b/docs/source/auto_examples/images/bak.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sinkhorn.png b/docs/source/auto_examples/images/sinkhorn.png
new file mode 100644
index 0000000..e003e13
--- /dev/null
+++ b/docs/source/auto_examples/images/sinkhorn.png
Binary files differ
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 384bf40..9b5a719 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -92,7 +92,7 @@ master_doc = 'index'
# General information about the project.
project = u'POT Python Optimal Transport'
-copyright = u'2016-2020, Rémi Flamary, Nicolas Courty'
+copyright = u'2016-2021, Rémi Flamary, Nicolas Courty'
author = u'Rémi Flamary, Nicolas Courty'
# The version info for the project you're documenting, acts as replacement for
@@ -162,7 +162,7 @@ html_theme = 'sphinx_rtd_theme'
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
-#html_theme_options = {}
+html_theme_options = {}
# Add any paths that contain custom themes here, relative to this directory.
#html_theme_path = []
@@ -337,7 +337,8 @@ texinfo_documents = [
intersphinx_mapping = {'python': ('https://docs.python.org/3', None),
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
'scipy': ('http://docs.scipy.org/doc/scipy/reference/', None),
- 'matplotlib': ('http://matplotlib.org/', None)}
+ 'matplotlib': ('http://matplotlib.org/', None),
+ 'torch': ('https://pytorch.org/docs/stable/', None)}
sphinx_gallery_conf = {
'examples_dirs': ['../../examples', '../../examples/da'],
@@ -345,6 +346,7 @@ sphinx_gallery_conf = {
'backreferences_dir': 'gen_modules/backreferences',
'inspect_global_variables' : True,
'doc_module' : ('ot','numpy','scipy','pylab'),
+ 'matplotlib_animations': True,
'reference_url': {
'ot': None}
}
diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst
index d56f812..232df7b 100644
--- a/docs/source/quickstart.rst
+++ b/docs/source/quickstart.rst
@@ -7,19 +7,175 @@ to use for different problems related to optimal transport (OT) and machine
learning. We refer when we can to concrete examples in the documentation that
are also available as notebooks on the POT Github.
-This document is not a tutorial on numerical optimal transport. For this we strongly
-recommend to read the very nice book [15]_ .
+.. note::
+
+ For a good introduction to numerical optimal transport we refer the reader
+ to `the book <https://arxiv.org/pdf/1803.00567.pdf>`_ by Peyré and Cuturi
+ [15]_. For more detailed introduction to OT and how it can be used
+ in ML applications we refer the reader to the following `OTML tutorial
+ <https://remi.flamary.com/cours/tuto_otml.html>`_.
+
+.. note::
+
+ Since version 0.8, POT provides a backend to automatically solve some OT
+ problems independently from the toolbox used by the user (numpy/torch/jax).
+ We provide a discussion about which functions are compatible in section
+ `Backend section <#solving-ot-with-multiple-backends>`_ .
+
+
+Why Optimal Transport ?
+-----------------------
+
+
+When to use OT
+^^^^^^^^^^^^^^
+
+Optimal Transport (OT) is a mathematical problem introduced by Gaspard Monge in
+1781 that aim at finding the most efficient way to move mass between
+distributions. The cost of moving a unit of mass between two positions is called
+the ground cost and the objective is to minimize the overall cost of moving one
+mass distribution onto another one. The optimization problem can be expressed
+for two distributions :math:`\mu_s` and :math:`\mu_t` as
+
+.. math::
+ \min_{m, m \# \mu_s = \mu_t} \int c(x,m(x))d\mu_s(x) ,
+
+where :math:`c(\cdot,\cdot)` is the ground cost and the constraint
+:math:`m \# \mu_s = \mu_t` ensures that :math:`\mu_s` is completely transported to :math:`\mu_t`.
+This problem is particularly difficult to solve because of this constraint and
+has been replaced in practice (on discrete distributions) by a
+linear program easier to solve. It corresponds to the Kantorovitch formulation
+where the Monge mapping :math:`m` is replaced by a joint distribution
+(OT matrix expressed in the next section) (see :ref:`kantorovitch_solve`).
+
+From the optimization problem above we can see that there are two main aspects
+to the OT solution that can be used in practical applications:
+
+- The optimal value (Wasserstein distance): Measures similarity between distributions.
+- The optimal mapping (Monge mapping, OT matrix): Finds correspondences between distributions.
+
+
+In the first case, OT can be used to measure similarity between distributions
+(or datasets), in this case the Wasserstein distance (the optimal value of the
+problem) is used. In the second case one can be interested in the way the mass
+is moved between the distributions (the mapping). This mapping can then be used
+to transfer knowledge between distributions.
+
+
+Wasserstein distance between distributions
+""""""""""""""""""""""""""""""""""""""""""
+
+OT is often used to measure similarity between distributions, especially
+when they do not share the same support. When the support between the
+distributions is disjoint OT-based Wasserstein distances compare favorably to
+popular f-divergences including the popular Kullback-Leibler, Jensen-Shannon
+divergences, and the Total Variation distance. What is particularly interesting
+for data science applications is that one can compute meaningful sub-gradients
+of the Wasserstein distance. For these reasons it became a very efficient tool
+for machine learning applications that need to measure and optimize similarity
+between empirical distributions.
+
+
+Numerous contributions make use of this an approach is the machine learning (ML)
+literature. For example OT was used for training `Generative
+Adversarial Networks (GANs) <https://arxiv.org/pdf/1701.07875.pdf>`_
+in order to overcome the vanishing gradient problem. It has also
+been used to find `discriminant <https://arxiv.org/pdf/1608.08063.pdf>`_ or
+`robust <https://arxiv.org/pdf/1901.08949.pdf>`_ subspaces for a dataset. The
+Wasserstein distance has also been used to measure `similarity between word
+embeddings of documents <http://proceedings.mlr.press/v37/kusnerb15.pdf>`_ or
+between `signals
+<https://www.math.ucdavis.edu/~saito/data/acha.read.s19/kolouri-etal_optimal-mass-transport.pdf>`_
+or `spectra <https://arxiv.org/pdf/1609.09799.pdf>`_.
+
+
+
+OT for mapping estimation
+"""""""""""""""""""""""""
+
+A very interesting aspect of OT problem is the OT mapping in itself. When
+computing optimal transport between discrete distributions one output is the OT
+matrix that will provide you with correspondences between the samples in each
+distributions.
+
+
+This correspondence is estimated with respect to the OT criterion and is found
+in a non-supervised way, which makes it very interesting on problems of transfer
+between datasets. It has been used to perform
+`color transfer between images <https://arxiv.org/pdf/1307.5551.pdf>`_ or in
+the context of `domain adaptation <https://arxiv.org/pdf/1507.00504.pdf>`_.
+More recent applications include the use of extension of OT (Gromov-Wasserstein)
+to find correspondences between languages in `word embeddings
+<https://arxiv.org/pdf/1809.00013.pdf>`_.
+
+
+When to use POT
+^^^^^^^^^^^^^^^
+
+
+The main objective of POT is to provide OT solvers for the rapidly growing area
+of OT in the context of machine learning. To this end we implement a number of
+solvers that have been proposed in research papers. Doing so we aim to promote
+reproducible research and foster novel developments.
+
+
+One very important aspect of POT is its ability to be easily extended. For
+instance we provide a very generic OT solver :any:`ot.optim.cg` that can solve
+OT problems with any smooth/continuous regularization term making it
+particularly practical for research purpose. Note that this generic solver has
+been used to solve both graph Laplacian regularization OT and Gromov
+Wasserstein [30]_.
+
+
+.. note::
+
+ POT is originally designed to solve OT problems with Numpy interface and
+ is not yet compatible with Pytorch API. We are currently working on a torch
+ submodule that will provide OT solvers and losses for the most common deep
+ learning configurations.
+
+
+When not to use POT
+"""""""""""""""""""
+
+While POT has to the best of our knowledge one of the most efficient exact OT
+solvers, it has not been designed to handle large scale OT problems. For
+instance the memory cost for an OT problem is always :math:`\mathcal{O}(n^2)` in
+memory because the cost matrix has to be computed. The exact solver in of time
+complexity :math:`\mathcal{O}(n^3\log(n))` and the Sinkhorn solver has been
+proven to be nearly :math:`\mathcal{O}(n^2)` which is still too complex for very
+large scale solvers.
+
+
+If you need to solve OT with large number of samples, we recommend to use
+entropic regularization and memory efficient implementation of Sinkhorn as
+proposed in `GeomLoss <https://www.kernel-operations.io/geomloss/>`_. This
+implementation is compatible with Pytorch and can handle large number of
+samples. Another approach to estimate the Wasserstein distance for very large
+number of sample is to use the trick from `Wasserstein GAN
+<https://arxiv.org/pdf/1701.07875.pdf>`_ that solves the problem
+in the dual with a neural network estimating the dual variable. Note that in this
+case you are only solving an approximation of the Wasserstein distance because
+the 1-Lipschitz constraint on the dual cannot be enforced exactly (approximated
+through filter thresholding or regularization). Finally note that in order to
+avoid solving large scale OT problems, a number of recent approached minimized
+the expected Wasserstein distance on minibtaches that is different from the
+Wasserstein but has better computational and
+`statistical properties <https://arxiv.org/pdf/1910.04091.pdf>`_.
Optimal transport and Wasserstein distance
------------------------------------------
.. note::
+
In POT, most functions that solve OT or regularized OT problems have two
versions that return the OT matrix or the value of the optimal solution. For
- instance :any:`ot.emd` return the OT matrix and :any:`ot.emd2` return the
+ instance :any:`ot.emd` returns the OT matrix and :any:`ot.emd2` returns the
Wassertsein distance. This approach has been implemented in practice for all
- solvers that return an OT matrix (even Gromov-Wasserstsein)
+ solvers that return an OT matrix (even Gromov-Wasserstsein).
+
+.. _kantorovitch_solve:
Solving optimal transport
^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -28,30 +184,31 @@ The optimal transport problem between discrete distributions is often expressed
as
.. math::
- \gamma^* = arg\min_\gamma \quad \sum_{i,j}\gamma_{i,j}M_{i,j}
+ \gamma^* = arg\min_{\gamma \in \mathbb{R}_+^{m\times n}} \quad \sum_{i,j}\gamma_{i,j}M_{i,j}
s.t. \gamma 1 = a; \gamma^T 1= b; \gamma\geq 0
-where :
+where:
-- :math:`M\in\mathbb{R}_+^{m\times n}` is the metric cost matrix defining the cost to move mass from bin :math:`a_i` to bin :math:`b_j`.
-- :math:`a` and :math:`b` are histograms on the simplex (positive, sum to 1) that represent the
-weights of each samples in the source an target distributions.
+ - :math:`M\in\mathbb{R}_+^{m\times n}` is the metric cost matrix defining the cost to move mass from bin :math:`a_i` to bin :math:`b_j`.
+
+ - :math:`a` and :math:`b` are histograms on the simplex (positive, sum to 1) that represent the weights of each samples in the source an target distributions.
Solving the linear program above can be done using the function :any:`ot.emd`
that will return the optimal transport matrix :math:`\gamma^*`:
.. code:: python
- # a,b are 1D histograms (sum to 1 and positive)
+ # a and b are 1D histograms (sum to 1 and positive)
# M is the ground cost matrix
- T=ot.emd(a,b,M) # exact linear program
+ T = ot.emd(a, b, M) # exact linear program
-The method implemented for solving the OT problem is the network simplex, it is
-implemented in C from [1]_. It has a complexity of :math:`O(n^3)` but the
+The method implemented for solving the OT problem is the network simplex. It is
+implemented in C from [1]_. It has a complexity of :math:`O(n^3)` but the
solver is quite efficient and uses sparsity of the solution.
.. hint::
+
Examples of use for :any:`ot.emd` are available in :
- :any:`auto_examples/plot_OT_2D_samples`
@@ -62,10 +219,11 @@ solver is quite efficient and uses sparsity of the solution.
Computing Wasserstein distance
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-The value of the OT solution is often more of interest than the OT matrix :
+The value of the OT solution is often more interesting than the OT matrix:
.. math::
- OT(a,b)=\min_\gamma \quad \sum_{i,j}\gamma_{i,j}M_{i,j}
+
+ OT(a,b) = \min_{\gamma \in \mathbb{R}_+^{m\times n}} \quad \sum_{i,j}\gamma_{i,j}M_{i,j}
s.t. \gamma 1 = a; \gamma^T 1= b; \gamma\geq 0
@@ -75,9 +233,9 @@ It can computed from an already estimated OT matrix with
.. code:: python
- # a,b are 1D histograms (sum to 1 and positive)
+ # a and b are 1D histograms (sum to 1 and positive)
# M is the ground cost matrix
- W=ot.emd2(a,b,M) # Wasserstein distance / EMD value
+ W = ot.emd2(a, b, M) # Wasserstein distance / EMD value
Note that the well known `Wasserstein distance
<https://en.wikipedia.org/wiki/Wasserstein_metric>`_ between distributions a and
@@ -86,19 +244,19 @@ b is defined as
.. math::
- W_p(a,b)=(\min_\gamma \sum_{i,j}\gamma_{i,j}\|x_i-y_j\|_p)^\frac{1}{p}
+ W_p(a,b)=(\min_{\gamma \in \mathbb{R}_+^{m\times n}} \sum_{i,j}\gamma_{i,j}\|x_i-y_j\|_p)^\frac{1}{p}
s.t. \gamma 1 = a; \gamma^T 1= b; \gamma\geq 0
This means that if you want to compute the :math:`W_2` you need to compute the
square root of :any:`ot.emd2` when providing
-:code:`M=ot.dist(xs,xt)` that use the squared euclidean distance by default. Computing
-the :math:`W_1` wasserstein distance can be done directly with :any:`ot.emd2`
-when providing :code:`M=ot.dist(xs,xt, metric='euclidean')` to use the euclidean
+:code:`M = ot.dist(xs, xt)`, that uses the squared euclidean distance by default. Computing
+the :math:`W_1` Wasserstein distance can be done directly with :any:`ot.emd2`
+when providing :code:`M = ot.dist(xs, xt, metric='euclidean')` to use the Euclidean
distance.
-
.. hint::
+
An example of use for :any:`ot.emd2` is available in :
- :any:`auto_examples/plot_compute_emd`
@@ -123,9 +281,9 @@ Another special case for estimating OT and Monge mapping is between Gaussian
distributions. In this case there exists a close form solution given in Remark
2.29 in [15]_ and the Monge mapping is an affine function and can be
also computed from the covariances and means of the source and target
-distributions. In the case when the finite sample dataset is supposed gaussian, we provide
-:any:`ot.da.OT_mapping_linear` that returns the parameters for the Monge
-mapping.
+distributions. In the case when the finite sample dataset is supposed Gaussian,
+we provide :any:`ot.da.OT_mapping_linear` that returns the parameters for the
+Monge mapping.
Regularized Optimal Transport
@@ -136,7 +294,7 @@ computational and statistical properties.
We address in this section the regularized OT problems that can be expressed as
.. math::
- \gamma^* = arg\min_\gamma \quad \sum_{i,j}\gamma_{i,j}M_{i,j} + \lambda\Omega(\gamma)
+ \gamma^* = arg\min_{\gamma \in \mathbb{R}_+^{m\times n}} \quad \sum_{i,j}\gamma_{i,j}M_{i,j} + \lambda\Omega(\gamma)
s.t. \gamma 1 = a; \gamma^T 1= b; \gamma\geq 0
@@ -175,8 +333,8 @@ solution of the resulting optimization problem can be expressed as:
where :math:`u` and :math:`v` are vectors and :math:`K=\exp(-M/\lambda)` where
the :math:`\exp` is taken component-wise. In order to solve the optimization
-problem, on can use an alternative projection algorithm called Sinkhorn-Knopp that can be very
-efficient for large values if regularization.
+problem, one can use an alternative projection algorithm called Sinkhorn-Knopp
+that can be very efficient for large values of regularization.
The Sinkhorn-Knopp algorithm is implemented in :any:`ot.sinkhorn` and
:any:`ot.sinkhorn2` that return respectively the OT matrix and the value of the
@@ -184,10 +342,10 @@ linear term. Note that the regularization parameter :math:`\lambda` in the
equation above is given to those functions with the parameter :code:`reg`.
>>> import ot
- >>> a=[.5,.5]
- >>> b=[.5,.5]
- >>> M=[[0.,1.],[1.,0.]]
- >>> ot.sinkhorn(a,b,M,1)
+ >>> a = [.5, .5]
+ >>> b = [.5, .5]
+ >>> M = [[0., 1.], [1., 0.]]
+ >>> ot.sinkhorn(a, b, M, 1)
array([[ 0.36552929, 0.13447071],
[ 0.13447071, 0.36552929]])
@@ -195,20 +353,27 @@ More details about the algorithms used are given in the following note.
.. note::
The main function to solve entropic regularized OT is :any:`ot.sinkhorn`.
- This function is a wrapper and the parameter :code:`method` help you select
+ This function is a wrapper and the parameter :code:`method` allows you to select
the actual algorithm used to solve the problem:
+ :code:`method='sinkhorn'` calls :any:`ot.bregman.sinkhorn_knopp` the
classic algorithm [2]_.
+ + :code:`method='sinkhorn_log'` calls :any:`ot.bregman.sinkhorn_log` the
+ sinkhorn algorithm in log space [2]_ that is more stable but can be
+ slower in numpy since `logsumexp` is not implmemented in parallel.
+ It is the recommended solver for applications that requires
+ differentiability with a small number of iterations.
+ :code:`method='sinkhorn_stabilized'` calls :any:`ot.bregman.sinkhorn_stabilized` the
log stabilized version of the algorithm [9]_.
+ :code:`method='sinkhorn_epsilon_scaling'` calls
:any:`ot.bregman.sinkhorn_epsilon_scaling` the epsilon scaling version
of the algorithm [9]_.
+ :code:`method='greenkhorn'` calls :any:`ot.bregman.greenkhorn` the
- greedy sinkhorn verison of the algorithm [22]_.
+ greedy Sinkhorn version of the algorithm [22]_.
+ + :code:`method='screenkhorn'` calls :any:`ot.bregman.screenkhorn` the
+ screening sinkhorn version of the algorithm [26]_.
- In addition to all those variants of sinkhorn, we have another
+ In addition to all those variants of Sinkhorn, we have another
implementation solving the problem in the smooth dual or semi-dual in
:any:`ot.smooth`. This solver uses the :any:`scipy.optimize.minimize`
function to solve the smooth problem with :code:`L-BFGS-B` algorithm. Tu use
@@ -216,12 +381,31 @@ More details about the algorithms used are given in the following note.
:any:`ot.smooth.smooth_ot_semi_dual` with parameter :code:`reg_type='kl'` to
choose entropic/Kullbach Leibler regularization.
+ **Choosing a Sinkhorn solver**
+
+ By default and when using a regularization parameter that is not too small
+ the default Sinkhorn solver should be enough. If you need to use a small
+ regularization to get sharper OT matrices, you should use the
+ :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
+ errors. This last solver can be very slow in practice and might not even
+ converge to a reasonable OT matrix in a finite time. This is why
+ :any:`ot.bregman.sinkhorn_epsilon_scaling` that relie on iterating the value
+ of the regularization (and using warm start) sometimes leads to better
+ solutions. Note that the greedy version of the Sinkhorn
+ :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
+ version of the Sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a
+ fast approximation of the Sinkhorn problem. For use of GPU and gradient
+ computation with small number of iterations we strongly recommend the
+ :any:`ot.bregman.sinkhorn_log` solver that will no need to check for
+ numerical problems.
-Recently [23]_ introduced the sinkhorn divergence that build from entropic
+
+
+Recently Genevay et al. [23]_ introduced the Sinkhorn divergence that build from entropic
regularization to compute fast and differentiable geometric divergence between
-empirical distributions. Note that we provide a function that compute directly
-(with no need to pre compute the :code:`M` matrix)
-the sinkhorn divergence for empirical distributions in
+empirical distributions. Note that we provide a function that computes directly
+(with no need to precompute the :code:`M` matrix)
+the Sinkhorn divergence for empirical distributions in
:any:`ot.bregman.empirical_sinkhorn_divergence`. Similarly one can compute the
OT matrix and loss for empirical distributions with respectively
:any:`ot.bregman.empirical_sinkhorn` and :any:`ot.bregman.empirical_sinkhorn2`.
@@ -229,7 +413,7 @@ OT matrix and loss for empirical distributions with respectively
Finally note that we also provide in :any:`ot.stochastic` several implementation
of stochastic solvers for entropic regularized OT [18]_ [19]_. Those pure Python
-implementations are not optimized for speed but provide a roust implementation
+implementations are not optimized for speed but provide a robust implementation
of algorithms in [18]_ [19]_.
.. hint::
@@ -244,11 +428,11 @@ of algorithms in [18]_ [19]_.
Other regularization
^^^^^^^^^^^^^^^^^^^^
-While entropic OT is the most common and favored in practice, there exist other
-kind of regularization. We provide in POT two specific solvers for other
-regularization terms, namely quadratic regularization and group lasso
-regularization. But we also provide in :any:`ot.optim` two generic solvers that allows solving any
-smooth regularization in practice.
+While entropic OT is the most common and favored in practice, there exists other
+kinds of regularizations. We provide in POT two specific solvers for other
+regularization terms, namely quadratic regularization and group Lasso
+regularization. But we also provide in :any:`ot.optim` two generic solvers
+that allows solving any smooth regularization in practice.
Quadratic regularization
""""""""""""""""""""""""
@@ -259,8 +443,8 @@ regularization of the form
.. math::
\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}^2
-this regularization term has a similar effect to entropic regularization in
-densifying the OT matrix but it keeps some sort of sparsity that is lost with
+This regularization term has an effect similar to entropic regularization by
+densifying the OT matrix, yet it keeps some sort of sparsity that is lost with
entropic regularization as soon as :math:`\lambda>0` [17]_. This problem can be
solved with POT using solvers from :any:`ot.smooth`, more specifically
functions :any:`ot.smooth.smooth_ot_dual` or
@@ -278,30 +462,29 @@ choose the quadratic regularization.
Group Lasso regularization
""""""""""""""""""""""""""
-Another regularization that has been used in recent years [5]_ is the group lasso
+Another regularization that has been used in recent years [5]_ is the group Lasso
regularization
.. math::
\Omega(\gamma)=\sum_{j,G\in\mathcal{G}} \|\gamma_{G,j}\|_q^p
-where :math:`\mathcal{G}` contains non overlapping groups of lines in the OT
-matrix. This regularization proposed in [5]_ will promote sparsity at the group level and for
+where :math:`\mathcal{G}` contains non-overlapping groups of lines in the OT
+matrix. This regularization proposed in [5]_ promotes sparsity at the group level and for
instance will force target samples to get mass from a small number of groups.
Note that the exact OT solution is already sparse so this regularization does
-not make sens if it is not combined with entropic regularization. Depending on
+not make sense if it is not combined with entropic regularization. Depending on
the choice of :code:`p` and :code:`q`, the problem can be solved with different
-approaches. When :code:`q=1` and :code:`p<1` the problem is non convex but can
+approaches. When :code:`q=1` and :code:`p<1` the problem is non-convex but can
be solved using an efficient majoration minimization approach with
:any:`ot.sinkhorn_lpl1_mm`. When :code:`q=2` and :code:`p=1` we recover the
convex group lasso and we provide a solver using generalized conditional
-gradient algorithm [7]_ in function
-:any:`ot.da.sinkhorn_l1l2_gl`.
+gradient algorithm [7]_ in function :any:`ot.da.sinkhorn_l1l2_gl`.
.. hint::
- Examples of group Lasso regularization are available in :
+ Examples of group Lasso regularization are available in:
- - :any:`auto_examples/plot_otda_classes`
- - :any:`auto_examples/plot_otda_d2`
+ - :any:`auto_examples/domain-adaptation/plot_otda_classes`
+ - :any:`auto_examples/domain-adaptation/plot_otda_d2`
Generic solvers
@@ -322,11 +505,10 @@ you can use function :any:`ot.optim.cg` that will use a conditional gradient as
proposed in [6]_ . You need to provide the regularization function as parameter
``f`` and its gradient as parameter ``df``. Note that the conditional gradient relies on
iterative solving of a linearization of the problem using the exact
-:any:`ot.emd` so it can be slow in practice. But, being an interior point
-algorithm, it always returns a
-transport matrix that does not violates the marginals.
+:any:`ot.emd` so it can be quite slow in practice. However, being an interior point
+algorithm, it always returns a transport matrix that does not violates the marginals.
-Another generic solver is proposed to solve the problem
+Another generic solver is proposed to solve the problem:
.. math::
\gamma^* = arg\min_\gamma \quad \sum_{i,j}\gamma_{i,j}M_{i,j}+ \lambda_e\Omega_e(\gamma) + \lambda\Omega(\gamma)
@@ -347,7 +529,7 @@ relies on :any:`ot.sinkhorn` for its iterations.
Wasserstein Barycenters
-----------------------
-A Wasserstein barycenter is a distribution that minimize its Wasserstein
+A Wasserstein barycenter is a distribution that minimizes its Wasserstein
distance with respect to other distributions [16]_. It corresponds to minimizing the
following problem by searching a distribution :math:`\mu` such that
@@ -378,18 +560,18 @@ be expressed as
where :math:`b_k` are also weights in the simplex. In the non-regularized case,
the problem above is a classical linear program. In this case we propose a
-solver :any:`ot.lp.barycenter` that rely on generic LP solvers. By default the
+solver :meth:`ot.lp.barycenter` that relies on generic LP solvers. By default the
function uses :any:`scipy.optimize.linprog`, but more efficient LP solvers from
cvxopt can be also used by changing parameter :code:`solver`. Note that this problem
requires to solve a very large linear program and can be very slow in
practice.
Similarly to the OT problem, OT barycenters can be computed in the regularized
-case. When using entropic regularization is used, the problem can be solved with a
-generalization of the sinkhorn algorithm based on bregman projections [3]_. This
+case. When entropic regularization is used, the problem can be solved with a
+generalization of the Sinkhorn algorithm based on Bregman projections [3]_. This
algorithm is provided in function :any:`ot.bregman.barycenter` also available as
:any:`ot.barycenter`. In this case, the algorithm scales better to large
-distributions and rely only on matrix multiplications that can be performed in
+distributions and relies only on matrix multiplications that can be performed in
parallel.
In addition to the speedup brought by regularization, one can also greatly
@@ -400,18 +582,18 @@ operators. We provide an implementation of this algorithm in function
:any:`ot.bregman.convolutional_barycenter2d`.
.. hint::
- Examples of Wasserstein (:any:`ot.lp.barycenter`) and regularized Wasserstein
+ Examples of Wasserstein (:meth:`ot.lp.barycenter`) and regularized Wasserstein
barycenter (:any:`ot.bregman.barycenter`) computation are available in :
- - :any:`auto_examples/plot_barycenter_1D`
- - :any:`auto_examples/plot_barycenter_lp_vs_entropic`
+ - :any:`auto_examples/barycenters/plot_barycenter_1D`
+ - :any:`auto_examples/barycenters/plot_barycenter_lp_vs_entropic`
An example of convolutional barycenter
(:any:`ot.bregman.convolutional_barycenter2d`) computation
for 2D images is available
in :
- - :any:`auto_examples/plot_convolutional_barycenter`
+ - :any:`auto_examples/barycenters/plot_convolutional_barycenter`
@@ -419,7 +601,7 @@ Barycenters with free support
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Estimating the Wasserstein barycenter with free support but fixed weights
-corresponds to solving the following optimization problem:
+corresponds to solving the following optimization problem:
.. math::
\min_{\{x_i\}} \quad \sum_{k} w_kW(\mu,\mu_k)
@@ -436,7 +618,7 @@ return a locally optimal support :math:`\{x_i\}` for uniform or given weights
An example of the free support barycenter estimation is available
in :
- - :any:`auto_examples/plot_free_support_barycenter`
+ - :any:`auto_examples/barycenters/plot_free_support_barycenter`
@@ -444,7 +626,7 @@ return a locally optimal support :math:`\{x_i\}` for uniform or given weights
Monge mapping and Domain adaptation
-----------------------------------
-The original transport problem investigated by Gaspard Monge was seeking for a
+The original transport problem investigated by Gaspard Monge was seeking for a
mapping function that maps (or transports) between a source and target
distribution but that minimizes the transport loss. The existence and uniqueness of this
optimal mapping is still an open problem in the general case but has been proven
@@ -462,24 +644,24 @@ approximate a Monge mapping from finite distributions.
First note that when the source and target distributions are supposed to be Gaussian
distributions, there exists a close form solution for the mapping and its an
affine function [14]_ of the form :math:`T(x)=Ax+b` . In this case we provide the function
-:any:`ot.da.OT_mapping_linear` that return the operator :math:`A` and vector
+:any:`ot.da.OT_mapping_linear` that returns the operator :math:`A` and vector
:math:`b`. Note that if the number of samples is too small there is a parameter
-:code:`reg` that provide a regularization for the covariance matrix estimation.
+:code:`reg` that provides a regularization for the covariance matrix estimation.
For a more general mapping estimation we also provide the barycentric mapping
-proposed in [6]_ . It is implemented in the class :any:`ot.da.EMDTransport` and
-other transport based classes in :any:`ot.da` . Those classes are discussed more
-in the following but follow an interface similar to sklearn classes. Finally a
+proposed in [6]_. It is implemented in the class :any:`ot.da.EMDTransport` and
+other transport-based classes in :any:`ot.da` . Those classes are discussed more
+in the following but follow an interface similar to scikit-learn classes. Finally a
method proposed in [8]_ that estimates a continuous mapping approximating the
barycentric mapping is provided in :any:`ot.da.joint_OT_mapping_linear` for
-linear mapping and :any:`ot.da.joint_OT_mapping_kernel` for non linear mapping.
+linear mapping and :any:`ot.da.joint_OT_mapping_kernel` for non-linear mapping.
.. hint::
An example of the linear Monge mapping estimation is available
in :
- - :any:`auto_examples/plot_otda_linear_mapping`
+ - :any:`auto_examples/domain-adaptation/plot_otda_linear_mapping`
Domain adaptation classes
^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -491,21 +673,19 @@ transport labeled source samples onto the target distribution with no labels.
We provide several classes based on :any:`ot.da.BaseTransport` that provide
several OT and mapping estimations. The interface of those classes is similar to
-classifiers in sklearn toolbox. At initialization, several parameters such as
- regularization parameter value can be set. Then one needs to estimate the
+classifiers in scikit-learn. At initialization, several parameters such as
+regularization parameter value can be set. Then one needs to estimate the
mapping with function :any:`ot.da.BaseTransport.fit`. Finally one can map the
samples from source to target with :any:`ot.da.BaseTransport.transform` and
from target to source with :any:`ot.da.BaseTransport.inverse_transform`.
-Here is
-an example for class :any:`ot.da.EMDTransport` :
+Here is an example for class :any:`ot.da.EMDTransport`:
.. code::
ot_emd = ot.da.EMDTransport()
ot_emd.fit(Xs=Xs, Xt=Xt)
-
- Mapped_Xs= ot_emd.transform(Xs=Xs)
+ Xs_mapped = ot_emd.transform(Xs=Xs)
A list of the provided implementation is given in the following note.
@@ -514,24 +694,24 @@ A list of the provided implementation is given in the following note.
Here is a list of the OT mapping classes inheriting from
:any:`ot.da.BaseTransport`
- * :any:`ot.da.EMDTransport` : Barycentric mapping with EMD transport
- * :any:`ot.da.SinkhornTransport` : Barycentric mapping with Sinkhorn transport
- * :any:`ot.da.SinkhornL1l2Transport` : Barycentric mapping with Sinkhorn +
+ * :any:`ot.da.EMDTransport`: Barycentric mapping with EMD transport
+ * :any:`ot.da.SinkhornTransport`: Barycentric mapping with Sinkhorn transport
+ * :any:`ot.da.SinkhornL1l2Transport`: Barycentric mapping with Sinkhorn +
group Lasso regularization [5]_
- * :any:`ot.da.SinkhornLpl1Transport` : Barycentric mapping with Sinkhorn +
+ * :any:`ot.da.SinkhornLpl1Transport`: Barycentric mapping with Sinkhorn +
non convex group Lasso regularization [5]_
- * :any:`ot.da.LinearTransport` : Linear mapping estimation between Gaussians
+ * :any:`ot.da.LinearTransport`: Linear mapping estimation between Gaussians
[14]_
- * :any:`ot.da.MappingTransport` : Nonlinear mapping estimation [8]_
+ * :any:`ot.da.MappingTransport`: Nonlinear mapping estimation [8]_
.. hint::
- Example of the use of OTDA classes are available in :
+ Examples of the use of OTDA classes are available in:
- - :any:`auto_examples/plot_otda_color_images`
- - :any:`auto_examples/plot_otda_mapping`
- - :any:`auto_examples/plot_otda_mapping_colors_images`
- - :any:`auto_examples/plot_otda_semi_supervised`
+ - :any:`auto_examples/domain-adaptation/plot_otda_color_images`
+ - :any:`auto_examples/domain-adaptation/plot_otda_mapping`
+ - :any:`auto_examples/domain-adaptation/plot_otda_mapping_colors_images`
+ - :any:`auto_examples/domain-adaptation/plot_otda_semi_supervised`
Other applications
------------------
@@ -545,14 +725,14 @@ Wasserstein Discriminant Analysis
Wasserstein Discriminant Analysis [11]_ is a generalization of `Fisher Linear Discriminant
Analysis <https://en.wikipedia.org/wiki/Linear_discriminant_analysis>`__ that
allows discrimination between classes that are not linearly separable. It
-consist in finding a linear projector optimizing the following criterion
+consists in finding a linear projector optimizing the following criterion
.. math::
P = \text{arg}\min_P \frac{\sum_i OT_e(\mu_i\#P,\mu_i\#P)}{\sum_{i,j\neq i}
OT_e(\mu_i\#P,\mu_j\#P)}
where :math:`\#` is the push-forward operator, :math:`OT_e` is the entropic OT
-loss and :math:`\mu_i` is the
+loss and :math:`\mu_i` is the
distribution of samples from class :math:`i`. :math:`P` is also constrained to
be in the Stiefel manifold. WDA can be solved in POT using function
:any:`ot.dr.wda`. It requires to have installed :code:`pymanopt` and
@@ -561,6 +741,7 @@ respectively. Note that we also provide the Fisher discriminant estimator in
:any:`ot.dr.fda` for easy comparison.
.. warning::
+
Note that due to the hard dependency on :code:`pymanopt` and
:code:`autograd`, :any:`ot.dr` is not imported by default. If you want to
use it you have to specifically import it with :code:`import ot.dr` .
@@ -569,7 +750,7 @@ respectively. Note that we also provide the Fisher discriminant estimator in
An example of the use of WDA is available in :
- - :any:`auto_examples/plot_WDA`
+ - :any:`auto_examples/others/plot_WDA`
Unbalanced optimal transport
@@ -610,7 +791,7 @@ linear term.
Examples of the use of :any:`ot.sinkhorn_unbalanced` are available in :
- - :any:`auto_examples/plot_UOT_1D`
+ - :any:`auto_examples/unbalanced-partial/plot_UOT_1D`
Unbalanced Barycenters
@@ -622,17 +803,17 @@ histograms with different masses as a Fréchet Mean:
.. math::
\min_{\mu} \quad \sum_{k} w_kW_u(\mu,\mu_k)
-Where :math:`W_u` is the unbalanced Wasserstein metric defined above. This problem
+where :math:`W_u` is the unbalanced Wasserstein metric defined above. This problem
can also be solved using generalized version of Sinkhorn's algorithm and it is
implemented the main function :any:`ot.barycenter_unbalanced`.
.. note::
The main function to compute UOT barycenters is :any:`ot.barycenter_unbalanced`.
- This function is a wrapper and the parameter :code:`method` help you select
+ This function is a wrapper and the parameter :code:`method` helps you select
the actual algorithm used to solve the problem:
- + :code:`method='sinkhorn'` calls :any:`ot.unbalanced.barycenter_unbalanced_sinkhorn_unbalanced`
+ + :code:`method='sinkhorn'` calls :meth:`ot.unbalanced.barycenter_unbalanced_sinkhorn_unbalanced`
the generalized Sinkhorn algorithm [10]_.
+ :code:`method='sinkhorn_stabilized'` calls :any:`ot.unbalanced.barycenter_unbalanced_stabilized`
the log stabilized version of the algorithm [10]_.
@@ -642,7 +823,7 @@ implemented the main function :any:`ot.barycenter_unbalanced`.
Examples of the use of :any:`ot.barycenter_unbalanced` are available in :
- - :any:`auto_examples/plot_UOT_barycenter_1D`
+ - :any:`auto_examples/unbalanced-partial/plot_UOT_barycenter_1D`
Partial optimal transport
@@ -686,9 +867,9 @@ regularization of the problem.
.. hint::
- Examples of the use of :any:`ot.partial` are available in :
+ Examples of the use of :any:`ot.partial` are available in:
- - :any:`auto_examples/plot_partial`
+ - :any:`auto_examples/unbalanced-partial/plot_partial_wass_and_gromov`
@@ -699,7 +880,7 @@ Gromov Wasserstein (GW) is a generalization of OT to distributions that do not l
the same space [13]_. In this case one cannot compute distance between samples
from the two distributions. [13]_ proposed instead to realign the metric spaces
by computing a transport between distance matrices. The Gromow Wasserstein
-alignement between two distributions can be expressed as the one minimizing:
+alignment between two distributions can be expressed as the one minimizing:
.. math::
GW = \min_\gamma \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*\gamma_{i,j}*\gamma_{k,l}
@@ -731,8 +912,8 @@ positive matrix. We provide a block coordinate optimization procedure in
barycenters respectively.
Finally note that recently a fusion between Wasserstein and GW, coined Fused
-Gromov-Wasserstein (FGW) has been proposed
-in [24]_. It allows to compute a similarity between objects that are only partly in
+Gromov-Wasserstein (FGW) has been proposed [24]_.
+It allows to compute a similarity between objects that are only partly in
the same space. As such it can be used to measure similarity between labeled
graphs for instance and also provide computable barycenters.
The implementations of FGW and FGW barycenter is provided in functions
@@ -740,20 +921,27 @@ The implementations of FGW and FGW barycenter is provided in functions
.. hint::
- Examples of computation of GW, regularized G and FGW are available in :
+ Examples of computation of GW, regularized G and FGW are available in:
- - :any:`auto_examples/plot_gromov`
- - :any:`auto_examples/plot_fgw`
+ - :any:`auto_examples/gromov/plot_gromov`
+ - :any:`auto_examples/gromov/plot_fgw`
- Examples of GW, regularized GW and FGW barycenters are available in :
+ Examples of GW, regularized GW and FGW barycenters are available in:
- - :any:`auto_examples/plot_gromov_barycenter`
- - :any:`auto_examples/plot_barycenter_fgw`
+ - :any:`auto_examples/gromov/plot_gromov_barycenter`
+ - :any:`auto_examples/gromov/plot_barycenter_fgw`
GPU acceleration
^^^^^^^^^^^^^^^^
+.. warning::
+
+ The :any:`ot.gpu` has been deprecated since the release 0.8 of POT and
+ should not be used. The GPU implementation (in Pytorch for instance) can be
+ used with the novel backends using the compatible functions from POT.
+
+
We provide several implementation of our OT solvers in :any:`ot.gpu`. Those
implementations use the :code:`cupy` toolbox that obviously need to be installed.
@@ -764,28 +952,80 @@ implementations use the :code:`cupy` toolbox that obviously need to be installed
algebra) have been implemented in :any:`ot.gpu`. Here is a short list on the
main entries:
- - :any:`ot.gpu.dist` : computation of distance matrix
- - :any:`ot.gpu.sinkhorn` : computation of sinkhorn
- - :any:`ot.gpu.sinkhorn_lpl1_mm` : computation of sinkhorn + group lasso
+ - :meth:`ot.gpu.dist`: computation of distance matrix
+ - :meth:`ot.gpu.sinkhorn`: computation of sinkhorn
+ - :meth:`ot.gpu.sinkhorn_lpl1_mm`: computation of sinkhorn + group lasso
Note that while the :any:`ot.gpu` module has been designed to be compatible with
-POT, calling its function with :any:`numpy` arrays will incur a large overhead due to
+POT, calling its function with :any:`numpy` arrays will incur a large overhead due to
the memory copy of the array on GPU prior to computation and conversion of the
array after computation. To avoid this overhead, we provide functions
-:any:`ot.gpu.to_gpu` and :any:`ot.gpu.to_np` that perform the conversion
+:meth:`ot.gpu.to_gpu` and :meth:`ot.gpu.to_np` that perform the conversion
explicitly.
-
.. warning::
- Note that due to the hard dependency on :code:`cupy`, :any:`ot.gpu` is not
+
+ Note that due to the hard dependency on :code:`cupy`, :any:`ot.gpu` is not
imported by default. If you want to
use it you have to specifically import it with :code:`import ot.gpu` .
-FAQ
----
+Solving OT with Multiple backends
+---------------------------------
+
+.. _backends_section:
+
+Since version 0.8, POT provides a backend that allows to code solvers
+independently from the type of the input arrays. The idea is to provide the user
+with a package that works seamlessly and returns a solution for instance as a
+Pytorch tensors when the function has Pytorch tensors as input.
+How it works
+^^^^^^^^^^^^
+
+The aim of the backend is to use the same function independently of the type of
+the input arrays.
+
+For instance when executing the following code
+
+.. code:: python
+
+ # a and b are 1D histograms (sum to 1 and positive)
+ # M is the ground cost matrix
+ T = ot.emd(a, b, M) # exact linear program
+ w = ot.emd2(a, b, M) # Wasserstein computation
+
+the functions :any:`ot.emd` and :any:`ot.emd2` can take inputs of the type
+:any:`numpy.array`, :any:`torch.tensor` or :any:`jax.numpy.array`. The output of
+the function will be the same type as the inputs and on the same device. When
+possible all computations are done on the same device and also when possible the
+output will be differentiable with respect to the input of the function.
+
+
+
+List of compatible Backends
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+- `Numpy <https://numpy.org/>`_ (all functions and solvers)
+- `Pytorch <https://pytorch.org/>`_ (all outputs differentiable w.r.t. inputs)
+- `Jax <https://github.com/google/jax>`_ (Some functions are differentiable some require a wrapper)
+
+List of compatible functions
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+This list will get longer for new releases and will hopefully disappear when POT
+become fully implemented with the backend.
+
+- :any:`ot.emd`
+- :any:`ot.emd2`
+- :any:`ot.sinkhorn`
+- :any:`ot.sinkhorn2`
+- :any:`ot.dist`
+
+
+FAQ
+---
1. **How to solve a discrete optimal transport problem ?**
@@ -798,10 +1038,10 @@ FAQ
.. code:: python
- # a,b are 1D histograms (sum to 1 and positive)
+ # a and b are 1D histograms (sum to 1 and positive)
# M is the ground cost matrix
- T=ot.emd(a,b,M) # exact linear program
- T_reg=ot.sinkhorn(a,b,M,reg) # entropic regularized OT
+ T = ot.emd(a, b, M) # exact linear program
+ T_reg = ot.sinkhorn(a, b, M, reg) # entropic regularized OT
More detailed examples can be seen on this example:
:doc:`auto_examples/plot_OT_2D_samples`
@@ -823,15 +1063,15 @@ FAQ
3. **Why is Sinkhorn slower than EMD ?**
This might come from the choice of the regularization term. The speed of
- convergence of sinkhorn depends directly on this term [22]_ and when the
- regularization gets very small the problem try and approximate the exact OT
+ convergence of Sinkhorn depends directly on this term [22]_. When the
+ regularization gets very small the problem tries to approximate the exact OT
which leads to slow convergence in addition to numerical problems. In other
- words, for large regularization sinkhorn will be very fast to converge, for
+ words, for large regularization Sinkhorn will be very fast to converge, for
small regularization (when you need an OT matrix close to the true OT), it
might be quicker to use the EMD solver.
- Also note that the numpy implementation of the sinkhorn can use parallel
- computation depending on the configuration of your system but very important
+ Also note that the numpy implementation of Sinkhorn can use parallel
+ computation depending on the configuration of your system, yet very important
speedup can be obtained by using a GPU implementation since all operations
are matrix/vector products.
@@ -863,11 +1103,6 @@ References
problems <https://arxiv.org/pdf/1412.5154.pdf>`__. SIAM Journal on
Scientific Computing, 37(2), A1111-A1138.
-.. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti,
- `Supervised planetary unmixing with optimal
- transport <https://hal.archives-ouvertes.fr/hal-01377236/document>`__,
- Whorkshop on Hyperspectral Image and Signal Processing : Evolution in
- Remote Sensing (WHISPERS), 2016.
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, `Optimal Transport
for Domain Adaptation <https://arxiv.org/pdf/1507.00504.pdf>`__, in IEEE
@@ -955,7 +1190,7 @@ References
iteration <https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf>`__,
Advances in Neural Information Processing Systems (NIPS) 31
-.. [23] Aude, G., Peyré, G., Cuturi, M., `Learning Generative Models with
+.. [23] Genevay, A., Peyré, G., Cuturi, M., `Learning Generative Models with
Sinkhorn Divergences <https://arxiv.org/abs/1706.00292>`__, Proceedings
of the Twenty-First International Conference on Artficial Intelligence
and Statistics, (AISTATS) 21, 2018
@@ -972,11 +1207,6 @@ References
.. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). Screening Sinkhorn
Algorithm for Regularized Optimal Transport <https://papers.nips.cc/paper/9386-screening-sinkhorn-algorithm-for-regularized-optimal-transport>,
Advances in Neural Information Processing Systems 33 (NeurIPS).
-
-.. [27] Redko I., Courty N., Flamary R., Tuia D. (2019). Optimal Transport for Multi-source
- Domain Adaptation under Target Shift <http://proceedings.mlr.press/v89/redko19a.html>,
- Proceedings of the Twenty-Second International Conference on Artificial Intelligence
- and Statistics (AISTATS) 22, 2019.
.. [28] Caffarelli, L. A., McCann, R. J. (2020). Free boundaries in optimal transport and
Monge-Ampere obstacle problems <http://www.math.toronto.edu/~mccann/papers/annals2010.pdf>,
@@ -985,3 +1215,7 @@ References
.. [29] Chapel, L., Alaya, M., Gasso, G. (2019). Partial Gromov-Wasserstein with
Applications on Positive-Unlabeled Learning <https://arxiv.org/abs/2002.08276>,
arXiv preprint arXiv:2002.08276.
+
+.. [30] Flamary, Rémi, et al. "Optimal transport with Laplacian regularization:
+ Applications to domain adaptation and shape matching." NIPS Workshop on Optimal
+ Transport and Machine Learning OTML. 2014.
diff --git a/docs/source/readme.rst b/docs/source/readme.rst
index b8cb48c..a8f1bc0 100644
--- a/docs/source/readme.rst
+++ b/docs/source/readme.rst
@@ -24,10 +24,9 @@ POT provides the following generic OT solvers (links to examples):
for regularized OT [7].
- Entropic regularization OT solver with `Sinkhorn Knopp
Algorithm <auto_examples/plot_OT_1D.html>`__
- [2] , stabilized version [9] [10], greedy Sinkhorn [22] and
+ [2] , stabilized version [9] [10] [34], greedy Sinkhorn [22] and
`Screening Sinkhorn
- [26] <auto_examples/plot_screenkhorn_1D.html>`__
- with optional GPU implementation (requires cupy).
+ [26] <auto_examples/plot_screenkhorn_1D.html>`__.
- Bregman projections for `Wasserstein
barycenter <auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html>`__
[3], `convolutional
@@ -35,6 +34,9 @@ POT provides the following generic OT solvers (links to examples):
[21] and unmixing [4].
- Sinkhorn divergence [23] and entropic regularization OT from
empirical data.
+- Debiased Sinkhorn barycenters `Sinkhorn divergence
+ barycenter <auto_examples/barycenters/plot_debiased_barycenter.html>`__
+ [37]
- `Smooth optimal transport
solvers <auto_examples/plot_OT_1D_smooth.html>`__
(dual and semi-dual) for KL and squared L2 regularizations [17].
@@ -45,7 +47,8 @@ POT provides the following generic OT solvers (links to examples):
distances <auto_examples/gromov/plot_gromov.html>`__
and `GW
barycenters <auto_examples/gromov/plot_gromov_barycenter.html>`__
- (exact [13] and regularized [12])
+ (exact [13] and regularized [12]), differentiable using gradients
+ from
- `Fused-Gromov-Wasserstein distances
solver <auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py>`__
and `FGW
@@ -55,6 +58,9 @@ POT provides the following generic OT solvers (links to examples):
solver <auto_examples/plot_stochastic.html>`__
for Large-scale Optimal Transport (semi-dual problem [18] and dual
problem [19])
+- `Stochastic solver of Gromov
+ Wasserstein <auto_examples/gromov/plot_gromov.html>`__
+ for large-scale problem with any loss functions [33]
- Non regularized `free support Wasserstein
barycenters <auto_examples/barycenters/plot_free_support_barycenter.html>`__
[20].
@@ -66,6 +72,15 @@ POT provides the following generic OT solvers (links to examples):
- `Partial Wasserstein and
Gromov-Wasserstein <auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html>`__
(exact [29] and entropic [3] formulations).
+- `Sliced
+ Wasserstein <auto_examples/sliced-wasserstein/plot_variance.html>`__
+ [31, 32] and Max-sliced Wasserstein [35] that can be used for
+ gradient flows [36].
+- `Several
+ backends <https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends>`__
+ for easy use of POT with
+ `Pytorch <https://pytorch.org/>`__/`jax <https://github.com/google/jax>`__/`Numpy <https://numpy.org/>`__
+ arrays.
POT provides the following Machine Learning related solvers:
@@ -96,22 +111,29 @@ Using and citing the toolbox
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
If you use this toolbox in your research and find it useful, please cite
-POT using the following reference:
+POT using the following reference from our `JMLR
+paper <https://jmlr.org/papers/v22/20-451.html>`__:
::
- Rémi Flamary and Nicolas Courty, POT Python Optimal Transport library,
- Website: https://pythonot.github.io/, 2017
+ Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer,
+ POT Python Optimal Transport library,
+ Journal of Machine Learning Research, 22(78):1−8, 2021.
+ Website: https://pythonot.github.io/
In Bibtex format:
-::
-
- @misc{flamary2017pot,
- title={POT Python Optimal Transport library},
- author={Flamary, R{'e}mi and Courty, Nicolas},
- url={https://pythonot.github.io/},
- year={2017}
+.. code:: bibtex
+
+ @article{flamary2021pot,
+ author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer},
+ title = {POT: Python Optimal Transport},
+ journal = {Journal of Machine Learning Research},
+ year = {2021},
+ volume = {22},
+ number = {78},
+ pages = {1-8},
+ url = {http://jmlr.org/papers/v22/20-451.html}
}
Installation
@@ -123,28 +145,21 @@ following Python modules:
- Numpy (>=1.16)
- Scipy (>=1.0)
-- Cython (>=0.23)
-- Matplotlib (>=1.5)
+- Cython (>=0.23) (build only, not necessary when installing from pip
+ or conda)
Pip installation
^^^^^^^^^^^^^^^^
-Note that due to a limitation of pip, ``cython`` and ``numpy`` need to
-be installed prior to installing POT. This can be done easily with
-
-::
-
- pip install numpy cython
-
You can install the toolbox through PyPI with:
-::
+.. code:: console
pip install POT
or get the very latest version by running:
-::
+.. code:: console
pip install -U https://github.com/PythonOT/POT/archive/master.zip # with --user for user install (no root)
@@ -155,7 +170,7 @@ If you use the Anaconda python distribution, POT is available in
`conda-forge <https://conda-forge.org>`__. To install it and the
required dependencies:
-::
+.. code:: console
conda install -c conda-forge pot
@@ -169,7 +184,8 @@ without errors:
import ot
-Note that for easier access the module is name ot instead of pot.
+Note that for easier access the module is named ``ot`` instead of
+``pot``.
Dependencies
~~~~~~~~~~~~
@@ -180,15 +196,17 @@ below
- **ot.dr** (Wasserstein dimensionality reduction) depends on autograd
and pymanopt that can be installed with:
- ::
+.. code:: shell
- pip install pymanopt autograd
+ pip install pymanopt autograd
- **ot.gpu** (GPU accelerated OT) depends on cupy that have to be
installed following instructions on `this
page <https://docs-cupy.chainer.org/en/stable/install.html>`__.
-
-obviously you need CUDA installed and a compatible GPU.
+ Obviously you will need CUDA installed and a compatible GPU. Note
+ that this module is deprecated since version 0.8 and will be deleted
+ in the future. GPU is now handled automatically through the backends
+ and several solver already can run on GPU using the Pytorch backend.
Examples
--------
@@ -198,36 +216,36 @@ Short examples
- Import the toolbox
- .. code:: python
+.. code:: python
- import ot
+ import ot
- Compute Wasserstein distances
- .. code:: python
+.. code:: python
- # a,b are 1D histograms (sum to 1 and positive)
- # M is the ground cost matrix
- Wd=ot.emd2(a,b,M) # exact linear program
- Wd_reg=ot.sinkhorn2(a,b,M,reg) # entropic regularized OT
- # if b is a matrix compute all distances to a and return a vector
+ # a,b are 1D histograms (sum to 1 and positive)
+ # M is the ground cost matrix
+ Wd = ot.emd2(a, b, M) # exact linear program
+ Wd_reg = ot.sinkhorn2(a, b, M, reg) # entropic regularized OT
+ # if b is a matrix compute all distances to a and return a vector
- Compute OT matrix
- .. code:: python
+.. code:: python
- # a,b are 1D histograms (sum to 1 and positive)
- # M is the ground cost matrix
- T=ot.emd(a,b,M) # exact linear program
- T_reg=ot.sinkhorn(a,b,M,reg) # entropic regularized OT
+ # a,b are 1D histograms (sum to 1 and positive)
+ # M is the ground cost matrix
+ T = ot.emd(a, b, M) # exact linear program
+ T_reg = ot.sinkhorn(a, b, M, reg) # entropic regularized OT
- Compute Wasserstein barycenter
- .. code:: python
+.. code:: python
- # A is a n*d matrix containing d 1D histograms
- # M is the ground cost matrix
- ba=ot.barycenter(A,M,reg) # reg is regularization parameter
+ # A is a n*d matrix containing d 1D histograms
+ # M is the ground cost matrix
+ ba = ot.barycenter(A, M, reg) # reg is regularization parameter
Examples and Notebooks
~~~~~~~~~~~~~~~~~~~~~~
@@ -265,10 +283,17 @@ The contributors to this library are
Rakotomamonjy <https://sites.google.com/site/alainrakotomamonjy/home>`__
- `Vayer Titouan <https://tvayer.github.io/>`__ (Gromov-Wasserstein -,
Fused-Gromov-Wasserstein)
-- `Hicham Janati <https://hichamjanati.github.io/>`__ (Unbalanced OT)
+- `Hicham Janati <https://hichamjanati.github.io/>`__ (Unbalanced OT,
+ Debiased barycenters)
- `Romain Tavenard <https://rtavenar.github.io/>`__ (1d Wasserstein)
- `Mokhtar Z. Alaya <http://mzalaya.github.io/>`__ (Screenkhorn)
- `Ievgen Redko <https://ievred.github.io/>`__ (Laplacian DA, JCPOT)
+- `Adrien Corenflos <https://adriencorenflos.github.io/>`__ (Sliced
+ Wasserstein Distance)
+- `Tanguy Kerdoncuff <https://hv0nnus.github.io/>`__ (Sampled Gromov
+ Wasserstein)
+- `Minhui Huang <https://mhhuang95.github.io>`__ (Projection Robust
+ Wasserstein Distance)
This toolbox benefit a lot from open source research and we would like
to thank the following persons for providing some code (in various
@@ -276,6 +301,8 @@ languages):
- `Gabriel Peyré <http://gpeyre.github.io/>`__ (Wasserstein Barycenters
in Matlab)
+- `Mathieu Blondel <https://mblondel.org/>`__ (original implementation
+ smooth OT)
- `Nicolas Bonneel <http://liris.cnrs.fr/~nbonneel/>`__ ( C++ code for
EMD)
- `Marco Cuturi <http://marcocuturi.net/>`__ (Sinkhorn Knopp in
@@ -285,20 +312,21 @@ Contributions and code of conduct
---------------------------------
Every contribution is welcome and should respect the `contribution
-guidelines <CONTRIBUTING.md>`__. Each member of the project is expected
-to follow the `code of conduct <CODE_OF_CONDUCT.md>`__.
+guidelines <.github/CONTRIBUTING.md>`__. Each member of the project is
+expected to follow the `code of conduct <.github/CODE_OF_CONDUCT.md>`__.
Support
-------
You can ask questions and join the development discussion:
-- On the `POT Slack channel <https://pot-toolbox.slack.com>`__
+- On the POT `slack channel <https://pot-toolbox.slack.com>`__
+- On the POT `gitter channel <https://gitter.im/PythonOT/community>`__
- On the POT `mailing
list <https://mail.python.org/mm3/mailman3/lists/pot.python.org/>`__
You can also post bug reports and feature requests in Github issues.
-Make sure to read our `guidelines <CONTRIBUTING.md>`__ first.
+Make sure to read our `guidelines <.github/CONTRIBUTING.md>`__ first.
References
----------
@@ -439,10 +467,10 @@ optimal transport and Monge-Ampere obstacle
problems <http://www.math.toronto.edu/~mccann/papers/annals2010.pdf>`__,
Annals of mathematics, 673-730.
-[29] Chapel, L., Alaya, M., Gasso, G. (2019). `Partial
-Gromov-Wasserstein with Applications on Positive-Unlabeled
-Learning <https://arxiv.org/abs/2002.08276>`__, arXiv preprint
-arXiv:2002.08276.
+[29] Chapel, L., Alaya, M., Gasso, G. (2020). `Partial Optimal Transport
+with Applications on Positive-Unlabeled
+Learning <https://arxiv.org/abs/2002.08276>`__, Advances in Neural
+Information Processing Systems (NeurIPS), 2020.
[30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). `Optimal
transport with Laplacian regularization: Applications to domain
@@ -450,11 +478,56 @@ adaptation and shape
matching <https://remi.flamary.com/biblio/flamary2014optlaplace.pdf>`__,
NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
+[31] Bonneel, Nicolas, et al. `Sliced and radon wasserstein barycenters
+of
+measures <https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf>`__,
+Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
+
+[32] Huang, M., Ma S., Lai, L. (2021). `A Riemannian Block Coordinate
+Descent Method for Computing the Projection Robust Wasserstein
+Distance <http://proceedings.mlr.press/v139/huang21e.html>`__,
+Proceedings of the 38th International Conference on Machine Learning
+(ICML).
+
+[33] Kerdoncuff T., Emonet R., Marc S. `Sampled Gromov
+Wasserstein <https://hal.archives-ouvertes.fr/hal-03232509/document>`__,
+Machine Learning Journal (MJL), 2021
+
+[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A.,
+& Peyré, G. (2019, April). `Interpolating between optimal transport and
+MMD using Sinkhorn
+divergences <http://proceedings.mlr.press/v89/feydy19a/feydy19a.pdf>`__.
+In The 22nd International Conference on Artificial Intelligence and
+Statistics (pp. 2681-2690). PMLR.
+
+[35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N.,
+Koyejo, S., ... & Schwing, A. G. (2019). `Max-sliced wasserstein
+distance and its use for
+gans <https://openaccess.thecvf.com/content_CVPR_2019/papers/Deshpande_Max-Sliced_Wasserstein_Distance_and_Its_Use_for_GANs_CVPR_2019_paper.pdf>`__.
+In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern
+Recognition (pp. 10648-10656).
+
+[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F.
+R. (2019, May). `Sliced-Wasserstein flows: Nonparametric generative
+modeling via optimal transport and
+diffusions <http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf>`__.
+In International Conference on Machine Learning (pp. 4104-4113). PMLR.
+
+[37] Janati, H., Cuturi, M., Gramfort, A. `Debiased sinkhorn
+barycenters <http://proceedings.mlr.press/v119/janati20a/janati20a.pdf>`__
+Proceedings of the 37th International Conference on Machine Learning,
+PMLR 119:4692-4701, 2020
+
+[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty,
+`Online Graph Dictionary
+Learning <https://arxiv.org/pdf/2102.06555.pdf>`__, International
+Conference on Machine Learning (ICML), 2021.
+
.. |PyPI version| image:: https://badge.fury.io/py/POT.svg
:target: https://badge.fury.io/py/POT
.. |Anaconda Cloud| image:: https://anaconda.org/conda-forge/pot/badges/version.svg
:target: https://anaconda.org/conda-forge/pot
-.. |Build Status| image:: https://github.com/PythonOT/POT/workflows/build/badge.svg
+.. |Build Status| image:: https://github.com/PythonOT/POT/workflows/build/badge.svg?branch=master&event=push
:target: https://github.com/PythonOT/POT/actions
.. |Codecov Status| image:: https://codecov.io/gh/PythonOT/POT/branch/master/graph/badge.svg
:target: https://codecov.io/gh/PythonOT/POT
diff --git a/docs/source/releases.rst b/docs/source/releases.rst
index 5a357f3..aa06105 100644
--- a/docs/source/releases.rst
+++ b/docs/source/releases.rst
@@ -1,6 +1,132 @@
Releases
========
+0.8.0
+-----
+
+*November 2021*
+
+This new stable release introduces several important features.
+
+First we now have an OpenMP compatible exact ot solver in ``ot.emd``.
+The OpenMP version is used when the parameter ``numThreads`` is greater
+than one and can lead to nice speedups on multi-core machines.
+
+| Second we have introduced a backend mechanism that allows to use
+ standard POT function seamlessly on Numpy, Pytorch and Jax arrays.
+ Other backends are coming but right now POT can be used seamlessly for
+ training neural networks in Pytorch. Notably we propose the first
+ differentiable computation of the exact OT loss with ``ot.emd2`` (can
+ be differentiated w.r.t. both cost matrix and sample weights), but
+ also for the classical Sinkhorn loss with ``ot.sinkhorn2``, the
+ Wasserstein distance in 1D with ``ot.wasserstein_1d``, sliced
+ Wasserstein with ``ot.sliced_wasserstein_distance`` and
+ Gromov-Wasserstein with ``ot.gromov_wasserstein2``. Examples of how
+ this new feature can be used are now available in the documentation
+ where the Pytorch backend is used to estimate a `minimal Wasserstein
+ estimator <https://PythonOT.github.io/auto_examples/backends/plot_unmix_optim_torch.html>`__,
+ a `Generative Network
+ (GAN) <https://PythonOT.github.io/auto_examples/backends/plot_wass2_gan_torch.html>`__,
+ for a `sliced Wasserstein gradient
+ flow <https://PythonOT.github.io/auto_examples/backends/plot_sliced_wass_grad_flow_pytorch.html>`__
+ and `optimizing the Gromov-Wassersein
+ distance <https://PythonOT.github.io/auto_examples/backends/plot_optim_gromov_pytorch.html>`__.
+ Note that the Jax backend is still in early development and quite slow
+ at the moment, we strongly recommend for Jax users to use the `OTT
+ toolbox <https://github.com/google-research/ott>`__ when possible.
+| As a result of this new feature, the old ``ot.gpu`` submodule is now
+ deprecated since GPU implementations can be done using GPU arrays on
+ the torch backends.
+
+Other novel features include implementation for `Sampled Gromov
+Wasserstein and Pointwise Gromov
+Wasserstein <https://PythonOT.github.io/auto_examples/gromov/plot_gromov.html#compute-gw-with-a-scalable-stochastic-method-with-any-loss-function>`__,
+Sinkhorn in log space with ``method='sinkhorn_log'``, `Projection Robust
+Wasserstein <https://PythonOT.github.io/gen_modules/ot.dr.html?highlight=robust#ot.dr.projection_robust_wasserstein>`__,
+ans `deviased Sinkorn
+barycenters <https://PythonOT.github.ioauto_examples/barycenters/plot_debiased_barycenter.html>`__.
+
+This release will also simplify the installation process. We have now a
+``pyproject.toml`` that defines the build dependency and POT should now
+build even when cython is not installed yet. Also we now provide
+pe-compiled wheels for linux ``aarch64`` that is used on Raspberry PI
+and android phones and for MacOS on ARM processors.
+
+Finally POT was accepted for publication in the Journal of Machine
+Learning Research (JMLR) open source software track and we ask the POT
+users to cite `this
+paper <https://www.jmlr.org/papers/v22/20-451.html>`__ from now on. The
+documentation has been improved in particular by adding a "Why OT?"
+section to the quick start guide and several new examples illustrating
+the new features. The documentation now has two version : the stable
+version https://pythonot.github.io/ corresponding to the last release
+and the master version https://pythonot.github.io/master that
+corresponds to the current master branch on GitHub.
+
+As usual, we want to thank all the POT contributors (now 37 people have
+contributed to the toolbox). But for this release we thank in particular
+Nathan Cassereau and Kamel Guerda from the AI support team at
+`IDRIS <http://www.idris.fr/>`__ for their support to the development of
+the backend and OpenMP implementations.
+
+New features
+^^^^^^^^^^^^
+
+- OpenMP support for exact OT solvers (PR #260)
+- Backend for running POT in numpy/torch + exact solver (PR #249)
+- Backend implementation of most functions in ``ot.bregman`` (PR #280)
+- Backend implementation of most functions in ``ot.optim`` (PR #282)
+- Backend implementation of most functions in ``ot.gromov`` (PR #294,
+ PR #302)
+- Test for arrays of different type and device (CPU/GPU) (PR #304,
+ #303)
+- Implementation of Sinkhorn in log space with
+ ``method='sinkhorn_log'`` (PR #290)
+- Implementation of regularization path for L2 Unbalanced OT (PR #274)
+- Implementation of Projection Robust Wasserstein (PR #267)
+- Implementation of Debiased Sinkhorn Barycenters (PR #291)
+- Implementation of Sampled Gromov Wasserstein and Pointwise Gromov
+ Wasserstein (PR #275)
+- Add ``pyproject.toml`` and build POT without installing cython first
+ (PR #293)
+- Lazy implementation in log space for sinkhorn on samples (PR #259)
+- Documentation cleanup (PR #298)
+- Two up-to-date documentations `for stable
+ release <https://PythonOT.github.io/>`__ and for `master
+ branch <https://pythonot.github.io/master/>`__.
+- Building wheels on ARM for Raspberry PI and smartphones (PR #238)
+- Update build wheels to new version and new pythons (PR #236, #253)
+- Implementation of sliced Wasserstein distance (Issue #202, PR #203)
+- Add minimal build to CI and perform pep8 test separately (PR #210)
+- Speedup of tests and return run time (PR #262)
+- Add "Why OT" discussion to the documentation (PR #220)
+- New introductory example to discrete OT in the documentation (PR
+ #191)
+- Add templates for Issues/PR on Github (PR#181)
+
+Closed issues
+^^^^^^^^^^^^^
+
+- Debug Memory leak in GAN example (#254)
+- DEbug GPU bug (Issue #284, #287, PR #288)
+- set\_gradients method for JAX backend (PR #278)
+- Quicker GAN example for CircleCI build (PR #258)
+- Better formatting in Readme (PR #234)
+- Debug CI tests (PR #240, #241, #242)
+- Bug in Partial OT solver dummy points (PR #215)
+- Bug when Armijo linesearch (Issue #184, #198, #281, PR #189, #199,
+ #286)
+- Bug Barycenter Sinkhorn (Issue 134, PR #195)
+- Infeasible solution in exact OT (Issues #126,#93, PR #217)
+- Doc for SUpport Barycenters (Issue #200, PR #201)
+- Fix labels transport in BaseTransport (Issue #207, PR #208)
+- Bug in ``emd_1d``, non respected bounds (Issue #169, PR #170)
+- Removed Python 2.7 support and update codecov file (PR #178)
+- Add normalization for WDA and test it (PR #172, #296)
+- Cleanup code for new version of ``flake8`` (PR #176)
+- Fixed requirements in ``setup.py`` (PR #174)
+- Removed specific MacOS flags (PR #175)
+
0.7.0
-----
@@ -50,7 +176,7 @@ problems.
This release is also the moment to thank all the POT contributors (old
and new) for helping making POT such a nice toolbox. A lot of changes
-(also in the API) are comming for the next versions.
+(also in the API) are coming for the next versions.
Features
^^^^^^^^
@@ -72,6 +198,8 @@ Features
Closed issues
^^^^^^^^^^^^^
+- Add JMLR paper to teh readme ad Mathieu Blondel to the Acknoledgments
+ (PR #231, #232)
- Bug in Unbalanced OT example (Issue #127)
- Clean Cython output when calling setup.py clean (Issue #122)
- Various Macosx compilation problems (Issue #113, Issue #118, PR#130)
@@ -103,8 +231,8 @@ mathematical problems and research but with the new contributions we now
implement algorithms and solvers from 24 scientific papers (listed in
the README.md file). New features include a direct implementation of the
`empirical Sinkhorn
-divergence <all.html#ot.bregman.empirical_sinkhorn_divergence>`__
-, a new efficient (Cython implementation) solver for `EMD in
+divergence <all.html#ot.bregman.empirical_sinkhorn_divergence>`__,
+a new efficient (Cython implementation) solver for `EMD in
1D <all.html#ot.lp.emd_1d>`__ and
corresponding `Wasserstein
1D <all.html#ot.lp.wasserstein_1d>`__.
diff --git a/examples/README.txt b/examples/README.txt
index 69a9f84..b48487f 100644
--- a/examples/README.txt
+++ b/examples/README.txt
@@ -1,7 +1,7 @@
Examples gallery
================
-This is a gallery of all the POT example files.
+This is a gallery of all the POT example files.
OT and regularized OT
diff --git a/examples/backends/README.txt b/examples/backends/README.txt
new file mode 100644
index 0000000..3ee0e27
--- /dev/null
+++ b/examples/backends/README.txt
@@ -0,0 +1,4 @@
+
+
+POT backend examples
+-------------------- \ No newline at end of file
diff --git a/examples/backends/plot_optim_gromov_pytorch.py b/examples/backends/plot_optim_gromov_pytorch.py
new file mode 100644
index 0000000..969707f
--- /dev/null
+++ b/examples/backends/plot_optim_gromov_pytorch.py
@@ -0,0 +1,260 @@
+r"""
+=================================
+Optimizing the Gromov-Wasserstein distance with PyTorch
+=================================
+
+In this exemple we use the pytorch backend to optimize the Gromov-Wasserstein
+(GW) loss between two graphs expressed as empirical distribution.
+
+In the first example we optimize the weights on the node of a simple template
+graph so that it minimizes the GW with a given Stochastic Block Model graph.
+We can see that this actually recovers the proportion of classes in the SBM
+and allows for an accurate clustering of the nodes using the GW optimal plan.
+
+In a second example we optimize simultaneously the weights and the sructure of
+the template graph which allows us to perform graph compression and to recover
+other properties of the SBM.
+
+The backend actually uses the gradients expressed in [38] to optimize the
+weights.
+
+[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online Graph
+Dictionary Learning, International Conference on Machine Learning (ICML), 2021.
+
+"""
+# Author: Rémi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 3
+
+from sklearn.manifold import MDS
+import numpy as np
+import matplotlib.pylab as pl
+import torch
+
+import ot
+from ot.gromov import gromov_wasserstein2
+
+# %%
+# Graph generation
+# ---------------
+
+rng = np.random.RandomState(42)
+
+
+def get_sbm(n, nc, ratio, P):
+ nbpc = np.round(n * ratio).astype(int)
+ n = np.sum(nbpc)
+ C = np.zeros((n, n))
+ for c1 in range(nc):
+ for c2 in range(c1 + 1):
+ if c1 == c2:
+ for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[:c1 + 1])):
+ for j in range(np.sum(nbpc[:c2]), i):
+ if rng.rand() <= P[c1, c2]:
+ C[i, j] = 1
+ else:
+ for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[:c1 + 1])):
+ for j in range(np.sum(nbpc[:c2]), np.sum(nbpc[:c2 + 1])):
+ if rng.rand() <= P[c1, c2]:
+ C[i, j] = 1
+
+ return C + C.T
+
+
+n = 100
+nc = 3
+ratio = np.array([.5, .3, .2])
+P = np.array(0.6 * np.eye(3) + 0.05 * np.ones((3, 3)))
+C1 = get_sbm(n, nc, ratio, P)
+
+# get 2d position for nodes
+x1 = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C1)
+
+
+def plot_graph(x, C, color='C0', s=None):
+ for j in range(C.shape[0]):
+ for i in range(j):
+ if C[i, j] > 0:
+ pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color='k')
+ pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors='k', cmap='tab10', vmax=9)
+
+
+pl.figure(1, (10, 5))
+pl.clf()
+pl.subplot(1, 2, 1)
+plot_graph(x1, C1, color='C0')
+pl.title("SBM Graph")
+pl.axis("off")
+pl.subplot(1, 2, 2)
+pl.imshow(C1, interpolation='nearest')
+pl.title("Adjacency matrix")
+pl.axis("off")
+
+
+# %%
+# Optimizing GW w.r.t. the weights on a template structure
+# ------------------------------------------------
+# The adajacency matrix C1 is block diagonal with 3 blocks. We want to
+# optimize the weights of a simple template C0=eye(3) and see if we can
+# recover the proportion of classes from the SBM (up to a permutation).
+
+C0 = np.eye(3)
+
+
+def min_weight_gw(C1, C2, a2, nb_iter_max=100, lr=1e-2):
+ """ solve min_a GW(C1,C2,a, a2) by gradient descent"""
+
+ # use pyTorch for our data
+ C1_torch = torch.tensor(C1)
+ C2_torch = torch.tensor(C2)
+
+ a0 = rng.rand(C1.shape[0]) # random_init
+ a0 /= a0.sum() # on simplex
+ a1_torch = torch.tensor(a0).requires_grad_(True)
+ a2_torch = torch.tensor(a2)
+
+ loss_iter = []
+
+ for i in range(nb_iter_max):
+
+ loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch)
+
+ loss_iter.append(loss.clone().detach().cpu().numpy())
+ loss.backward()
+
+ #print("{:03d} | {}".format(i, loss_iter[-1]))
+
+ # performs a step of projected gradient descent
+ with torch.no_grad():
+ grad = a1_torch.grad
+ a1_torch -= grad * lr # step
+ a1_torch.grad.zero_()
+ a1_torch.data = ot.utils.proj_simplex(a1_torch)
+
+ a1 = a1_torch.clone().detach().cpu().numpy()
+
+ return a1, loss_iter
+
+
+a0_est, loss_iter0 = min_weight_gw(C0, C1, ot.unif(n), nb_iter_max=100, lr=1e-2)
+
+pl.figure(2)
+pl.plot(loss_iter0)
+pl.title("Loss along iterations")
+
+print("Estimated weights : ", a0_est)
+print("True proportions : ", ratio)
+
+
+# %%
+# It is clear that the optimization has converged and that we recover the
+# ratio of the different classes in the SBM graph up to a permutation.
+
+
+# %%
+# Community clustering with uniform and estimated weights
+# --------------------------------------------
+# The GW OT plan can be used to perform a clustering of the nodes of a graph
+# when computing the GW with a simple template like C0 by labeling nodes in
+# the original graph using by the index of the noe in the template receiving
+# the most mass.
+#
+# We show here the result of such a clustering when using uniform weights on
+# the template C0 and when using the optimal weights previously estimated.
+
+
+T_unif = ot.gromov_wasserstein(C1, C0, ot.unif(n), ot.unif(3))
+label_unif = T_unif.argmax(1)
+
+T_est = ot.gromov_wasserstein(C1, C0, ot.unif(n), a0_est)
+label_est = T_est.argmax(1)
+
+pl.figure(3, (10, 5))
+pl.clf()
+pl.subplot(1, 2, 1)
+plot_graph(x1, C1, color=label_unif)
+pl.title("Graph clustering unif. weights")
+pl.axis("off")
+pl.subplot(1, 2, 2)
+plot_graph(x1, C1, color=label_est)
+pl.title("Graph clustering est. weights")
+pl.axis("off")
+
+
+# %%
+# Graph compression with GW
+# -------------------------
+
+# Now we optimize both the weights and structure of a small graph that
+# minimize the GW distance wrt our data graph. This can be seen as graph
+# compression but can also recover important properties of an SBM such
+# as its class proportion but also its matrix of probability of links between
+# classes
+
+
+def graph_compession_gw(nb_nodes, C2, a2, nb_iter_max=100, lr=1e-2):
+ """ solve min_a GW(C1,C2,a, a2) by gradient descent"""
+
+ # use pyTorch for our data
+
+ C2_torch = torch.tensor(C2)
+ a2_torch = torch.tensor(a2)
+
+ a0 = rng.rand(nb_nodes) # random_init
+ a0 /= a0.sum() # on simplex
+ a1_torch = torch.tensor(a0).requires_grad_(True)
+ C0 = np.eye(nb_nodes)
+ C1_torch = torch.tensor(C0).requires_grad_(True)
+
+ loss_iter = []
+
+ for i in range(nb_iter_max):
+
+ loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch)
+
+ loss_iter.append(loss.clone().detach().cpu().numpy())
+ loss.backward()
+
+ #print("{:03d} | {}".format(i, loss_iter[-1]))
+
+ # performs a step of projected gradient descent
+ with torch.no_grad():
+ grad = a1_torch.grad
+ a1_torch -= grad * lr # step
+ a1_torch.grad.zero_()
+ a1_torch.data = ot.utils.proj_simplex(a1_torch)
+
+ grad = C1_torch.grad
+ C1_torch -= grad * lr # step
+ C1_torch.grad.zero_()
+ C1_torch.data = torch.clamp(C1_torch, 0, 1)
+
+ a1 = a1_torch.clone().detach().cpu().numpy()
+ C1 = C1_torch.clone().detach().cpu().numpy()
+
+ return a1, C1, loss_iter
+
+
+nb_nodes = 3
+a0_est2, C0_est2, loss_iter2 = graph_compession_gw(nb_nodes, C1, ot.unif(n),
+ nb_iter_max=100, lr=5e-2)
+
+pl.figure(4)
+pl.plot(loss_iter2)
+pl.title("Loss along iterations")
+
+
+print("Estimated weights : ", a0_est2)
+print("True proportions : ", ratio)
+
+pl.figure(6, (10, 3.5))
+pl.clf()
+pl.subplot(1, 2, 1)
+pl.imshow(P, vmin=0, vmax=1)
+pl.title('True SBM P matrix')
+pl.subplot(1, 2, 2)
+pl.imshow(C0_est2, vmin=0, vmax=1)
+pl.title('Estimated C0 matrix')
+pl.colorbar()
diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
new file mode 100644
index 0000000..05b9952
--- /dev/null
+++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
@@ -0,0 +1,185 @@
+r"""
+=================================
+Sliced Wasserstein barycenter and gradient flow with PyTorch
+=================================
+
+In this exemple we use the pytorch backend to optimize the sliced Wasserstein
+loss between two empirical distributions [31].
+
+In the first example one we perform a
+gradient flow on the support of a distribution that minimize the sliced
+Wassersein distance as poposed in [36].
+
+In the second exemple we optimize with a gradient descent the sliced
+Wasserstein barycenter between two distributions as in [31].
+
+[31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of
+measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
+
+[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R.
+(2019, May). Sliced-Wasserstein flows: Nonparametric generative modeling
+via optimal transport and diffusions. In International Conference on
+Machine Learning (pp. 4104-4113). PMLR.
+
+
+"""
+# Author: Rémi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+
+# %%
+# Loading the data
+
+
+import numpy as np
+import matplotlib.pylab as pl
+import torch
+import ot
+import matplotlib.animation as animation
+
+I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::4, ::4, 2]
+I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::4, ::4, 2]
+
+sz = I2.shape[0]
+XX, YY = np.meshgrid(np.arange(sz), np.arange(sz))
+
+x1 = np.stack((XX[I1 == 0], YY[I1 == 0]), 1) * 1.0
+x2 = np.stack((XX[I2 == 0] + 60, -YY[I2 == 0] + 32), 1) * 1.0
+x3 = np.stack((XX[I2 == 0], -YY[I2 == 0] + 32), 1) * 1.0
+
+pl.figure(1, (8, 4))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)
+
+# %%
+# Sliced Wasserstein gradient flow with Pytorch
+# ---------------------------------------------
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# use pyTorch for our data
+x1_torch = torch.tensor(x1).to(device=device).requires_grad_(True)
+x2_torch = torch.tensor(x2).to(device=device)
+
+
+lr = 1e3
+nb_iter_max = 100
+
+x_all = np.zeros((nb_iter_max, x1.shape[0], 2))
+
+loss_iter = []
+
+# generator for random permutations
+gen = torch.Generator()
+gen.manual_seed(42)
+
+for i in range(nb_iter_max):
+
+ loss = ot.sliced_wasserstein_distance(x1_torch, x2_torch, n_projections=20, seed=gen)
+
+ loss_iter.append(loss.clone().detach().cpu().numpy())
+ loss.backward()
+
+ # performs a step of projected gradient descent
+ with torch.no_grad():
+ grad = x1_torch.grad
+ x1_torch -= grad * lr / (1 + i / 5e1) # step
+ x1_torch.grad.zero_()
+ x_all[i, :, :] = x1_torch.clone().detach().cpu().numpy()
+
+xb = x1_torch.clone().detach().cpu().numpy()
+
+pl.figure(2, (8, 4))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu^{(0)}$')
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$')
+pl.scatter(xb[:, 0], xb[:, 1], alpha=0.5, label='$\mu^{(100)}$')
+pl.title('Sliced Wasserstein gradient flow')
+pl.legend()
+ax = pl.axis()
+
+# %%
+# Animate trajectories of the gradient flow along iteration
+# -------------------------------------------------------
+
+pl.figure(3, (8, 4))
+
+
+def _update_plot(i):
+ pl.clf()
+ pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu^{(0)}$')
+ pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$')
+ pl.scatter(x_all[i, :, 0], x_all[i, :, 1], alpha=0.5, label='$\mu^{(100)}$')
+ pl.title('Sliced Wasserstein gradient flow Iter. {}'.format(i))
+ pl.axis(ax)
+ return 1
+
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100, repeat_delay=2000)
+
+# %%
+# Compute the Sliced Wasserstein Barycenter
+#
+x1_torch = torch.tensor(x1).to(device=device)
+x3_torch = torch.tensor(x3).to(device=device)
+xbinit = np.random.randn(500, 2) * 10 + 16
+xbary_torch = torch.tensor(xbinit).to(device=device).requires_grad_(True)
+
+lr = 1e3
+nb_iter_max = 100
+
+x_all = np.zeros((nb_iter_max, xbary_torch.shape[0], 2))
+
+loss_iter = []
+
+# generator for random permutations
+gen = torch.Generator()
+gen.manual_seed(42)
+
+alpha = 0.5
+
+for i in range(nb_iter_max):
+
+ loss = alpha * ot.sliced_wasserstein_distance(xbary_torch, x3_torch, n_projections=50, seed=gen) \
+ + (1 - alpha) * ot.sliced_wasserstein_distance(xbary_torch, x1_torch, n_projections=50, seed=gen)
+
+ loss_iter.append(loss.clone().detach().cpu().numpy())
+ loss.backward()
+
+ # performs a step of projected gradient descent
+ with torch.no_grad():
+ grad = xbary_torch.grad
+ xbary_torch -= grad * lr # / (1 + i / 5e1) # step
+ xbary_torch.grad.zero_()
+ x_all[i, :, :] = xbary_torch.clone().detach().cpu().numpy()
+
+xb = xbary_torch.clone().detach().cpu().numpy()
+
+pl.figure(4, (8, 4))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu$')
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$')
+pl.scatter(xb[:, 0] + 30, xb[:, 1], alpha=0.5, label='Barycenter')
+pl.title('Sliced Wasserstein barycenter')
+pl.legend()
+ax = pl.axis()
+
+
+# %%
+# Animate trajectories of the barycenter along gradient descent
+# -------------------------------------------------------
+
+pl.figure(5, (8, 4))
+
+
+def _update_plot(i):
+ pl.clf()
+ pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu^{(0)}$')
+ pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$')
+ pl.scatter(x_all[i, :, 0] + 30, x_all[i, :, 1], alpha=0.5, label='$\mu^{(100)}$')
+ pl.title('Sliced Wasserstein barycenter Iter. {}'.format(i))
+ pl.axis(ax)
+ return 1
+
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100, repeat_delay=2000)
diff --git a/examples/backends/plot_unmix_optim_torch.py b/examples/backends/plot_unmix_optim_torch.py
new file mode 100644
index 0000000..9ae66e9
--- /dev/null
+++ b/examples/backends/plot_unmix_optim_torch.py
@@ -0,0 +1,161 @@
+# -*- coding: utf-8 -*-
+r"""
+=================================
+Wasserstein unmixing with PyTorch
+=================================
+
+In this example we estimate mixing parameters from distributions that minimize
+the Wasserstein distance. In other words we suppose that a target
+distribution :math:`\mu^t` can be expressed as a weighted sum of source
+distributions :math:`\mu^s_k` with the following model:
+
+.. math::
+ \mu^t = \sum_{k=1}^K w_k\mu^s_k
+
+where :math:`\mathbf{w}` is a vector of size :math:`K` and belongs in the
+distribution simplex :math:`\Delta_K`.
+
+In order to estimate this weight vector we propose to optimize the Wasserstein
+distance between the model and the observed :math:`\mu^t` with respect to
+the vector. This leads to the following optimization problem:
+
+.. math::
+ \min_{\mathbf{w}\in\Delta_K} \quad W \left(\mu^t,\sum_{k=1}^K w_k\mu^s_k\right)
+
+This minimization is done in this example with a simple projected gradient
+descent in PyTorch. We use the automatic backend of POT that allows us to
+compute the Wasserstein distance with :any:`ot.emd2` with
+differentiable losses.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 2
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import torch
+
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% Data
+
+nt = 100
+nt1 = 10 #
+
+ns1 = 50
+ns = 2 * ns1
+
+rng = np.random.RandomState(2)
+
+xt = rng.randn(nt, 2) * 0.2
+xt[:nt1, 0] += 1
+xt[nt1:, 1] += 1
+
+
+xs1 = rng.randn(ns1, 2) * 0.2
+xs1[:, 0] += 1
+xs2 = rng.randn(ns1, 2) * 0.2
+xs2[:, 1] += 1
+
+xs = np.concatenate((xs1, xs2))
+
+# Sample reweighting matrix H
+H = np.zeros((ns, 2))
+H[:ns1, 0] = 1 / ns1
+H[ns1:, 1] = 1 / ns1
+# each columns sums to 1 and has weights only for samples form the
+# corresponding source distribution
+
+M = ot.dist(xs, xt)
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% plot the distributions
+
+pl.figure(1)
+pl.scatter(xt[:, 0], xt[:, 1], label='Target $\mu^t$', alpha=0.5)
+pl.scatter(xs1[:, 0], xs1[:, 1], label='Source $\mu^s_1$', alpha=0.5)
+pl.scatter(xs2[:, 0], xs2[:, 1], label='Source $\mu^s_2$', alpha=0.5)
+pl.title('Sources and Target distributions')
+pl.legend()
+
+
+##############################################################################
+# Optimization of the model wrt the Wasserstein distance
+# ------------------------------------------------------
+
+
+#%% Weights optimization with gradient descent
+
+# convert numpy arrays to torch tensors
+H2 = torch.tensor(H)
+M2 = torch.tensor(M)
+
+# weights for the source distributions
+w = torch.tensor(ot.unif(2), requires_grad=True)
+
+# uniform weights for target
+b = torch.tensor(ot.unif(nt))
+
+lr = 2e-3 # learning rate
+niter = 500 # number of iterations
+losses = [] # loss along the iterations
+
+# loss for the minimal Wasserstein estimator
+
+
+def get_loss(w):
+ a = torch.mv(H2, w) # distribution reweighting
+ return ot.emd2(a, b, M2) # squared Wasserstein 2
+
+
+for i in range(niter):
+
+ loss = get_loss(w)
+ losses.append(float(loss))
+
+ loss.backward()
+
+ with torch.no_grad():
+ w -= lr * w.grad # gradient step
+ w[:] = ot.utils.proj_simplex(w) # projection on the simplex
+
+ w.grad.zero_()
+
+
+##############################################################################
+# Estimated weights and convergence of the objective
+# ---------------------------------------------------
+
+we = w.detach().numpy()
+print('Estimated mixture:', we)
+
+pl.figure(2)
+pl.semilogy(losses)
+pl.grid()
+pl.title('Wasserstein distance')
+pl.xlabel("Iterations")
+
+##############################################################################
+# Ploting the reweighted source distribution
+# ------------------------------------------
+
+pl.figure(3)
+
+# compute source weights
+ws = H.dot(we)
+
+pl.scatter(xt[:, 0], xt[:, 1], label='Target $\mu^t$', alpha=0.5)
+pl.scatter(xs[:, 0], xs[:, 1], color='C3', s=ws * 20 * ns, label='Weighted sources $\sum_{k} w_k\mu^s_k$', alpha=0.5)
+pl.title('Target and reweighted source distributions')
+pl.legend()
diff --git a/examples/backends/plot_wass1d_torch.py b/examples/backends/plot_wass1d_torch.py
new file mode 100644
index 0000000..0abdd6d
--- /dev/null
+++ b/examples/backends/plot_wass1d_torch.py
@@ -0,0 +1,152 @@
+r"""
+=================================
+Wasserstein 1D with PyTorch
+=================================
+
+In this small example, we consider the following minization problem:
+
+.. math::
+ \mu^* = \min_\mu W(\mu,\nu)
+
+where :math:`\nu` is a reference 1D measure. The problem is handled
+by a projected gradient descent method, where the gradient is computed
+by pyTorch automatic differentiation. The projection on the simplex
+ensures that the iterate will remain on the probability simplex.
+
+This example illustrates both `wasserstein_1d` function and backend use within
+the POT framework.
+"""
+# Author: Nicolas Courty <ncourty@irisa.fr>
+# Rémi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import matplotlib as mpl
+import torch
+
+from ot.lp import wasserstein_1d
+from ot.datasets import make_1D_gauss as gauss
+from ot.utils import proj_simplex
+
+red = np.array(mpl.colors.to_rgb('red'))
+blue = np.array(mpl.colors.to_rgb('blue'))
+
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a = gauss(n, m=20, s=5) # m= mean, s= std
+b = gauss(n, m=60, s=10)
+
+# enforce sum to one on the support
+a = a / a.sum()
+b = b / b.sum()
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# use pyTorch for our data
+x_torch = torch.tensor(x).to(device=device)
+a_torch = torch.tensor(a).to(device=device).requires_grad_(True)
+b_torch = torch.tensor(b).to(device=device)
+
+lr = 1e-6
+nb_iter_max = 800
+
+loss_iter = []
+
+pl.figure(1, figsize=(8, 4))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+
+for i in range(nb_iter_max):
+ # Compute the Wasserstein 1D with torch backend
+ loss = wasserstein_1d(x_torch, x_torch, a_torch, b_torch, p=2)
+ # record the corresponding loss value
+ loss_iter.append(loss.clone().detach().cpu().numpy())
+ loss.backward()
+
+ # performs a step of projected gradient descent
+ with torch.no_grad():
+ grad = a_torch.grad
+ a_torch -= a_torch.grad * lr # step
+ a_torch.grad.zero_()
+ a_torch.data = proj_simplex(a_torch) # projection onto the simplex
+
+ # plot one curve every 10 iterations
+ if i % 10 == 0:
+ mix = float(i) / nb_iter_max
+ pl.plot(x, a_torch.clone().detach().cpu().numpy(), c=(1 - mix) * blue + mix * red)
+
+pl.legend()
+pl.title('Distribution along the iterations of the projected gradient descent')
+pl.show()
+
+pl.figure(2)
+pl.plot(range(nb_iter_max), loss_iter, lw=3)
+pl.title('Evolution of the loss along iterations', fontsize=16)
+pl.show()
+
+# %%
+# Wasserstein barycenter
+# ---------
+# In this example, we consider the following Wasserstein barycenter problem
+# $$ \\eta^* = \\min_\\eta\;\;\; (1-t)W(\\mu,\\eta) + tW(\\eta,\\nu)$$
+# where :math:`\\mu` and :math:`\\nu` are reference 1D measures, and :math:`t`
+# is a parameter :math:`\in [0,1]`. The problem is handled by a project gradient
+# descent method, where the gradient is computed by pyTorch automatic differentiation.
+# The projection on the simplex ensures that the iterate will remain on the
+# probability simplex.
+#
+# This example illustrates both `wasserstein_1d` function and backend use within the
+# POT framework.
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# use pyTorch for our data
+x_torch = torch.tensor(x).to(device=device)
+a_torch = torch.tensor(a).to(device=device)
+b_torch = torch.tensor(b).to(device=device)
+bary_torch = torch.tensor((a + b).copy() / 2).to(device=device).requires_grad_(True)
+
+
+lr = 1e-6
+nb_iter_max = 2000
+
+loss_iter = []
+
+# instant of the interpolation
+t = 0.5
+
+for i in range(nb_iter_max):
+ # Compute the Wasserstein 1D with torch backend
+ loss = (1 - t) * wasserstein_1d(x_torch, x_torch, a_torch.detach(), bary_torch, p=2) + t * wasserstein_1d(x_torch, x_torch, b_torch, bary_torch, p=2)
+ # record the corresponding loss value
+ loss_iter.append(loss.clone().detach().cpu().numpy())
+ loss.backward()
+
+ # performs a step of projected gradient descent
+ with torch.no_grad():
+ grad = bary_torch.grad
+ bary_torch -= bary_torch.grad * lr # step
+ bary_torch.grad.zero_()
+ bary_torch.data = proj_simplex(bary_torch) # projection onto the simplex
+
+pl.figure(3, figsize=(8, 4))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+pl.plot(x, bary_torch.clone().detach().cpu().numpy(), c='green', label='W barycenter')
+pl.legend()
+pl.title('Wasserstein barycenter computed by gradient descent')
+pl.show()
+
+pl.figure(4)
+pl.plot(range(nb_iter_max), loss_iter, lw=3)
+pl.title('Evolution of the loss along iterations', fontsize=16)
+pl.show()
diff --git a/examples/backends/plot_wass2_gan_torch.py b/examples/backends/plot_wass2_gan_torch.py
new file mode 100644
index 0000000..ca5b3c9
--- /dev/null
+++ b/examples/backends/plot_wass2_gan_torch.py
@@ -0,0 +1,227 @@
+# -*- coding: utf-8 -*-
+r"""
+========================================
+Wasserstein 2 Minibatch GAN with PyTorch
+========================================
+
+In this example we train a Wasserstein GAN using Wasserstein 2 on minibatches
+as a distribution fitting term.
+
+We want to train a generator :math:`G_\theta` that generates realistic
+data from random noise drawn form a Gaussian :math:`\mu_n` distribution so
+that the data is indistinguishable from true data in the data distribution
+:math:`\mu_d`. To this end Wasserstein GAN [Arjovsky2017] aim at optimizing
+the parameters :math:`\theta` of the generator with the following
+optimization problem:
+
+.. math::
+ \min_{\theta} W(\mu_d,G_\theta\#\mu_n)
+
+
+In practice we do not have access to the full distribution :math:`\mu_d` but
+samples and we cannot compute the Wasserstein distance for lare dataset.
+[Arjovsky2017] proposed to approximate the dual potential of Wasserstein 1
+with a neural network recovering an optimization problem similar to GAN.
+In this example
+we will optimize the expectation of the Wasserstein distance over minibatches
+at each iterations as proposed in [Genevay2018]. Optimizing the Minibatches
+of the Wasserstein distance has been studied in[Fatras2019].
+
+[Arjovsky2017] Arjovsky, M., Chintala, S., & Bottou, L. (2017, July).
+Wasserstein generative adversarial networks. In International conference
+on machine learning (pp. 214-223). PMLR.
+
+[Genevay2018] Genevay, Aude, Gabriel Peyré, and Marco Cuturi. "Learning generative models
+with sinkhorn divergences." International Conference on Artificial Intelligence
+and Statistics. PMLR, 2018.
+
+[Fatras2019] Fatras, K., Zine, Y., Flamary, R., Gribonval, R., & Courty, N.
+(2020, June). Learning with minibatch Wasserstein: asymptotic and gradient
+properties. In the 23nd International Conference on Artificial Intelligence
+and Statistics (Vol. 108).
+
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 3
+
+import numpy as np
+import matplotlib.pyplot as pl
+import matplotlib.animation as animation
+import torch
+from torch import nn
+import ot
+
+
+# %%
+# Data generation
+# ---------------
+
+torch.manual_seed(1)
+sigma = 0.1
+n_dims = 2
+n_features = 2
+
+
+def get_data(n_samples):
+ c = torch.rand(size=(n_samples, 1))
+ angle = c * 2 * np.pi
+ x = torch.cat((torch.cos(angle), torch.sin(angle)), 1)
+ x += torch.randn(n_samples, 2) * sigma
+ return x
+
+
+# %%
+# Plot data
+# ---------
+
+# plot the distributions
+x = get_data(500)
+pl.figure(1)
+pl.scatter(x[:, 0], x[:, 1], label='Data samples from $\mu_d$', alpha=0.5)
+pl.title('Data distribution')
+pl.legend()
+
+
+# %%
+# Generator Model
+# ---------------
+
+# define the MLP model
+class Generator(torch.nn.Module):
+ def __init__(self):
+ super(Generator, self).__init__()
+ self.fc1 = nn.Linear(n_features, 200)
+ self.fc2 = nn.Linear(200, 500)
+ self.fc3 = nn.Linear(500, n_dims)
+ self.relu = torch.nn.ReLU() # instead of Heaviside step fn
+
+ def forward(self, x):
+ output = self.fc1(x)
+ output = self.relu(output) # instead of Heaviside step fn
+ output = self.fc2(output)
+ output = self.relu(output)
+ output = self.fc3(output)
+ return output
+
+# %%
+# Training the model
+# ------------------
+
+
+G = Generator()
+optimizer = torch.optim.RMSprop(G.parameters(), lr=0.00019, eps=1e-5)
+
+# number of iteration and size of the batches
+n_iter = 200 # set to 200 for doc build but 1000 is better ;)
+size_batch = 500
+
+# generate statis samples to see their trajectory along training
+n_visu = 100
+xnvisu = torch.randn(n_visu, n_features)
+xvisu = torch.zeros(n_iter, n_visu, n_dims)
+
+ab = torch.ones(size_batch) / size_batch
+losses = []
+
+
+for i in range(n_iter):
+
+ # generate noise samples
+ xn = torch.randn(size_batch, n_features)
+
+ # generate data samples
+ xd = get_data(size_batch)
+
+ # generate sample along iterations
+ xvisu[i, :, :] = G(xnvisu).detach()
+
+ # generate smaples and compte distance matrix
+ xg = G(xn)
+ M = ot.dist(xg, xd)
+
+ loss = ot.emd2(ab, ab, M)
+ losses.append(float(loss.detach()))
+
+ if i % 10 == 0:
+ print("Iter: {:3d}, loss={}".format(i, losses[-1]))
+
+ loss.backward()
+ optimizer.step()
+
+ del M
+
+pl.figure(2)
+pl.semilogy(losses)
+pl.grid()
+pl.title('Wasserstein distance')
+pl.xlabel("Iterations")
+
+
+# %%
+# Plot trajectories of generated samples along iterations
+# -------------------------------------------------------
+
+
+pl.figure(3, (10, 10))
+
+ivisu = [0, 10, 25, 50, 75, 125, 15, 175, 199]
+
+for i in range(9):
+ pl.subplot(3, 3, i + 1)
+ pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1)
+ pl.scatter(xvisu[ivisu[i], :, 0], xvisu[ivisu[i], :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5)
+ pl.xticks(())
+ pl.yticks(())
+ pl.title('Iter. {}'.format(ivisu[i]))
+ if i == 0:
+ pl.legend()
+
+# %%
+# Animate trajectories of generated samples along iteration
+# -------------------------------------------------------
+
+pl.figure(4, (8, 8))
+
+
+def _update_plot(i):
+ pl.clf()
+ pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1)
+ pl.scatter(xvisu[i, :, 0], xvisu[i, :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5)
+ pl.xticks(())
+ pl.yticks(())
+ pl.xlim((-1.5, 1.5))
+ pl.ylim((-1.5, 1.5))
+ pl.title('Iter. {}'.format(i))
+ return 1
+
+
+i = 0
+pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1)
+pl.scatter(xvisu[i, :, 0], xvisu[i, :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5)
+pl.xticks(())
+pl.yticks(())
+pl.xlim((-1.5, 1.5))
+pl.ylim((-1.5, 1.5))
+pl.title('Iter. {}'.format(ivisu[i]))
+
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter, interval=100, repeat_delay=2000)
+
+# %%
+# Generate and visualize data
+# ---------------------------
+
+size_batch = 500
+xd = get_data(size_batch)
+xn = torch.randn(size_batch, 2)
+x = G(xn).detach().numpy()
+
+pl.figure(5)
+pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.5)
+pl.scatter(x[:, 0], x[:, 1], label='Data samples from $G\#\mu_n$', alpha=0.5)
+pl.title('Sources and Target distributions')
+pl.legend()
diff --git a/examples/barycenters/plot_barycenter_1D.py b/examples/barycenters/plot_barycenter_1D.py
index 63dc460..2373e99 100644
--- a/examples/barycenters/plot_barycenter_1D.py
+++ b/examples/barycenters/plot_barycenter_1D.py
@@ -18,10 +18,10 @@ SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
#
# License: MIT License
-# sphinx_gallery_thumbnail_number = 4
+# sphinx_gallery_thumbnail_number = 1
import numpy as np
-import matplotlib.pylab as pl
+import matplotlib.pyplot as plt
import ot
# necessary for 3d plot even if not used
from mpl_toolkits.mplot3d import Axes3D # noqa
@@ -51,18 +51,6 @@ M = ot.utils.dist0(n)
M /= M.max()
##############################################################################
-# Plot data
-# ---------
-
-#%% plot the distributions
-
-pl.figure(1, figsize=(6.4, 3))
-for i in range(n_distributions):
- pl.plot(x, A[:, i])
-pl.title('Distributions')
-pl.tight_layout()
-
-##############################################################################
# Barycenter computation
# ----------------------
@@ -78,24 +66,20 @@ bary_l2 = A.dot(weights)
reg = 1e-3
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
-pl.figure(2)
-pl.clf()
-pl.subplot(2, 1, 1)
-for i in range(n_distributions):
- pl.plot(x, A[:, i])
-pl.title('Distributions')
+f, (ax1, ax2) = plt.subplots(2, 1, tight_layout=True, num=1)
+ax1.plot(x, A, color="black")
+ax1.set_title('Distributions')
-pl.subplot(2, 1, 2)
-pl.plot(x, bary_l2, 'r', label='l2')
-pl.plot(x, bary_wass, 'g', label='Wasserstein')
-pl.legend()
-pl.title('Barycenters')
-pl.tight_layout()
+ax2.plot(x, bary_l2, 'r', label='l2')
+ax2.plot(x, bary_wass, 'g', label='Wasserstein')
+ax2.set_title('Barycenters')
+
+plt.legend()
+plt.show()
##############################################################################
# Barycentric interpolation
# -------------------------
-
#%% barycenter interpolation
n_alpha = 11
@@ -106,24 +90,23 @@ B_l2 = np.zeros((n, n_alpha))
B_wass = np.copy(B_l2)
-for i in range(0, n_alpha):
+for i in range(n_alpha):
alpha = alpha_list[i]
weights = np.array([1 - alpha, alpha])
B_l2[:, i] = A.dot(weights)
B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights)
#%% plot interpolation
+plt.figure(2)
-pl.figure(3)
-
-cmap = pl.cm.get_cmap('viridis')
+cmap = plt.cm.get_cmap('viridis')
verts = []
zs = alpha_list
for i, z in enumerate(zs):
ys = B_l2[:, i]
verts.append(list(zip(x, ys)))
-ax = pl.gcf().gca(projection='3d')
+ax = plt.gcf().gca(projection='3d')
poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
@@ -134,18 +117,18 @@ ax.set_ylabel('$\\alpha$')
ax.set_ylim3d(0, 1)
ax.set_zlabel('')
ax.set_zlim3d(0, B_l2.max() * 1.01)
-pl.title('Barycenter interpolation with l2')
-pl.tight_layout()
+plt.title('Barycenter interpolation with l2')
+plt.tight_layout()
-pl.figure(4)
-cmap = pl.cm.get_cmap('viridis')
+plt.figure(3)
+cmap = plt.cm.get_cmap('viridis')
verts = []
zs = alpha_list
for i, z in enumerate(zs):
ys = B_wass[:, i]
verts.append(list(zip(x, ys)))
-ax = pl.gcf().gca(projection='3d')
+ax = plt.gcf().gca(projection='3d')
poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
@@ -156,7 +139,7 @@ ax.set_ylabel('$\\alpha$')
ax.set_ylim3d(0, 1)
ax.set_zlabel('')
ax.set_zlim3d(0, B_l2.max() * 1.01)
-pl.title('Barycenter interpolation with Wasserstein')
-pl.tight_layout()
+plt.title('Barycenter interpolation with Wasserstein')
+plt.tight_layout()
-pl.show()
+plt.show()
diff --git a/examples/barycenters/plot_barycenter_lp_vs_entropic.py b/examples/barycenters/plot_barycenter_lp_vs_entropic.py
index 57a6bac..6502f16 100644
--- a/examples/barycenters/plot_barycenter_lp_vs_entropic.py
+++ b/examples/barycenters/plot_barycenter_lp_vs_entropic.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""
=================================================================================
-1D Wasserstein barycenter comparison between exact LP and entropic regularization
+1D Wasserstein barycenter: exact LP vs entropic regularization
=================================================================================
This example illustrates the computation of regularized Wasserstein Barycenter
diff --git a/examples/barycenters/plot_convolutional_barycenter.py b/examples/barycenters/plot_convolutional_barycenter.py
index cbcd4a1..3721f31 100644
--- a/examples/barycenters/plot_convolutional_barycenter.py
+++ b/examples/barycenters/plot_convolutional_barycenter.py
@@ -6,17 +6,18 @@
Convolutional Wasserstein Barycenter example
============================================
-This example is designed to illustrate how the Convolutional Wasserstein Barycenter
-function of POT works.
+This example is designed to illustrate how the Convolutional Wasserstein
+Barycenter function of POT works.
"""
# Author: Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License
-
+import os
+from pathlib import Path
import numpy as np
-import pylab as pl
+import matplotlib.pyplot as plt
import ot
##############################################################################
@@ -25,22 +26,19 @@ import ot
#
# The four distributions are constructed from 4 simple images
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
-f1 = 1 - pl.imread('../../data/redcross.png')[:, :, 2]
-f2 = 1 - pl.imread('../../data/duck.png')[:, :, 2]
-f3 = 1 - pl.imread('../../data/heart.png')[:, :, 2]
-f4 = 1 - pl.imread('../../data/tooth.png')[:, :, 2]
+f1 = 1 - plt.imread(os.path.join(data_path, 'redcross.png'))[:, :, 2]
+f2 = 1 - plt.imread(os.path.join(data_path, 'tooth.png'))[:, :, 2]
+f3 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2]
+f4 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2]
-A = []
f1 = f1 / np.sum(f1)
f2 = f2 / np.sum(f2)
f3 = f3 / np.sum(f3)
f4 = f4 / np.sum(f4)
-A.append(f1)
-A.append(f2)
-A.append(f3)
-A.append(f4)
-A = np.array(A)
+A = np.array([f1, f2, f3, f4])
nb_images = 5
@@ -57,14 +55,13 @@ v4 = np.array((0, 0, 0, 1))
# ----------------------------------------
#
-pl.figure(figsize=(10, 10))
-pl.title('Convolutional Wasserstein Barycenters in POT')
+fig, axes = plt.subplots(nb_images, nb_images, figsize=(7, 7))
+plt.suptitle('Convolutional Wasserstein Barycenters in POT')
cm = 'Blues'
# regularization parameter
reg = 0.004
for i in range(nb_images):
for j in range(nb_images):
- pl.subplot(nb_images, nb_images, i * nb_images + j + 1)
tx = float(i) / (nb_images - 1)
ty = float(j) / (nb_images - 1)
@@ -74,19 +71,19 @@ for i in range(nb_images):
weights = (1 - ty) * tmp1 + ty * tmp2
if i == 0 and j == 0:
- pl.imshow(f1, cmap=cm)
- pl.axis('off')
+ axes[i, j].imshow(f1, cmap=cm)
elif i == 0 and j == (nb_images - 1):
- pl.imshow(f3, cmap=cm)
- pl.axis('off')
+ axes[i, j].imshow(f3, cmap=cm)
elif i == (nb_images - 1) and j == 0:
- pl.imshow(f2, cmap=cm)
- pl.axis('off')
+ axes[i, j].imshow(f2, cmap=cm)
elif i == (nb_images - 1) and j == (nb_images - 1):
- pl.imshow(f4, cmap=cm)
- pl.axis('off')
+ axes[i, j].imshow(f4, cmap=cm)
else:
# call to barycenter computation
- pl.imshow(ot.bregman.convolutional_barycenter2d(A, reg, weights), cmap=cm)
- pl.axis('off')
-pl.show()
+ axes[i, j].imshow(
+ ot.bregman.convolutional_barycenter2d(A, reg, weights),
+ cmap=cm
+ )
+ axes[i, j].axis('off')
+plt.tight_layout()
+plt.show()
diff --git a/examples/barycenters/plot_debiased_barycenter.py b/examples/barycenters/plot_debiased_barycenter.py
new file mode 100644
index 0000000..2a603dd
--- /dev/null
+++ b/examples/barycenters/plot_debiased_barycenter.py
@@ -0,0 +1,131 @@
+# -*- coding: utf-8 -*-
+"""
+=================================
+Debiased Sinkhorn barycenter demo
+=================================
+
+This example illustrates the computation of the debiased Sinkhorn barycenter
+as proposed in [37]_.
+
+
+.. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th
+ International Conference on Machine Learning, PMLR 119:4692-4701, 2020
+"""
+
+# Author: Hicham Janati <hicham.janati100@gmail.com>
+#
+# License: MIT License
+# sphinx_gallery_thumbnail_number = 3
+
+import os
+from pathlib import Path
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+import ot
+from ot.bregman import (barycenter, barycenter_debiased,
+ convolutional_barycenter2d,
+ convolutional_barycenter2d_debiased)
+
+##############################################################################
+# Debiased barycenter of 1D Gaussians
+# ------------------------------------
+
+#%% parameters
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+
+# creating matrix A containing all distributions
+A = np.vstack((a1, a2)).T
+n_distributions = A.shape[1]
+
+# loss matrix + normalization
+M = ot.utils.dist0(n)
+M /= M.max()
+
+#%% barycenter computation
+
+alpha = 0.2 # 0<=alpha<=1
+weights = np.array([1 - alpha, alpha])
+
+epsilons = [5e-3, 1e-2, 5e-2]
+
+
+bars = [barycenter(A, M, reg, weights) for reg in epsilons]
+bars_debiased = [barycenter_debiased(A, M, reg, weights) for reg in epsilons]
+labels = ["Sinkhorn barycenter", "Debiased barycenter"]
+colors = ["indianred", "gold"]
+
+f, axes = plt.subplots(1, len(epsilons), tight_layout=True, sharey=True,
+ figsize=(12, 4), num=1)
+for ax, eps, bar, bar_debiased in zip(axes, epsilons, bars, bars_debiased):
+ ax.plot(A[:, 0], color="k", ls="--", label="Input data", alpha=0.3)
+ ax.plot(A[:, 1], color="k", ls="--", alpha=0.3)
+ for data, label, color in zip([bar, bar_debiased], labels, colors):
+ ax.plot(data, color=color, label=label, lw=2)
+ ax.set_title(r"$\varepsilon = %.3f$" % eps)
+plt.legend()
+plt.show()
+
+
+##############################################################################
+# Debiased barycenter of 2D images
+# ---------------------------------
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
+f1 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2]
+f2 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2]
+
+A = np.asarray([f1, f2]) + 1e-2
+A /= A.sum(axis=(1, 2))[:, None, None]
+
+##############################################################################
+# Display the input images
+
+fig, axes = plt.subplots(1, 2, figsize=(7, 4), num=2)
+for ax, img in zip(axes, A):
+ ax.imshow(img, cmap="Greys")
+ ax.axis("off")
+fig.tight_layout()
+plt.show()
+
+
+##############################################################################
+# Barycenter computation and visualization
+# ----------------------------------------
+#
+
+bars_sinkhorn, bars_debiased = [], []
+epsilons = [5e-3, 7e-3, 1e-2]
+for eps in epsilons:
+ bar = convolutional_barycenter2d(A, eps)
+ bar_debiased, log = convolutional_barycenter2d_debiased(A, eps, log=True)
+ bars_sinkhorn.append(bar)
+ bars_debiased.append(bar_debiased)
+
+titles = ["Sinkhorn", "Debiased"]
+all_bars = [bars_sinkhorn, bars_debiased]
+fig, axes = plt.subplots(2, 3, figsize=(8, 6), num=3)
+for jj, (method, ax_row, bars) in enumerate(zip(titles, axes, all_bars)):
+ for ii, (ax, img, eps) in enumerate(zip(ax_row, bars, epsilons)):
+ ax.imshow(img, cmap="Greys")
+ if jj == 0:
+ ax.set_title(r"$\varepsilon = %.3f$" % eps, fontsize=13)
+ ax.set_xticks([])
+ ax.set_yticks([])
+ ax.spines['top'].set_visible(False)
+ ax.spines['right'].set_visible(False)
+ ax.spines['bottom'].set_visible(False)
+ ax.spines['left'].set_visible(False)
+ if ii == 0:
+ ax.set_ylabel(method, fontsize=15)
+fig.tight_layout()
+plt.show()
diff --git a/examples/barycenters/plot_free_support_barycenter.py b/examples/barycenters/plot_free_support_barycenter.py
index 27ddc8e..2d68a39 100644
--- a/examples/barycenters/plot_free_support_barycenter.py
+++ b/examples/barycenters/plot_free_support_barycenter.py
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
"""
-====================================================
+========================================================
2D free support Wasserstein barycenters of distributions
-====================================================
+========================================================
Illustration of 2D Wasserstein barycenters if distributions are weighted
sum of diracs.
diff --git a/examples/domain-adaptation/plot_otda_color_images.py b/examples/domain-adaptation/plot_otda_color_images.py
index 929365e..06dc8ab 100644
--- a/examples/domain-adaptation/plot_otda_color_images.py
+++ b/examples/domain-adaptation/plot_otda_color_images.py
@@ -19,17 +19,20 @@ SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
# sphinx_gallery_thumbnail_number = 2
+import os
+from pathlib import Path
+
import numpy as np
-import matplotlib.pylab as pl
+from matplotlib import pyplot as plt
import ot
-r = np.random.RandomState(42)
+rng = np.random.RandomState(42)
-def im2mat(I):
+def im2mat(img):
"""Converts an image to matrix (one pixel per line)"""
- return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+ return img.reshape((img.shape[0] * img.shape[1], img.shape[2]))
def mat2im(X, shape):
@@ -37,8 +40,8 @@ def mat2im(X, shape):
return X.reshape(shape)
-def minmax(I):
- return np.clip(I, 0, 1)
+def minmax(img):
+ return np.clip(img, 0, 1)
##############################################################################
@@ -46,16 +49,19 @@ def minmax(I):
# -------------
# Loading images
-I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256
-I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
+
+I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256
+I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256
X1 = im2mat(I1)
X2 = im2mat(I2)
# training samples
-nb = 1000
-idx1 = r.randint(X1.shape[0], size=(nb,))
-idx2 = r.randint(X2.shape[0], size=(nb,))
+nb = 500
+idx1 = rng.randint(X1.shape[0], size=(nb,))
+idx2 = rng.randint(X2.shape[0], size=(nb,))
Xs = X1[idx1, :]
Xt = X2[idx2, :]
@@ -65,39 +71,39 @@ Xt = X2[idx2, :]
# Plot original image
# -------------------
-pl.figure(1, figsize=(6.4, 3))
+plt.figure(1, figsize=(6.4, 3))
-pl.subplot(1, 2, 1)
-pl.imshow(I1)
-pl.axis('off')
-pl.title('Image 1')
+plt.subplot(1, 2, 1)
+plt.imshow(I1)
+plt.axis('off')
+plt.title('Image 1')
-pl.subplot(1, 2, 2)
-pl.imshow(I2)
-pl.axis('off')
-pl.title('Image 2')
+plt.subplot(1, 2, 2)
+plt.imshow(I2)
+plt.axis('off')
+plt.title('Image 2')
##############################################################################
# Scatter plot of colors
# ----------------------
-pl.figure(2, figsize=(6.4, 3))
+plt.figure(2, figsize=(6.4, 3))
-pl.subplot(1, 2, 1)
-pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
-pl.axis([0, 1, 0, 1])
-pl.xlabel('Red')
-pl.ylabel('Blue')
-pl.title('Image 1')
+plt.subplot(1, 2, 1)
+plt.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
+plt.axis([0, 1, 0, 1])
+plt.xlabel('Red')
+plt.ylabel('Blue')
+plt.title('Image 1')
-pl.subplot(1, 2, 2)
-pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
-pl.axis([0, 1, 0, 1])
-pl.xlabel('Red')
-pl.ylabel('Blue')
-pl.title('Image 2')
-pl.tight_layout()
+plt.subplot(1, 2, 2)
+plt.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
+plt.axis([0, 1, 0, 1])
+plt.xlabel('Red')
+plt.ylabel('Blue')
+plt.title('Image 2')
+plt.tight_layout()
##############################################################################
@@ -130,37 +136,37 @@ I2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape))
# Plot new images
# ---------------
-pl.figure(3, figsize=(8, 4))
+plt.figure(3, figsize=(8, 4))
-pl.subplot(2, 3, 1)
-pl.imshow(I1)
-pl.axis('off')
-pl.title('Image 1')
+plt.subplot(2, 3, 1)
+plt.imshow(I1)
+plt.axis('off')
+plt.title('Image 1')
-pl.subplot(2, 3, 2)
-pl.imshow(I1t)
-pl.axis('off')
-pl.title('Image 1 Adapt')
+plt.subplot(2, 3, 2)
+plt.imshow(I1t)
+plt.axis('off')
+plt.title('Image 1 Adapt')
-pl.subplot(2, 3, 3)
-pl.imshow(I1te)
-pl.axis('off')
-pl.title('Image 1 Adapt (reg)')
+plt.subplot(2, 3, 3)
+plt.imshow(I1te)
+plt.axis('off')
+plt.title('Image 1 Adapt (reg)')
-pl.subplot(2, 3, 4)
-pl.imshow(I2)
-pl.axis('off')
-pl.title('Image 2')
+plt.subplot(2, 3, 4)
+plt.imshow(I2)
+plt.axis('off')
+plt.title('Image 2')
-pl.subplot(2, 3, 5)
-pl.imshow(I2t)
-pl.axis('off')
-pl.title('Image 2 Adapt')
+plt.subplot(2, 3, 5)
+plt.imshow(I2t)
+plt.axis('off')
+plt.title('Image 2 Adapt')
-pl.subplot(2, 3, 6)
-pl.imshow(I2te)
-pl.axis('off')
-pl.title('Image 2 Adapt (reg)')
-pl.tight_layout()
+plt.subplot(2, 3, 6)
+plt.imshow(I2te)
+plt.axis('off')
+plt.title('Image 2 Adapt (reg)')
+plt.tight_layout()
-pl.show()
+plt.show()
diff --git a/examples/domain-adaptation/plot_otda_jcpot.py b/examples/domain-adaptation/plot_otda_jcpot.py
index c495690..0d974f4 100644
--- a/examples/domain-adaptation/plot_otda_jcpot.py
+++ b/examples/domain-adaptation/plot_otda_jcpot.py
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
"""
-========================
+================================
OT for multi-source target shift
-========================
+================================
This example introduces a target shift problem with two 2D source and 1 target domain.
diff --git a/examples/domain-adaptation/plot_otda_linear_mapping.py b/examples/domain-adaptation/plot_otda_linear_mapping.py
index dbf16b8..a44096a 100644
--- a/examples/domain-adaptation/plot_otda_linear_mapping.py
+++ b/examples/domain-adaptation/plot_otda_linear_mapping.py
@@ -13,9 +13,11 @@ Linear OT mapping estimation
# License: MIT License
# sphinx_gallery_thumbnail_number = 2
+import os
+from pathlib import Path
import numpy as np
-import pylab as pl
+from matplotlib import pyplot as plt
import ot
##############################################################################
@@ -26,17 +28,19 @@ n = 1000
d = 2
sigma = .1
+rng = np.random.RandomState(42)
+
# source samples
-angles = np.random.rand(n, 1) * 2 * np.pi
+angles = rng.rand(n, 1) * 2 * np.pi
xs = np.concatenate((np.sin(angles), np.cos(angles)),
- axis=1) + sigma * np.random.randn(n, 2)
+ axis=1) + sigma * rng.randn(n, 2)
xs[:n // 2, 1] += 2
# target samples
-anglet = np.random.rand(n, 1) * 2 * np.pi
+anglet = rng.rand(n, 1) * 2 * np.pi
xt = np.concatenate((np.sin(anglet), np.cos(anglet)),
- axis=1) + sigma * np.random.randn(n, 2)
+ axis=1) + sigma * rng.randn(n, 2)
xt[:n // 2, 1] += 2
@@ -48,9 +52,9 @@ xt = xt.dot(A) + b
# Plot data
# ---------
-pl.figure(1, (5, 5))
-pl.plot(xs[:, 0], xs[:, 1], '+')
-pl.plot(xt[:, 0], xt[:, 1], 'o')
+plt.figure(1, (5, 5))
+plt.plot(xs[:, 0], xs[:, 1], '+')
+plt.plot(xt[:, 0], xt[:, 1], 'o')
##############################################################################
@@ -66,22 +70,22 @@ xst = xs.dot(Ae) + be
# Plot transported samples
# ------------------------
-pl.figure(1, (5, 5))
-pl.clf()
-pl.plot(xs[:, 0], xs[:, 1], '+')
-pl.plot(xt[:, 0], xt[:, 1], 'o')
-pl.plot(xst[:, 0], xst[:, 1], '+')
+plt.figure(1, (5, 5))
+plt.clf()
+plt.plot(xs[:, 0], xs[:, 1], '+')
+plt.plot(xt[:, 0], xt[:, 1], 'o')
+plt.plot(xst[:, 0], xst[:, 1], '+')
-pl.show()
+plt.show()
##############################################################################
# Load image data
# ---------------
-def im2mat(I):
+def im2mat(img):
"""Converts and image to matrix (one pixel per line)"""
- return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+ return img.reshape((img.shape[0] * img.shape[1], img.shape[2]))
def mat2im(X, shape):
@@ -89,13 +93,16 @@ def mat2im(X, shape):
return X.reshape(shape)
-def minmax(I):
- return np.clip(I, 0, 1)
+def minmax(img):
+ return np.clip(img, 0, 1)
# Loading images
-I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256
-I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
+
+I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256
+I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256
X1 = im2mat(I1)
@@ -123,24 +130,24 @@ I2t = minmax(mat2im(xts, I2.shape))
# Plot transformed images
# -----------------------
-pl.figure(2, figsize=(10, 7))
+plt.figure(2, figsize=(10, 7))
-pl.subplot(2, 2, 1)
-pl.imshow(I1)
-pl.axis('off')
-pl.title('Im. 1')
+plt.subplot(2, 2, 1)
+plt.imshow(I1)
+plt.axis('off')
+plt.title('Im. 1')
-pl.subplot(2, 2, 2)
-pl.imshow(I2)
-pl.axis('off')
-pl.title('Im. 2')
+plt.subplot(2, 2, 2)
+plt.imshow(I2)
+plt.axis('off')
+plt.title('Im. 2')
-pl.subplot(2, 2, 3)
-pl.imshow(I1t)
-pl.axis('off')
-pl.title('Mapping Im. 1')
+plt.subplot(2, 2, 3)
+plt.imshow(I1t)
+plt.axis('off')
+plt.title('Mapping Im. 1')
-pl.subplot(2, 2, 4)
-pl.imshow(I2t)
-pl.axis('off')
-pl.title('Inverse mapping Im. 2')
+plt.subplot(2, 2, 4)
+plt.imshow(I2t)
+plt.axis('off')
+plt.title('Inverse mapping Im. 2')
diff --git a/examples/domain-adaptation/plot_otda_mapping_colors_images.py b/examples/domain-adaptation/plot_otda_mapping_colors_images.py
index ee5c8b0..dbece70 100644
--- a/examples/domain-adaptation/plot_otda_mapping_colors_images.py
+++ b/examples/domain-adaptation/plot_otda_mapping_colors_images.py
@@ -21,17 +21,19 @@ discrete optimal transport", Neural Information Processing Systems (NIPS), 2016.
# License: MIT License
# sphinx_gallery_thumbnail_number = 3
+import os
+from pathlib import Path
import numpy as np
-import matplotlib.pylab as pl
+from matplotlib import pyplot as plt
import ot
-r = np.random.RandomState(42)
+rng = np.random.RandomState(42)
-def im2mat(I):
+def im2mat(img):
"""Converts and image to matrix (one pixel per line)"""
- return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+ return img.reshape((img.shape[0] * img.shape[1], img.shape[2]))
def mat2im(X, shape):
@@ -39,8 +41,8 @@ def mat2im(X, shape):
return X.reshape(shape)
-def minmax(I):
- return np.clip(I, 0, 1)
+def minmax(img):
+ return np.clip(img, 0, 1)
##############################################################################
@@ -48,17 +50,19 @@ def minmax(I):
# -------------
# Loading images
-I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256
-I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
+I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256
+I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256
X1 = im2mat(I1)
X2 = im2mat(I2)
# training samples
-nb = 1000
-idx1 = r.randint(X1.shape[0], size=(nb,))
-idx2 = r.randint(X2.shape[0], size=(nb,))
+nb = 500
+idx1 = rng.randint(X1.shape[0], size=(nb,))
+idx2 = rng.randint(X2.shape[0], size=(nb,))
Xs = X1[idx1, :]
Xt = X2[idx2, :]
@@ -99,76 +103,76 @@ Image_mapping_gaussian = minmax(mat2im(X1tn, I1.shape))
# Plot original images
# --------------------
-pl.figure(1, figsize=(6.4, 3))
-pl.subplot(1, 2, 1)
-pl.imshow(I1)
-pl.axis('off')
-pl.title('Image 1')
+plt.figure(1, figsize=(6.4, 3))
+plt.subplot(1, 2, 1)
+plt.imshow(I1)
+plt.axis('off')
+plt.title('Image 1')
-pl.subplot(1, 2, 2)
-pl.imshow(I2)
-pl.axis('off')
-pl.title('Image 2')
-pl.tight_layout()
+plt.subplot(1, 2, 2)
+plt.imshow(I2)
+plt.axis('off')
+plt.title('Image 2')
+plt.tight_layout()
##############################################################################
# Plot pixel values distribution
# ------------------------------
-pl.figure(2, figsize=(6.4, 5))
+plt.figure(2, figsize=(6.4, 5))
-pl.subplot(1, 2, 1)
-pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
-pl.axis([0, 1, 0, 1])
-pl.xlabel('Red')
-pl.ylabel('Blue')
-pl.title('Image 1')
+plt.subplot(1, 2, 1)
+plt.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
+plt.axis([0, 1, 0, 1])
+plt.xlabel('Red')
+plt.ylabel('Blue')
+plt.title('Image 1')
-pl.subplot(1, 2, 2)
-pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
-pl.axis([0, 1, 0, 1])
-pl.xlabel('Red')
-pl.ylabel('Blue')
-pl.title('Image 2')
-pl.tight_layout()
+plt.subplot(1, 2, 2)
+plt.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
+plt.axis([0, 1, 0, 1])
+plt.xlabel('Red')
+plt.ylabel('Blue')
+plt.title('Image 2')
+plt.tight_layout()
##############################################################################
# Plot transformed images
# -----------------------
-pl.figure(2, figsize=(10, 5))
+plt.figure(2, figsize=(10, 5))
-pl.subplot(2, 3, 1)
-pl.imshow(I1)
-pl.axis('off')
-pl.title('Im. 1')
+plt.subplot(2, 3, 1)
+plt.imshow(I1)
+plt.axis('off')
+plt.title('Im. 1')
-pl.subplot(2, 3, 4)
-pl.imshow(I2)
-pl.axis('off')
-pl.title('Im. 2')
+plt.subplot(2, 3, 4)
+plt.imshow(I2)
+plt.axis('off')
+plt.title('Im. 2')
-pl.subplot(2, 3, 2)
-pl.imshow(Image_emd)
-pl.axis('off')
-pl.title('EmdTransport')
+plt.subplot(2, 3, 2)
+plt.imshow(Image_emd)
+plt.axis('off')
+plt.title('EmdTransport')
-pl.subplot(2, 3, 5)
-pl.imshow(Image_sinkhorn)
-pl.axis('off')
-pl.title('SinkhornTransport')
+plt.subplot(2, 3, 5)
+plt.imshow(Image_sinkhorn)
+plt.axis('off')
+plt.title('SinkhornTransport')
-pl.subplot(2, 3, 3)
-pl.imshow(Image_mapping_linear)
-pl.axis('off')
-pl.title('MappingTransport (linear)')
+plt.subplot(2, 3, 3)
+plt.imshow(Image_mapping_linear)
+plt.axis('off')
+plt.title('MappingTransport (linear)')
-pl.subplot(2, 3, 6)
-pl.imshow(Image_mapping_gaussian)
-pl.axis('off')
-pl.title('MappingTransport (gaussian)')
-pl.tight_layout()
+plt.subplot(2, 3, 6)
+plt.imshow(Image_mapping_gaussian)
+plt.axis('off')
+plt.title('MappingTransport (gaussian)')
+plt.tight_layout()
-pl.show()
+plt.show()
diff --git a/examples/gromov/plot_barycenter_fgw.py b/examples/gromov/plot_barycenter_fgw.py
index 3f81765..556e08f 100644
--- a/examples/gromov/plot_barycenter_fgw.py
+++ b/examples/gromov/plot_barycenter_fgw.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""
=================================
-Plot graphs' barycenter using FGW
+Plot graphs barycenter using FGW
=================================
This example illustrates the computation barycenter of labeled graphs using
diff --git a/examples/gromov/plot_fgw.py b/examples/gromov/plot_fgw.py
index 97fe619..5475fb3 100644
--- a/examples/gromov/plot_fgw.py
+++ b/examples/gromov/plot_fgw.py
@@ -26,7 +26,7 @@ from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein
##############################################################################
# Generate data
-# ---------
+# -------------
#%% parameters
# We create two 1D random measures
@@ -76,7 +76,7 @@ pl.show()
##############################################################################
# Create structure matrices and across-feature distance matrix
-# ---------
+# ------------------------------------------------------------
#%% Structure matrices and across-features distance matrix
C1 = ot.dist(xs)
@@ -88,7 +88,7 @@ Got = ot.emd([], [], M)
##############################################################################
# Plot matrices
-# ---------
+# -------------
#%%
cmap = 'Reds'
@@ -131,7 +131,7 @@ pl.show()
##############################################################################
# Compute FGW/GW
-# ---------
+# --------------
#%% Computing FGW and GW
alpha = 1e-3
@@ -145,7 +145,7 @@ Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True,
##############################################################################
# Visualize transport matrices
-# ---------
+# ----------------------------
#%% visu OT matrix
cmap = 'Blues'
diff --git a/examples/gromov/plot_gromov.py b/examples/gromov/plot_gromov.py
index deb2f86..5a362cf 100644
--- a/examples/gromov/plot_gromov.py
+++ b/examples/gromov/plot_gromov.py
@@ -104,3 +104,37 @@ pl.imshow(gw, cmap='jet')
pl.title('Entropic Gromov Wasserstein')
pl.show()
+
+#############################################################################
+#
+# Compute GW with a scalable stochastic method with any loss function
+# ----------------------------------------------------------------------
+
+
+def loss(x, y):
+ return np.abs(x - y)
+
+
+pgw, plog = ot.gromov.pointwise_gromov_wasserstein(C1, C2, p, q, loss, max_iter=100,
+ log=True)
+
+sgw, slog = ot.gromov.sampled_gromov_wasserstein(C1, C2, p, q, loss, epsilon=0.1, max_iter=100,
+ log=True)
+
+print('Pointwise Gromov-Wasserstein distance estimated: ' + str(plog['gw_dist_estimated']))
+print('Variance estimated: ' + str(plog['gw_dist_std']))
+print('Sampled Gromov-Wasserstein distance: ' + str(slog['gw_dist_estimated']))
+print('Variance estimated: ' + str(slog['gw_dist_std']))
+
+
+pl.figure(1, (10, 5))
+
+pl.subplot(1, 2, 1)
+pl.imshow(pgw.toarray(), cmap='jet')
+pl.title('Pointwise Gromov Wasserstein')
+
+pl.subplot(1, 2, 2)
+pl.imshow(sgw, cmap='jet')
+pl.title('Sampled Gromov Wasserstein')
+
+pl.show()
diff --git a/examples/gromov/plot_gromov_barycenter.py b/examples/gromov/plot_gromov_barycenter.py
index f6f031a..7fe081f 100755
--- a/examples/gromov/plot_gromov_barycenter.py
+++ b/examples/gromov/plot_gromov_barycenter.py
@@ -13,11 +13,13 @@ computation in POT.
#
# License: MIT License
+import os
+from pathlib import Path
import numpy as np
import scipy as sp
-import matplotlib.pylab as pl
+from matplotlib import pyplot as plt
from sklearn import manifold
from sklearn.decomposition import PCA
@@ -84,22 +86,24 @@ def smacof_mds(C, dim, max_iter=3000, eps=1e-9):
# The four distributions are constructed from 4 simple images
-def im2mat(I):
+def im2mat(img):
"""Converts and image to matrix (one pixel per line)"""
- return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+ return img.reshape((img.shape[0] * img.shape[1], img.shape[2]))
-square = pl.imread('../../data/square.png').astype(np.float64)[:, :, 2]
-cross = pl.imread('../../data/cross.png').astype(np.float64)[:, :, 2]
-triangle = pl.imread('../../data/triangle.png').astype(np.float64)[:, :, 2]
-star = pl.imread('../../data/star.png').astype(np.float64)[:, :, 2]
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
+
+square = plt.imread(os.path.join(data_path, 'square.png')).astype(np.float64)[:, :, 2]
+cross = plt.imread(os.path.join(data_path, 'cross.png')).astype(np.float64)[:, :, 2]
+triangle = plt.imread(os.path.join(data_path, 'triangle.png')).astype(np.float64)[:, :, 2]
+star = plt.imread(os.path.join(data_path, 'star.png')).astype(np.float64)[:, :, 2]
shapes = [square, cross, triangle, star]
S = 4
xs = [[] for i in range(S)]
-
for nb in range(4):
for i in range(8):
for j in range(8):
@@ -184,64 +188,64 @@ npost23 = [smacof_mds(Ct23[s], 2) for s in range(2)]
npost23 = [clf.fit_transform(npost23[s]) for s in range(2)]
-fig = pl.figure(figsize=(10, 10))
+fig = plt.figure(figsize=(10, 10))
-ax1 = pl.subplot2grid((4, 4), (0, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax1 = plt.subplot2grid((4, 4), (0, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax1.scatter(npos[0][:, 0], npos[0][:, 1], color='r')
-ax2 = pl.subplot2grid((4, 4), (0, 1))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax2 = plt.subplot2grid((4, 4), (0, 1))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax2.scatter(npost01[1][:, 0], npost01[1][:, 1], color='b')
-ax3 = pl.subplot2grid((4, 4), (0, 2))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax3 = plt.subplot2grid((4, 4), (0, 2))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax3.scatter(npost01[0][:, 0], npost01[0][:, 1], color='b')
-ax4 = pl.subplot2grid((4, 4), (0, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax4 = plt.subplot2grid((4, 4), (0, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax4.scatter(npos[1][:, 0], npos[1][:, 1], color='r')
-ax5 = pl.subplot2grid((4, 4), (1, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax5 = plt.subplot2grid((4, 4), (1, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax5.scatter(npost02[1][:, 0], npost02[1][:, 1], color='b')
-ax6 = pl.subplot2grid((4, 4), (1, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax6 = plt.subplot2grid((4, 4), (1, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax6.scatter(npost13[1][:, 0], npost13[1][:, 1], color='b')
-ax7 = pl.subplot2grid((4, 4), (2, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax7 = plt.subplot2grid((4, 4), (2, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax7.scatter(npost02[0][:, 0], npost02[0][:, 1], color='b')
-ax8 = pl.subplot2grid((4, 4), (2, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax8 = plt.subplot2grid((4, 4), (2, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax8.scatter(npost13[0][:, 0], npost13[0][:, 1], color='b')
-ax9 = pl.subplot2grid((4, 4), (3, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax9 = plt.subplot2grid((4, 4), (3, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax9.scatter(npos[2][:, 0], npos[2][:, 1], color='r')
-ax10 = pl.subplot2grid((4, 4), (3, 1))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax10 = plt.subplot2grid((4, 4), (3, 1))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax10.scatter(npost23[1][:, 0], npost23[1][:, 1], color='b')
-ax11 = pl.subplot2grid((4, 4), (3, 2))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax11 = plt.subplot2grid((4, 4), (3, 2))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax11.scatter(npost23[0][:, 0], npost23[0][:, 1], color='b')
-ax12 = pl.subplot2grid((4, 4), (3, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax12 = plt.subplot2grid((4, 4), (3, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax12.scatter(npos[3][:, 0], npos[3][:, 1], color='r')
diff --git a/examples/plot_Intro_OT.py b/examples/plot_Intro_OT.py
new file mode 100644
index 0000000..2e2c6fd
--- /dev/null
+++ b/examples/plot_Intro_OT.py
@@ -0,0 +1,373 @@
+# coding: utf-8
+"""
+=============================================
+Introduction to Optimal Transport with Python
+=============================================
+
+This example gives an introduction on how to use Optimal Transport in Python.
+
+"""
+
+# Author: Remi Flamary, Nicolas Courty, Aurelie Boisbunon
+#
+# License: MIT License
+# sphinx_gallery_thumbnail_number = 1
+
+##############################################################################
+# POT Python Optimal Transport Toolbox
+# ------------------------------------
+#
+# POT installation
+# ```````````````````
+#
+# * Install with pip::
+#
+# pip install pot
+# * Install with conda::
+#
+# conda install -c conda-forge pot
+#
+# Import the toolbox
+# ```````````````````
+#
+
+import numpy as np # always need it
+import pylab as pl # do the plots
+
+import ot # ot
+
+import time
+
+##############################################################################
+# Getting help
+# `````````````
+#
+# Online documentation : `<https://pythonot.github.io/all.html>`_
+#
+# Or inline help:
+#
+
+help(ot.dist)
+
+
+##############################################################################
+# First OT Problem
+# ----------------
+#
+# We will solve the Bakery/Cafés problem of transporting croissants from a
+# number of Bakeries to Cafés in a City (in this case Manhattan). We did a
+# quick google map search in Manhattan for bakeries and Cafés:
+#
+# .. image:: images/bak.png
+# :align: center
+# :alt: bakery-cafe-manhattan
+# :width: 600px
+# :height: 280px
+#
+# We extracted from this search their positions and generated fictional
+# production and sale number (that both sum to the same value).
+#
+# We have acess to the position of Bakeries ``bakery_pos`` and their
+# respective production ``bakery_prod`` which describe the source
+# distribution. The Cafés where the croissants are sold are defined also by
+# their position ``cafe_pos`` and ``cafe_prod``, and describe the target
+# distribution. For fun we also provide a
+# map ``Imap`` that will illustrate the position of these shops in the city.
+#
+#
+# Now we load the data
+#
+#
+
+data = np.load('../data/manhattan.npz')
+
+bakery_pos = data['bakery_pos']
+bakery_prod = data['bakery_prod']
+cafe_pos = data['cafe_pos']
+cafe_prod = data['cafe_prod']
+Imap = data['Imap']
+
+print('Bakery production: {}'.format(bakery_prod))
+print('Cafe sale: {}'.format(cafe_prod))
+print('Total croissants : {}'.format(cafe_prod.sum()))
+
+
+##############################################################################
+# Plotting bakeries in the city
+# -----------------------------
+#
+# Next we plot the position of the bakeries and cafés on the map. The size of
+# the circle is proportional to their production.
+#
+
+pl.figure(1, (7, 6))
+pl.clf()
+pl.imshow(Imap, interpolation='bilinear') # plot the map
+pl.scatter(bakery_pos[:, 0], bakery_pos[:, 1], s=bakery_prod, c='r', ec='k', label='Bakeries')
+pl.scatter(cafe_pos[:, 0], cafe_pos[:, 1], s=cafe_prod, c='b', ec='k', label='Cafés')
+pl.legend()
+pl.title('Manhattan Bakeries and Cafés')
+
+
+##############################################################################
+# Cost matrix
+# -----------
+#
+#
+# We can now compute the cost matrix between the bakeries and the cafés, which
+# will be the transport cost matrix. This can be done using the
+# `ot.dist <https://pythonot.github.io/all.html#ot.dist>`_ function that
+# defaults to squared Euclidean distance but can return other things such as
+# cityblock (or Manhattan distance).
+#
+
+C = ot.dist(bakery_pos, cafe_pos)
+
+labels = [str(i) for i in range(len(bakery_prod))]
+f = pl.figure(2, (14, 7))
+pl.clf()
+pl.subplot(121)
+pl.imshow(Imap, interpolation='bilinear') # plot the map
+for i in range(len(cafe_pos)):
+ pl.text(cafe_pos[i, 0], cafe_pos[i, 1], labels[i], color='b',
+ fontsize=14, fontweight='bold', ha='center', va='center')
+for i in range(len(bakery_pos)):
+ pl.text(bakery_pos[i, 0], bakery_pos[i, 1], labels[i], color='r',
+ fontsize=14, fontweight='bold', ha='center', va='center')
+pl.title('Manhattan Bakeries and Cafés')
+
+ax = pl.subplot(122)
+im = pl.imshow(C, cmap="coolwarm")
+pl.title('Cost matrix')
+cbar = pl.colorbar(im, ax=ax, shrink=0.5, use_gridspec=True)
+cbar.ax.set_ylabel("cost", rotation=-90, va="bottom")
+
+pl.xlabel('Cafés')
+pl.ylabel('Bakeries')
+pl.tight_layout()
+
+
+##############################################################################
+# The red cells in the matrix image show the bakeries and cafés that are
+# further away, and thus more costly to transport from one to the other, while
+# the blue ones show those that are very close to each other, with respect to
+# the squared Euclidean distance.
+
+
+##############################################################################
+# Solving the OT problem with `ot.emd <https://pythonot.github.io/all.html#ot.emd>`_
+# -----------------------------------------------------------------------------------
+
+start = time.time()
+ot_emd = ot.emd(bakery_prod, cafe_prod, C)
+time_emd = time.time() - start
+
+##############################################################################
+# The function returns the transport matrix, which we can then visualize (next section).
+
+##############################################################################
+# Transportation plan vizualization
+# `````````````````````````````````
+#
+# A good vizualization of the OT matrix in the 2D plane is to denote the
+# transportation of mass between a Bakery and a Café by a line. This can easily
+# be done with a double ``for`` loop.
+#
+# In order to make it more interpretable one can also use the ``alpha``
+# parameter of plot and set it to ``alpha=G[i,j]/G.max()``.
+
+# Plot the matrix and the map
+f = pl.figure(3, (14, 7))
+pl.clf()
+pl.subplot(121)
+pl.imshow(Imap, interpolation='bilinear') # plot the map
+for i in range(len(bakery_pos)):
+ for j in range(len(cafe_pos)):
+ pl.plot([bakery_pos[i, 0], cafe_pos[j, 0]], [bakery_pos[i, 1], cafe_pos[j, 1]],
+ '-k', lw=3. * ot_emd[i, j] / ot_emd.max())
+for i in range(len(cafe_pos)):
+ pl.text(cafe_pos[i, 0], cafe_pos[i, 1], labels[i], color='b', fontsize=14,
+ fontweight='bold', ha='center', va='center')
+for i in range(len(bakery_pos)):
+ pl.text(bakery_pos[i, 0], bakery_pos[i, 1], labels[i], color='r', fontsize=14,
+ fontweight='bold', ha='center', va='center')
+pl.title('Manhattan Bakeries and Cafés')
+
+ax = pl.subplot(122)
+im = pl.imshow(ot_emd)
+for i in range(len(bakery_prod)):
+ for j in range(len(cafe_prod)):
+ text = ax.text(j, i, '{0:g}'.format(ot_emd[i, j]),
+ ha="center", va="center", color="w")
+pl.title('Transport matrix')
+
+pl.xlabel('Cafés')
+pl.ylabel('Bakeries')
+pl.tight_layout()
+
+##############################################################################
+# The transport matrix gives the number of croissants that can be transported
+# from each bakery to each café. We can see that the bakeries only need to
+# transport croissants to one or two cafés, the transport matrix being very
+# sparse.
+
+##############################################################################
+# OT loss and dual variables
+# --------------------------
+#
+# The resulting wasserstein loss loss is of the form:
+#
+# .. math::
+# W=\sum_{i,j}\gamma_{i,j}C_{i,j}
+#
+# where :math:`\gamma` is the optimal transport matrix.
+#
+
+W = np.sum(ot_emd * C)
+print('Wasserstein loss (EMD) = {0:.2f}'.format(W))
+
+##############################################################################
+# Regularized OT with Sinkhorn
+# ----------------------------
+#
+# The Sinkhorn algorithm is very simple to code. You can implement it directly
+# using the following pseudo-code
+#
+# .. image:: images/sinkhorn.png
+# :align: center
+# :alt: Sinkhorn algorithm
+# :width: 440px
+# :height: 240px
+#
+# In this algorithm, :math:`\oslash` corresponds to the element-wise division.
+#
+# An alternative is to use the POT toolbox with
+# `ot.sinkhorn <https://pythonot.github.io/all.html#ot.sinkhorn>`_
+#
+# Be careful of numerical problems. A good pre-processing for Sinkhorn is to
+# divide the cost matrix ``C`` by its maximum value.
+
+##############################################################################
+# Algorithm
+# `````````
+
+# Compute Sinkhorn transport matrix from algorithm
+reg = 0.1
+K = np.exp(-C / C.max() / reg)
+nit = 100
+u = np.ones((len(bakery_prod), ))
+for i in range(1, nit):
+ v = cafe_prod / np.dot(K.T, u)
+ u = bakery_prod / (np.dot(K, v))
+ot_sink_algo = np.atleast_2d(u).T * (K * v.T) # Equivalent to np.dot(np.diag(u), np.dot(K, np.diag(v)))
+
+# Compute Sinkhorn transport matrix with POT
+ot_sinkhorn = ot.sinkhorn(bakery_prod, cafe_prod, reg=reg, M=C / C.max())
+
+# Difference between the 2
+print('Difference between algo and ot.sinkhorn = {0:.2g}'.format(np.sum(np.power(ot_sink_algo - ot_sinkhorn, 2))))
+
+##############################################################################
+# Plot the matrix and the map
+# ```````````````````````````
+
+print('Min. of Sinkhorn\'s transport matrix = {0:.2g}'.format(np.min(ot_sinkhorn)))
+
+f = pl.figure(4, (13, 6))
+pl.clf()
+pl.subplot(121)
+pl.imshow(Imap, interpolation='bilinear') # plot the map
+for i in range(len(bakery_pos)):
+ for j in range(len(cafe_pos)):
+ pl.plot([bakery_pos[i, 0], cafe_pos[j, 0]],
+ [bakery_pos[i, 1], cafe_pos[j, 1]],
+ '-k', lw=3. * ot_sinkhorn[i, j] / ot_sinkhorn.max())
+for i in range(len(cafe_pos)):
+ pl.text(cafe_pos[i, 0], cafe_pos[i, 1], labels[i], color='b',
+ fontsize=14, fontweight='bold', ha='center', va='center')
+for i in range(len(bakery_pos)):
+ pl.text(bakery_pos[i, 0], bakery_pos[i, 1], labels[i], color='r',
+ fontsize=14, fontweight='bold', ha='center', va='center')
+pl.title('Manhattan Bakeries and Cafés')
+
+ax = pl.subplot(122)
+im = pl.imshow(ot_sinkhorn)
+for i in range(len(bakery_prod)):
+ for j in range(len(cafe_prod)):
+ text = ax.text(j, i, np.round(ot_sinkhorn[i, j], 1),
+ ha="center", va="center", color="w")
+pl.title('Transport matrix')
+
+pl.xlabel('Cafés')
+pl.ylabel('Bakeries')
+pl.tight_layout()
+
+
+##############################################################################
+# We notice right away that the matrix is not sparse at all with Sinkhorn,
+# each bakery delivering croissants to all 5 cafés with that solution. Also,
+# this solution gives a transport with fractions, which does not make sense
+# in the case of croissants. This was not the case with EMD.
+
+##############################################################################
+# Varying the regularization parameter in Sinkhorn
+# ````````````````````````````````````````````````
+#
+
+reg_parameter = np.logspace(-3, 0, 20)
+W_sinkhorn_reg = np.zeros((len(reg_parameter), ))
+time_sinkhorn_reg = np.zeros((len(reg_parameter), ))
+
+f = pl.figure(5, (14, 5))
+pl.clf()
+max_ot = 100 # plot matrices with the same colorbar
+for k in range(len(reg_parameter)):
+ start = time.time()
+ ot_sinkhorn = ot.sinkhorn(bakery_prod, cafe_prod, reg=reg_parameter[k], M=C / C.max())
+ time_sinkhorn_reg[k] = time.time() - start
+
+ if k % 4 == 0 and k > 0: # we only plot a few
+ ax = pl.subplot(1, 5, k / 4)
+ im = pl.imshow(ot_sinkhorn, vmin=0, vmax=max_ot)
+ pl.title('reg={0:.2g}'.format(reg_parameter[k]))
+ pl.xlabel('Cafés')
+ pl.ylabel('Bakeries')
+
+ # Compute the Wasserstein loss for Sinkhorn, and compare with EMD
+ W_sinkhorn_reg[k] = np.sum(ot_sinkhorn * C)
+pl.tight_layout()
+
+
+##############################################################################
+# This series of graph shows that the solution of Sinkhorn starts with something
+# very similar to EMD (although not sparse) for very small values of the
+# regularization parameter, and tends to a more uniform solution as the
+# regularization parameter increases.
+#
+
+##############################################################################
+# Wasserstein loss and computational time
+# ```````````````````````````````````````
+#
+
+# Plot the matrix and the map
+f = pl.figure(6, (4, 4))
+pl.clf()
+pl.title("Comparison between Sinkhorn and EMD")
+
+pl.plot(reg_parameter, W_sinkhorn_reg, 'o', label="Sinkhorn")
+XLim = pl.xlim()
+pl.plot(XLim, [W, W], '--k', label="EMD")
+pl.legend()
+pl.xlabel("reg")
+pl.ylabel("Wasserstein loss")
+
+##############################################################################
+# In this last graph, we show the impact of the regularization parameter on
+# the Wasserstein loss. We can see that higher
+# values of ``reg`` leads to a much higher Wasserstein loss.
+#
+# The Wasserstein loss of EMD is displayed for
+# comparison. The Wasserstein loss of Sinkhorn can be a little lower than that
+# of EMD for low values of ``reg``, but it quickly gets much higher.
+#
diff --git a/examples/plot_OT_1D_smooth.py b/examples/plot_OT_1D_smooth.py
index 75cd295..b07f99f 100644
--- a/examples/plot_OT_1D_smooth.py
+++ b/examples/plot_OT_1D_smooth.py
@@ -87,7 +87,7 @@ pl.show()
##############################################################################
# Solve Smooth OT
-# --------------
+# ---------------
#%% Smooth OT with KL regularization
diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py
index 1544e82..af1bc12 100644
--- a/examples/plot_OT_2D_samples.py
+++ b/examples/plot_OT_2D_samples.py
@@ -107,7 +107,7 @@ pl.show()
##############################################################################
# Emprirical Sinkhorn
-# ----------------
+# -------------------
#%% sinkhorn
diff --git a/examples/sliced-wasserstein/README.txt b/examples/sliced-wasserstein/README.txt
new file mode 100644
index 0000000..a575345
--- /dev/null
+++ b/examples/sliced-wasserstein/README.txt
@@ -0,0 +1,4 @@
+
+
+Sliced Wasserstein Distance
+--------------------------- \ No newline at end of file
diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py
new file mode 100644
index 0000000..7d73907
--- /dev/null
+++ b/examples/sliced-wasserstein/plot_variance.py
@@ -0,0 +1,86 @@
+# -*- coding: utf-8 -*-
+"""
+==============================
+2D Sliced Wasserstein Distance
+==============================
+
+This example illustrates the computation of the sliced Wasserstein Distance as
+proposed in [31].
+
+[31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of
+measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
+
+"""
+
+# Author: Adrien Corenflos <adrien.corenflos@aalto.fi>
+#
+# License: MIT License
+
+import matplotlib.pylab as pl
+import numpy as np
+
+import ot
+
+##############################################################################
+# Generate data
+# -------------
+
+# %% parameters and data generation
+
+n = 500 # nb samples
+
+mu_s = np.array([0, 0])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([4, 4])
+cov_t = np.array([[1, -.8], [-.8, 1]])
+
+xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
+xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
+
+a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples
+
+##############################################################################
+# Plot data
+# ---------
+
+# %% plot samples
+
+pl.figure(1)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('Source and target distributions')
+
+###############################################################################
+# Sliced Wasserstein distance for different seeds and number of projections
+# -------------------------------------------------------------------------
+
+n_seed = 50
+n_projections_arr = np.logspace(0, 3, 25, dtype=int)
+res = np.empty((n_seed, 25))
+
+# %% Compute statistics
+for seed in range(n_seed):
+ for i, n_projections in enumerate(n_projections_arr):
+ res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed=seed)
+
+res_mean = np.mean(res, axis=0)
+res_std = np.std(res, axis=0)
+
+###############################################################################
+# Plot Sliced Wasserstein Distance
+# --------------------------------
+
+pl.figure(2)
+pl.plot(n_projections_arr, res_mean, label="SWD")
+pl.fill_between(n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5)
+
+pl.legend()
+pl.xscale('log')
+
+pl.xlabel("Number of projections")
+pl.ylabel("Distance")
+pl.title('Sliced Wasserstein Distance with 95% confidence inverval')
+
+pl.show()
diff --git a/examples/unbalanced-partial/plot_UOT_1D.py b/examples/unbalanced-partial/plot_UOT_1D.py
index 2ea8b05..183849c 100644
--- a/examples/unbalanced-partial/plot_UOT_1D.py
+++ b/examples/unbalanced-partial/plot_UOT_1D.py
@@ -61,8 +61,7 @@ ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
##############################################################################
# Solve Unbalanced Sinkhorn
-# --------------
-
+# -------------------------
# Sinkhorn
diff --git a/examples/unbalanced-partial/plot_partial_wass_and_gromov.py b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py
index 0c5cbf9..ac4194c 100755
--- a/examples/unbalanced-partial/plot_partial_wass_and_gromov.py
+++ b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py
@@ -4,7 +4,7 @@
Partial Wasserstein and Gromov-Wasserstein example
==================================================
-This example is designed to show how to use the Partial (Gromov-)Wassertsein
+This example is designed to show how to use the Partial (Gromov-)Wasserstein
distance computation in POT.
"""
@@ -123,11 +123,12 @@ C1 = sp.spatial.distance.cdist(xs, xs)
C2 = sp.spatial.distance.cdist(xt, xt)
# transport 100% of the mass
-print('-----m = 1')
+print('------m = 1')
m = 1
res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True)
res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
- m=m, log=True)
+ m=m, log=True,
+ verbose=True)
print('Wasserstein distance (m = 1): ' + str(log0['partial_gw_dist']))
print('Entropic Wasserstein distance (m = 1): ' + str(log['partial_gw_dist']))
@@ -136,18 +137,20 @@ pl.figure(1, (10, 5))
pl.title("mass to be transported m = 1")
pl.subplot(1, 2, 1)
pl.imshow(res0, cmap='jet')
-pl.title('Wasserstein')
+pl.title('Gromov-Wasserstein')
pl.subplot(1, 2, 2)
pl.imshow(res, cmap='jet')
-pl.title('Entropic Wasserstein')
+pl.title('Entropic Gromov-Wasserstein')
pl.show()
# transport 2/3 of the mass
-print('-----m = 2/3')
+print('------m = 2/3')
m = 2 / 3
-res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True)
+res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True,
+ verbose=True)
res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
- m=m, log=True)
+ m=m, log=True,
+ verbose=True)
print('Partial Wasserstein distance (m = 2/3): ' +
str(log0['partial_gw_dist']))
@@ -158,8 +161,8 @@ pl.figure(1, (10, 5))
pl.title("mass to be transported m = 2/3")
pl.subplot(1, 2, 1)
pl.imshow(res0, cmap='jet')
-pl.title('Partial Wasserstein')
+pl.title('Partial Gromov-Wasserstein')
pl.subplot(1, 2, 2)
pl.imshow(res, cmap='jet')
-pl.title('Entropic partial Wasserstein')
+pl.title('Entropic partial Gromov-Wasserstein')
pl.show()
diff --git a/examples/unbalanced-partial/plot_regpath.py b/examples/unbalanced-partial/plot_regpath.py
new file mode 100644
index 0000000..4a51c2d
--- /dev/null
+++ b/examples/unbalanced-partial/plot_regpath.py
@@ -0,0 +1,135 @@
+# -*- coding: utf-8 -*-
+"""
+================================================================
+Regularization path of l2-penalized unbalanced optimal transport
+================================================================
+This example illustrate the regularization path for 2D unbalanced
+optimal transport. We present here both the fully relaxed case
+and the semi-relaxed case.
+
+[Chapel et al., 2021] Chapel, L., Flamary, R., Wu, H., Févotte, C.,
+and Gasso, G. (2021). Unbalanced optimal transport through non-negative
+penalized linear regression.
+"""
+
+# Author: Haoran Wu <haoran.wu@univ-ubs.fr>
+# License: MIT License
+
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% parameters and data generation
+
+n = 50 # nb samples
+
+mu_s = np.array([-1, -1])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([4, 4])
+cov_t = np.array([[1, -.8], [-.8, 1]])
+
+np.random.seed(0)
+xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
+xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
+
+a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples
+
+# loss matrix
+M = ot.dist(xs, xt)
+M /= M.max()
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% plot 2 distribution samples
+
+pl.figure(1)
+pl.scatter(xs[:, 0], xs[:, 1], c='C0', label='Source')
+pl.scatter(xt[:, 0], xt[:, 1], c='C1', label='Target')
+pl.legend(loc=2)
+pl.title('Source and target distributions')
+pl.show()
+
+##############################################################################
+# Compute semi-relaxed and fully relaxed regularization paths
+# -----------
+
+#%%
+final_gamma = 1e-8
+t, t_list, g_list = ot.regpath.regularization_path(a, b, M, reg=final_gamma,
+ semi_relaxed=False)
+t2, t_list2, g_list2 = ot.regpath.regularization_path(a, b, M, reg=final_gamma,
+ semi_relaxed=True)
+
+
+##############################################################################
+# Plot the regularization path
+# ----------------
+
+#%% fully relaxed l2-penalized UOT
+
+pl.figure(2)
+selected_gamma = [2e-1, 1e-1, 5e-2, 1e-3]
+for p in range(4):
+ tp = ot.regpath.compute_transport_plan(selected_gamma[p], g_list,
+ t_list)
+ P = tp.reshape((n, n))
+ pl.subplot(2, 2, p + 1)
+ if P.sum() > 0:
+ P = P / P.max()
+ for i in range(n):
+ for j in range(n):
+ if P[i, j] > 0:
+ pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2',
+ alpha=P[i, j] * 0.3)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 2,
+ label='Re-weighted source', alpha=1)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 2,
+ label='Re-weighted target', alpha=1)
+ pl.plot([], [], color='C2', alpha=0.8, label='OT plan')
+ pl.title(r'$\ell_2$ UOT $\gamma$={}'.format(selected_gamma[p]),
+ fontsize=11)
+ if p < 2:
+ pl.xticks(())
+pl.show()
+
+
+##############################################################################
+# Plot the semi-relaxed regularization path
+# -------------------
+
+#%% semi-relaxed l2-penalized UOT
+
+pl.figure(3)
+selected_gamma = [10, 1, 1e-1, 1e-2]
+for p in range(4):
+ tp = ot.regpath.compute_transport_plan(selected_gamma[p], g_list2,
+ t_list2)
+ P = tp.reshape((n, n))
+ pl.subplot(2, 2, p + 1)
+ if P.sum() > 0:
+ P = P / P.max()
+ for i in range(n):
+ for j in range(n):
+ if P[i, j] > 0:
+ pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2',
+ alpha=P[i, j] * 0.3)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=1, label='Target marginal')
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * 2 * (1 + p),
+ label='Source marginal', alpha=1)
+ pl.plot([], [], color='C2', alpha=0.8, label='OT plan')
+ pl.title(r'Semi-relaxed $l_2$ UOT $\gamma$={}'.format(selected_gamma[p]),
+ fontsize=11)
+ if p < 2:
+ pl.xticks(())
+pl.show()
diff --git a/ot/__init__.py b/ot/__init__.py
index 0e6e2e2..b6dc2b4 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -5,7 +5,8 @@
:py:mod:`ot.lp`, :py:mod:`ot.bregman`, :py:mod:`ot.optim`
:py:mod:`ot.utils`, :py:mod:`ot.datasets`,
:py:mod:`ot.gromov`, :py:mod:`ot.smooth`
- :py:mod:`ot.stochastic`
+ :py:mod:`ot.stochastic`, :py:mod:`ot.partial`, :py:mod:`ot.regpath`
+ , :py:mod:`ot.unbalanced`.
The following sub-modules are not imported due to additional dependencies:
@@ -33,21 +34,30 @@ from . import smooth
from . import stochastic
from . import unbalanced
from . import partial
+from . import backend
+from . import regpath
# OT functions
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
from .bregman import sinkhorn, sinkhorn2, barycenter
-from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2
+from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced,
+ sinkhorn_unbalanced2)
from .da import sinkhorn_lpl1_mm
+from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance
+from .gromov import (gromov_wasserstein, gromov_wasserstein2,
+ gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
# utils functions
from .utils import dist, unif, tic, toc, toq
-__version__ = "0.7.0"
+__version__ = "0.8.0"
-__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets',
- 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
- 'emd_1d', 'emd2_1d', 'wasserstein_1d',
+__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
+ 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
+ 'emd2_1d', 'wasserstein_1d', 'backend',
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
'sinkhorn_unbalanced', 'barycenter_unbalanced',
- 'sinkhorn_unbalanced2']
+ 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance',
+ 'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2',
+ 'max_sliced_wasserstein_distance',
+ 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath']
diff --git a/ot/backend.py b/ot/backend.py
new file mode 100644
index 0000000..a044f84
--- /dev/null
+++ b/ot/backend.py
@@ -0,0 +1,1502 @@
+# -*- coding: utf-8 -*-
+"""
+Multi-lib backend for POT
+
+The goal is to write backend-agnostic code. Whether you're using Numpy, PyTorch,
+or Jax, POT code should work nonetheless.
+To achieve that, POT provides backend classes which implements functions in their respective backend
+imitating Numpy API. As a convention, we use nx instead of np to refer to the backend.
+
+Examples
+--------
+
+>>> from ot.utils import list_to_array
+>>> from ot.backend import get_backend
+>>> def f(a, b): # the function does not know which backend to use
+... a, b = list_to_array(a, b) # if a list in given, make it an array
+... nx = get_backend(a, b) # infer the backend from the arguments
+... c = nx.dot(a, b) # now use the backend to do any calculation
+... return c
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+import scipy.special as scipy
+from scipy.sparse import issparse, coo_matrix, csr_matrix
+
+try:
+ import torch
+ torch_type = torch.Tensor
+except ImportError:
+ torch = False
+ torch_type = float
+
+try:
+ import jax
+ import jax.numpy as jnp
+ import jax.scipy.special as jscipy
+ jax_type = jax.numpy.ndarray
+except ImportError:
+ jax = False
+ jax_type = float
+
+str_type_error = "All array should be from the same type/backend. Current types are : {}"
+
+
+def get_backend_list():
+ """Returns the list of available backends"""
+ lst = [NumpyBackend(), ]
+
+ if torch:
+ lst.append(TorchBackend())
+
+ if jax:
+ lst.append(JaxBackend())
+
+ return lst
+
+
+def get_backend(*args):
+ """Returns the proper backend for a list of input arrays
+
+ Also raises TypeError if all arrays are not from the same backend
+ """
+ # check that some arrays given
+ if not len(args) > 0:
+ raise ValueError(" The function takes at least one parameter")
+ # check all same type
+ if not len(set(type(a) for a in args)) == 1:
+ raise ValueError(str_type_error.format([type(a) for a in args]))
+
+ if isinstance(args[0], np.ndarray):
+ return NumpyBackend()
+ elif isinstance(args[0], torch_type):
+ return TorchBackend()
+ elif isinstance(args[0], jax_type):
+ return JaxBackend()
+ else:
+ raise ValueError("Unknown type of non implemented backend.")
+
+
+def to_numpy(*args):
+ """Returns numpy arrays from any compatible backend"""
+
+ if len(args) == 1:
+ return get_backend(args[0]).to_numpy(args[0])
+ else:
+ return [get_backend(a).to_numpy(a) for a in args]
+
+
+class Backend():
+ """
+ Backend abstract class.
+ Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend`
+
+ - The `__name__` class attribute refers to the name of the backend.
+ - The `__type__` class attribute refers to the data structure used by the backend.
+ """
+
+ __name__ = None
+ __type__ = None
+ __type_list__ = None
+
+ rng_ = None
+
+ def __str__(self):
+ return self.__name__
+
+ # convert to numpy
+ def to_numpy(self, a):
+ """Returns the numpy version of a tensor"""
+ raise NotImplementedError()
+
+ # convert from numpy
+ def from_numpy(self, a, type_as=None):
+ """Creates a tensor cloning a numpy array, with the given precision (defaulting to input's precision) and the given device (in case of GPUs)"""
+ raise NotImplementedError()
+
+ def set_gradients(self, val, inputs, grads):
+ """Define the gradients for the value val wrt the inputs """
+ raise NotImplementedError()
+
+ def zeros(self, shape, type_as=None):
+ r"""
+ Creates a tensor full of zeros.
+
+ This function follows the api from :any:`numpy.zeros`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.zeros.html
+ """
+ raise NotImplementedError()
+
+ def ones(self, shape, type_as=None):
+ r"""
+ Creates a tensor full of ones.
+
+ This function follows the api from :any:`numpy.ones`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.ones.html
+ """
+ raise NotImplementedError()
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ r"""
+ Returns evenly spaced values within a given interval.
+
+ This function follows the api from :any:`numpy.arange`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.arange.html
+ """
+ raise NotImplementedError()
+
+ def full(self, shape, fill_value, type_as=None):
+ r"""
+ Creates a tensor with given shape, filled with given value.
+
+ This function follows the api from :any:`numpy.full`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.full.html
+ """
+ raise NotImplementedError()
+
+ def eye(self, N, M=None, type_as=None):
+ r"""
+ Creates the identity matrix of given size.
+
+ This function follows the api from :any:`numpy.eye`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.eye.html
+ """
+ raise NotImplementedError()
+
+ def sum(self, a, axis=None, keepdims=False):
+ r"""
+ Sums tensor elements over given dimensions.
+
+ This function follows the api from :any:`numpy.sum`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.sum.html
+ """
+ raise NotImplementedError()
+
+ def cumsum(self, a, axis=None):
+ r"""
+ Returns the cumulative sum of tensor elements over given dimensions.
+
+ This function follows the api from :any:`numpy.cumsum`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html
+ """
+ raise NotImplementedError()
+
+ def max(self, a, axis=None, keepdims=False):
+ r"""
+ Returns the maximum of an array or maximum along given dimensions.
+
+ This function follows the api from :any:`numpy.amax`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.amax.html
+ """
+ raise NotImplementedError()
+
+ def min(self, a, axis=None, keepdims=False):
+ r"""
+ Returns the maximum of an array or maximum along given dimensions.
+
+ This function follows the api from :any:`numpy.amin`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.amin.html
+ """
+ raise NotImplementedError()
+
+ def maximum(self, a, b):
+ r"""
+ Returns element-wise maximum of array elements.
+
+ This function follows the api from :any:`numpy.maximum`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.maximum.html
+ """
+ raise NotImplementedError()
+
+ def minimum(self, a, b):
+ r"""
+ Returns element-wise minimum of array elements.
+
+ This function follows the api from :any:`numpy.minimum`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.minimum.html
+ """
+ raise NotImplementedError()
+
+ def dot(self, a, b):
+ r"""
+ Returns the dot product of two tensors.
+
+ This function follows the api from :any:`numpy.dot`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html
+ """
+ raise NotImplementedError()
+
+ def abs(self, a):
+ r"""
+ Computes the absolute value element-wise.
+
+ This function follows the api from :any:`numpy.absolute`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.absolute.html
+ """
+ raise NotImplementedError()
+
+ def exp(self, a):
+ r"""
+ Computes the exponential value element-wise.
+
+ This function follows the api from :any:`numpy.exp`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.exp.html
+ """
+ raise NotImplementedError()
+
+ def log(self, a):
+ r"""
+ Computes the natural logarithm, element-wise.
+
+ This function follows the api from :any:`numpy.log`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.log.html
+ """
+ raise NotImplementedError()
+
+ def sqrt(self, a):
+ r"""
+ Returns the non-ngeative square root of a tensor, element-wise.
+
+ This function follows the api from :any:`numpy.sqrt`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html
+ """
+ raise NotImplementedError()
+
+ def power(self, a, exponents):
+ r"""
+ First tensor elements raised to powers from second tensor, element-wise.
+
+ This function follows the api from :any:`numpy.power`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.power.html
+ """
+ raise NotImplementedError()
+
+ def norm(self, a):
+ r"""
+ Computes the matrix frobenius norm.
+
+ This function follows the api from :any:`numpy.linalg.norm`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html
+ """
+ raise NotImplementedError()
+
+ def any(self, a):
+ r"""
+ Tests whether any tensor element along given dimensions evaluates to True.
+
+ This function follows the api from :any:`numpy.any`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.any.html
+ """
+ raise NotImplementedError()
+
+ def isnan(self, a):
+ r"""
+ Tests element-wise for NaN and returns result as a boolean tensor.
+
+ This function follows the api from :any:`numpy.isnan`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.isnan.html
+ """
+ raise NotImplementedError()
+
+ def isinf(self, a):
+ r"""
+ Tests element-wise for positive or negative infinity and returns result as a boolean tensor.
+
+ This function follows the api from :any:`numpy.isinf`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.isinf.html
+ """
+ raise NotImplementedError()
+
+ def einsum(self, subscripts, *operands):
+ r"""
+ Evaluates the Einstein summation convention on the operands.
+
+ This function follows the api from :any:`numpy.einsum`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html
+ """
+ raise NotImplementedError()
+
+ def sort(self, a, axis=-1):
+ r"""
+ Returns a sorted copy of a tensor.
+
+ This function follows the api from :any:`numpy.sort`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.sort.html
+ """
+ raise NotImplementedError()
+
+ def argsort(self, a, axis=None):
+ r"""
+ Returns the indices that would sort a tensor.
+
+ This function follows the api from :any:`numpy.argsort`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.argsort.html
+ """
+ raise NotImplementedError()
+
+ def searchsorted(self, a, v, side='left'):
+ r"""
+ Finds indices where elements should be inserted to maintain order in given tensor.
+
+ This function follows the api from :any:`numpy.searchsorted`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html
+ """
+ raise NotImplementedError()
+
+ def flip(self, a, axis=None):
+ r"""
+ Reverses the order of elements in a tensor along given dimensions.
+
+ This function follows the api from :any:`numpy.flip`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.flip.html
+ """
+ raise NotImplementedError()
+
+ def clip(self, a, a_min, a_max):
+ """
+ Limits the values in a tensor.
+
+ This function follows the api from :any:`numpy.clip`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.clip.html
+ """
+ raise NotImplementedError()
+
+ def repeat(self, a, repeats, axis=None):
+ r"""
+ Repeats elements of a tensor.
+
+ This function follows the api from :any:`numpy.repeat`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.repeat.html
+ """
+ raise NotImplementedError()
+
+ def take_along_axis(self, arr, indices, axis):
+ r"""
+ Gathers elements of a tensor along given dimensions.
+
+ This function follows the api from :any:`numpy.take_along_axis`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html
+ """
+ raise NotImplementedError()
+
+ def concatenate(self, arrays, axis=0):
+ r"""
+ Joins a sequence of tensors along an existing dimension.
+
+ This function follows the api from :any:`numpy.concatenate`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html
+ """
+ raise NotImplementedError()
+
+ def zero_pad(self, a, pad_width):
+ r"""
+ Pads a tensor.
+
+ This function follows the api from :any:`numpy.pad`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
+ """
+ raise NotImplementedError()
+
+ def argmax(self, a, axis=None):
+ r"""
+ Returns the indices of the maximum values of a tensor along given dimensions.
+
+ This function follows the api from :any:`numpy.argmax`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.argmax.html
+ """
+ raise NotImplementedError()
+
+ def mean(self, a, axis=None):
+ r"""
+ Computes the arithmetic mean of a tensor along given dimensions.
+
+ This function follows the api from :any:`numpy.mean`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.mean.html
+ """
+ raise NotImplementedError()
+
+ def std(self, a, axis=None):
+ r"""
+ Computes the standard deviation of a tensor along given dimensions.
+
+ This function follows the api from :any:`numpy.std`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.std.html
+ """
+ raise NotImplementedError()
+
+ def linspace(self, start, stop, num):
+ r"""
+ Returns a specified number of evenly spaced values over a given interval.
+
+ This function follows the api from :any:`numpy.linspace`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.linspace.html
+ """
+ raise NotImplementedError()
+
+ def meshgrid(self, a, b):
+ r"""
+ Returns coordinate matrices from coordinate vectors (Numpy convention).
+
+ This function follows the api from :any:`numpy.meshgrid`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html
+ """
+ raise NotImplementedError()
+
+ def diag(self, a, k=0):
+ r"""
+ Extracts or constructs a diagonal tensor.
+
+ This function follows the api from :any:`numpy.diag`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.diag.html
+ """
+ raise NotImplementedError()
+
+ def unique(self, a):
+ r"""
+ Finds unique elements of given tensor.
+
+ This function follows the api from :any:`numpy.unique`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.unique.html
+ """
+ raise NotImplementedError()
+
+ def logsumexp(self, a, axis=None):
+ r"""
+ Computes the log of the sum of exponentials of input elements.
+
+ This function follows the api from :any:`scipy.special.logsumexp`
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html
+ """
+ raise NotImplementedError()
+
+ def stack(self, arrays, axis=0):
+ r"""
+ Joins a sequence of tensors along a new dimension.
+
+ This function follows the api from :any:`numpy.stack`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.stack.html
+ """
+ raise NotImplementedError()
+
+ def outer(self, a, b):
+ r"""
+ Computes the outer product between two vectors.
+
+ This function follows the api from :any:`numpy.outer`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.outer.html
+ """
+ raise NotImplementedError()
+
+ def reshape(self, a, shape):
+ r"""
+ Gives a new shape to a tensor without changing its data.
+
+ This function follows the api from :any:`numpy.reshape`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.reshape.html
+ """
+ raise NotImplementedError()
+
+ def seed(self, seed=None):
+ r"""
+ Sets the seed for the random generator.
+
+ This function follows the api from :any:`numpy.random.seed`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.random.seed.html
+ """
+ raise NotImplementedError()
+
+ def rand(self, *size, type_as=None):
+ r"""
+ Generate uniform random numbers.
+
+ This function follows the api from :any:`numpy.random.rand`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.random.rand.html
+ """
+ raise NotImplementedError()
+
+ def randn(self, *size, type_as=None):
+ r"""
+ Generate normal Gaussian random numbers.
+
+ This function follows the api from :any:`numpy.random.rand`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.random.rand.html
+ """
+ raise NotImplementedError()
+
+ def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
+ r"""
+ Creates a sparse tensor in COOrdinate format.
+
+ This function follows the api from :any:`scipy.sparse.coo_matrix`
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html
+ """
+ raise NotImplementedError()
+
+ def issparse(self, a):
+ r"""
+ Checks whether or not the input tensor is a sparse tensor.
+
+ This function follows the api from :any:`scipy.sparse.issparse`
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.issparse.html
+ """
+ raise NotImplementedError()
+
+ def tocsr(self, a):
+ r"""
+ Converts this matrix to Compressed Sparse Row format.
+
+ This function follows the api from :any:`scipy.sparse.coo_matrix.tocsr`
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.tocsr.html
+ """
+ raise NotImplementedError()
+
+ def eliminate_zeros(self, a, threshold=0.):
+ r"""
+ Removes entries smaller than the given threshold from the sparse tensor.
+
+ This function follows the api from :any:`scipy.sparse.csr_matrix.eliminate_zeros`
+
+ See: https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.sparse.csr_matrix.eliminate_zeros.html
+ """
+ raise NotImplementedError()
+
+ def todense(self, a):
+ r"""
+ Converts a sparse tensor to a dense tensor.
+
+ This function follows the api from :any:`scipy.sparse.csr_matrix.toarray`
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.toarray.html
+ """
+ raise NotImplementedError()
+
+ def where(self, condition, x, y):
+ r"""
+ Returns elements chosen from x or y depending on condition.
+
+ This function follows the api from :any:`numpy.where`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.where.html
+ """
+ raise NotImplementedError()
+
+ def copy(self, a):
+ r"""
+ Returns a copy of the given tensor.
+
+ This function follows the api from :any:`numpy.copy`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.copy.html
+ """
+ raise NotImplementedError()
+
+ def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
+ r"""
+ Returns True if two arrays are element-wise equal within a tolerance.
+
+ This function follows the api from :any:`numpy.allclose`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.allclose.html
+ """
+ raise NotImplementedError()
+
+ def dtype_device(self, a):
+ r"""
+ Returns the dtype and the device of the given tensor.
+ """
+ raise NotImplementedError()
+
+ def assert_same_dtype_device(self, a, b):
+ r"""
+ Checks whether or not the two given inputs have the same dtype as well as the same device
+ """
+ raise NotImplementedError()
+
+
+class NumpyBackend(Backend):
+ """
+ NumPy implementation of the backend
+
+ - `__name__` is "numpy"
+ - `__type__` is np.ndarray
+ """
+
+ __name__ = 'numpy'
+ __type__ = np.ndarray
+ __type_list__ = [np.array(1, dtype=np.float32),
+ np.array(1, dtype=np.float64)]
+
+ rng_ = np.random.RandomState()
+
+ def to_numpy(self, a):
+ return a
+
+ def from_numpy(self, a, type_as=None):
+ if type_as is None:
+ return a
+ elif isinstance(a, float):
+ return a
+ else:
+ return a.astype(type_as.dtype)
+
+ def set_gradients(self, val, inputs, grads):
+ # No gradients for numpy
+ return val
+
+ def zeros(self, shape, type_as=None):
+ if type_as is None:
+ return np.zeros(shape)
+ else:
+ return np.zeros(shape, dtype=type_as.dtype)
+
+ def ones(self, shape, type_as=None):
+ if type_as is None:
+ return np.ones(shape)
+ else:
+ return np.ones(shape, dtype=type_as.dtype)
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ return np.arange(start, stop, step)
+
+ def full(self, shape, fill_value, type_as=None):
+ if type_as is None:
+ return np.full(shape, fill_value)
+ else:
+ return np.full(shape, fill_value, dtype=type_as.dtype)
+
+ def eye(self, N, M=None, type_as=None):
+ if type_as is None:
+ return np.eye(N, M)
+ else:
+ return np.eye(N, M, dtype=type_as.dtype)
+
+ def sum(self, a, axis=None, keepdims=False):
+ return np.sum(a, axis, keepdims=keepdims)
+
+ def cumsum(self, a, axis=None):
+ return np.cumsum(a, axis)
+
+ def max(self, a, axis=None, keepdims=False):
+ return np.max(a, axis, keepdims=keepdims)
+
+ def min(self, a, axis=None, keepdims=False):
+ return np.min(a, axis, keepdims=keepdims)
+
+ def maximum(self, a, b):
+ return np.maximum(a, b)
+
+ def minimum(self, a, b):
+ return np.minimum(a, b)
+
+ def dot(self, a, b):
+ return np.dot(a, b)
+
+ def abs(self, a):
+ return np.abs(a)
+
+ def exp(self, a):
+ return np.exp(a)
+
+ def log(self, a):
+ return np.log(a)
+
+ def sqrt(self, a):
+ return np.sqrt(a)
+
+ def power(self, a, exponents):
+ return np.power(a, exponents)
+
+ def norm(self, a):
+ return np.sqrt(np.sum(np.square(a)))
+
+ def any(self, a):
+ return np.any(a)
+
+ def isnan(self, a):
+ return np.isnan(a)
+
+ def isinf(self, a):
+ return np.isinf(a)
+
+ def einsum(self, subscripts, *operands):
+ return np.einsum(subscripts, *operands)
+
+ def sort(self, a, axis=-1):
+ return np.sort(a, axis)
+
+ def argsort(self, a, axis=-1):
+ return np.argsort(a, axis)
+
+ def searchsorted(self, a, v, side='left'):
+ if a.ndim == 1:
+ return np.searchsorted(a, v, side)
+ else:
+ # this is a not very efficient way to make numpy
+ # searchsorted work on 2d arrays
+ ret = np.empty(v.shape, dtype=int)
+ for i in range(a.shape[0]):
+ ret[i, :] = np.searchsorted(a[i, :], v[i, :], side)
+ return ret
+
+ def flip(self, a, axis=None):
+ return np.flip(a, axis)
+
+ def outer(self, a, b):
+ return np.outer(a, b)
+
+ def clip(self, a, a_min, a_max):
+ return np.clip(a, a_min, a_max)
+
+ def repeat(self, a, repeats, axis=None):
+ return np.repeat(a, repeats, axis)
+
+ def take_along_axis(self, arr, indices, axis):
+ return np.take_along_axis(arr, indices, axis)
+
+ def concatenate(self, arrays, axis=0):
+ return np.concatenate(arrays, axis)
+
+ def zero_pad(self, a, pad_width):
+ return np.pad(a, pad_width)
+
+ def argmax(self, a, axis=None):
+ return np.argmax(a, axis=axis)
+
+ def mean(self, a, axis=None):
+ return np.mean(a, axis=axis)
+
+ def std(self, a, axis=None):
+ return np.std(a, axis=axis)
+
+ def linspace(self, start, stop, num):
+ return np.linspace(start, stop, num)
+
+ def meshgrid(self, a, b):
+ return np.meshgrid(a, b)
+
+ def diag(self, a, k=0):
+ return np.diag(a, k)
+
+ def unique(self, a):
+ return np.unique(a)
+
+ def logsumexp(self, a, axis=None):
+ return scipy.logsumexp(a, axis=axis)
+
+ def stack(self, arrays, axis=0):
+ return np.stack(arrays, axis)
+
+ def reshape(self, a, shape):
+ return np.reshape(a, shape)
+
+ def seed(self, seed=None):
+ if seed is not None:
+ self.rng_.seed(seed)
+
+ def rand(self, *size, type_as=None):
+ return self.rng_.rand(*size)
+
+ def randn(self, *size, type_as=None):
+ return self.rng_.randn(*size)
+
+ def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
+ if type_as is None:
+ return coo_matrix((data, (rows, cols)), shape=shape)
+ else:
+ return coo_matrix((data, (rows, cols)), shape=shape, dtype=type_as.dtype)
+
+ def issparse(self, a):
+ return issparse(a)
+
+ def tocsr(self, a):
+ if self.issparse(a):
+ return a.tocsr()
+ else:
+ return csr_matrix(a)
+
+ def eliminate_zeros(self, a, threshold=0.):
+ if threshold > 0:
+ if self.issparse(a):
+ a.data[self.abs(a.data) <= threshold] = 0
+ else:
+ a[self.abs(a) <= threshold] = 0
+ if self.issparse(a):
+ a.eliminate_zeros()
+ return a
+
+ def todense(self, a):
+ if self.issparse(a):
+ return a.toarray()
+ else:
+ return a
+
+ def where(self, condition, x, y):
+ return np.where(condition, x, y)
+
+ def copy(self, a):
+ return a.copy()
+
+ def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
+ return np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
+
+ def dtype_device(self, a):
+ if hasattr(a, "dtype"):
+ return a.dtype, "cpu"
+ else:
+ return type(a), "cpu"
+
+ def assert_same_dtype_device(self, a, b):
+ # numpy has implicit type conversion so we automatically validate the test
+ pass
+
+
+class JaxBackend(Backend):
+ """
+ JAX implementation of the backend
+
+ - `__name__` is "jax"
+ - `__type__` is jax.numpy.ndarray
+ """
+
+ __name__ = 'jax'
+ __type__ = jax_type
+ __type_list__ = None
+
+ rng_ = None
+
+ def __init__(self):
+ self.rng_ = jax.random.PRNGKey(42)
+
+ for d in jax.devices():
+ self.__type_list__ = [jax.device_put(jnp.array(1, dtype=jnp.float32), d),
+ jax.device_put(jnp.array(1, dtype=jnp.float64), d)]
+
+ def to_numpy(self, a):
+ return np.array(a)
+
+ def _change_device(self, a, type_as):
+ return jax.device_put(a, type_as.device_buffer.device())
+
+ def from_numpy(self, a, type_as=None):
+ if type_as is None:
+ return jnp.array(a)
+ else:
+ return self._change_device(jnp.array(a).astype(type_as.dtype), type_as)
+
+ def set_gradients(self, val, inputs, grads):
+ from jax.flatten_util import ravel_pytree
+ val, = jax.lax.stop_gradient((val,))
+
+ ravelled_inputs, _ = ravel_pytree(inputs)
+ ravelled_grads, _ = ravel_pytree(grads)
+
+ aux = jnp.sum(ravelled_inputs * ravelled_grads) / 2
+ aux = aux - jax.lax.stop_gradient(aux)
+
+ val, = jax.tree_map(lambda z: z + aux, (val,))
+ return val
+
+ def zeros(self, shape, type_as=None):
+ if type_as is None:
+ return jnp.zeros(shape)
+ else:
+ return self._change_device(jnp.zeros(shape, dtype=type_as.dtype), type_as)
+
+ def ones(self, shape, type_as=None):
+ if type_as is None:
+ return jnp.ones(shape)
+ else:
+ return self._change_device(jnp.ones(shape, dtype=type_as.dtype), type_as)
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ return jnp.arange(start, stop, step)
+
+ def full(self, shape, fill_value, type_as=None):
+ if type_as is None:
+ return jnp.full(shape, fill_value)
+ else:
+ return self._change_device(jnp.full(shape, fill_value, dtype=type_as.dtype), type_as)
+
+ def eye(self, N, M=None, type_as=None):
+ if type_as is None:
+ return jnp.eye(N, M)
+ else:
+ return self._change_device(jnp.eye(N, M, dtype=type_as.dtype), type_as)
+
+ def sum(self, a, axis=None, keepdims=False):
+ return jnp.sum(a, axis, keepdims=keepdims)
+
+ def cumsum(self, a, axis=None):
+ return jnp.cumsum(a, axis)
+
+ def max(self, a, axis=None, keepdims=False):
+ return jnp.max(a, axis, keepdims=keepdims)
+
+ def min(self, a, axis=None, keepdims=False):
+ return jnp.min(a, axis, keepdims=keepdims)
+
+ def maximum(self, a, b):
+ return jnp.maximum(a, b)
+
+ def minimum(self, a, b):
+ return jnp.minimum(a, b)
+
+ def dot(self, a, b):
+ return jnp.dot(a, b)
+
+ def abs(self, a):
+ return jnp.abs(a)
+
+ def exp(self, a):
+ return jnp.exp(a)
+
+ def log(self, a):
+ return jnp.log(a)
+
+ def sqrt(self, a):
+ return jnp.sqrt(a)
+
+ def power(self, a, exponents):
+ return jnp.power(a, exponents)
+
+ def norm(self, a):
+ return jnp.sqrt(jnp.sum(jnp.square(a)))
+
+ def any(self, a):
+ return jnp.any(a)
+
+ def isnan(self, a):
+ return jnp.isnan(a)
+
+ def isinf(self, a):
+ return jnp.isinf(a)
+
+ def einsum(self, subscripts, *operands):
+ return jnp.einsum(subscripts, *operands)
+
+ def sort(self, a, axis=-1):
+ return jnp.sort(a, axis)
+
+ def argsort(self, a, axis=-1):
+ return jnp.argsort(a, axis)
+
+ def searchsorted(self, a, v, side='left'):
+ if a.ndim == 1:
+ return jnp.searchsorted(a, v, side)
+ else:
+ # this is a not very efficient way to make jax numpy
+ # searchsorted work on 2d arrays
+ return jnp.array([jnp.searchsorted(a[i, :], v[i, :], side) for i in range(a.shape[0])])
+
+ def flip(self, a, axis=None):
+ return jnp.flip(a, axis)
+
+ def outer(self, a, b):
+ return jnp.outer(a, b)
+
+ def clip(self, a, a_min, a_max):
+ return jnp.clip(a, a_min, a_max)
+
+ def repeat(self, a, repeats, axis=None):
+ return jnp.repeat(a, repeats, axis)
+
+ def take_along_axis(self, arr, indices, axis):
+ return jnp.take_along_axis(arr, indices, axis)
+
+ def concatenate(self, arrays, axis=0):
+ return jnp.concatenate(arrays, axis)
+
+ def zero_pad(self, a, pad_width):
+ return jnp.pad(a, pad_width)
+
+ def argmax(self, a, axis=None):
+ return jnp.argmax(a, axis=axis)
+
+ def mean(self, a, axis=None):
+ return jnp.mean(a, axis=axis)
+
+ def std(self, a, axis=None):
+ return jnp.std(a, axis=axis)
+
+ def linspace(self, start, stop, num):
+ return jnp.linspace(start, stop, num)
+
+ def meshgrid(self, a, b):
+ return jnp.meshgrid(a, b)
+
+ def diag(self, a, k=0):
+ return jnp.diag(a, k)
+
+ def unique(self, a):
+ return jnp.unique(a)
+
+ def logsumexp(self, a, axis=None):
+ return jscipy.logsumexp(a, axis=axis)
+
+ def stack(self, arrays, axis=0):
+ return jnp.stack(arrays, axis)
+
+ def reshape(self, a, shape):
+ return jnp.reshape(a, shape)
+
+ def seed(self, seed=None):
+ if seed is not None:
+ self.rng_ = jax.random.PRNGKey(seed)
+
+ def rand(self, *size, type_as=None):
+ self.rng_, subkey = jax.random.split(self.rng_)
+ if type_as is not None:
+ return jax.random.uniform(subkey, shape=size, dtype=type_as.dtype)
+ else:
+ return jax.random.uniform(subkey, shape=size)
+
+ def randn(self, *size, type_as=None):
+ self.rng_, subkey = jax.random.split(self.rng_)
+ if type_as is not None:
+ return jax.random.normal(subkey, shape=size, dtype=type_as.dtype)
+ else:
+ return jax.random.normal(subkey, shape=size)
+
+ def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
+ # Currently, JAX does not support sparse matrices
+ data = self.to_numpy(data)
+ rows = self.to_numpy(rows)
+ cols = self.to_numpy(cols)
+ nx = NumpyBackend()
+ coo_matrix = nx.coo_matrix(data, rows, cols, shape=shape, type_as=type_as)
+ matrix = nx.todense(coo_matrix)
+ return self.from_numpy(matrix)
+
+ def issparse(self, a):
+ # Currently, JAX does not support sparse matrices
+ return False
+
+ def tocsr(self, a):
+ # Currently, JAX does not support sparse matrices
+ return a
+
+ def eliminate_zeros(self, a, threshold=0.):
+ # Currently, JAX does not support sparse matrices
+ if threshold > 0:
+ return self.where(
+ self.abs(a) <= threshold,
+ self.zeros((1,), type_as=a),
+ a
+ )
+ return a
+
+ def todense(self, a):
+ # Currently, JAX does not support sparse matrices
+ return a
+
+ def where(self, condition, x, y):
+ return jnp.where(condition, x, y)
+
+ def copy(self, a):
+ # No need to copy, JAX arrays are immutable
+ return a
+
+ def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
+ return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
+
+ def dtype_device(self, a):
+ return a.dtype, a.device_buffer.device()
+
+ def assert_same_dtype_device(self, a, b):
+ a_dtype, a_device = self.dtype_device(a)
+ b_dtype, b_device = self.dtype_device(b)
+
+ assert a_dtype == b_dtype, "Dtype discrepancy"
+ assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
+
+
+class TorchBackend(Backend):
+ """
+ PyTorch implementation of the backend
+
+ - `__name__` is "torch"
+ - `__type__` is torch.Tensor
+ """
+
+ __name__ = 'torch'
+ __type__ = torch_type
+ __type_list__ = None
+
+ rng_ = None
+
+ def __init__(self):
+
+ self.rng_ = torch.Generator()
+ self.rng_.seed()
+
+ self.__type_list__ = [torch.tensor(1, dtype=torch.float32),
+ torch.tensor(1, dtype=torch.float64)]
+
+ if torch.cuda.is_available():
+ self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda'))
+ self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda'))
+
+ from torch.autograd import Function
+
+ # define a function that takes inputs val and grads
+ # ad returns a val tensor with proper gradients
+ class ValFunction(Function):
+
+ @staticmethod
+ def forward(ctx, val, grads, *inputs):
+ ctx.grads = grads
+ return val
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ # the gradients are grad
+ return (None, None) + ctx.grads
+
+ self.ValFunction = ValFunction
+
+ def to_numpy(self, a):
+ return a.cpu().detach().numpy()
+
+ def from_numpy(self, a, type_as=None):
+ if isinstance(a, float):
+ a = np.array(a)
+ if type_as is None:
+ return torch.from_numpy(a)
+ else:
+ return torch.as_tensor(a, dtype=type_as.dtype, device=type_as.device)
+
+ def set_gradients(self, val, inputs, grads):
+
+ Func = self.ValFunction()
+
+ res = Func.apply(val, grads, *inputs)
+
+ return res
+
+ def zeros(self, shape, type_as=None):
+ if isinstance(shape, int):
+ shape = (shape,)
+ if type_as is None:
+ return torch.zeros(shape)
+ else:
+ return torch.zeros(shape, dtype=type_as.dtype, device=type_as.device)
+
+ def ones(self, shape, type_as=None):
+ if isinstance(shape, int):
+ shape = (shape,)
+ if type_as is None:
+ return torch.ones(shape)
+ else:
+ return torch.ones(shape, dtype=type_as.dtype, device=type_as.device)
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ if type_as is None:
+ return torch.arange(start, stop, step)
+ else:
+ return torch.arange(start, stop, step, device=type_as.device)
+
+ def full(self, shape, fill_value, type_as=None):
+ if isinstance(shape, int):
+ shape = (shape,)
+ if type_as is None:
+ return torch.full(shape, fill_value)
+ else:
+ return torch.full(shape, fill_value, dtype=type_as.dtype, device=type_as.device)
+
+ def eye(self, N, M=None, type_as=None):
+ if M is None:
+ M = N
+ if type_as is None:
+ return torch.eye(N, m=M)
+ else:
+ return torch.eye(N, m=M, dtype=type_as.dtype, device=type_as.device)
+
+ def sum(self, a, axis=None, keepdims=False):
+ if axis is None:
+ return torch.sum(a)
+ else:
+ return torch.sum(a, axis, keepdim=keepdims)
+
+ def cumsum(self, a, axis=None):
+ if axis is None:
+ return torch.cumsum(a.flatten(), 0)
+ else:
+ return torch.cumsum(a, axis)
+
+ def max(self, a, axis=None, keepdims=False):
+ if axis is None:
+ return torch.max(a)
+ else:
+ return torch.max(a, axis, keepdim=keepdims)[0]
+
+ def min(self, a, axis=None, keepdims=False):
+ if axis is None:
+ return torch.min(a)
+ else:
+ return torch.min(a, axis, keepdim=keepdims)[0]
+
+ def maximum(self, a, b):
+ if isinstance(a, int) or isinstance(a, float):
+ a = torch.tensor([float(a)], dtype=b.dtype, device=b.device)
+ if isinstance(b, int) or isinstance(b, float):
+ b = torch.tensor([float(b)], dtype=a.dtype, device=a.device)
+ if hasattr(torch, "maximum"):
+ return torch.maximum(a, b)
+ else:
+ return torch.max(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0]
+
+ def minimum(self, a, b):
+ if isinstance(a, int) or isinstance(a, float):
+ a = torch.tensor([float(a)], dtype=b.dtype, device=b.device)
+ if isinstance(b, int) or isinstance(b, float):
+ b = torch.tensor([float(b)], dtype=a.dtype, device=a.device)
+ if hasattr(torch, "minimum"):
+ return torch.minimum(a, b)
+ else:
+ return torch.min(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0]
+
+ def dot(self, a, b):
+ return torch.matmul(a, b)
+
+ def abs(self, a):
+ return torch.abs(a)
+
+ def exp(self, a):
+ return torch.exp(a)
+
+ def log(self, a):
+ return torch.log(a)
+
+ def sqrt(self, a):
+ return torch.sqrt(a)
+
+ def power(self, a, exponents):
+ return torch.pow(a, exponents)
+
+ def norm(self, a):
+ return torch.sqrt(torch.sum(torch.square(a)))
+
+ def any(self, a):
+ return torch.any(a)
+
+ def isnan(self, a):
+ return torch.isnan(a)
+
+ def isinf(self, a):
+ return torch.isinf(a)
+
+ def einsum(self, subscripts, *operands):
+ return torch.einsum(subscripts, *operands)
+
+ def sort(self, a, axis=-1):
+ sorted0, indices = torch.sort(a, dim=axis)
+ return sorted0
+
+ def argsort(self, a, axis=-1):
+ sorted, indices = torch.sort(a, dim=axis)
+ return indices
+
+ def searchsorted(self, a, v, side='left'):
+ right = (side != 'left')
+ return torch.searchsorted(a, v, right=right)
+
+ def flip(self, a, axis=None):
+ if axis is None:
+ return torch.flip(a, tuple(i for i in range(len(a.shape))))
+ if isinstance(axis, int):
+ return torch.flip(a, (axis,))
+ else:
+ return torch.flip(a, dims=axis)
+
+ def outer(self, a, b):
+ return torch.outer(a, b)
+
+ def clip(self, a, a_min, a_max):
+ return torch.clamp(a, a_min, a_max)
+
+ def repeat(self, a, repeats, axis=None):
+ return torch.repeat_interleave(a, repeats, dim=axis)
+
+ def take_along_axis(self, arr, indices, axis):
+ return torch.gather(arr, axis, indices)
+
+ def concatenate(self, arrays, axis=0):
+ return torch.cat(arrays, dim=axis)
+
+ def zero_pad(self, a, pad_width):
+ from torch.nn.functional import pad
+ # pad_width is an array of ndim tuples indicating how many 0 before and after
+ # we need to add. We first need to make it compliant with torch syntax, that
+ # starts with the last dim, then second last, etc.
+ how_pad = tuple(element for tupl in pad_width[::-1] for element in tupl)
+ return pad(a, how_pad)
+
+ def argmax(self, a, axis=None):
+ return torch.argmax(a, dim=axis)
+
+ def mean(self, a, axis=None):
+ if axis is not None:
+ return torch.mean(a, dim=axis)
+ else:
+ return torch.mean(a)
+
+ def std(self, a, axis=None):
+ if axis is not None:
+ return torch.std(a, dim=axis, unbiased=False)
+ else:
+ return torch.std(a, unbiased=False)
+
+ def linspace(self, start, stop, num):
+ return torch.linspace(start, stop, num, dtype=torch.float64)
+
+ def meshgrid(self, a, b):
+ X, Y = torch.meshgrid(a, b)
+ return X.T, Y.T
+
+ def diag(self, a, k=0):
+ return torch.diag(a, diagonal=k)
+
+ def unique(self, a):
+ return torch.unique(a)
+
+ def logsumexp(self, a, axis=None):
+ if axis is not None:
+ return torch.logsumexp(a, dim=axis)
+ else:
+ return torch.logsumexp(a, dim=tuple(range(len(a.shape))))
+
+ def stack(self, arrays, axis=0):
+ return torch.stack(arrays, dim=axis)
+
+ def reshape(self, a, shape):
+ return torch.reshape(a, shape)
+
+ def seed(self, seed=None):
+ if isinstance(seed, int):
+ self.rng_.manual_seed(seed)
+ elif isinstance(seed, torch.Generator):
+ self.rng_ = seed
+ else:
+ raise ValueError("Non compatible seed : {}".format(seed))
+
+ def rand(self, *size, type_as=None):
+ if type_as is not None:
+ return torch.rand(size=size, generator=self.rng_, dtype=type_as.dtype, device=type_as.device)
+ else:
+ return torch.rand(size=size, generator=self.rng_)
+
+ def randn(self, *size, type_as=None):
+ if type_as is not None:
+ return torch.randn(size=size, dtype=type_as.dtype, generator=self.rng_, device=type_as.device)
+ else:
+ return torch.randn(size=size, generator=self.rng_)
+
+ def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
+ if type_as is None:
+ return torch.sparse_coo_tensor(torch.stack([rows, cols]), data, size=shape)
+ else:
+ return torch.sparse_coo_tensor(
+ torch.stack([rows, cols]), data, size=shape,
+ dtype=type_as.dtype, device=type_as.device
+ )
+
+ def issparse(self, a):
+ return getattr(a, "is_sparse", False) or getattr(a, "is_sparse_csr", False)
+
+ def tocsr(self, a):
+ # Versions older than 1.9 do not support CSR tensors. PyTorch 1.9 and 1.10 offer a very limited support
+ return self.todense(a)
+
+ def eliminate_zeros(self, a, threshold=0.):
+ if self.issparse(a):
+ if threshold > 0:
+ mask = self.abs(a) <= threshold
+ mask = ~mask
+ mask = mask.nonzero()
+ else:
+ mask = a._values().nonzero()
+ nv = a._values().index_select(0, mask.view(-1))
+ ni = a._indices().index_select(1, mask.view(-1))
+ return self.coo_matrix(nv, ni[0], ni[1], shape=a.shape, type_as=a)
+ else:
+ if threshold > 0:
+ a[self.abs(a) <= threshold] = 0
+ return a
+
+ def todense(self, a):
+ if self.issparse(a):
+ return a.to_dense()
+ else:
+ return a
+
+ def where(self, condition, x, y):
+ return torch.where(condition, x, y)
+
+ def copy(self, a):
+ return torch.clone(a)
+
+ def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
+ return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
+
+ def dtype_device(self, a):
+ return a.dtype, a.device
+
+ def assert_same_dtype_device(self, a, b):
+ a_dtype, a_device = self.dtype_device(a)
+ b_dtype, b_device = self.dtype_device(b)
+
+ assert a_dtype == b_dtype, "Dtype discrepancy"
+ assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
diff --git a/ot/bregman.py b/ot/bregman.py
index f1f8437..cce52e2 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -7,70 +7,104 @@ Bregman projections solvers for entropic regularized OT
# Nicolas Courty <ncourty@irisa.fr>
# Kilian Fatras <kilian.fatras@irisa.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
-# Hicham Janati <hicham.janati@inria.fr>
+# Hicham Janati <hicham.janati100@gmail.com>
# Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>
# Alexander Tong <alexander.tong@yale.edu>
# Ievgen Redko <ievgen.redko@univ-st-etienne.fr>
+# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
#
# License: MIT License
-import numpy as np
import warnings
-from .utils import unif, dist
+
+import numpy as np
from scipy.optimize import fmin_l_bfgs_b
+from ot.utils import unif, dist, list_to_array
+from .backend import get_backend
+
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+ stopThr=1e-9, verbose=False, log=False, warn=True,
+ **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg}\cdot\Omega(\gamma)
+
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- s.t. \gamma 1 = a
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma^T 1= b
+ \gamma &\geq 0
- \gamma\geq 0
where :
- - M is the (dim_a, dim_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (histograms, both sum to 1)
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ weights (histograms, both sum to 1)
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
+ scaling algorithm as proposed in :ref:`[2] <references-sinkhorn>`
+
+ **Choosing a Sinkhorn solver**
+
+ By default and when using a regularization parameter that is not too small
+ the default sinkhorn solver should be enough. If you need to use a small
+ regularization to get sharper OT matrices, you should use the
+ :py:func:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
+ errors. This last solver can be very slow in practice and might not even
+ converge to a reasonable OT matrix in a finite time. This is why
+ :py:func:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value
+ of the regularization (and using warm start) sometimes leads to better
+ solutions. Note that the greedy version of the sinkhorn
+ :py:func:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
+ version of the sinkhorn :py:func:`ot.bregman.screenkhorn` aim at providing a
+ fast approximation of the Sinkhorn problem. For use of GPU and gradient
+ computation with small number of iterations we strongly recommend the
+ :py:func:`ot.bregman.sinkhorn_log` solver that will no need to check for
+ numerical problems.
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
+ b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (dim_a, dim_b)
+ and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
+ (return OT loss + dual variables in log)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
method : str
- method used for the solver either 'sinkhorn', 'greenkhorn', 'sinkhorn_stabilized' or
- 'sinkhorn_epsilon_scaling', see those function for specific parameters
+ method used for the solver either 'sinkhorn','sinkhorn_log',
+ 'greenkhorn', 'sinkhorn_stabilized' or 'sinkhorn_epsilon_scaling', see
+ those function for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
-
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't convergence.
Returns
-------
- gamma : ndarray, shape (dim_a, dim_b)
+ gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -86,102 +120,152 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
array([[0.36552929, 0.13447071],
[0.13447071, 0.36552929]])
-
+ .. _references-sinkhorn:
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
+ of Optimal Transport, Advances in Neural Information Processing
+ Systems (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms
+ for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems.
+ arXiv preprint arXiv:1607.05816.
+ .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé,
+ A., & Peyré, G. (2019, April). Interpolating between optimal transport
+ and MMD using Sinkhorn divergences. In The 22nd International Conference
+ on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
See Also
--------
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
- ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
- ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
- ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
+ ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] <references-sinkhorn>`
+ ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn
+ :ref:`[9] <references-sinkhorn>` :ref:`[10] <references-sinkhorn>`
+ ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling
+ :ref:`[9] <references-sinkhorn>` :ref:`[10] <references-sinkhorn>`
"""
if method.lower() == 'sinkhorn':
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn,
**kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn,
+ **kwargs)
elif method.lower() == 'greenkhorn':
return greenkhorn(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log)
+ stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn)
elif method.lower() == 'sinkhorn_stabilized':
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
+ log=log, warn=warn,
+ **kwargs)
elif method.lower() == 'sinkhorn_epsilon_scaling':
return sinkhorn_epsilon_scaling(a, b, M, reg,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
+ log=log, warn=warn,
+ **kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+ stopThr=1e-9, verbose=False, log=False, warn=False, **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the loss
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg}\cdot\Omega(\gamma)
+
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- s.t. \gamma 1 = a
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma^T 1= b
+ \gamma &\geq 0
- \gamma\geq 0
where :
- - M is the (dim_a, dim_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (histograms, both sum to 1)
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ weights (histograms, both sum to 1)
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
+ scaling algorithm as proposed in :ref:`[2] <references-sinkhorn2>`
+
+
+ **Choosing a Sinkhorn solver**
+
+ By default and when using a regularization parameter that is not too small
+ the default sinkhorn solver should be enough. If you need to use a small
+ regularization to get sharper OT matrices, you should use the
+ :py:func:`ot.bregman.sinkhorn_log` solver that will avoid numerical
+ errors. This last solver can be very slow in practice and might not even
+ converge to a reasonable OT matrix in a finite time. This is why
+ :py:func:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value
+ of the regularization (and using warm start) sometimes leads to better
+ solutions. Note that the greedy version of the sinkhorn
+ :py:func:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
+ version of the sinkhorn :py:func:`ot.bregman.screenkhorn` aim a providing a
+ fast approximation of the Sinkhorn problem. For use of GPU and gradient
+ computation with small number of iterations we strongly recommend the
+ :py:func:`ot.bregman.sinkhorn_log` solver that will no need to check for
+ numerical problems.
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
+ b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (dim_a, dim_b)
+ and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
+ (return OT loss + dual variables in log)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
method : str
- method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
- 'sinkhorn_epsilon_scaling', see those function for specific parameters
+ method used for the solver either 'sinkhorn','sinkhorn_log',
+ 'sinkhorn_stabilized', see those function for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't convergence.
Returns
-------
- W : (n_hists) ndarray or float
+ W : (n_hists) float/array-like
Optimal transportation loss for the given parameters
log : dict
log dictionary return only if log==True in parameters
+
Examples
--------
@@ -190,99 +274,142 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
>>> b=[.5, .5]
>>> M=[[0., 1.], [1., 0.]]
>>> ot.sinkhorn2(a, b, M, 1)
- array([0.26894142])
-
+ 0.26894142136999516
+ .. _references-sinkhorn2:
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of
+ Optimal Transport, Advances in Neural Information
+ Processing Systems (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms
+ for Entropy Regularized Transport Problems.
+ arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems.
+ arXiv preprint arXiv:1607.05816.
- [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017
+ .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation
+ algorithms for optimal transport via Sinkhorn iteration,
+ Advances in Neural Information Processing Systems (NIPS) 31, 2017
+ .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I.,
+ Trouvé, A., & Peyré, G. (2019, April).
+ Interpolating between optimal transport and MMD using Sinkhorn
+ divergences. In The 22nd International Conference on Artificial
+ Intelligence and Statistics (pp. 2681-2690). PMLR.
See Also
--------
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
- ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
- ot.bregman.greenkhorn : Greenkhorn [21]
- ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
- ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
-
+ ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] <references-sinkhorn2>`
+ ot.bregman.greenkhorn : Greenkhorn :ref:`[21] <references-sinkhorn2>`
+ ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn
+ :ref:`[9] <references-sinkhorn2>` :ref:`[10] <references-sinkhorn2>`
"""
- b = np.asarray(b, dtype=np.float64)
+
+ M, a, b = list_to_array(M, a, b)
+ nx = get_backend(M, a, b)
+
if len(b.shape) < 2:
- b = b[:, None]
- if method.lower() == 'sinkhorn':
- return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
- **kwargs)
- elif method.lower() == 'sinkhorn_stabilized':
- return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
- **kwargs)
- elif method.lower() == 'sinkhorn_epsilon_scaling':
- return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
+ if method.lower() == 'sinkhorn':
+ res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ res = sinkhorn_log(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_stabilized':
+ res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
+ if log:
+ return nx.sum(M * res[0]), res[1]
+ else:
+ return nx.sum(M * res)
+
else:
- raise ValueError("Unknown method '%s'." % method)
+
+ if method.lower() == 'sinkhorn':
+ return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_stabilized':
+ return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
-def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
+ verbose=False, log=False, warn=True,
+ **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg}\cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma\geq 0
+ \gamma &\geq 0
where :
- - M is the (dim_a, dim_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (histograms, both sum to 1)
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ weights (histograms, both sum to 1)
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
+ The algorithm used for solving the problem is the Sinkhorn-Knopp
+ matrix scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-knopp>`
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
+ b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (dim_a, dim_b)
+ and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
+ (return OT loss + dual variables in log)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't convergence.
Returns
-------
- gamma : ndarray, shape (dim_a, dim_b)
+ gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -299,10 +426,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
[0.13447071, 0.36552929]])
+ .. _references-sinkhorn-knopp:
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
+ of Optimal Transport, Advances in Neural Information
+ Processing Systems (NIPS) 26, 2013
See Also
@@ -312,18 +442,18 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
"""
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(M, a, b)
if len(a) == 0:
- a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M)
if len(b) == 0:
- b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+ b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M)
# init data
dim_a = len(a)
- dim_b = len(b)
+ dim_b = b.shape[0]
if len(b.shape) > 1:
n_hists = b.shape[1]
@@ -336,66 +466,64 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
# we assume that no distances are null except those of the diagonal of
# distances
if n_hists:
- u = np.ones((dim_a, n_hists)) / dim_a
- v = np.ones((dim_b, n_hists)) / dim_b
+ u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
+ v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
else:
- u = np.ones(dim_a) / dim_a
- v = np.ones(dim_b) / dim_b
+ u = nx.ones(dim_a, type_as=M) / dim_a
+ v = nx.ones(dim_b, type_as=M) / dim_b
- # print(reg)
-
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
-
- # print(np.min(K))
- tmp2 = np.empty(b.shape, dtype=M.dtype)
+ K = nx.exp(M / (-reg))
Kp = (1 / a).reshape(-1, 1) * K
- cpt = 0
+
err = 1
- while (err > stopThr and cpt < numItermax):
+ for ii in range(numItermax):
uprev = u
vprev = v
+ KtransposeU = nx.dot(K.T, u)
+ v = b / KtransposeU
+ u = 1. / nx.dot(Kp, v)
- KtransposeU = np.dot(K.T, u)
- v = np.divide(b, KtransposeU)
- u = 1. / np.dot(Kp, v)
-
- if (np.any(KtransposeU == 0)
- or np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ if (nx.any(KtransposeU == 0)
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
+ or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
- print('Warning: numerical errors at iteration', cpt)
+ warnings.warn('Warning: numerical errors at iteration %d' % ii)
u = uprev
v = vprev
break
- if cpt % 10 == 0:
+ if ii % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
if n_hists:
- np.einsum('ik,ij,jk->jk', u, K, v, out=tmp2)
+ tmp2 = nx.einsum('ik,ij,jk->jk', u, K, v)
else:
# compute right marginal tmp2= (diag(u)Kdiag(v))^T1
- np.einsum('i,ij,j->j', u, K, v, out=tmp2)
- err = np.linalg.norm(tmp2 - b) # violation of marginal
+ tmp2 = nx.einsum('i,ij,j->j', u, K, v)
+ err = nx.norm(tmp2 - b) # violation of marginal
if log:
log['err'].append(err)
+ if err < stopThr:
+ break
if verbose:
- if cpt % 200 == 0:
+ if ii % 200 == 0:
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
- cpt = cpt + 1
+ print('{:5d}|{:8e}|'.format(ii, err))
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
+ log['niter'] = ii
log['u'] = u
log['v'] = v
if n_hists: # return only loss
- res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
+ res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M)
if log:
return res, log
else:
@@ -409,58 +537,259 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
return u.reshape((-1, 1)) * K * v.reshape((1, -1))
-def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
- log=False):
+def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
+ log=False, warn=True, **kwargs):
r"""
- Solve the entropic regularization optimal transport problem and return the OT matrix
+ Solve the entropic regularization optimal transport problem in log space
+ and return the OT matrix
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg}\cdot\Omega(\gamma)
+
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
+
+ \gamma^T \mathbf{1} &= \mathbf{b}
+
+ \gamma &\geq 0
+ where :
+
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
+ scaling algorithm :ref:`[2] <references-sinkhorn-log>` with the
+ implementation from :ref:`[34] <references-sinkhorn-log>`
+
+
+ Parameters
+ ----------
+ a : array-like, shape (dim_a,)
+ samples weights in the source domain
+ b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
+ samples in the target domain, compute sinkhorn with multiple targets
+ and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log)
+ M : array-like, shape (dim_a, dim_b)
+ loss matrix
+ reg : float
+ Regularization term >0
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't convergence.
+
+ Returns
+ -------
+ gamma : array-like, shape (dim_a, dim_b)
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[0., 1.], [1., 0.]]
+ >>> ot.sinkhorn(a, b, M, 1)
+ array([[0.36552929, 0.13447071],
+ [0.13447071, 0.36552929]])
+
+
+ .. _references-sinkhorn-log:
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of
+ Optimal Transport, Advances in Neural Information Processing
+ Systems (NIPS) 26, 2013
+
+ .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I.,
+ Trouvé, A., & Peyré, G. (2019, April). Interpolating between
+ optimal transport and MMD using Sinkhorn divergences. In The
+ 22nd International Conference on Artificial Intelligence and
+ Statistics (pp. 2681-2690). PMLR.
+
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(M, a, b)
+
+ if len(a) == 0:
+ a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M)
+ if len(b) == 0:
+ b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M)
+
+ # init data
+ dim_a = len(a)
+ dim_b = b.shape[0]
+
+ if len(b.shape) > 1:
+ n_hists = b.shape[1]
+ else:
+ n_hists = 0
- The algorithm used is based on the paper
+ if n_hists: # we do not want to use tensors sor we do a loop
- Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration
- by Jason Altschuler, Jonathan Weed, Philippe Rigollet
- appeared at NIPS 2017
+ lst_loss = []
+ lst_u = []
+ lst_v = []
- which is a stochastic version of the Sinkhorn-Knopp algorithm [2].
+ for k in range(n_hists):
+ res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+ if log:
+ lst_loss.append(nx.sum(M * res[0]))
+ lst_u.append(res[1]['log_u'])
+ lst_v.append(res[1]['log_v'])
+ else:
+ lst_loss.append(nx.sum(M * res))
+ res = nx.stack(lst_loss)
+ if log:
+ log = {'log_u': nx.stack(lst_u, 1),
+ 'log_v': nx.stack(lst_v, 1), }
+ log['u'] = nx.exp(log['log_u'])
+ log['v'] = nx.exp(log['log_v'])
+ return res, log
+ else:
+ return res
+
+ else:
+
+ if log:
+ log = {'err': []}
+
+ Mr = - M / reg
+
+ # we assume that no distances are null except those of the diagonal of
+ # distances
+
+ u = nx.zeros(dim_a, type_as=M)
+ v = nx.zeros(dim_b, type_as=M)
+
+ def get_logT(u, v):
+ if n_hists:
+ return Mr[:, :, None] + u + v
+ else:
+ return Mr + u[:, None] + v[None, :]
+
+ loga = nx.log(a)
+ logb = nx.log(b)
+
+ err = 1
+ for ii in range(numItermax):
+
+ v = logb - nx.logsumexp(Mr + u[:, None], 0)
+ u = loga - nx.logsumexp(Mr + v[None, :], 1)
+
+ if ii % 10 == 0:
+ # we can speed up the process by checking for the error only all
+ # the 10th iterations
+
+ # compute right marginal tmp2= (diag(u)Kdiag(v))^T1
+ tmp2 = nx.sum(nx.exp(get_logT(u, v)), 0)
+ err = nx.norm(tmp2 - b) # violation of marginal
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if ii % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
+ if err < stopThr:
+ break
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
+
+ if log:
+ log['niter'] = ii
+ log['log_u'] = u
+ log['log_v'] = v
+ log['u'] = nx.exp(u)
+ log['v'] = nx.exp(v)
+
+ return nx.exp(get_logT(u, v)), log
+
+ else:
+ return nx.exp(get_logT(u, v))
+
+
+def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
+ log=False, warn=True):
+ r"""
+ Solve the entropic regularization optimal transport problem and return the OT matrix
+
+ The algorithm used is based on the paper :ref:`[22] <references-greenkhorn>`
+ which is a stochastic version of the Sinkhorn-Knopp
+ algorithm :ref:`[2] <references-greenkhorn>`
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg}\cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma\geq 0
+ \gamma &\geq 0
where :
- - M is the (dim_a, dim_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (histograms, both sum to 1)
-
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ weights (histograms, both sum to 1)
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
+ b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (dim_a, dim_b)
+ and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
+ (return OT loss + dual variables in log)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't convergence.
Returns
-------
- gamma : ndarray, shape (dim_a, dim_b)
+ gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -477,11 +806,18 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
[0.13447071, 0.36552929]])
+ .. _references-greenkhorn:
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
- [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
+ of Optimal Transport, Advances in Neural Information
+ Processing Systems (NIPS) 26, 2013
+
+ .. [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time
+ approximation algorithms for optimal transport via Sinkhorn
+ iteration, Advances in Neural Information Processing
+ Systems (NIPS) 31, 2017
See Also
@@ -491,68 +827,70 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
"""
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(M, a, b)
+ if nx.__name__ == "jax":
+ raise TypeError("JAX arrays have been received. Greenkhorn is not "
+ "compatible with JAX")
if len(a) == 0:
- a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
if len(b) == 0:
- b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+ b = nx.ones((M.shape[1],), type_as=M) / M.shape[1]
dim_a = a.shape[0]
dim_b = b.shape[0]
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty_like(M)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
+ K = nx.exp(-M / reg)
- u = np.full(dim_a, 1. / dim_a)
- v = np.full(dim_b, 1. / dim_b)
- G = u[:, np.newaxis] * K * v[np.newaxis, :]
+ u = nx.full((dim_a,), 1. / dim_a, type_as=K)
+ v = nx.full((dim_b,), 1. / dim_b, type_as=K)
+ G = u[:, None] * K * v[None, :]
- viol = G.sum(1) - a
- viol_2 = G.sum(0) - b
+ viol = nx.sum(G, axis=1) - a
+ viol_2 = nx.sum(G, axis=0) - b
stopThr_val = 1
-
if log:
log = dict()
log['u'] = u
log['v'] = v
- for i in range(numItermax):
- i_1 = np.argmax(np.abs(viol))
- i_2 = np.argmax(np.abs(viol_2))
- m_viol_1 = np.abs(viol[i_1])
- m_viol_2 = np.abs(viol_2[i_2])
- stopThr_val = np.maximum(m_viol_1, m_viol_2)
+ for ii in range(numItermax):
+ i_1 = nx.argmax(nx.abs(viol))
+ i_2 = nx.argmax(nx.abs(viol_2))
+ m_viol_1 = nx.abs(viol[i_1])
+ m_viol_2 = nx.abs(viol_2[i_2])
+ stopThr_val = nx.maximum(m_viol_1, m_viol_2)
if m_viol_1 > m_viol_2:
old_u = u[i_1]
- u[i_1] = a[i_1] / (K[i_1, :].dot(v))
- G[i_1, :] = u[i_1] * K[i_1, :] * v
-
- viol[i_1] = u[i_1] * K[i_1, :].dot(v) - a[i_1]
- viol_2 += (K[i_1, :].T * (u[i_1] - old_u) * v)
+ new_u = a[i_1] / (K[i_1, :].dot(v))
+ G[i_1, :] = new_u * K[i_1, :] * v
+ viol[i_1] = new_u * K[i_1, :].dot(v) - a[i_1]
+ viol_2 += (K[i_1, :].T * (new_u - old_u) * v)
+ u[i_1] = new_u
else:
old_v = v[i_2]
- v[i_2] = b[i_2] / (K[:, i_2].T.dot(u))
- G[:, i_2] = u * K[:, i_2] * v[i_2]
+ new_v = b[i_2] / (K[:, i_2].T.dot(u))
+ G[:, i_2] = u * K[:, i_2] * new_v
# aviol = (G@one_m - a)
# aviol_2 = (G.T@one_n - b)
- viol += (-old_v + v[i_2]) * K[:, i_2] * u
- viol_2[i_2] = v[i_2] * K[:, i_2].dot(u) - b[i_2]
-
- # print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2)))
+ viol += (-old_v + new_v) * K[:, i_2] * u
+ viol_2[i_2] = new_v * K[:, i_2].dot(u) - b[i_2]
+ v[i_2] = new_v
if stopThr_val <= stopThr:
break
else:
- print('Warning: Algorithm did not converge')
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
+ log["n_iter"] = ii
log['u'] = u
log['v'] = v
@@ -564,58 +902,66 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
warmstart=None, verbose=False, print_period=20,
- log=False, **kwargs):
+ log=False, warn=True, **kwargs):
r"""
Solve the entropic regularization OT problem with log stabilization
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg}\cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma\geq 0
+ \gamma &\geq 0
where :
- - M is the (dim_a, dim_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (histograms, both sum to 1)
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ weights (histograms, both sum to 1)
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
- scaling algorithm as proposed in [2]_ but with the log stabilization
- proposed in [10]_ an defined in [9]_ (Algo 3.1) .
+ scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-stabilized>`
+ but with the log stabilization
+ proposed in :ref:`[10] <references-sinkhorn-stabilized>` an defined in
+ :ref:`[9] <references-sinkhorn-stabilized>` (Algo 3.1) .
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (dim_b,)
+ b : array-like, shape (dim_b,)
samples in the target domain
- M : ndarray, shape (dim_a, dim_b)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
tau : float
- thershold for max value in u or v for log scaling
- warmstart : tible of vectors
- if given then sarting values for alpha an beta log scalings
+ threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}`
+ for log scaling
+ warmstart : table of vectors
+ if given then starting values for alpha and beta log scalings
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't convergence.
Returns
-------
- gamma : ndarray, shape (dim_a, dim_b)
+ gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -632,14 +978,21 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
[0.13447071, 0.36552929]])
+ .. _references-sinkhorn-stabilized:
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of
+ Optimal Transport, Advances in Neural Information Processing
+ Systems (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms
+ for Entropy Regularized Transport Problems.
+ arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems.
+ arXiv preprint arXiv:1607.05816.
See Also
@@ -649,19 +1002,19 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
"""
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(M, a, b)
if len(a) == 0:
- a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
if len(b) == 0:
- b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+ b = nx.ones((M.shape[1],), type_as=M) / M.shape[1]
# test if multiple target
if len(b.shape) > 1:
n_hists = b.shape[1]
- a = a[:, np.newaxis]
+ a = a[:, None]
else:
n_hists = 0
@@ -669,123 +1022,123 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
dim_a = len(a)
dim_b = len(b)
- cpt = 0
if log:
log = {'err': []}
# we assume that no distances are null except those of the diagonal of
# distances
if warmstart is None:
- alpha, beta = np.zeros(dim_a), np.zeros(dim_b)
+ alpha, beta = nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M)
else:
alpha, beta = warmstart
if n_hists:
- u = np.ones((dim_a, n_hists)) / dim_a
- v = np.ones((dim_b, n_hists)) / dim_b
+ u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
+ v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
else:
- u, v = np.ones(dim_a) / dim_a, np.ones(dim_b) / dim_b
+ u, v = nx.ones(dim_a, type_as=M), nx.ones(dim_b, type_as=M)
+ u /= dim_a
+ v /= dim_b
def get_K(alpha, beta):
"""log space computation"""
- return np.exp(-(M - alpha.reshape((dim_a, 1))
+ return nx.exp(-(M - alpha.reshape((dim_a, 1))
- beta.reshape((1, dim_b))) / reg)
def get_Gamma(alpha, beta, u, v):
"""log space gamma computation"""
- return np.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b)))
- / reg + np.log(u.reshape((dim_a, 1))) + np.log(v.reshape((1, dim_b))))
-
- # print(np.min(K))
+ return nx.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b)))
+ / reg + nx.log(u.reshape((dim_a, 1))) + nx.log(v.reshape((1, dim_b))))
K = get_K(alpha, beta)
transp = K
- loop = 1
- cpt = 0
err = 1
- while loop:
+ for ii in range(numItermax):
uprev = u
vprev = v
# sinkhorn update
- v = b / (np.dot(K.T, u) + 1e-16)
- u = a / (np.dot(K, v) + 1e-16)
+ v = b / (nx.dot(K.T, u))
+ u = a / (nx.dot(K, v))
# remove numerical problems and store them in K
- if np.abs(u).max() > tau or np.abs(v).max() > tau:
+ if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau:
if n_hists:
- alpha, beta = alpha + reg * \
- np.max(np.log(u), 1), beta + reg * np.max(np.log(v))
+ alpha, beta = alpha + reg * nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v))
else:
- alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v)
+ alpha, beta = alpha + reg * nx.log(u), beta + reg * nx.log(v)
if n_hists:
- u, v = np.ones((dim_a, n_hists)) / dim_a, np.ones((dim_b, n_hists)) / dim_b
+ u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
+ v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
else:
- u, v = np.ones(dim_a) / dim_a, np.ones(dim_b) / dim_b
+ u = nx.ones(dim_a, type_as=M) / dim_a
+ v = nx.ones(dim_b, type_as=M) / dim_b
K = get_K(alpha, beta)
- if cpt % print_period == 0:
+ if ii % print_period == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
if n_hists:
- err_u = abs(u - uprev).max()
- err_u /= max(abs(u).max(), abs(uprev).max(), 1.)
- err_v = abs(v - vprev).max()
- err_v /= max(abs(v).max(), abs(vprev).max(), 1.)
+ err_u = nx.max(nx.abs(u - uprev))
+ err_u /= max(nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.0)
+ err_v = nx.max(nx.abs(v - vprev))
+ err_v /= max(nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1.0)
err = 0.5 * (err_u + err_v)
else:
transp = get_Gamma(alpha, beta, u, v)
- err = np.linalg.norm((np.sum(transp, axis=0) - b))
+ err = nx.norm(nx.sum(transp, axis=0) - b)
if log:
log['err'].append(err)
if verbose:
- if cpt % (print_period * 20) == 0:
+ if ii % (print_period * 20) == 0:
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
+ print('{:5d}|{:8e}|'.format(ii, err))
if err <= stopThr:
- loop = False
-
- if cpt >= numItermax:
- loop = False
+ break
- if np.any(np.isnan(u)) or np.any(np.isnan(v)):
+ if nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)):
# we have reached the machine precision
# come back to previous solution and quit loop
- print('Warning: numerical errors at iteration', cpt)
+ warnings.warn('Numerical errors at iteration %d' % ii)
u = uprev
v = vprev
break
-
- cpt = cpt + 1
-
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
if n_hists:
alpha = alpha[:, None]
beta = beta[:, None]
- logu = alpha / reg + np.log(u)
- logv = beta / reg + np.log(v)
+ logu = alpha / reg + nx.log(u)
+ logv = beta / reg + nx.log(v)
+ log["n_iter"] = ii
log['logu'] = logu
log['logv'] = logv
- log['alpha'] = alpha + reg * np.log(u)
- log['beta'] = beta + reg * np.log(v)
+ log['alpha'] = alpha + reg * nx.log(u)
+ log['beta'] = beta + reg * nx.log(v)
log['warmstart'] = (log['alpha'], log['beta'])
if n_hists:
- res = np.zeros((n_hists))
- for i in range(n_hists):
- res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
+ res = nx.stack([
+ nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
+ for i in range(n_hists)
+ ])
return res, log
else:
return get_Gamma(alpha, beta, u, v), log
else:
if n_hists:
- res = np.zeros((n_hists))
- for i in range(n_hists):
- res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
+ res = nx.stack([
+ nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
+ for i in range(n_hists)
+ ])
return res
else:
return get_Gamma(alpha, beta, u, v)
@@ -794,70 +1147,73 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
numInnerItermax=100, tau=1e3, stopThr=1e-9,
warmstart=None, verbose=False, print_period=10,
- log=False, **kwargs):
+ log=False, warn=True, **kwargs):
r"""
Solve the entropic regularization optimal transport problem with log
stabilization and epsilon scaling.
-
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg}\cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma\geq 0
- where :
+ \gamma &\geq 0
- - M is the (dim_a, dim_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (histograms, both sum to 1)
+ where :
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
- scaling algorithm as proposed in [2]_ but with the log stabilization
- proposed in [10]_ and the log scaling proposed in [9]_ algorithm 3.2
-
+ scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-epsilon-scaling>`
+ but with the log stabilization
+ proposed in :ref:`[10] <references-sinkhorn-epsilon-scaling>` and the log scaling
+ proposed in :ref:`[9] <references-sinkhorn-epsilon-scaling>` algorithm 3.2
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (dim_b,)
+ b : array-like, shape (dim_b,)
samples in the target domain
- M : ndarray, shape (dim_a, dim_b)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
tau : float
- thershold for max value in u or v for log scaling
+ threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{b}`
+ for log scaling
warmstart : tuple of vectors
- if given then sarting values for alpha an beta log scalings
+ if given then starting values for alpha and beta log scalings
numItermax : int, optional
Max number of iterations
numInnerItermax : int, optional
- Max number of iterationsin the inner slog stabilized sinkhorn
+ Max number of iterations in the inner slog stabilized sinkhorn
epsilon0 : int, optional
first epsilon regularization value (then exponential decrease to reg)
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't convergence.
Returns
-------
- gamma : ndarray, shape (dim_a, dim_b)
+ gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
-
Examples
--------
-
>>> import ot
>>> a=[.5, .5]
>>> b=[.5, .5]
@@ -866,29 +1222,32 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
array([[0.36552929, 0.13447071],
[0.13447071, 0.36552929]])
-
+ .. _references-sinkhorn-epsilon-scaling:
References
----------
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
+ Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
+ Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
See Also
--------
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
-
"""
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(M, a, b)
if len(a) == 0:
- a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
if len(b) == 0:
- b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+ b = nx.ones((M.shape[1],), type_as=M) / M.shape[1]
# init data
dim_a = len(a)
@@ -898,14 +1257,14 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
numItermin = 35
numItermax = max(numItermin, numItermax) # ensure that last velue is exact
- cpt = 0
+ ii = 0
if log:
log = {'err': []}
# we assume that no distances are null except those of the diagonal of
# distances
if warmstart is None:
- alpha, beta = np.zeros(dim_a), np.zeros(dim_b)
+ alpha, beta = nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M)
else:
alpha, beta = warmstart
@@ -913,12 +1272,10 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
def get_reg(n): # exponential decreasing
return (epsilon0 - reg) * np.exp(-n) + reg
- loop = 1
- cpt = 0
err = 1
- while loop:
+ for ii in range(numItermax):
- regi = get_reg(cpt)
+ regi = get_reg(ii)
G, logi = sinkhorn_stabilized(a, b, M, regi,
numItermax=numInnerItermax, stopThr=1e-9,
@@ -928,33 +1285,31 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
alpha = logi['alpha']
beta = logi['beta']
- if cpt >= numItermax:
- loop = False
-
- if cpt % (print_period) == 0: # spsion nearly converged
+ if ii % (print_period) == 0: # spsion nearly converged
# we can speed up the process by checking for the error only all
# the 10th iterations
transp = G
- err = np.linalg.norm(
- (np.sum(transp, axis=0) - b)) ** 2 + np.linalg.norm((np.sum(transp, axis=1) - a)) ** 2
+ err = nx.norm(nx.sum(transp, axis=0) - b) ** 2 + nx.norm(nx.sum(transp, axis=1) - a) ** 2
if log:
log['err'].append(err)
if verbose:
- if cpt % (print_period * 10) == 0:
- print(
- '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
-
- if err <= stopThr and cpt > numItermin:
- loop = False
+ if ii % (print_period * 10) == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
- cpt = cpt + 1
- # print('err=',err,' cpt=',cpt)
+ if err <= stopThr and ii > numItermin:
+ break
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
log['alpha'] = alpha
log['beta'] = beta
log['warmstart'] = (log['alpha'], log['beta'])
+ log['niter'] = ii
return G, log
else:
return G
@@ -962,76 +1317,94 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
def geometricBar(weights, alldistribT):
"""return the weighted geometric mean of distributions"""
+ weights, alldistribT = list_to_array(weights, alldistribT)
+ nx = get_backend(weights, alldistribT)
assert (len(weights) == alldistribT.shape[1])
- return np.exp(np.dot(np.log(alldistribT), weights.T))
+ return nx.exp(nx.dot(nx.log(alldistribT), weights.T))
def geometricMean(alldistribT):
"""return the geometric mean of distributions"""
- return np.exp(np.mean(np.log(alldistribT), axis=1))
+ alldistribT = list_to_array(alldistribT)
+ nx = get_backend(alldistribT)
+ return nx.exp(nx.mean(nx.log(alldistribT), axis=1))
def projR(gamma, p):
"""return the KL projection on the row constrints """
- return np.multiply(gamma.T, p / np.maximum(np.sum(gamma, axis=1), 1e-10)).T
+ gamma, p = list_to_array(gamma, p)
+ nx = get_backend(gamma, p)
+ return (gamma.T * p / nx.maximum(nx.sum(gamma, axis=1), 1e-10)).T
def projC(gamma, q):
"""return the KL projection on the column constrints """
- return np.multiply(gamma, q / np.maximum(np.sum(gamma, axis=0), 1e-10))
+ gamma, q = list_to_array(gamma, q)
+ nx = get_backend(gamma, q)
+ return gamma * q / nx.maximum(nx.sum(gamma, axis=0), 1e-10)
def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
- stopThr=1e-4, verbose=False, log=False, **kwargs):
- r"""Compute the entropic regularized wasserstein barycenter of distributions A
+ stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}`
The function solves the following optimization problem:
.. math::
- \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+ \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
where :
- - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
- - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
- - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein
+ distance (see :py:func:`ot.bregman.sinkhorn`)
+ if `method` is `sinkhorn` or `sinkhorn_stabilized` or `sinkhorn_log`.
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - `reg` and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling
+ algorithm as proposed in :ref:`[3] <references-barycenter>`
Parameters
----------
- A : ndarray, shape (dim, n_hists)
- n_hists training distributions a_i of size dim
- M : ndarray, shape (dim, dim)
+ A : array-like, shape (dim, n_hists)
+ `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim`
+ M : array-like, shape (dim, dim)
loss matrix for OT
reg : float
Regularization term > 0
method : str (optional)
- method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized'
- weights : ndarray, shape (n_hists,)
- Weights of each histogram a_i on the simplex (barycentric coodinates)
+ method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized' or 'sinkhorn_log'
+ weights : array-like, shape (n_hists,)
+ Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates)
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't convergence.
Returns
-------
- a : (dim,) ndarray
+ a : (dim,) array-like
Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
+ .. _references-barycenter:
References
----------
- .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
+ Iterative Bregman projections for regularized transportation problems.
+ SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
"""
@@ -1039,232 +1412,327 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
return barycenter_sinkhorn(A, M, reg, weights=weights,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn,
**kwargs)
elif method.lower() == 'sinkhorn_stabilized':
return barycenter_stabilized(A, M, reg, weights=weights,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
+ log=log, warn=warn, **kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ return _barycenter_sinkhorn_log(A, M, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn, **kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
- stopThr=1e-4, verbose=False, log=False):
- r"""Compute the entropic regularized wasserstein barycenter of distributions A
+ stopThr=1e-4, verbose=False, log=False, warn=True):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}`
The function solves the following optimization problem:
.. math::
- \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+ \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
where :
- - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
- - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
- - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance
+ (see :py:func:`ot.bregman.sinkhorn`)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - `reg` and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
+ scaling algorithm as proposed in :ref:`[3]<references-barycenter-sinkhorn>`.
Parameters
----------
- A : ndarray, shape (dim, n_hists)
- n_hists training distributions a_i of size dim
- M : ndarray, shape (dim, dim)
+ A : array-like, shape (dim, n_hists)
+ `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim`
+ M : array-like, shape (dim, dim)
loss matrix for OT
reg : float
Regularization term > 0
- weights : ndarray, shape (n_hists,)
- Weights of each histogram a_i on the simplex (barycentric coodinates)
+ weights : array-like, shape (n_hists,)
+ Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates)
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't convergence.
Returns
-------
- a : (dim,) ndarray
+ a : (dim,) array-like
Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
+ .. _references-barycenter-sinkhorn:
References
----------
- .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
+ Iterative Bregman projections for regularized transportation problems.
+ SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
"""
+ A, M = list_to_array(A, M)
+
+ nx = get_backend(A, M)
+
if weights is None:
- weights = np.ones(A.shape[1]) / A.shape[1]
+ weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1]
else:
assert (len(weights) == A.shape[1])
if log:
log = {'err': []}
- # M = M/np.median(M) # suggested by G. Peyre
- K = np.exp(-M / reg)
+ K = nx.exp(-M / reg)
- cpt = 0
err = 1
- UKv = np.dot(K, np.divide(A.T, np.sum(K, axis=0)).T)
+ UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T)
+
u = (geometricMean(UKv) / UKv.T).T
- while (err > stopThr and cpt < numItermax):
- cpt = cpt + 1
- UKv = u * np.dot(K, np.divide(A, np.dot(K, u)))
+ for ii in range(numItermax):
+
+ UKv = u * nx.dot(K, A / nx.dot(K, u))
u = (u.T * geometricBar(weights, UKv)).T / UKv
- if cpt % 10 == 1:
- err = np.sum(np.std(UKv, axis=1))
+ if ii % 10 == 1:
+ err = nx.sum(nx.std(UKv, axis=1))
# log and verbose print
if log:
log['err'].append(err)
+ if err < stopThr:
+ break
if verbose:
- if cpt % 200 == 0:
+ if ii % 200 == 0:
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
-
+ print('{:5d}|{:8e}|'.format(ii, err))
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
- log['niter'] = cpt
+ log['niter'] = ii
return geometricBar(weights, UKv), log
else:
return geometricBar(weights, UKv)
+def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False, warn=True):
+ r"""Compute the entropic wasserstein barycenter in log-domain
+ """
+
+ A, M = list_to_array(A, M)
+ dim, n_hists = A.shape
+
+ nx = get_backend(A, M)
+
+ if nx.__name__ == "jax":
+ raise NotImplementedError("Log-domain functions are not yet implemented"
+ " for Jax. Use numpy or torch arrays instead.")
+
+ if weights is None:
+ weights = nx.ones(n_hists, type_as=A) / n_hists
+ else:
+ assert (len(weights) == A.shape[1])
+
+ if log:
+ log = {'err': []}
+
+ M = - M / reg
+ logA = nx.log(A + 1e-15)
+ log_KU, G = nx.zeros((2, *logA.shape), type_as=A)
+ err = 1
+ for ii in range(numItermax):
+ log_bar = nx.zeros(dim, type_as=A)
+ for k in range(n_hists):
+ f = logA[:, k] - nx.logsumexp(M + G[None, :, k], axis=1)
+ log_KU[:, k] = nx.logsumexp(M + f[:, None], axis=0)
+ log_bar = log_bar + weights[k] * log_KU[:, k]
+
+ if ii % 10 == 1:
+ err = nx.exp(G + log_KU).std(axis=1).sum()
+
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ if err < stopThr:
+ break
+ if verbose:
+ if ii % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
+
+ G = log_bar[:, None] - log_KU
+
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
+ if log:
+ log['niter'] = ii
+ return nx.exp(log_bar), log
+ else:
+ return nx.exp(log_bar)
+
+
def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
- stopThr=1e-4, verbose=False, log=False):
- r"""Compute the entropic regularized wasserstein barycenter of distributions A
- with stabilization.
+ stopThr=1e-4, verbose=False, log=False, warn=True):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` with stabilization.
The function solves the following optimization problem:
.. math::
- \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+ \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
where :
- - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
- - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
- - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein
+ distance (see :py:func:`ot.bregman.sinkhorn`)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - `reg` and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling
+ algorithm as proposed in :ref:`[3] <references-barycenter-stabilized>`
Parameters
----------
- A : ndarray, shape (dim, n_hists)
- n_hists training distributions a_i of size dim
- M : ndarray, shape (dim, dim)
+ A : array-like, shape (dim, n_hists)
+ `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim`
+ M : array-like, shape (dim, dim)
loss matrix for OT
reg : float
Regularization term > 0
tau : float
- thershold for max value in u or v for log scaling
- weights : ndarray, shape (n_hists,)
- Weights of each histogram a_i on the simplex (barycentric coodinates)
+ threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}`
+ for log scaling
+ weights : array-like, shape (n_hists,)
+ Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates)
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't convergence.
Returns
-------
- a : (dim,) ndarray
+ a : (dim,) array-like
Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
+ .. _references-barycenter-stabilized:
References
----------
- .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
+ Iterative Bregman projections for regularized transportation problems.
+ SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
"""
+ A, M = list_to_array(A, M)
+
+ nx = get_backend(A, M)
+
dim, n_hists = A.shape
if weights is None:
- weights = np.ones(n_hists) / n_hists
+ weights = nx.ones((n_hists,), type_as=M) / n_hists
else:
assert (len(weights) == A.shape[1])
if log:
log = {'err': []}
- u = np.ones((dim, n_hists)) / dim
- v = np.ones((dim, n_hists)) / dim
+ u = nx.ones((dim, n_hists), type_as=M) / dim
+ v = nx.ones((dim, n_hists), type_as=M) / dim
- # print(reg)
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
+ K = nx.exp(-M / reg)
- cpt = 0
err = 1.
- alpha = np.zeros(dim)
- beta = np.zeros(dim)
- q = np.ones(dim) / dim
- while (err > stopThr and cpt < numItermax):
+ alpha = nx.zeros((dim,), type_as=M)
+ beta = nx.zeros((dim,), type_as=M)
+ q = nx.ones((dim,), type_as=M) / dim
+ for ii in range(numItermax):
qprev = q
- Kv = K.dot(v)
- u = A / (Kv + 1e-16)
- Ktu = K.T.dot(u)
+ Kv = nx.dot(K, v)
+ u = A / Kv
+ Ktu = nx.dot(K.T, u)
q = geometricBar(weights, Ktu)
Q = q[:, None]
- v = Q / (Ktu + 1e-16)
+ v = Q / Ktu
absorbing = False
- if (u > tau).any() or (v > tau).any():
+ if nx.any(u > tau) or nx.any(v > tau):
absorbing = True
- alpha = alpha + reg * np.log(np.max(u, 1))
- beta = beta + reg * np.log(np.max(v, 1))
- K = np.exp((alpha[:, None] + beta[None, :] -
- M) / reg)
- v = np.ones_like(v)
- Kv = K.dot(v)
- if (np.any(Ktu == 0.)
- or np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ alpha += reg * nx.log(nx.max(u, 1))
+ beta += reg * nx.log(nx.max(v, 1))
+ K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg)
+ v = nx.ones(tuple(v.shape), type_as=v)
+ Kv = nx.dot(K, v)
+ if (nx.any(Ktu == 0.)
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
+ or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
- warnings.warn('Numerical errors at iteration %s' % cpt)
+ warnings.warn('Numerical errors at iteration %s' % ii)
q = qprev
break
- if (cpt % 10 == 0 and not absorbing) or cpt == 0:
+ if (ii % 10 == 0 and not absorbing) or ii == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = abs(u * Kv - A).max()
+ err = nx.max(nx.abs(u * Kv - A))
if log:
log['err'].append(err)
+ if err < stopThr:
+ break
if verbose:
- if cpt % 50 == 0:
+ if ii % 50 == 0:
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
+ print('{:5d}|{:8e}|'.format(ii, err))
- cpt += 1
- if err > stopThr:
- warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." +
- "Try a larger entropy `reg`" +
- "Or a larger absorption threshold `tau`.")
+ else:
+ if warn:
+ warnings.warn("Stabilized Sinkhorn did not converge." +
+ "Try a larger entropy `reg`" +
+ "Or a larger absorption threshold `tau`.")
if log:
- log['niter'] = cpt
+ log['niter'] = ii
log['logu'] = np.log(u + 1e-16)
log['logv'] = np.log(v + 1e-16)
return q, log
@@ -1272,157 +1740,717 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
return q
-def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
- stopThr=1e-9, stabThr=1e-30, verbose=False,
- log=False):
- r"""Compute the entropic regularized wasserstein barycenter of distributions A
- where A is a collection of 2D images.
+def barycenter_debiased(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
+ stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs):
+ r"""Compute the debiased Sinkhorn barycenter of distributions A
The function solves the following optimization problem:
.. math::
- \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+ \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i)
where :
- - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
- - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions of matrix :math:`\mathbf{A}`
- - reg is the regularization strength scalar value
+ - :math:`S_{reg}(\cdot,\cdot)` is the debiased Sinkhorn divergence
+ (see :py:func:`ot.bregman.empirical_sinkhorn_divergence`)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - `reg` and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [21]_
+ The algorithm used for solving the problem is the debiased Sinkhorn
+ algorithm as proposed in :ref:`[37] <references-barycenter-debiased>`
Parameters
----------
- A : ndarray, shape (n_hists, width, height)
- n distributions (2D images) of size width x height
+ A : array-like, shape (dim, n_hists)
+ `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim`
+ M : array-like, shape (dim, dim)
+ loss matrix for OT
+ reg : float
+ Regularization term > 0
+ method : str (optional)
+ method used for the solver either 'sinkhorn' or 'sinkhorn_log'
+ weights : array-like, shape (n_hists,)
+ Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't convergence.
+
+
+ Returns
+ -------
+ a : (dim,) array-like
+ Wasserstein barycenter
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ .. _references-barycenter-debiased:
+ References
+ ----------
+ .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International
+ Conference on Machine Learning, PMLR 119:4692-4701, 2020
+ """
+
+ if method.lower() == 'sinkhorn':
+ return _barycenter_debiased(A, M, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn, **kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ return _barycenter_debiased_log(A, M, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn, **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
+
+
+def _barycenter_debiased(A, M, reg, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False, warn=True):
+ r"""Compute the debiased sinkhorn barycenter of distributions A.
+ """
+
+ A, M = list_to_array(A, M)
+
+ nx = get_backend(A, M)
+
+ if weights is None:
+ weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1]
+ else:
+ assert (len(weights) == A.shape[1])
+
+ if log:
+ log = {'err': []}
+
+ K = nx.exp(-M / reg)
+
+ err = 1
+
+ UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T)
+
+ u = (geometricMean(UKv) / UKv.T).T
+ c = nx.ones(A.shape[0], type_as=A)
+ bar = nx.ones(A.shape[0], type_as=A)
+
+ for ii in range(numItermax):
+ bold = bar
+ UKv = nx.dot(K, A / nx.dot(K, u))
+ bar = c * geometricBar(weights, UKv)
+ u = bar[:, None] / UKv
+ c = (c * bar / nx.dot(K, c)) ** 0.5
+
+ if ii % 10 == 9:
+ err = abs(bar - bold).max() / max(bar.max(), 1.)
+
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ # debiased Sinkhorn does not converge monotonically
+ # guarantee a few iterations are done before stopping
+ if err < stopThr and ii > 20:
+ break
+ if verbose:
+ if ii % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
+ if log:
+ log['niter'] = ii
+ return bar, log
+ else:
+ return bar
+
+
+def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False,
+ warn=True):
+ r"""Compute the debiased sinkhorn barycenter in log domain.
+ """
+
+ A, M = list_to_array(A, M)
+ dim, n_hists = A.shape
+
+ nx = get_backend(A, M)
+ if nx.__name__ == "jax":
+ raise NotImplementedError("Log-domain functions are not yet implemented"
+ " for Jax. Use numpy or torch arrays instead.")
+
+ if weights is None:
+ weights = nx.ones(n_hists, type_as=A) / n_hists
+ else:
+ assert (len(weights) == A.shape[1])
+
+ if log:
+ log = {'err': []}
+
+ M = - M / reg
+ logA = nx.log(A + 1e-15)
+ log_KU, G = nx.zeros((2, *logA.shape), type_as=A)
+ c = nx.zeros(dim, type_as=A)
+ err = 1
+ for ii in range(numItermax):
+ log_bar = nx.zeros(dim, type_as=A)
+ for k in range(n_hists):
+ f = logA[:, k] - nx.logsumexp(M + G[None, :, k], axis=1)
+ log_KU[:, k] = nx.logsumexp(M + f[:, None], axis=0)
+ log_bar += weights[k] * log_KU[:, k]
+ log_bar += c
+ if ii % 10 == 1:
+ err = nx.exp(G + log_KU).std(axis=1).sum()
+
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ if err < stopThr and ii > 20:
+ break
+ if verbose:
+ if ii % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
+
+ G = log_bar[:, None] - log_KU
+ for _ in range(10):
+ c = 0.5 * (c + log_bar - nx.logsumexp(M + c[:, None], axis=0))
+
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
+ if log:
+ log['niter'] = ii
+ return nx.exp(log_bar), log
+ else:
+ return nx.exp(log_bar)
+
+
+def convolutional_barycenter2d(A, reg, weights=None, method="sinkhorn", numItermax=10000,
+ stopThr=1e-4, verbose=False, log=False,
+ warn=True, **kwargs):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}`
+ where :math:`\mathbf{A}` is a collection of 2D images.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein
+ distance (see :py:func:`ot.bregman.sinkhorn`)
+ - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions
+ of matrix :math:`\mathbf{A}`
+ - `reg` is the regularization strength scalar value
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm
+ as proposed in :ref:`[21] <references-convolutional-barycenter-2d>`
+
+ Parameters
+ ----------
+ A : array-like, shape (n_hists, width, height)
+ `n` distributions (2D images) of size `width` x `height`
reg : float
Regularization term >0
- weights : ndarray, shape (n_hists,)
+ weights : array-like, shape (n_hists,)
Weights of each image on the simplex (barycentric coodinates)
+ method : string, optional
+ method used for the solver either 'sinkhorn' or 'sinkhorn_log'
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (> 0)
+ Stop threshold on error (> 0)
stabThr : float, optional
Stabilization threshold to avoid numerical precision issue
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't convergence.
Returns
-------
- a : ndarray, shape (width, height)
+ a : array-like, shape (width, height)
2D Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
+
+ .. _references-convolutional-barycenter-2d:
References
----------
- .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015).
- Convolutional wasserstein distances: Efficient optimal transportation on geometric domains
- ACM Transactions on Graphics (TOG), 34(4), 66
+ .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher,
+ A., Nguyen, A. & Guibas, L. (2015). Convolutional wasserstein distances:
+ Efficient optimal transportation on geometric domains. ACM Transactions
+ on Graphics (TOG), 34(4), 66
+
+ .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th
+ International Conference on Machine Learning, PMLR 119:4692-4701, 2020
+ """
+
+ if method.lower() == 'sinkhorn':
+ return _convolutional_barycenter2d(A, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ return _convolutional_barycenter2d_log(A, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
+ **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
+def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
+ stopThr=1e-9, stabThr=1e-30, verbose=False,
+ log=False, warn=True):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions A
+ where A is a collection of 2D images.
"""
+ A = list_to_array(A)
+
+ nx = get_backend(A)
+
if weights is None:
- weights = np.ones(A.shape[0]) / A.shape[0]
+ weights = nx.ones((A.shape[0],), type_as=A) / A.shape[0]
else:
assert (len(weights) == A.shape[0])
if log:
log = {'err': []}
- b = np.zeros_like(A[0, :, :])
- U = np.ones_like(A)
- KV = np.ones_like(A)
-
- cpt = 0
+ bar = nx.ones(A.shape[1:], type_as=A)
+ bar /= bar.sum()
+ U = nx.ones(A.shape, type_as=A)
+ V = nx.ones(A.shape, type_as=A)
err = 1
# build the convolution operator
# this is equivalent to blurring on horizontal then vertical directions
- t = np.linspace(0, 1, A.shape[1])
- [Y, X] = np.meshgrid(t, t)
- xi1 = np.exp(-(X - Y) ** 2 / reg)
+ t = nx.linspace(0, 1, A.shape[1])
+ [Y, X] = nx.meshgrid(t, t)
+ K1 = nx.exp(-(X - Y) ** 2 / reg)
+
+ t = nx.linspace(0, 1, A.shape[2])
+ [Y, X] = nx.meshgrid(t, t)
+ K2 = nx.exp(-(X - Y) ** 2 / reg)
+
+ def convol_imgs(imgs):
+ kx = nx.einsum("...ij,kjl->kil", K1, imgs)
+ kxy = nx.einsum("...ij,klj->kli", K2, kx)
+ return kxy
+
+ KU = convol_imgs(U)
+ for ii in range(numItermax):
+ V = bar[None] / KU
+ KV = convol_imgs(V)
+ U = A / KV
+ KU = convol_imgs(U)
+ bar = nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0))
+ if ii % 10 == 9:
+ err = (V * KU).std(axis=0).sum()
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if ii % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
+ if err < stopThr:
+ break
+
+ else:
+ if warn:
+ warnings.warn("Convolutional Sinkhorn did not converge. "
+ "Try a larger number of iterations `numItermax` "
+ "or a larger entropy `reg`.")
+ if log:
+ log['niter'] = ii
+ log['U'] = U
+ return bar, log
+ else:
+ return bar
+
+
+def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000,
+ stopThr=1e-4, stabThr=1e-30, verbose=False,
+ log=False, warn=True):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions A
+ where A is a collection of 2D images in log-domain.
+ """
+
+ A = list_to_array(A)
+
+ nx = get_backend(A)
+ if nx.__name__ == "jax":
+ raise NotImplementedError("Log-domain functions are not yet implemented"
+ " for Jax. Use numpy or torch arrays instead.")
+
+ n_hists, width, height = A.shape
+
+ if weights is None:
+ weights = nx.ones((n_hists,), type_as=A) / n_hists
+ else:
+ assert (len(weights) == n_hists)
+
+ if log:
+ log = {'err': []}
+
+ err = 1
+ # build the convolution operator
+ # this is equivalent to blurring on horizontal then vertical directions
+ t = nx.linspace(0, 1, width)
+ [Y, X] = nx.meshgrid(t, t)
+ M1 = - (X - Y) ** 2 / reg
+
+ t = nx.linspace(0, 1, height)
+ [Y, X] = nx.meshgrid(t, t)
+ M2 = - (X - Y) ** 2 / reg
+
+ def convol_img(log_img):
+ log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1)
+ log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T
+ return log_img
+
+ logA = nx.log(A + stabThr)
+ log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A)
+ err = 1
+ for ii in range(numItermax):
+ log_bar = nx.zeros((width, height), type_as=A)
+ for k in range(n_hists):
+ f = logA[k] - convol_img(G[k])
+ log_KU[k] = convol_img(f)
+ log_bar = log_bar + weights[k] * log_KU[k]
+
+ if ii % 10 == 9:
+ err = nx.exp(G + log_KU).std(axis=0).sum()
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if ii % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
+ if err < stopThr:
+ break
+ G = log_bar[None, :, :] - log_KU
+
+ else:
+ if warn:
+ warnings.warn("Convolutional Sinkhorn did not converge. "
+ "Try a larger number of iterations `numItermax` "
+ "or a larger entropy `reg`.")
+ if log:
+ log['niter'] = ii
+ return nx.exp(log_bar), log
+ else:
+ return nx.exp(log_bar)
+
+
+def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn",
+ numItermax=10000, stopThr=1e-3,
+ verbose=False, log=False, warn=True,
+ **kwargs):
+ r"""Compute the debiased sinkhorn barycenter of distributions :math:`\mathbf{A}`
+ where :math:`\mathbf{A}` is a collection of 2D images.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`S_{reg}(\cdot,\cdot)` is the debiased entropic regularized Wasserstein
+ distance (see :py:func:`ot.bregman.barycenter_debiased`)
+ - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two
+ dimensions of matrix :math:`\mathbf{A}`
+ - `reg` is the regularization strength scalar value
+
+ The algorithm used for solving the problem is the debiased Sinkhorn scaling
+ algorithm as proposed in :ref:`[37] <references-convolutional-barycenter2d-debiased>`
+
+ Parameters
+ ----------
+ A : array-like, shape (n_hists, width, height)
+ `n` distributions (2D images) of size `width` x `height`
+ reg : float
+ Regularization term >0
+ weights : array-like, shape (n_hists,)
+ Weights of each image on the simplex (barycentric coodinates)
+ method : string, optional
+ method used for the solver either 'sinkhorn' or 'sinkhorn_log'
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (> 0)
+ stabThr : float, optional
+ Stabilization threshold to avoid numerical precision issue
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't convergence.
+
+
+ Returns
+ -------
+ a : array-like, shape (width, height)
+ 2D Wasserstein barycenter
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ .. _references-convolutional-barycenter2d-debiased:
+ References
+ ----------
+
+ .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International
+ Conference on Machine Learning, PMLR 119:4692-4701, 2020
+ """
- t = np.linspace(0, 1, A.shape[2])
- [Y, X] = np.meshgrid(t, t)
- xi2 = np.exp(-(X - Y) ** 2 / reg)
+ if method.lower() == 'sinkhorn':
+ return _convolutional_barycenter2d_debiased(A, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ return _convolutional_barycenter2d_debiased_log(A, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
+ **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
- def K(x):
- return np.dot(np.dot(xi1, x), xi2)
- while (err > stopThr and cpt < numItermax):
+def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000,
+ stopThr=1e-3, stabThr=1e-15, verbose=False,
+ log=False, warn=True):
+ r"""Compute the debiased barycenter of 2D images via sinkhorn convolutions.
+ """
- bold = b
- cpt = cpt + 1
+ A = list_to_array(A)
+ n_hists, width, height = A.shape
- b = np.zeros_like(A[0, :, :])
- for r in range(A.shape[0]):
- KV[r, :, :] = K(A[r, :, :] / np.maximum(stabThr, K(U[r, :, :])))
- b += weights[r] * np.log(np.maximum(stabThr, U[r, :, :] * KV[r, :, :]))
- b = np.exp(b)
- for r in range(A.shape[0]):
- U[r, :, :] = b / np.maximum(stabThr, KV[r, :, :])
+ nx = get_backend(A)
- if cpt % 10 == 1:
- err = np.sum(np.abs(bold - b))
+ if weights is None:
+ weights = nx.ones((n_hists,), type_as=A) / n_hists
+ else:
+ assert (len(weights) == n_hists)
+
+ if log:
+ log = {'err': []}
+
+ bar = nx.ones((width, height), type_as=A)
+ bar /= width * height
+ U = nx.ones(A.shape, type_as=A)
+ V = nx.ones(A.shape, type_as=A)
+ c = nx.ones(A.shape[1:], type_as=A)
+ err = 1
+
+ # build the convolution operator
+ # this is equivalent to blurring on horizontal then vertical directions
+ t = nx.linspace(0, 1, width)
+ [Y, X] = nx.meshgrid(t, t)
+ K1 = nx.exp(-(X - Y) ** 2 / reg)
+
+ t = nx.linspace(0, 1, height)
+ [Y, X] = nx.meshgrid(t, t)
+ K2 = nx.exp(-(X - Y) ** 2 / reg)
+
+ def convol_imgs(imgs):
+ kx = nx.einsum("...ij,kjl->kil", K1, imgs)
+ kxy = nx.einsum("...ij,klj->kli", K2, kx)
+ return kxy
+
+ KU = convol_imgs(U)
+ for ii in range(numItermax):
+ V = bar[None] / KU
+ KV = convol_imgs(V)
+ U = A / KV
+ KU = convol_imgs(U)
+ bar = c * nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0))
+
+ for _ in range(10):
+ c = (c * bar / convol_imgs(c[None]).squeeze()) ** 0.5
+
+ if ii % 10 == 9:
+ err = (V * KU).std(axis=0).sum()
# log and verbose print
if log:
log['err'].append(err)
if verbose:
- if cpt % 200 == 0:
+ if ii % 200 == 0:
print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
+ print('{:5d}|{:8e}|'.format(ii, err))
+ # debiased Sinkhorn does not converge monotonically
+ # guarantee a few iterations are done before stopping
+ if err < stopThr and ii > 20:
+ break
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
- log['niter'] = cpt
+ log['niter'] = ii
log['U'] = U
- return b, log
+ return bar, log
+ else:
+ return bar
+
+
+def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10000,
+ stopThr=1e-3, stabThr=1e-30, verbose=False,
+ log=False, warn=True):
+ r"""Compute the debiased barycenter of 2D images in log-domain.
+ """
+
+ A = list_to_array(A)
+ n_hists, width, height = A.shape
+ nx = get_backend(A)
+ if nx.__name__ == "jax":
+ raise NotImplementedError("Log-domain functions are not yet implemented"
+ " for Jax. Use numpy or torch arrays instead.")
+ if weights is None:
+ weights = nx.ones((n_hists,), type_as=A) / n_hists
+ else:
+ assert (len(weights) == A.shape[0])
+
+ if log:
+ log = {'err': []}
+
+ err = 1
+ # build the convolution operator
+ # this is equivalent to blurring on horizontal then vertical directions
+ t = nx.linspace(0, 1, width)
+ [Y, X] = nx.meshgrid(t, t)
+ M1 = - (X - Y) ** 2 / reg
+
+ t = nx.linspace(0, 1, height)
+ [Y, X] = nx.meshgrid(t, t)
+ M2 = - (X - Y) ** 2 / reg
+
+ def convol_img(log_img):
+ log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1)
+ log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T
+ return log_img
+
+ logA = nx.log(A + stabThr)
+ log_bar, c = nx.zeros((2, width, height), type_as=A)
+ log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A)
+ err = 1
+ for ii in range(numItermax):
+ log_bar = nx.zeros((width, height), type_as=A)
+ for k in range(n_hists):
+ f = logA[k] - convol_img(G[k])
+ log_KU[k] = convol_img(f)
+ log_bar = log_bar + weights[k] * log_KU[k]
+ log_bar += c
+ for _ in range(10):
+ c = 0.5 * (c + log_bar - convol_img(c))
+
+ if ii % 10 == 9:
+ err = nx.exp(G + log_KU).std(axis=0).sum()
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if ii % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
+ if err < stopThr and ii > 20:
+ break
+ G = log_bar[None, :, :] - log_KU
+
+ else:
+ if warn:
+ warnings.warn("Convolutional Sinkhorn did not converge. "
+ "Try a larger number of iterations `numItermax` "
+ "or a larger entropy `reg`.")
+ if log:
+ log['niter'] = ii
+ return nx.exp(log_bar), log
else:
- return b
+ return nx.exp(log_bar)
def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
- stopThr=1e-3, verbose=False, log=False):
+ stopThr=1e-3, verbose=False, log=False, warn=True):
r"""
Compute the unmixing of an observation with a given dictionary using Wasserstein distance
The function solve the following optimization problem:
.. math::
- \mathbf{h} = arg\min_\mathbf{h} (1- \\alpha) W_{M,reg}(\mathbf{a},\mathbf{Dh})+\\alpha W_{M0,reg0}(\mathbf{h}_0,\mathbf{h})
+
+ \mathbf{h} = \mathop{\arg \min}_\mathbf{h} \quad
+ (1 - \alpha) W_{\mathbf{M}, \mathrm{reg}}(\mathbf{a}, \mathbf{Dh}) +
+ \alpha W_{\mathbf{M_0}, \mathrm{reg}_0}(\mathbf{h}_0, \mathbf{h})
where :
- - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with M loss matrix (see ot.bregman.sinkhorn)
- - :math: `\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, its expected shape is `(dim_a, n_atoms)`
+ - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance
+ with :math:`\mathbf{M}` loss matrix (see :py:func:`ot.bregman.sinkhorn`)
+ - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`,
+ its expected shape is `(dim_a, n_atoms)`
- :math:`\mathbf{h}` is the estimated unmixing of dimension `n_atoms`
- :math:`\mathbf{a}` is an observed distribution of dimension `dim_a`
- - :math:`\mathbf{h}_0` is a prior on `h` of dimension `dim_prior`
- - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix (dim_a, dim_a) for OT data fitting
- - reg0 and :math:`\mathbf{M0}` are respectively the regularization term and the cost matrix (dim_prior, n_atoms) regularization
- - :math:`\\alpha`weight data fitting and regularization
+ - :math:`\mathbf{h}_0` is a prior on :math:`\mathbf{h}` of dimension `dim_prior`
+ - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the
+ cost matrix (`dim_a`, `dim_a`) for OT data fitting
+ - `reg`:math:`_0` and :math:`\mathbf{M_0}` are respectively the regularization
+ term and the cost matrix (`dim_prior`, `n_atoms`) regularization
+ - :math:`\alpha` weight data fitting and regularization
- The optimization problem is solved suing the algorithm described in [4]
+ The optimization problem is solved following the algorithm described
+ in :ref:`[4] <references-unmix>`
Parameters
----------
- a : ndarray, shape (dim_a)
+ a : array-like, shape (dim_a)
observed distribution (histogram, sums to 1)
- D : ndarray, shape (dim_a, n_atoms)
+ D : array-like, shape (dim_a, n_atoms)
dictionary matrix
- M : ndarray, shape (dim_a, dim_a)
+ M : array-like, shape (dim_a, dim_a)
loss matrix
- M0 : ndarray, shape (n_atoms, dim_prior)
+ M0 : array-like, shape (n_atoms, dim_prior)
loss matrix
- h0 : ndarray, shape (n_atoms,)
+ h0 : array-like, shape (n_atoms,)
prior on the estimated unmixing h
reg : float
Regularization term >0 (Wasserstein data fitting)
@@ -1433,105 +2461,125 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
-
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't convergence.
Returns
-------
- h : ndarray, shape (n_atoms,)
+ h : array-like, shape (n_atoms,)
Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
+
+ .. _references-unmix:
References
----------
- .. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, Supervised planetary unmixing with optimal transport, Whorkshop on Hyperspectral Image and Signal Processing : Evolution in Remote Sensing (WHISPERS), 2016.
-
+ .. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti,
+ Supervised planetary unmixing with optimal transport, Whorkshop
+ on Hyperspectral Image and Signal Processing :
+ Evolution in Remote Sensing (WHISPERS), 2016.
"""
+ a, D, M, M0, h0 = list_to_array(a, D, M, M0, h0)
+
+ nx = get_backend(a, D, M, M0, h0)
+
# M = M/np.median(M)
- K = np.exp(-M / reg)
+ K = nx.exp(-M / reg)
# M0 = M0/np.median(M0)
- K0 = np.exp(-M0 / reg0)
+ K0 = nx.exp(-M0 / reg0)
old = h0
err = 1
- cpt = 0
# log = {'niter':0, 'all_err':[]}
if log:
log = {'err': []}
- while (err > stopThr and cpt < numItermax):
+ for ii in range(numItermax):
K = projC(K, a)
K0 = projC(K0, h0)
- new = np.sum(K0, axis=1)
+ new = nx.sum(K0, axis=1)
# we recombine the current selection from dictionnary
- inv_new = np.dot(D, new)
- other = np.sum(K, axis=1)
+ inv_new = nx.dot(D, new)
+ other = nx.sum(K, axis=1)
# geometric interpolation
- delta = np.exp(alpha * np.log(other) + (1 - alpha) * np.log(inv_new))
+ delta = nx.exp(alpha * nx.log(other) + (1 - alpha) * nx.log(inv_new))
K = projR(K, delta)
- K0 = np.dot(np.diag(np.dot(D.T, delta / inv_new)), K0)
+ K0 = nx.dot(nx.diag(nx.dot(D.T, delta / inv_new)), K0)
- err = np.linalg.norm(np.sum(K0, axis=1) - old)
+ err = nx.norm(nx.sum(K0, axis=1) - old)
old = new
if log:
log['err'].append(err)
if verbose:
- if cpt % 200 == 0:
+ if ii % 200 == 0:
print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
-
- cpt = cpt + 1
-
+ print('{:5d}|{:8e}|'.format(ii, err))
+ if err < stopThr:
+ break
+ else:
+ if warn:
+ warnings.warn("Unmixing algorithm did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
- log['niter'] = cpt
- return np.sum(K0, axis=1), log
+ log['niter'] = ii
+ return nx.sum(K0, axis=1), log
else:
- return np.sum(K0, axis=1)
+ return nx.sum(K0, axis=1)
def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
- stopThr=1e-6, verbose=False, log=False, **kwargs):
- r'''Joint OT and proportion estimation for multi-source target shift as proposed in [27]
+ stopThr=1e-6, verbose=False, log=False, warn=True, **kwargs):
+ r'''Joint OT and proportion estimation for multi-source target shift as
+ proposed in :ref:`[27] <references-jcpot-barycenter>`
The function solves the following optimization problem:
.. math::
- \mathbf{h} = arg\min_{\mathbf{h}}\quad \sum_{k=1}^{K} \lambda_k
+ \mathbf{h} = \mathop{\arg \min}_{\mathbf{h}} \quad \sum_{k=1}^{K} \lambda_k
W_{reg}((\mathbf{D}_2^{(k)} \mathbf{h})^T, \mathbf{a})
s.t. \ \forall k, \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h}
where :
- - :math:`\lambda_k` is the weight of k-th source domain
- - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
- - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to k-th source domain defined as in [p. 5, 27], its expected shape is `(n_k, C)` where `n_k` is the number of elements in the k-th source domain and `C` is the number of classes
- - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size C
+ - :math:`\lambda_k` is the weight of `k`-th source domain
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance
+ (see :py:func:`ot.bregman.sinkhorn`)
+ - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to `k`-th source domain
+ defined as in [p. 5, :ref:`27 <references-jcpot-barycenter>`], its expected shape
+ is :math:`(n_k, C)` where :math:`n_k` is the number of elements in the `k`-th source
+ domain and `C` is the number of classes
+ - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size `C`
- :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n`
- - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, 27], its expected shape is `(n_k, C)`
+ - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in
+ [p. 5, :ref:`27 <references-jcpot-barycenter>`], its expected shape is :math:`(n_k, C)`
- The problem consist in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain.
+ The problem consist in solving a Wasserstein barycenter problem to estimate
+ the proportions :math:`\mathbf{h}` in the target domain.
The algorithm used for solving the problem is the Iterative Bregman projections algorithm
- with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and uniform target distribution.
+ with two sets of marginal constraints related to the unknown vector
+ :math:`\mathbf{h}` and uniform target distribution.
Parameters
----------
- Xs : list of K np.ndarray(nsk,d)
+ Xs : list of K array-like(nsk,d)
features of all source domains' samples
- Ys : list of K np.ndarray(nsk,)
+ Ys : list of K array-like(nsk,)
labels of all source domains' samples
- Xt : np.ndarray (nt,d)
+ Xt : array-like (nt,d)
samples in the target domain
reg : float
Regularization term > 0
@@ -1541,28 +2589,37 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
Max number of iterations
stopThr : float, optional
Stop threshold on relative change in the barycenter (>0)
- log : bool, optional
- record log if True
verbose : bool, optional (default=False)
Controls the verbosity of the optimization algorithm
+ log : bool, optional
+ record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't convergence.
Returns
-------
- h : (C,) ndarray
+ h : (C,) array-like
proportion estimation in the target domain
log : dict
log dictionary return only if log==True in parameters
+ .. _references-jcpot-barycenter:
References
----------
.. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia
- "Optimal transport for multi-source domain adaptation under target shift",
- International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
-
+ "Optimal transport for multi-source domain adaptation under target shift",
+ International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
'''
- nbclasses = len(np.unique(Ys[0]))
+
+ Xs = list_to_array(*Xs)
+ Ys = list_to_array(*Ys)
+ Xt = list_to_array(Xt)
+
+ nx = get_backend(*Xs, *Ys, Xt)
+
+ nbclasses = len(nx.unique(Ys[0]))
nbdomains = len(Xs)
# log dictionary
@@ -1579,19 +2636,19 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
dom = {}
nsk = Xs[d].shape[0] # get number of elements for this domain
dom['nbelem'] = nsk
- classes = np.unique(Ys[d]) # get number of classes for this domain
+ classes = nx.unique(Ys[d]) # get number of classes for this domain
# format classes to start from 0 for convenience
- if np.min(classes) != 0:
- Ys[d] = Ys[d] - np.min(classes)
- classes = np.unique(Ys[d])
+ if nx.min(classes) != 0:
+ Ys[d] -= nx.min(classes)
+ classes = nx.unique(Ys[d])
# build the corresponding D_1 and D_2 matrices
- Dtmp1 = np.zeros((nbclasses, nsk))
- Dtmp2 = np.zeros((nbclasses, nsk))
+ Dtmp1 = nx.zeros((nbclasses, nsk), type_as=Xs[0])
+ Dtmp2 = nx.zeros((nbclasses, nsk), type_as=Xs[0])
for c in classes:
- nbelemperclass = np.sum(Ys[d] == c)
+ nbelemperclass = nx.sum(Ys[d] == c)
if nbelemperclass != 0:
Dtmp1[int(c), Ys[d] == c] = 1.
Dtmp2[int(c), Ys[d] == c] = 1. / (nbelemperclass)
@@ -1602,51 +2659,54 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
Mtmp = dist(Xs[d], Xt, metric=metric)
M.append(Mtmp)
- Ktmp = np.empty(Mtmp.shape, dtype=Mtmp.dtype)
- np.divide(Mtmp, -reg, out=Ktmp)
- np.exp(Ktmp, out=Ktmp)
+ Ktmp = nx.exp(-Mtmp / reg)
K.append(Ktmp)
# uniform target distribution
- a = unif(np.shape(Xt)[0])
+ a = nx.from_numpy(unif(Xt.shape[0]), type_as=Xs[0])
- cpt = 0 # iterations count
err = 1
- old_bary = np.ones((nbclasses))
+ old_bary = nx.ones((nbclasses,), type_as=Xs[0])
- while (err > stopThr and cpt < numItermax):
+ for ii in range(numItermax):
- bary = np.zeros((nbclasses))
+ bary = nx.zeros((nbclasses,), type_as=Xs[0])
# update coupling matrices for marginal constraints w.r.t. uniform target distribution
for d in range(nbdomains):
K[d] = projC(K[d], a)
- other = np.sum(K[d], axis=1)
- bary = bary + np.log(np.dot(D1[d], other)) / nbdomains
+ other = nx.sum(K[d], axis=1)
+ bary += nx.log(nx.dot(D1[d], other)) / nbdomains
- bary = np.exp(bary)
+ bary = nx.exp(bary)
# update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27]
for d in range(nbdomains):
- new = np.dot(D2[d].T, bary)
+ new = nx.dot(D2[d].T, bary)
K[d] = projR(K[d], new)
- err = np.linalg.norm(bary - old_bary)
- cpt = cpt + 1
+ err = nx.norm(bary - old_bary)
+
old_bary = bary
if log:
log['err'].append(err)
+ if err < stopThr:
+ break
if verbose:
- if cpt % 200 == 0:
+ if ii % 200 == 0:
print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
-
- bary = bary / np.sum(bary)
+ print('{:5d}|{:8e}|'.format(ii, err))
+ else:
+ if warn:
+ warnings.warn("Algorithm did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
+ bary = bary / nx.sum(bary)
if log:
- log['niter'] = cpt
+ log['niter'] = ii
log['M'] = M
log['D1'] = D1
log['D2'] = D2
@@ -1657,8 +2717,8 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
- numIterMax=10000, stopThr=1e-9, verbose=False,
- log=False, **kwargs):
+ numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False,
+ log=False, warn=True, **kwargs):
r'''
Solve the entropic regularization optimal transport problem and return the
OT matrix from empirical data
@@ -1666,45 +2726,56 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma\geq 0
+ \gamma &\geq 0
where :
- - :math:`M` is the (n_samples_a, n_samples_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`a` and :math:`b` are source and target weights (sum to 1)
+ - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
Parameters
----------
- X_s : ndarray, shape (n_samples_a, dim)
+ X_s : array-like, shape (n_samples_a, dim)
samples in the source domain
- X_t : ndarray, shape (n_samples_b, dim)
+ X_t : array-like, shape (n_samples_b, dim)
samples in the target domain
reg : float
Regularization term >0
- a : ndarray, shape (n_samples_a,)
+ a : array-like, shape (n_samples_a,)
samples weights in the source domain
- b : ndarray, shape (n_samples_b,)
+ b : array-like, shape (n_samples_b,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
+ isLazy: boolean, optional
+ If True, then only calculate the cost matrix by block and return
+ the dual potentials only (to save memory). If False, calculate full
+ cost matrix and return outputs of sinkhorn function.
+ batchSize: int or tuple of 2 int, optional
+ Size of the batches used to compute the sinkhorn update without memory overhead.
+ When a tuple is provided it sets the size of the left/right batches.
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't convergence.
Returns
-------
- gamma : ndarray, shape (n_samples_a, n_samples_b)
+ gamma : array-like, shape (n_samples_a, n_samples_b)
Regularized optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -1715,9 +2786,9 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
>>> n_samples_a = 2
>>> n_samples_b = 2
>>> reg = 0.1
- >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
- >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
- >>> empirical_sinkhorn(X_s, X_t, reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE
+ >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1))
+ >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1))
+ >>> empirical_sinkhorn(X_s, X_t, reg=reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE
array([[4.99977301e-01, 2.26989344e-05],
[2.26989344e-05, 4.99977301e-01]])
@@ -1725,30 +2796,115 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
+ Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
+ Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
'''
+ X_s, X_t = list_to_array(X_s, X_t)
+
+ nx = get_backend(X_s, X_t)
+
+ ns, nt = X_s.shape[0], X_t.shape[0]
if a is None:
- a = unif(np.shape(X_s)[0])
+ a = nx.from_numpy(unif(ns), type_as=X_s)
if b is None:
- b = unif(np.shape(X_t)[0])
+ b = nx.from_numpy(unif(nt), type_as=X_s)
+
+ if isLazy:
+ if log:
+ dict_log = {"err": []}
- M = dist(X_s, X_t, metric=metric)
+ log_a, log_b = nx.log(a), nx.log(b)
+ f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
+
+ if isinstance(batchSize, int):
+ bs, bt = batchSize, batchSize
+ elif isinstance(batchSize, tuple) and len(batchSize) == 2:
+ bs, bt = batchSize[0], batchSize[1]
+ else:
+ raise ValueError("Batch size must be in integer or a tuple of two integers")
+
+ range_s, range_t = range(0, ns, bs), range(0, nt, bt)
+
+ lse_f = nx.zeros((ns,), type_as=a)
+ lse_g = nx.zeros((nt,), type_as=a)
+
+ X_s_np = nx.to_numpy(X_s)
+ X_t_np = nx.to_numpy(X_t)
+
+ for i_ot in range(numIterMax):
+
+ lse_f_cols = []
+ for i in range_s:
+ M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric)
+ M = nx.from_numpy(M, type_as=a)
+ lse_f_cols.append(
+ nx.logsumexp(g[None, :] - M / reg, axis=1)
+ )
+ lse_f = nx.concatenate(lse_f_cols, axis=0)
+ f = log_a - lse_f
+
+ lse_g_cols = []
+ for j in range_t:
+ M = dist(X_s_np, X_t_np[j:j + bt, :], metric=metric)
+ M = nx.from_numpy(M, type_as=a)
+ lse_g_cols.append(
+ nx.logsumexp(f[:, None] - M / reg, axis=0)
+ )
+ lse_g = nx.concatenate(lse_g_cols, axis=0)
+ g = log_b - lse_g
+
+ if (i_ot + 1) % 10 == 0:
+ m1_cols = []
+ for i in range_s:
+ M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric)
+ M = nx.from_numpy(M, type_as=a)
+ m1_cols.append(
+ nx.sum(nx.exp(f[i:i + bs, None] + g[None, :] - M / reg), axis=1)
+ )
+ m1 = nx.concatenate(m1_cols, axis=0)
+ err = nx.sum(nx.abs(m1 - a))
+ if log:
+ dict_log["err"].append(err)
+
+ if verbose and (i_ot + 1) % 100 == 0:
+ print("Error in marginal at iteration {} = {}".format(i_ot + 1, err))
+
+ if err <= stopThr:
+ break
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
+ if log:
+ dict_log["u"] = f
+ dict_log["v"] = g
+ return (f, g, dict_log)
+ else:
+ return (f, g)
- if log:
- pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
- return pi, log
else:
- pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
- return pi
+ M = dist(X_s, X_t, metric=metric)
+ if log:
+ pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=True, **kwargs)
+ return pi, log
+ else:
+ pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=False, **kwargs)
+ return pi
-def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
- verbose=False, log=False, **kwargs):
+def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
+ numIterMax=10000, stopThr=1e-9, isLazy=False,
+ batchSize=100, verbose=False, log=False, warn=True, **kwargs):
r'''
Solve the entropic regularization optimal transport problem from empirical
data and return the OT loss
@@ -1757,46 +2913,57 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma\geq 0
+ \gamma &\geq 0
where :
- - :math:`M` is the (n_samples_a, n_samples_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`a` and :math:`b` are source and target weights (sum to 1)
+ - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
Parameters
----------
- X_s : ndarray, shape (n_samples_a, dim)
+ X_s : array-like, shape (n_samples_a, dim)
samples in the source domain
- X_t : ndarray, shape (n_samples_b, dim)
+ X_t : array-like, shape (n_samples_b, dim)
samples in the target domain
reg : float
Regularization term >0
- a : ndarray, shape (n_samples_a,)
+ a : array-like, shape (n_samples_a,)
samples weights in the source domain
- b : ndarray, shape (n_samples_b,)
+ b : array-like, shape (n_samples_b,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
+ isLazy: boolean, optional
+ If True, then only calculate the cost matrix by block and return
+ the dual potentials only (to save memory). If False, calculate
+ full cost matrix and return outputs of sinkhorn function.
+ batchSize: int or tuple of 2 int, optional
+ Size of the batches used to compute the sinkhorn update without memory overhead.
+ When a tuple is provided it sets the size of the left/right batches.
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't convergence.
Returns
-------
- gamma : ndarray, shape (n_samples_a, n_samples_b)
- Regularized optimal transportation matrix for the given parameters
+ W : (n_hists) array-like or float
+ Optimal transportation loss for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -1806,41 +2973,94 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
>>> n_samples_a = 2
>>> n_samples_b = 2
>>> reg = 0.1
- >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
- >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
- >>> empirical_sinkhorn2(X_s, X_t, reg, verbose=False)
- array([4.53978687e-05])
+ >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1))
+ >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1))
+ >>> b = np.full((n_samples_b, 3), 1/n_samples_b)
+ >>> empirical_sinkhorn2(X_s, X_t, b=b, reg=reg, verbose=False)
+ array([4.53978687e-05, 4.53978687e-05, 4.53978687e-05])
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
+ of Optimal Transport, Advances in Neural Information
+ Processing Systems (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling
+ Algorithms for Entropy Regularized Transport Problems.
+ arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems.
+ arXiv preprint arXiv:1607.05816.
'''
+ X_s, X_t = list_to_array(X_s, X_t)
+
+ nx = get_backend(X_s, X_t)
+
+ ns, nt = X_s.shape[0], X_t.shape[0]
if a is None:
- a = unif(np.shape(X_s)[0])
+ a = nx.from_numpy(unif(ns), type_as=X_s)
if b is None:
- b = unif(np.shape(X_t)[0])
+ b = nx.from_numpy(unif(nt), type_as=X_s)
- M = dist(X_s, X_t, metric=metric)
+ if isLazy:
+ if log:
+ f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric,
+ numIterMax=numIterMax,
+ stopThr=stopThr,
+ isLazy=isLazy,
+ batchSize=batchSize,
+ verbose=verbose, log=log,
+ warn=warn)
+ else:
+ f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric,
+ numIterMax=numIterMax, stopThr=stopThr,
+ isLazy=isLazy, batchSize=batchSize,
+ verbose=verbose, log=log,
+ warn=warn)
+
+ bs = batchSize if isinstance(batchSize, int) else batchSize[0]
+ range_s = range(0, ns, bs)
+
+ loss = 0
+
+ X_s_np = nx.to_numpy(X_s)
+ X_t_np = nx.to_numpy(X_t)
+
+ for i in range_s:
+ M_block = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric)
+ M_block = nx.from_numpy(M_block, type_as=a)
+ pi_block = nx.exp(f[i:i + bs, None] + g[None, :] - M_block / reg)
+ loss += nx.sum(M_block * pi_block)
+
+ if log:
+ return loss, dict_log
+ else:
+ return loss
- if log:
- sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
- **kwargs)
- return sinkhorn_loss, log
else:
- sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
- **kwargs)
- return sinkhorn_loss
+ M = dist(nx.to_numpy(X_s), nx.to_numpy(X_t), metric=metric)
+ M = nx.from_numpy(M, type_as=a)
+
+ if log:
+ sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn, **kwargs)
+ return sinkhorn_loss, log
+ else:
+ sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn, **kwargs)
+ return sinkhorn_loss
-def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
- verbose=False, log=False, **kwargs):
+def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
+ numIterMax=10000, stopThr=1e-9,
+ verbose=False, log=False, warn=True,
+ **kwargs):
r'''
Compute the sinkhorn divergence loss from empirical data
@@ -1849,64 +3069,72 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
.. math::
- W &= \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ W &= \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot\Omega(\gamma)
- W_a &= \min_{\gamma_a} <\gamma_a,M_a>_F + reg\cdot\Omega(\gamma_a)
+ W_a &= \min_{\gamma_a} \quad \langle \gamma_a, \mathbf{M_a} \rangle_F +
+ \mathrm{reg} \cdot\Omega(\gamma_a)
- W_b &= \min_{\gamma_b} <\gamma_b,M_b>_F + reg\cdot\Omega(\gamma_b)
+ W_b &= \min_{\gamma_b} \quad \langle \gamma_b, \mathbf{M_b} \rangle_F +
+ \mathrm{reg} \cdot\Omega(\gamma_b)
- S &= W - 1/2 * (W_a + W_b)
+ S &= W - \frac{W_a + W_b}{2}
.. math::
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma\geq 0
+ \gamma &\geq 0
- \gamma_a 1 = a
+ \gamma_a \mathbf{1} &= \mathbf{a}
- \gamma_a^T 1= a
+ \gamma_a^T \mathbf{1} &= \mathbf{a}
- \gamma_a\geq 0
+ \gamma_a &\geq 0
- \gamma_b 1 = b
+ \gamma_b \mathbf{1} &= \mathbf{b}
- \gamma_b^T 1= b
+ \gamma_b^T \mathbf{1} &= \mathbf{b}
- \gamma_b\geq 0
+ \gamma_b &\geq 0
where :
- - :math:`M` (resp. :math:`M_a, M_b`) is the (n_samples_a, n_samples_b) metric cost matrix (resp (n_samples_a, n_samples_a) and (n_samples_b, n_samples_b))
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`a` and :math:`b` are source and target weights (sum to 1)
+ - :math:`\mathbf{M}` (resp. :math:`\mathbf{M_a}`, :math:`\mathbf{M_b}`)
+ is the (`n_samples_a`, `n_samples_b`) metric cost matrix
+ (resp (`n_samples_a, n_samples_a`) and (`n_samples_b`, `n_samples_b`))
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
Parameters
----------
- X_s : ndarray, shape (n_samples_a, dim)
+ X_s : array-like, shape (n_samples_a, dim)
samples in the source domain
- X_t : ndarray, shape (n_samples_b, dim)
+ X_t : array-like, shape (n_samples_b, dim)
samples in the target domain
reg : float
Regularization term >0
- a : ndarray, shape (n_samples_a,)
+ a : array-like, shape (n_samples_a,)
samples weights in the source domain
- b : ndarray, shape (n_samples_b,)
+ b : array-like, shape (n_samples_b,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't convergence.
Returns
-------
- gamma : ndarray, shape (n_samples_a, n_samples_b)
- Regularized optimal transportation matrix for the given parameters
+ W : (1,) array-like
+ Optimal transportation symmetrized loss for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -1915,27 +3143,36 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
>>> n_samples_a = 2
>>> n_samples_b = 4
>>> reg = 0.1
- >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
- >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
+ >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1))
+ >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1))
>>> empirical_sinkhorn_divergence(X_s, X_t, reg) # doctest: +ELLIPSIS
- array([1.499...])
+ 1.499887176049052
References
----------
- .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018
+ .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative
+ Models with Sinkhorn Divergences, Proceedings of the Twenty-First
+ International Conference on Artficial Intelligence and Statistics,
+ (AISTATS) 21, 2018
'''
if log:
- sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax,
- stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
+ numIterMax=numIterMax,
+ stopThr=1e-9, verbose=verbose,
+ log=log, warn=warn, **kwargs)
- sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax,
- stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
+ numIterMax=numIterMax,
+ stopThr=1e-9, verbose=verbose,
+ log=log, warn=warn, **kwargs)
- sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax,
- stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
+ numIterMax=numIterMax,
+ stopThr=1e-9, verbose=verbose,
+ log=log, warn=warn, **kwargs)
- sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
+ sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
log = {}
log['sinkhorn_loss_ab'] = sinkhorn_loss_ab
@@ -1948,99 +3185,119 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
return max(0, sinkhorn_div), log
else:
- sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
- verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
+ numIterMax=numIterMax, stopThr=1e-9,
+ verbose=verbose, log=log,
+ warn=warn, **kwargs)
+
+ sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
+ numIterMax=numIterMax, stopThr=1e-9,
+ verbose=verbose, log=log,
+ warn=warn, **kwargs)
+
+ sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
+ numIterMax=numIterMax, stopThr=1e-9,
+ verbose=verbose, log=log,
+ warn=warn, **kwargs)
+
+ sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
+ return max(0, sinkhorn_div)
- sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
- verbose=verbose, log=log, **kwargs)
- sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
- verbose=verbose, log=log, **kwargs)
+def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
+ restricted=True, maxiter=10000, maxfun=10000, pgtol=1e-09,
+ verbose=False, log=False):
+ r"""
+ Screening Sinkhorn Algorithm for Regularized Optimal Transport
- sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
- return max(0, sinkhorn_div)
+ The function solves an approximate dual of Sinkhorn divergence :ref:`[2]
+ <references-screenkhorn>` which is written as the following optimization problem:
+
+ .. math::
+ (\mathbf{u}, \mathbf{v}) = \mathop{\arg \min}_{\mathbf{u}, \mathbf{v}} \quad
+ \mathbf{1}_{ns}^T \mathbf{B}(\mathbf{u}, \mathbf{v}) \mathbf{1}_{nt} -
+ \langle \kappa \mathbf{u}, \mathbf{a} \rangle -
+ \langle \frac{1}{\kappa} \mathbf{v}, \mathbf{b} \rangle
-def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, restricted=True,
- maxiter=10000, maxfun=10000, pgtol=1e-09, verbose=False, log=False):
- r""""
- Screening Sinkhorn Algorithm for Regularized Optimal Transport
+ where:
- The function solves an approximate dual of Sinkhorn divergence [2] which is written as the following optimization problem:
+ .. math::
- ..math::
- (u, v) = \argmin_{u, v} 1_{ns}^T B(u,v) 1_{nt} - <\kappa u, a> - <v/\kappa, b>
+ \mathbf{B}(\mathbf{u}, \mathbf{v}) = \mathrm{diag}(e^\mathbf{u}) \mathbf{K} \mathrm{diag}(e^\mathbf{v}) \text{, with } \mathbf{K} = e^{-\mathbf{M} / \mathrm{reg}} \text{ and}
- where B(u,v) = \diag(e^u) K \diag(e^v), with K = e^{-M/reg} and
+ .. math::
- s.t. e^{u_i} \geq \epsilon / \kappa, for all i \in {1, ..., ns}
+ s.t. \ e^{u_i} &\geq \epsilon / \kappa, \forall i \in \{1, \ldots, ns\}
- e^{v_j} \geq \epsilon \kappa, for all j \in {1, ..., nt}
+ e^{v_j} &\geq \epsilon \kappa, \forall j \in \{1, \ldots, nt\}
- The parameters \kappa and \epsilon are determined w.r.t the couple number budget of points (ns_budget, nt_budget), see Equation (5) in [26]
+ The parameters `kappa` and `epsilon` are determined w.r.t the couple number
+ budget of points (`ns_budget`, `nt_budget`), see Equation (5)
+ in :ref:`[26] <references-screenkhorn>`
Parameters
----------
- a : `numpy.ndarray`, shape=(ns,)
+ a: array-like, shape=(ns,)
samples weights in the source domain
-
- b : `numpy.ndarray`, shape=(nt,)
+ b: array-like, shape=(nt,)
samples weights in the target domain
-
- M : `numpy.ndarray`, shape=(ns, nt)
+ M: array-like, shape=(ns, nt)
Cost matrix
-
- reg : `float`
+ reg: `float`
Level of the entropy regularisation
-
- ns_budget : `int`, deafult=None
- Number budget of points to be keeped in the source domain
- If it is None then 50% of the source sample points will be keeped
-
- nt_budget : `int`, deafult=None
- Number budget of points to be keeped in the target domain
- If it is None then 50% of the target sample points will be keeped
-
- uniform : `bool`, default=False
- If `True`, the source and target distribution are supposed to be uniform, i.e., a_i = 1 / ns and b_j = 1 / nt
-
+ ns_budget: `int`, default=None
+ Number budget of points to be kept in the source domain.
+ If it is None then 50% of the source sample points will be kept
+ nt_budget: `int`, default=None
+ Number budget of points to be kept in the target domain.
+ If it is None then 50% of the target sample points will be kept
+ uniform: `bool`, default=False
+ If `True`, the source and target distribution are supposed to be uniform,
+ i.e., :math:`a_i = 1 / ns` and :math:`b_j = 1 / nt`
restricted : `bool`, default=True
If `True`, a warm-start initialization for the L-BFGS-B solver
using a restricted Sinkhorn algorithm with at most 5 iterations
-
- maxiter : `int`, default=10000
+ maxiter: `int`, default=10000
Maximum number of iterations in LBFGS solver
+ maxfun: `int`, default=10000
+ Maximum number of function evaluations in LBFGS solver
+ pgtol: `float`, default=1e-09
+ Final objective function accuracy in LBFGS solver
+ verbose: `bool`, default=False
+ If `True`, display informations about the cardinals of the active sets
+ and the parameters kappa and epsilon
- maxfun : `int`, default=10000
- Maximum number of function evaluations in LBFGS solver
- pgtol : `float`, default=1e-09
- Final objective function accuracy in LBFGS solver
+ .. admonition:: Dependency
- verbose : `bool`, default=False
- If `True`, dispaly informations about the cardinals of the active sets and the paramerters kappa
- and epsilon
+ To gain more efficiency, :py:func:`ot.bregman.screenkhorn` needs to call the "Bottleneck"
+ package (https://pypi.org/project/Bottleneck/) in the screening pre-processing step.
- Dependency
- ----------
- To gain more efficiency, screenkhorn needs to call the "Bottleneck" package (https://pypi.org/project/Bottleneck/)
- in the screening pre-processing step. If Bottleneck isn't installed, the following error message appears:
- "Bottleneck module doesn't exist. Install it from https://pypi.org/project/Bottleneck/"
+ If Bottleneck isn't installed, the following error message appears:
+
+ "Bottleneck module doesn't exist. Install it from https://pypi.org/project/Bottleneck/"
Returns
-------
- gamma : `numpy.ndarray`, shape=(ns, nt)
+ gamma : array-like, shape=(ns, nt)
Screened optimal transportation matrix for the given parameters
log : `dict`, default=False
Log dictionary return only if log==True in parameters
+ .. _references-screenkhorn:
References
-----------
- .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). Screening Sinkhorn Algorithm for Regularized Optimal Transport (NIPS) 33, 2019
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport,
+ Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019).
+ Screening Sinkhorn Algorithm for Regularized Optimal Transport (NIPS) 33, 2019
"""
# check if bottleneck module exists
@@ -2048,12 +3305,17 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
import bottleneck
except ImportError:
warnings.warn(
- "Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance.")
+ "Bottleneck module is not installed. Install it from"
+ " https://pypi.org/project/Bottleneck/ for better performance.")
bottleneck = np
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(M, a, b)
+ if nx.__name__ == "jax":
+ raise TypeError("JAX arrays have been received but screenkhorn is not "
+ "compatible with JAX.")
+
ns, nt = M.shape
# by default, we keep only 50% of the sample data points
@@ -2063,9 +3325,7 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
nt_budget = int(np.floor(0.5 * nt))
# calculate the Gibbs kernel
- K = np.empty_like(M)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
+ K = nx.exp(-M / reg)
def projection(u, epsilon):
u[u <= epsilon] = epsilon
@@ -2077,8 +3337,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
if ns_budget == ns and nt_budget == nt:
# full number of budget points (ns, nt) = (ns_budget, nt_budget)
- Isel = np.ones(ns, dtype=bool)
- Jsel = np.ones(nt, dtype=bool)
+ Isel = nx.from_numpy(np.ones(ns, dtype=bool))
+ Jsel = nx.from_numpy(np.ones(nt, dtype=bool))
epsilon = 0.0
kappa = 1.0
@@ -2094,57 +3354,63 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
K_IJc = []
K_IcJ = []
- vec_eps_IJc = np.zeros(nt)
- vec_eps_IcJ = np.zeros(ns)
+ vec_eps_IJc = nx.zeros((nt,), type_as=M)
+ vec_eps_IcJ = nx.zeros((ns,), type_as=M)
else:
# sum of rows and columns of K
- K_sum_cols = K.sum(axis=1)
- K_sum_rows = K.sum(axis=0)
+ K_sum_cols = nx.sum(K, axis=1)
+ K_sum_rows = nx.sum(K, axis=0)
if uniform:
if ns / ns_budget < 4:
- aK_sort = np.sort(K_sum_cols)
+ aK_sort = nx.sort(K_sum_cols)
epsilon_u_square = a[0] / aK_sort[ns_budget - 1]
else:
- aK_sort = bottleneck.partition(K_sum_cols, ns_budget - 1)[ns_budget - 1]
+ aK_sort = nx.from_numpy(
+ bottleneck.partition(nx.to_numpy(K_sum_cols), ns_budget - 1)[ns_budget - 1],
+ type_as=M
+ )
epsilon_u_square = a[0] / aK_sort
if nt / nt_budget < 4:
- bK_sort = np.sort(K_sum_rows)
+ bK_sort = nx.sort(K_sum_rows)
epsilon_v_square = b[0] / bK_sort[nt_budget - 1]
else:
- bK_sort = bottleneck.partition(K_sum_rows, nt_budget - 1)[nt_budget - 1]
+ bK_sort = nx.from_numpy(
+ bottleneck.partition(nx.to_numpy(K_sum_rows), nt_budget - 1)[nt_budget - 1],
+ type_as=M
+ )
epsilon_v_square = b[0] / bK_sort
else:
aK = a / K_sum_cols
bK = b / K_sum_rows
- aK_sort = np.sort(aK)[::-1]
+ aK_sort = nx.flip(nx.sort(aK), axis=0)
epsilon_u_square = aK_sort[ns_budget - 1]
- bK_sort = np.sort(bK)[::-1]
+ bK_sort = nx.flip(nx.sort(bK), axis=0)
epsilon_v_square = bK_sort[nt_budget - 1]
# active sets I and J (see Lemma 1 in [26])
Isel = a >= epsilon_u_square * K_sum_cols
Jsel = b >= epsilon_v_square * K_sum_rows
- if sum(Isel) != ns_budget:
+ if nx.sum(Isel) != ns_budget:
if uniform:
aK = a / K_sum_cols
- aK_sort = np.sort(aK)[::-1]
- epsilon_u_square = aK_sort[ns_budget - 1:ns_budget + 1].mean()
+ aK_sort = nx.flip(nx.sort(aK), axis=0)
+ epsilon_u_square = nx.mean(aK_sort[ns_budget - 1:ns_budget + 1])
Isel = a >= epsilon_u_square * K_sum_cols
- ns_budget = sum(Isel)
+ ns_budget = nx.sum(Isel)
- if sum(Jsel) != nt_budget:
+ if nx.sum(Jsel) != nt_budget:
if uniform:
bK = b / K_sum_rows
- bK_sort = np.sort(bK)[::-1]
- epsilon_v_square = bK_sort[nt_budget - 1:nt_budget + 1].mean()
+ bK_sort = nx.flip(nx.sort(bK), axis=0)
+ epsilon_v_square = nx.mean(bK_sort[nt_budget - 1:nt_budget + 1])
Jsel = b >= epsilon_v_square * K_sum_rows
- nt_budget = sum(Jsel)
+ nt_budget = nx.sum(Jsel)
epsilon = (epsilon_u_square * epsilon_v_square) ** (1 / 4)
kappa = (epsilon_v_square / epsilon_u_square) ** (1 / 2)
@@ -2152,7 +3418,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
if verbose:
print("epsilon = %s\n" % epsilon)
print("kappa = %s\n" % kappa)
- print('Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n' % (sum(Isel), sum(Jsel)))
+ print('Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n'
+ % (sum(Isel), sum(Jsel)))
# Ic, Jc: complementary of the active sets I and J
Ic = ~Isel
@@ -2162,18 +3429,18 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
K_IcJ = K[np.ix_(Ic, Jsel)]
K_IJc = K[np.ix_(Isel, Jc)]
- K_min = K_IJ.min()
+ K_min = nx.min(K_IJ)
if K_min == 0:
- K_min = np.finfo(float).tiny
+ K_min = float(np.finfo(float).tiny)
# a_I, b_J, a_Ic, b_Jc
a_I = a[Isel]
b_J = b[Jsel]
if not uniform:
- a_I_min = a_I.min()
- a_I_max = a_I.max()
- b_J_max = b_J.max()
- b_J_min = b_J.min()
+ a_I_min = nx.min(a_I)
+ a_I_max = nx.max(a_I)
+ b_J_max = nx.max(b_J)
+ b_J_min = nx.min(b_J)
else:
a_I_min = a_I[0]
a_I_max = a_I[0]
@@ -2182,33 +3449,37 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
# box constraints in L-BFGS-B (see Proposition 1 in [26])
bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / (
- ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget
+ ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget
bounds_v = [(
- max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))),
- epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget
+ max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))),
+ epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget
# pre-calculated constants for the objective
- vec_eps_IJc = epsilon * kappa * (K_IJc * np.ones(nt - nt_budget).reshape((1, -1))).sum(axis=1)
- vec_eps_IcJ = (epsilon / kappa) * (np.ones(ns - ns_budget).reshape((-1, 1)) * K_IcJ).sum(axis=0)
+ vec_eps_IJc = epsilon * kappa * nx.sum(
+ K_IJc * nx.ones((nt - nt_budget,), type_as=M)[None, :],
+ axis=1
+ )
+ vec_eps_IcJ = (epsilon / kappa) * nx.sum(
+ nx.ones((ns - ns_budget,), type_as=M)[:, None] * K_IcJ,
+ axis=0
+ )
# initialisation
- u0 = np.full(ns_budget, (1. / ns_budget) + epsilon / kappa)
- v0 = np.full(nt_budget, (1. / nt_budget) + epsilon * kappa)
+ u0 = nx.full((ns_budget,), 1. / ns_budget + epsilon / kappa, type_as=M)
+ v0 = nx.full((nt_budget,), 1. / nt_budget + epsilon * kappa, type_as=M)
# pre-calculed constants for Restricted Sinkhorn (see Algorithm 1 in supplementary of [26])
if restricted:
if ns_budget != ns or nt_budget != nt:
- cst_u = kappa * epsilon * K_IJc.sum(axis=1)
- cst_v = epsilon * K_IcJ.sum(axis=0) / kappa
+ cst_u = kappa * epsilon * nx.sum(K_IJc, axis=1)
+ cst_v = epsilon * nx.sum(K_IcJ, axis=0) / kappa
- cpt = 1
- while cpt < 5: # 5 iterations
- K_IJ_v = np.dot(K_IJ.T, u0) + cst_v
+ for _ in range(5): # 5 iterations
+ K_IJ_v = nx.dot(K_IJ.T, u0) + cst_v
v0 = b_J / (kappa * K_IJ_v)
- KIJ_u = np.dot(K_IJ, v0) + cst_u
+ KIJ_u = nx.dot(K_IJ, v0) + cst_u
u0 = (kappa * a_I) / KIJ_u
- cpt += 1
u0 = projection(u0, epsilon / kappa)
v0 = projection(v0, epsilon * kappa)
@@ -2219,15 +3490,13 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
def restricted_sinkhorn(usc, vsc, max_iter=5):
"""
- Restricted Sinkhorn Algorithm as a warm-start initialized point for L-BFGS-B (see Algorithm 1 in supplementary of [26])
+ Restricted Sinkhorn Algorithm as a warm-start initialized pointfor L-BFGS-B)
"""
- cpt = 1
- while cpt < max_iter:
- K_IJ_v = np.dot(K_IJ.T, usc) + cst_v
+ for _ in range(max_iter):
+ K_IJ_v = nx.dot(K_IJ.T, usc) + cst_v
vsc = b_J / (kappa * K_IJ_v)
- KIJ_u = np.dot(K_IJ, vsc) + cst_u
+ KIJ_u = nx.dot(K_IJ, vsc) + cst_u
usc = (kappa * a_I) / KIJ_u
- cpt += 1
usc = projection(usc, epsilon / kappa)
vsc = projection(vsc, epsilon * kappa)
@@ -2235,17 +3504,20 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
return usc, vsc
def screened_obj(usc, vsc):
- part_IJ = np.dot(np.dot(usc, K_IJ), vsc) - kappa * np.dot(a_I, np.log(usc)) - (1. / kappa) * np.dot(b_J,
- np.log(vsc))
- part_IJc = np.dot(usc, vec_eps_IJc)
- part_IcJ = np.dot(vec_eps_IcJ, vsc)
+ part_IJ = (
+ nx.dot(nx.dot(usc, K_IJ), vsc)
+ - kappa * nx.dot(a_I, nx.log(usc))
+ - (1. / kappa) * nx.dot(b_J, nx.log(vsc))
+ )
+ part_IJc = nx.dot(usc, vec_eps_IJc)
+ part_IcJ = nx.dot(vec_eps_IcJ, vsc)
psi_epsilon = part_IJ + part_IJc + part_IcJ
return psi_epsilon
def screened_grad(usc, vsc):
# gradients of Psi_(kappa,epsilon) w.r.t u and v
- grad_u = np.dot(K_IJ, vsc) + vec_eps_IJc - kappa * a_I / usc
- grad_v = np.dot(K_IJ.T, usc) + vec_eps_IcJ - (1. / kappa) * b_J / vsc
+ grad_u = nx.dot(K_IJ, vsc) + vec_eps_IJc - kappa * a_I / usc
+ grad_v = nx.dot(K_IJ.T, usc) + vec_eps_IcJ - (1. / kappa) * b_J / vsc
return grad_u, grad_v
def bfgspost(theta):
@@ -2255,20 +3527,20 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
f = screened_obj(u, v)
# gradient
g_u, g_v = screened_grad(u, v)
- g = np.hstack([g_u, g_v])
- return f, g
+ g = nx.concatenate([g_u, g_v], axis=0)
+ return nx.to_numpy(f), nx.to_numpy(g)
# ----------------------------------------------------------------------------------------------------------------#
# Step 2: L-BFGS-B solver #
# ----------------------------------------------------------------------------------------------------------------#
u0, v0 = restricted_sinkhorn(u0, v0)
- theta0 = np.hstack([u0, v0])
+ theta0 = nx.concatenate([u0, v0], axis=0)
bounds = bounds_u + bounds_v # constraint bounds
def obj(theta):
- return bfgspost(theta)
+ return bfgspost(nx.from_numpy(theta, type_as=M))
theta, _, _ = fmin_l_bfgs_b(func=obj,
x0=theta0,
@@ -2276,12 +3548,13 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
maxfun=maxfun,
pgtol=pgtol,
maxiter=maxiter)
+ theta = nx.from_numpy(theta, type_as=M)
usc = theta[:ns_budget]
vsc = theta[ns_budget:]
- usc_full = np.full(ns, epsilon / kappa)
- vsc_full = np.full(nt, epsilon * kappa)
+ usc_full = nx.full((ns,), epsilon / kappa, type_as=M)
+ vsc_full = nx.full((nt,), epsilon * kappa, type_as=M)
usc_full[Isel] = usc
vsc_full[Jsel] = vsc
@@ -2293,7 +3566,7 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
log['Jsel'] = Jsel
gamma = usc_full[:, None] * K * vsc_full[None, :]
- gamma = gamma / gamma.sum()
+ gamma = gamma / nx.sum(gamma)
if log:
return gamma, log
diff --git a/ot/da.py b/ot/da.py
index b881a8b..4fd97df 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -26,34 +26,36 @@ from .optim import gcg
def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
numInnerItermax=200, stopInnerThr=1e-9, verbose=False,
log=False):
- """
+ r"""
Solve the entropic regularization optimal transport problem with nonconvex
group lasso regularization
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)
- + \eta \Omega_g(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot \Omega_e(\gamma) + \eta \ \Omega_g(\gamma)
+
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
+
+ \gamma^T \mathbf{1} = \mathbf{b}
- s.t. \gamma 1 = a
+ \gamma \geq 0
- \gamma^T 1= b
- \gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- :math:`\Omega_e` is the entropic regularization term :math:`\Omega_e
(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\Omega_g` is the group lasso regularization term
:math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1`
- where :math:`\mathcal{I}_c` are the index of samples from class c
+ where :math:`\mathcal{I}_c` are the index of samples from class `c`
in the source domain.
- - a and b are source and target weights (sum to 1)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
The algorithm used for solving the problem is the generalized conditional
- gradient as proposed in [5]_ [7]_
+ gradient as proposed in :ref:`[5, 7] <references-sinkhorn-lpl1-mm>`.
Parameters
@@ -84,19 +86,20 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : (ns, nt) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
+ .. _references-sinkhorn-lpl1-mm:
References
----------
-
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
"Optimal Transport for Domain Adaptation," in IEEE
Transactions on Pattern Analysis and Machine Intelligence ,
vol.PP, no.99, pp.1-1
+
.. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
Generalized conditional gradient: analysis of convergence
and applications. arXiv preprint arXiv:1510.06567.
@@ -137,34 +140,36 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
numInnerItermax=200, stopInnerThr=1e-9, verbose=False,
log=False):
- """
+ r"""
Solve the entropic regularization optimal transport problem with group
lasso regularization
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)+
- \eta \Omega_g(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot \Omega_e(\gamma) + \eta \ \Omega_g(\gamma)
+
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
+
+ \gamma^T \mathbf{1} = \mathbf{b}
- s.t. \gamma 1 = a
+ \gamma \geq 0
- \gamma^T 1= b
- \gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- :math:`\Omega_e` is the entropic regularization term
:math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\Omega_g` is the group lasso regulaization term
:math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^2`
where :math:`\mathcal{I}_c` are the index of samples from class
- c in the source domain.
- - a and b are source and target weights (sum to 1)
+ `c` in the source domain.
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
The algorithm used for solving the problem is the generalised conditional
- gradient as proposed in [5]_ [7]_
+ gradient as proposed in :ref:`[5, 7] <references-sinkhorn-l1l2-gl>`.
Parameters
@@ -195,18 +200,19 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : (ns, nt) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
+ .. _references-sinkhorn-l1l2-gl:
References
----------
-
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
"Optimal Transport for Domain Adaptation," in IEEE Transactions
on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
+
.. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
Generalized conditional gradient: analysis of convergence and
applications. arXiv preprint arXiv:1510.06567.
@@ -245,38 +251,40 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
verbose2=False, numItermax=100, numInnerItermax=10,
stopInnerThr=1e-6, stopThr=1e-5, log=False,
**kwargs):
- """Joint OT and linear mapping estimation as proposed in [8]
+ r"""Joint OT and linear mapping estimation as proposed in
+ :ref:`[8] <references-joint-OT-mapping-linear>`.
The function solves the following optimization problem:
.. math::
- \min_{\gamma,L}\quad \|L(X_s) -n_s\gamma X_t\|^2_F +
- \mu<\gamma,M>_F + \eta \|L -I\|^2_F
+ \min_{\gamma,L}\quad \|L(\mathbf{X_s}) - n_s\gamma \mathbf{X_t} \|^2_F +
+ \mu \langle \gamma, \mathbf{M} \rangle_F + \eta \|L - \mathbf{I}\|^2_F
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} = \mathbf{b}
+
+ \gamma \geq 0
- \gamma\geq 0
where :
- - M is the (ns,nt) squared euclidean cost matrix between samples in
- Xs and Xt (scaled by ns)
- - :math:`L` is a dxd linear operator that approximates the barycentric
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) squared euclidean cost matrix between samples in
+ :math:`\mathbf{X_s}` and :math:`\mathbf{X_t}` (scaled by :math:`n_s`)
+ - :math:`L` is a :math:`d\times d` linear operator that approximates the barycentric
mapping
- - :math:`I` is the identity matrix (neutral linear mapping)
- - a and b are uniform source and target weights
+ - :math:`\mathbf{I}` is the identity matrix (neutral linear mapping)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are uniform source and target weights
The problem consist in solving jointly an optimal transport matrix
:math:`\gamma` and a linear mapping that fits the barycentric mapping
- :math:`n_s\gamma X_t`.
+ :math:`n_s\gamma \mathbf{X_t}`.
One can also estimate a mapping with constant bias (see supplementary
- material of [8]) using the bias optional argument.
+ material of :ref:`[8] <references-joint-OT-mapping-linear>`) using the bias optional argument.
The algorithm used for solving the problem is the block coordinate
- descent that alternates between updates of G (using conditionnal gradient)
- and the update of L using a classical least square solver.
+ descent that alternates between updates of :math:`\mathbf{G}` (using conditionnal gradient)
+ and the update of :math:`\mathbf{L}` using a classical least square solver.
Parameters
@@ -307,17 +315,17 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : (ns, nt) ndarray
Optimal transportation matrix for the given parameters
- L : (d x d) ndarray
- Linear mapping matrix (d+1 x d if bias)
+ L : (d, d) ndarray
+ Linear mapping matrix ((:math:`d+1`, `d`) if bias)
log : dict
log dictionary return only if log==True in parameters
+ .. _references-joint-OT-mapping-linear:
References
----------
-
.. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
"Mapping estimation for discrete optimal transport",
Neural Information Processing Systems (NIPS), 2016.
@@ -434,37 +442,41 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
numItermax=100, numInnerItermax=10,
stopInnerThr=1e-6, stopThr=1e-5, log=False,
**kwargs):
- """Joint OT and nonlinear mapping estimation with kernels as proposed in [8]
+ r"""Joint OT and nonlinear mapping estimation with kernels as proposed in
+ :ref:`[8] <references-joint-OT-mapping-kernel>`.
The function solves the following optimization problem:
.. math::
- \min_{\gamma,L\in\mathcal{H}}\quad \|L(X_s) -
- n_s\gamma X_t\|^2_F + \mu<\gamma,M>_F + \eta \|L\|^2_\mathcal{H}
+ \min_{\gamma, L\in\mathcal{H}}\quad \|L(\mathbf{X_s}) -
+ n_s\gamma \mathbf{X_t}\|^2_F + \mu \langle \gamma, \mathbf{M} \rangle_F +
+ \eta \|L\|^2_\mathcal{H}
+
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
- s.t. \gamma 1 = a
+ \gamma^T \mathbf{1} = \mathbf{b}
+
+ \gamma \geq 0
- \gamma^T 1= b
- \gamma\geq 0
where :
- - M is the (ns,nt) squared euclidean cost matrix between samples in
- Xs and Xt (scaled by ns)
- - :math:`L` is a ns x d linear operator on a kernel matrix that
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) squared euclidean cost matrix between samples in
+ :math:`\mathbf{X_s}` and :math:`\mathbf{X_t}` (scaled by :math:`n_s`)
+ - :math:`L` is a :math:`n_s \times d` linear operator on a kernel matrix that
approximates the barycentric mapping
- - a and b are uniform source and target weights
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are uniform source and target weights
The problem consist in solving jointly an optimal transport matrix
:math:`\gamma` and the nonlinear mapping that fits the barycentric mapping
- :math:`n_s\gamma X_t`.
+ :math:`n_s\gamma \mathbf{X_t}`.
One can also estimate a mapping with constant bias (see supplementary
- material of [8]) using the bias optional argument.
+ material of :ref:`[8] <references-joint-OT-mapping-kernel>`) using the bias optional argument.
The algorithm used for solving the problem is the block coordinate
- descent that alternates between updates of G (using conditionnal gradient)
- and the update of L using a classical kernel least square solver.
+ descent that alternates between updates of :math:`\mathbf{G}` (using conditionnal gradient)
+ and the update of :math:`\mathbf{L}` using a classical kernel least square solver.
Parameters
@@ -478,7 +490,7 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
eta : float, optional
Regularization term for the linear mapping L (>0)
kerneltype : str,optional
- kernel used by calling function ot.utils.kernel (gaussian by default)
+ kernel used by calling function :py:func:`ot.utils.kernel` (gaussian by default)
sigma : float, optional
Gaussian kernel bandwidth.
bias : bool,optional
@@ -501,17 +513,17 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : (ns, nt) ndarray
Optimal transportation matrix for the given parameters
- L : (ns x d) ndarray
- Nonlinear mapping matrix (ns+1 x d if bias)
+ L : (ns, d) ndarray
+ Nonlinear mapping matrix ((:math:`n_s+1`, `d`) if bias)
log : dict
log dictionary return only if log==True in parameters
+ .. _references-joint-OT-mapping-kernel:
References
----------
-
.. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
"Mapping estimation for discrete optimal transport",
Neural Information Processing Systems (NIPS), 2016.
@@ -645,26 +657,27 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
wt=None, bias=True, log=False):
- """ return OT linear operator between samples
+ r"""Return OT linear operator between samples.
The function estimates the optimal linear operator that aligns the two
empirical distributions. This is equivalent to estimating the closed
- form mapping between two Gaussian distributions :math:`N(\mu_s,\Sigma_s)`
- and :math:`N(\mu_t,\Sigma_t)` as proposed in [14] and discussed in remark
- 2.29 in [15].
+ form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)`
+ and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in
+ :ref:`[14] <references-OT-mapping-linear>` and discussed in remark 2.29 in
+ :ref:`[15] <references-OT-mapping-linear>`.
The linear operator from source to target :math:`M`
.. math::
- M(x)=Ax+b
+ M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b}
where :
.. math::
- A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2}
+ \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2}
\Sigma_s^{-1/2}
- .. math::
- b=\mu_t-A\mu_s
+
+ \mathbf{b} &= \mu_t - \mathbf{A} \mu_s
Parameters
----------
@@ -673,35 +686,35 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
xt : np.ndarray (nt,d)
samples in the target domain
reg : float,optional
- regularization added to the diagonals of convariances (>0)
+ regularization added to the diagonals of covariances (>0)
ws : np.ndarray (ns,1), optional
weights for the source samples
wt : np.ndarray (ns,1), optional
weights for the target samples
bias: boolean, optional
- estimate bias b else b=0 (default:True)
+ estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True)
log : bool, optional
record log if True
Returns
-------
- A : (d x d) ndarray
+ A : (d, d) ndarray
Linear operator
- b : (1 x d) ndarray
+ b : (1, d) ndarray
bias
log : dict
log dictionary return only if log==True in parameters
+ .. _references-OT-mapping-linear:
References
----------
-
.. [14] Knott, M. and Smith, C. S. "On the optimal mapping of
distributions", Journal of Optimization Theory and Applications
Vol 43, 1984
- .. [15] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+ .. [15] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
Transport", 2018.
@@ -754,24 +767,34 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
r"""Solve the optimal transport problem (OT) with Laplacian regularization
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + eta\Omega_\alpha(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \eta \cdot \Omega_\alpha(\gamma)
- s.t.\ \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} = \mathbf{b}
- \gamma\geq 0
+ \gamma \geq 0
where:
- - a and b are source and target weights (sum to 1)
- - xs and xt are source and target samples
- - M is the (ns,nt) metric cost matrix
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
+ - :math:`\mathbf{x_s}` and :math:`\mathbf{x_t}` are source and target samples
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- :math:`\Omega_\alpha` is the Laplacian regularization term
- :math:`\Omega_\alpha = (1-\alpha)/n_s^2\sum_{i,j}S^s_{i,j}\|T(\mathbf{x}^s_i)-T(\mathbf{x}^s_j)\|^2+\alpha/n_t^2\sum_{i,j}S^t_{i,j}^'\|T(\mathbf{x}^t_i)-T(\mathbf{x}^t_j)\|^2`
- with :math:`S^s_{i,j}, S^t_{i,j}` denoting source and target similarity matrices and :math:`T(\cdot)` being a barycentric mapping
- The algorithm used for solving the problem is the conditional gradient algorithm as proposed in [5].
+ .. math::
+ \Omega_\alpha = \frac{1 - \alpha}{n_s^2} \sum_{i,j}
+ \mathbf{S^s}_{i,j} \|T(\mathbf{x}^s_i) - T(\mathbf{x}^s_j) \|^2 +
+ \frac{\alpha}{n_t^2} \sum_{i,j}
+ \mathbf{S^t}_{i,j} \|T(\mathbf{x}^t_i) - T(\mathbf{x}^t_j) \|^2
+
+
+ with :math:`\mathbf{S^s}_{i,j}, \mathbf{S^t}_{i,j}` denoting source and target similarity
+ matrices and :math:`T(\cdot)` being a barycentric mapping.
+
+ The algorithm used for solving the problem is the conditional gradient algorithm as proposed in
+ :ref:`[5] <references-emd-laplace>`.
Parameters
----------
@@ -811,22 +834,23 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : (ns, nt) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
+ .. _references-emd-laplace:
References
----------
-
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
"Optimal Transport for Domain Adaptation," in IEEE
- Transactions on Pattern Analysis and Machine Intelligence ,
+ Transactions on Pattern Analysis and Machine Intelligence,
vol.PP, no.99, pp.1-1
+
.. [30] R. Flamary, N. Courty, D. Tuia, A. Rakotomamonjy,
"Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching,"
- in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
+ in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
See Also
--------
@@ -882,7 +906,7 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
def distribution_estimation_uniform(X):
- """estimates a uniform distribution from an array of samples X
+ """estimates a uniform distribution from an array of samples :math:`\mathbf{X}`
Parameters
----------
@@ -892,7 +916,7 @@ def distribution_estimation_uniform(X):
Returns
-------
mu : array-like, shape (n_samples,)
- The uniform distribution estimated from X
+ The uniform distribution estimated from :math:`\mathbf{X}`
"""
return unif(X.shape[0])
@@ -902,32 +926,32 @@ class BaseTransport(BaseEstimator):
"""Base class for OTDA objects
- Notes
- -----
- All estimators should specify all the parameters that can be set
- at the class level in their ``__init__`` as explicit keyword
- arguments (no ``*args`` or ``**kwargs``).
+ .. note::
+ All estimators should specify all the parameters that can be set
+ at the class level in their ``__init__`` as explicit keyword
+ arguments (no ``*args`` or ``**kwargs``).
- the fit method should:
+ The fit method should:
- estimate a cost matrix and store it in a `cost_` attribute
- - estimate a coupling matrix and store it in a `coupling_`
- attribute
+ - estimate a coupling matrix and store it in a `coupling_` attribute
- estimate distributions from source and target data and store them in
- mu_s and mu_t attributes
- - store Xs and Xt in attributes to be used later on in transform and
- inverse_transform methods
+ `mu_s` and `mu_t` attributes
+ - store `Xs` and `Xt` in attributes to be used later on in `transform` and
+ `inverse_transform` methods
+
+ `transform` method should always get as input a `Xs` parameter
+
+ `inverse_transform` method should always get as input a `Xt` parameter
- transform method should always get as input a Xs parameter
- inverse_transform method should always get as input a Xt parameter
+ `transform_labels` method should always get as input a `ys` parameter
- transform_labels method should always get as input a ys parameter
- inverse_transform_labels method should always get as input a yt parameter
+ `inverse_transform_labels` method should always get as input a `yt` parameter
"""
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples
- (Xs, ys) and (Xt, yt)
+ :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
----------
@@ -938,8 +962,8 @@ class BaseTransport(BaseEstimator):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -987,8 +1011,8 @@ class BaseTransport(BaseEstimator):
def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples
- (Xs, ys) and (Xt, yt) and transports source samples Xs onto target
- ones Xt
+ :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
+ and transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
Parameters
----------
@@ -999,8 +1023,8 @@ class BaseTransport(BaseEstimator):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -1014,7 +1038,7 @@ class BaseTransport(BaseEstimator):
return self.fit(Xs, ys, Xt, yt).transform(Xs, ys, Xt, yt)
def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
- """Transports source samples Xs onto target ones Xt
+ """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
Parameters
----------
@@ -1025,8 +1049,8 @@ class BaseTransport(BaseEstimator):
Xt : array-like, shape (n_target_samples, n_features)
The target input samples.
yt : array-like, shape (n_target_samples,)
- The class labels for target. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels for target. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -1081,7 +1105,8 @@ class BaseTransport(BaseEstimator):
return transp_Xs
def transform_labels(self, ys=None):
- """Propagate source labels ys to obtain estimated target labels as in [27]
+ """Propagate source labels :math:`\mathbf{y_s}` to obtain estimated target labels as in
+ :ref:`[27] <references-basetransport-transform-labels>`.
Parameters
----------
@@ -1093,9 +1118,10 @@ class BaseTransport(BaseEstimator):
transp_ys : array-like, shape (n_target_samples, nb_classes)
Estimated soft target labels.
+
+ .. _references-basetransport-transform-labels:
References
----------
-
.. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia
"Optimal transport for multi-source domain adaptation under target shift",
International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
@@ -1111,7 +1137,7 @@ class BaseTransport(BaseEstimator):
D1 = np.zeros((n, len(ysTemp)))
# perform label propagation
- transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
+ transp = self.coupling_ / np.sum(self.coupling_, 0, keepdims=True)
# set nans to 0
transp[~ np.isfinite(transp)] = 0
@@ -1126,7 +1152,7 @@ class BaseTransport(BaseEstimator):
def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
batch_size=128):
- """Transports target samples Xt onto source samples Xs
+ """Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}`
Parameters
----------
@@ -1137,8 +1163,8 @@ class BaseTransport(BaseEstimator):
Xt : array-like, shape (n_target_samples, n_features)
The target input samples.
yt : array-like, shape (n_target_samples,)
- The target class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The target class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -1192,7 +1218,8 @@ class BaseTransport(BaseEstimator):
return transp_Xt
def inverse_transform_labels(self, yt=None):
- """Propagate target labels yt to obtain estimated source labels ys
+ """Propagate target labels :math:`\mathbf{y_t}` to obtain estimated source labels
+ :math:`\mathbf{y_s}`
Parameters
----------
@@ -1228,39 +1255,41 @@ class BaseTransport(BaseEstimator):
class LinearTransport(BaseTransport):
- """ OT linear operator between empirical distributions
+ r""" OT linear operator between empirical distributions
The function estimates the optimal linear operator that aligns the two
empirical distributions. This is equivalent to estimating the closed
- form mapping between two Gaussian distributions :math:`N(\mu_s,\Sigma_s)`
- and :math:`N(\mu_t,\Sigma_t)` as proposed in [14] and discussed in
- remark 2.29 in [15].
+ form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)`
+ and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in
+ :ref:`[14] <references-lineartransport>` and discussed in remark 2.29 in
+ :ref:`[15] <references-lineartransport>`.
The linear operator from source to target :math:`M`
.. math::
- M(x)=Ax+b
+ M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b}
where :
.. math::
- A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2}
+ \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2}
\Sigma_s^{-1/2}
- .. math::
- b=\mu_t-A\mu_s
+
+ \mathbf{b} &= \mu_t - \mathbf{A} \mu_s
Parameters
----------
reg : float,optional
- regularization added to the daigonals of convariances (>0)
+ regularization added to the daigonals of covariances (>0)
bias: boolean, optional
- estimate bias b else b=0 (default:True)
+ estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True)
log : bool, optional
record log if True
+
+ .. _references-lineartransport:
References
----------
-
.. [14] Knott, M. and Smith, C. S. "On the optimal mapping of
distributions", Journal of Optimization Theory and Applications
Vol 43, 1984
@@ -1279,7 +1308,7 @@ class LinearTransport(BaseTransport):
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples
- (Xs, ys) and (Xt, yt)
+ :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
----------
@@ -1290,8 +1319,8 @@ class LinearTransport(BaseTransport):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -1325,7 +1354,7 @@ class LinearTransport(BaseTransport):
return self
def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
- """Transports source samples Xs onto target ones Xt
+ """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
Parameters
----------
@@ -1336,8 +1365,8 @@ class LinearTransport(BaseTransport):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -1358,7 +1387,7 @@ class LinearTransport(BaseTransport):
def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
batch_size=128):
- """Transports target samples Xt onto target samples Xs
+ """Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}`
Parameters
----------
@@ -1369,8 +1398,8 @@ class LinearTransport(BaseTransport):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -1392,7 +1421,7 @@ class LinearTransport(BaseTransport):
class SinkhornTransport(BaseTransport):
- """Domain Adapatation OT method based on Sinkhorn Algorithm
+ """Domain Adaptation OT method based on Sinkhorn Algorithm
Parameters
----------
@@ -1400,7 +1429,7 @@ class SinkhornTransport(BaseTransport):
Entropic regularization parameter
max_iter : int, float, optional (default=1000)
The minimum number of iteration before stopping the optimization
- algorithm if no it has not converged
+ algorithm if it has not converged
tol : float, optional (default=10e-9)
The precision required to stop the optimization algorithm.
verbose : bool, optional (default=False)
@@ -1417,8 +1446,8 @@ class SinkhornTransport(BaseTransport):
out_of_sample_map : string, optional (default="ferradans")
The kind of out of sample mapping to apply to transport samples
from a domain into another one. Currently the only possible option is
- "ferradans" which uses the method proposed in [6].
- limit_max: float, optional (defaul=np.infty)
+ "ferradans" which uses the method proposed in :ref:`[6] <references-sinkhorntransport>`.
+ limit_max: float, optional (default=np.infty)
Controls the semi supervised mode. Transport between labeled source
and target samples of different classes will exhibit an cost defined
by this variable
@@ -1428,16 +1457,20 @@ class SinkhornTransport(BaseTransport):
coupling_ : array-like, shape (n_source_samples, n_target_samples)
The optimal coupling
log_ : dictionary
- The dictionary of log, empty dic if parameter log is not True
+ The dictionary of log, empty dict if parameter log is not True
+
+ .. _references-sinkhorntransport:
References
----------
.. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
"Optimal Transport for Domain Adaptation," in IEEE Transactions
on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
+
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
Transport, Advances in Neural Information Processing Systems (NIPS)
26, 2013
+
.. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
Regularized discrete optimal transport. SIAM Journal on Imaging
Sciences, 7(3), 1853-1882.
@@ -1461,7 +1494,7 @@ class SinkhornTransport(BaseTransport):
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples
- (Xs, ys) and (Xt, yt)
+ :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
----------
@@ -1472,8 +1505,8 @@ class SinkhornTransport(BaseTransport):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -1504,7 +1537,7 @@ class SinkhornTransport(BaseTransport):
class EMDTransport(BaseTransport):
- """Domain Adapatation OT method based on Earth Mover's Distance
+ """Domain Adaptation OT method based on Earth Mover's Distance
Parameters
----------
@@ -1520,7 +1553,7 @@ class EMDTransport(BaseTransport):
out_of_sample_map : string, optional (default="ferradans")
The kind of out of sample mapping to apply to transport samples
from a domain into another one. Currently the only possible option is
- "ferradans" which uses the method proposed in [6].
+ "ferradans" which uses the method proposed in :ref:`[6] <references-emdtransport>`.
limit_max: float, optional (default=10)
Controls the semi supervised mode. Transport between labeled source
and target samples of different classes will exhibit an infinite cost
@@ -1534,14 +1567,16 @@ class EMDTransport(BaseTransport):
coupling_ : array-like, shape (n_source_samples, n_target_samples)
The optimal coupling
+
+ .. _references-emdtransport:
References
----------
.. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
- "Optimal Transport for Domain Adaptation," in IEEE Transactions
- on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
+ "Optimal Transport for Domain Adaptation," in IEEE Transactions
+ on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
.. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
- Regularized discrete optimal transport. SIAM Journal on Imaging
- Sciences, 7(3), 1853-1882.
+ Regularized discrete optimal transport. SIAM Journal on Imaging
+ Sciences, 7(3), 1853-1882.
"""
def __init__(self, metric="sqeuclidean", norm=None, log=False,
@@ -1558,7 +1593,7 @@ class EMDTransport(BaseTransport):
def fit(self, Xs, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples
- (Xs, ys) and (Xt, yt)
+ :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
----------
@@ -1569,8 +1604,8 @@ class EMDTransport(BaseTransport):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -1597,8 +1632,7 @@ class EMDTransport(BaseTransport):
class SinkhornLpl1Transport(BaseTransport):
-
- """Domain Adapatation OT method based on sinkhorn algorithm +
+ r"""Domain Adaptation OT method based on sinkhorn algorithm +
LpL1 class regularization.
Parameters
@@ -1609,7 +1643,7 @@ class SinkhornLpl1Transport(BaseTransport):
Class regularization parameter
max_iter : int, float, optional (default=10)
The minimum number of iteration before stopping the optimization
- algorithm if no it has not converged
+ algorithm if it has not converged
max_inner_iter : int, float, optional (default=200)
The number of iteration in the inner loop
log : bool, optional (default=False)
@@ -1628,8 +1662,8 @@ class SinkhornLpl1Transport(BaseTransport):
out_of_sample_map : string, optional (default="ferradans")
The kind of out of sample mapping to apply to transport samples
from a domain into another one. Currently the only possible option is
- "ferradans" which uses the method proposed in [6].
- limit_max: float, optional (defaul=np.infty)
+ "ferradans" which uses the method proposed in :ref:`[6] <references-sinkhornlpl1transport>`.
+ limit_max: float, optional (default=np.infty)
Controls the semi supervised mode. Transport between labeled source
and target samples of different classes will exhibit a cost defined by
limit_max.
@@ -1639,16 +1673,19 @@ class SinkhornLpl1Transport(BaseTransport):
coupling_ : array-like, shape (n_source_samples, n_target_samples)
The optimal coupling
+
+ .. _references-sinkhornlpl1transport:
References
----------
-
.. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
"Optimal Transport for Domain Adaptation," in IEEE
Transactions on Pattern Analysis and Machine Intelligence ,
vol.PP, no.99, pp.1-1
+
.. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
Generalized conditional gradient: analysis of convergence
and applications. arXiv preprint arXiv:1510.06567.
+
.. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
Regularized discrete optimal transport. SIAM Journal on Imaging
Sciences, 7(3), 1853-1882.
@@ -1675,7 +1712,7 @@ class SinkhornLpl1Transport(BaseTransport):
def fit(self, Xs, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples
- (Xs, ys) and (Xt, yt)
+ :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
----------
@@ -1686,8 +1723,8 @@ class SinkhornLpl1Transport(BaseTransport):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -1719,13 +1756,14 @@ class SinkhornLpl1Transport(BaseTransport):
class EMDLaplaceTransport(BaseTransport):
- """Domain Adapatation OT method based on Earth Mover's Distance with Laplacian regularization
+ """Domain Adaptation OT method based on Earth Mover's Distance with Laplacian regularization
Parameters
----------
reg_type : string optional (default='pos')
Type of the regularization term: 'pos' and 'disp' for
- regularization term defined in [2] and [6], respectively.
+ regularization term defined in :ref:`[2] <references-emdlaplacetransport>` and
+ :ref:`[6] <references-emdlaplacetransport>`, respectively.
reg_lap : float, optional (default=1)
Laplacian regularization parameter
reg_src : float, optional (default=0.5)
@@ -1756,24 +1794,27 @@ class EMDLaplaceTransport(BaseTransport):
out_of_sample_map : string, optional (default="ferradans")
The kind of out of sample mapping to apply to transport samples
from a domain into another one. Currently the only possible option is
- "ferradans" which uses the method proposed in [6].
+ "ferradans" which uses the method proposed in :ref:`[6] <references-emdlaplacetransport>`.
Attributes
----------
coupling_ : array-like, shape (n_source_samples, n_target_samples)
The optimal coupling
+
+ .. _references-emdlaplacetransport:
References
----------
.. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
"Optimal Transport for Domain Adaptation," in IEEE Transactions
on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
+
.. [2] R. Flamary, N. Courty, D. Tuia, A. Rakotomamonjy,
"Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching,"
- in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
+ in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
+
.. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
- Regularized discrete optimal transport. SIAM Journal on Imaging
- Sciences, 7(3), 1853-1882.
+ Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
"""
def __init__(self, reg_type='pos', reg_lap=1., reg_src=1., metric="sqeuclidean",
@@ -1799,7 +1840,7 @@ class EMDLaplaceTransport(BaseTransport):
def fit(self, Xs, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples
- (Xs, ys) and (Xt, yt)
+ :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
----------
@@ -1810,8 +1851,8 @@ class EMDLaplaceTransport(BaseTransport):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -1840,8 +1881,8 @@ class EMDLaplaceTransport(BaseTransport):
class SinkhornL1l2Transport(BaseTransport):
- """Domain Adapatation OT method based on sinkhorn algorithm +
- l1l2 class regularization.
+ """Domain Adaptation OT method based on sinkhorn algorithm +
+ L1L2 class regularization.
Parameters
----------
@@ -1851,7 +1892,7 @@ class SinkhornL1l2Transport(BaseTransport):
Class regularization parameter
max_iter : int, float, optional (default=10)
The minimum number of iteration before stopping the optimization
- algorithm if no it has not converged
+ algorithm if it has not converged
max_inner_iter : int, float, optional (default=200)
The number of iteration in the inner loop
tol : float, optional (default=10e-9)
@@ -1870,7 +1911,7 @@ class SinkhornL1l2Transport(BaseTransport):
out_of_sample_map : string, optional (default="ferradans")
The kind of out of sample mapping to apply to transport samples
from a domain into another one. Currently the only possible option is
- "ferradans" which uses the method proposed in [6].
+ "ferradans" which uses the method proposed in :ref:`[6] <references-sinkhornl1l2transport>`.
limit_max: float, optional (default=10)
Controls the semi supervised mode. Transport between labeled source
and target samples of different classes will exhibit an infinite cost
@@ -1881,18 +1922,21 @@ class SinkhornL1l2Transport(BaseTransport):
coupling_ : array-like, shape (n_source_samples, n_target_samples)
The optimal coupling
log_ : dictionary
- The dictionary of log, empty dic if parameter log is not True
+ The dictionary of log, empty dict if parameter log is not True
+
+ .. _references-sinkhornl1l2transport:
References
----------
-
.. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
"Optimal Transport for Domain Adaptation," in IEEE
Transactions on Pattern Analysis and Machine Intelligence ,
vol.PP, no.99, pp.1-1
+
.. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
Generalized conditional gradient: analysis of convergence
and applications. arXiv preprint arXiv:1510.06567.
+
.. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
Regularized discrete optimal transport. SIAM Journal on Imaging
Sciences, 7(3), 1853-1882.
@@ -1919,7 +1963,7 @@ class SinkhornL1l2Transport(BaseTransport):
def fit(self, Xs, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples
- (Xs, ys) and (Xt, yt)
+ :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
----------
@@ -1930,8 +1974,8 @@ class SinkhornL1l2Transport(BaseTransport):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -1973,7 +2017,7 @@ class MappingTransport(BaseEstimator):
mu : float, optional (default=1)
Weight for the linear OT loss (>0)
eta : float, optional (default=0.001)
- Regularization term for the linear mapping L (>0)
+ Regularization term for the linear mapping `L` (>0)
bias : bool, optional (default=False)
Estimate linear mapping with constant bias
metric : string, optional (default="sqeuclidean")
@@ -2004,17 +2048,20 @@ class MappingTransport(BaseEstimator):
----------
coupling_ : array-like, shape (n_source_samples, n_target_samples)
The optimal coupling
- mapping_ : array-like, shape (n_features (+ 1), n_features)
- (if bias) for kernel == linear
+ mapping_ :
The associated mapping
- array-like, shape (n_source_samples (+ 1), n_features)
- (if bias) for kernel == gaussian
+
+ - array-like, shape (`n_features` (+ 1), `n_features`),
+ (if bias) for kernel == linear
+
+ - array-like, shape (`n_source_samples` (+ 1), `n_features`),
+ (if bias) for kernel == gaussian
log_ : dictionary
- The dictionary of log, empty dic if parameter log is not True
+ The dictionary of log, empty dict if parameter log is not True
+
References
----------
-
.. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
"Mapping estimation for discrete optimal transport",
Neural Information Processing Systems (NIPS), 2016.
@@ -2042,7 +2089,8 @@ class MappingTransport(BaseEstimator):
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
"""Builds an optimal coupling and estimates the associated mapping
- from source and target sets of samples (Xs, ys) and (Xt, yt)
+ from source and target sets of samples
+ :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
----------
@@ -2053,8 +2101,8 @@ class MappingTransport(BaseEstimator):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -2098,7 +2146,7 @@ class MappingTransport(BaseEstimator):
return self
def transform(self, Xs):
- """Transports source samples Xs onto target ones Xt
+ """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
Parameters
----------
@@ -2138,7 +2186,7 @@ class MappingTransport(BaseEstimator):
class UnbalancedSinkhornTransport(BaseTransport):
- """Domain Adapatation unbalanced OT method based on sinkhorn algorithm
+ """Domain Adaptation unbalanced OT method based on sinkhorn algorithm
Parameters
----------
@@ -2151,7 +2199,7 @@ class UnbalancedSinkhornTransport(BaseTransport):
'sinkhorn_epsilon_scaling', see those function for specific parameters
max_iter : int, float, optional (default=10)
The minimum number of iteration before stopping the optimization
- algorithm if no it has not converged
+ algorithm if it has not converged
tol : float, optional (default=10e-9)
Stop threshold on error (inner sinkhorn solver) (>0)
verbose : bool, optional (default=False)
@@ -2168,7 +2216,7 @@ class UnbalancedSinkhornTransport(BaseTransport):
out_of_sample_map : string, optional (default="ferradans")
The kind of out of sample mapping to apply to transport samples
from a domain into another one. Currently the only possible option is
- "ferradans" which uses the method proposed in [6].
+ "ferradans" which uses the method proposed in :ref:`[6] <references-unbalancedsinkhorntransport>`.
limit_max: float, optional (default=10)
Controls the semi supervised mode. Transport between labeled source
and target samples of different classes will exhibit an infinite cost
@@ -2179,14 +2227,16 @@ class UnbalancedSinkhornTransport(BaseTransport):
coupling_ : array-like, shape (n_source_samples, n_target_samples)
The optimal coupling
log_ : dictionary
- The dictionary of log, empty dic if parameter log is not True
+ The dictionary of log, empty dict if parameter log is not True
+
+ .. _references-unbalancedsinkhorntransport:
References
----------
-
.. [1] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
- Scaling algorithms for unbalanced transport problems. arXiv preprint
- arXiv:1607.05816.
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
+
.. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
Regularized discrete optimal transport. SIAM Journal on Imaging
Sciences, 7(3), 1853-1882.
@@ -2212,7 +2262,7 @@ class UnbalancedSinkhornTransport(BaseTransport):
def fit(self, Xs, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples
- (Xs, ys) and (Xt, yt)
+ :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
----------
@@ -2223,8 +2273,8 @@ class UnbalancedSinkhornTransport(BaseTransport):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -2258,7 +2308,7 @@ class UnbalancedSinkhornTransport(BaseTransport):
class JCPOTTransport(BaseTransport):
- """Domain Adapatation OT method for multi-source target shift based on Wasserstein barycenter algorithm.
+ """Domain Adaptation OT method for multi-source target shift based on Wasserstein barycenter algorithm.
Parameters
----------
@@ -2266,7 +2316,7 @@ class JCPOTTransport(BaseTransport):
Entropic regularization parameter
max_iter : int, float, optional (default=10)
The minimum number of iteration before stopping the optimization
- algorithm if no it has not converged
+ algorithm if it has not converged
tol : float, optional (default=10e-9)
Stop threshold on error (inner sinkhorn solver) (>0)
verbose : bool, optional (default=False)
@@ -2283,7 +2333,7 @@ class JCPOTTransport(BaseTransport):
out_of_sample_map : string, optional (default="ferradans")
The kind of out of sample mapping to apply to transport samples
from a domain into another one. Currently the only possible option is
- "ferradans" which uses the method proposed in [6].
+ "ferradans" which uses the method proposed in :ref:`[6] <references-jcpottransport>`.
Attributes
----------
@@ -2292,11 +2342,12 @@ class JCPOTTransport(BaseTransport):
proportions_ : array-like, shape (n_classes,)
Estimated class proportions in the target domain
log_ : dictionary
- The dictionary of log, empty dic if parameter log is not True
+ The dictionary of log, empty dict if parameter log is not True
+
+ .. _references-jcpottransport:
References
----------
-
.. [1] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia
"Optimal transport for multi-source domain adaptation under target shift",
International Conference on Artificial Intelligence and Statistics (AISTATS),
@@ -2323,7 +2374,7 @@ class JCPOTTransport(BaseTransport):
def fit(self, Xs, ys=None, Xt=None, yt=None):
"""Building coupling matrices from a list of source and target sets of samples
- (Xs, ys) and (Xt, yt)
+ :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
----------
@@ -2334,8 +2385,8 @@ class JCPOTTransport(BaseTransport):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -2368,7 +2419,7 @@ class JCPOTTransport(BaseTransport):
return self
def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
- """Transports source samples Xs onto target ones Xt
+ """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
Parameters
----------
@@ -2379,8 +2430,8 @@ class JCPOTTransport(BaseTransport):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -2440,7 +2491,8 @@ class JCPOTTransport(BaseTransport):
return transp_Xs
def transform_labels(self, ys=None):
- """Propagate source labels ys to obtain target labels as in [27]
+ """Propagate source labels :math:`\mathbf{y_s}` to obtain target labels as in
+ :ref:`[27] <references-jcpottransport-transform-labels>`
Parameters
----------
@@ -2451,6 +2503,14 @@ class JCPOTTransport(BaseTransport):
-------
yt : array-like, shape (n_target_samples, nb_classes)
Estimated soft target labels.
+
+
+ .. _references-jcpottransport-transform-labels:
+ References
+ ----------
+ .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia
+ "Optimal transport for multi-source domain adaptation under target shift",
+ International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
"""
# check the necessary inputs parameters are here
@@ -2482,11 +2542,12 @@ class JCPOTTransport(BaseTransport):
return yt.T
def inverse_transform_labels(self, yt=None):
- """Propagate source labels ys to obtain target labels
+ """Propagate target labels :math:`\mathbf{y_t}` to obtain estimated source labels
+ :math:`\mathbf{y_s}`
Parameters
----------
- yt : array-like, shape (n_source_samples,)
+ yt : array-like, shape (n_target_samples,)
The target class labels
Returns
diff --git a/ot/datasets.py b/ot/datasets.py
index b86ef3b..ad6390c 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -13,7 +13,7 @@ from .utils import check_random_state, deprecated
def make_1D_gauss(n, m, s):
- """return a 1D histogram for a gaussian distribution (n bins, mean m and std s)
+ """return a 1D histogram for a gaussian distribution (`n` bins, mean `m` and std `s`)
Parameters
----------
@@ -26,7 +26,7 @@ def make_1D_gauss(n, m, s):
Returns
-------
- h : ndarray (n,)
+ h : ndarray (`n`,)
1D histogram for a gaussian distribution
"""
x = np.arange(n, dtype=np.float64)
@@ -41,7 +41,7 @@ def get_1D_gauss(n, m, sigma):
def make_2D_samples_gauss(n, m, sigma, random_state=None):
- """Return n samples drawn from 2D gaussian N(m,sigma)
+ """Return `n` samples drawn from 2D gaussian :math:`\mathcal{N}(m, \sigma)`
Parameters
----------
@@ -59,8 +59,8 @@ def make_2D_samples_gauss(n, m, sigma, random_state=None):
Returns
-------
- X : ndarray, shape (n, 2)
- n samples drawn from N(m, sigma).
+ X : ndarray, shape (`n`, 2)
+ n samples drawn from :math:`\mathcal{N}(m, \sigma)`.
"""
generator = check_random_state(random_state)
@@ -102,7 +102,7 @@ def make_data_classif(dataset, n, nz=.5, theta=0, p=.5, random_state=None, **kwa
Returns
-------
X : ndarray, shape (n, d)
- n observation of size d
+ `n` observation of size `d`
y : ndarray, shape (n,)
labels of the samples.
"""
diff --git a/ot/dr.py b/ot/dr.py
index 11d2e10..c2f51f8 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -10,6 +10,7 @@ Dimension reduction with OT
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
+# Minhui Huang <mhhuang@ucdavis.edu>
#
# License: MIT License
@@ -21,7 +22,7 @@ from pymanopt.solvers import SteepestDescent, TrustRegions
def dist(x1, x2):
- """ Compute squared euclidean distance between samples (autograd)
+ r""" Compute squared euclidean distance between samples (autograd)
"""
x1p2 = np.sum(np.square(x1), 1)
x2p2 = np.sum(np.square(x2), 1)
@@ -29,7 +30,7 @@ def dist(x1, x2):
def sinkhorn(w1, w2, M, reg, k):
- """Sinkhorn algorithm with fixed number of iteration (autograd)
+ r"""Sinkhorn algorithm with fixed number of iteration (autograd)
"""
K = np.exp(-M / reg)
ui = np.ones((M.shape[0],))
@@ -42,14 +43,14 @@ def sinkhorn(w1, w2, M, reg, k):
def split_classes(X, y):
- """split samples in X by classes in y
+ r"""split samples in :math:`\mathbf{X}` by classes in :math:`\mathbf{y}`
"""
lstsclass = np.unique(y)
return [X[y == i, :].astype(np.float32) for i in lstsclass]
def fda(X, y, p=2, reg=1e-16):
- """Fisher Discriminant Analysis
+ r"""Fisher Discriminant Analysis
Parameters
----------
@@ -108,20 +109,21 @@ def fda(X, y, p=2, reg=1e-16):
return Popt, proj
-def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
- """
- Wasserstein Discriminant Analysis [11]_
+def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False):
+ r"""
+ Wasserstein Discriminant Analysis :ref:`[11] <references-wda>`
The function solves the following optimization problem:
.. math::
- P = \\text{arg}\min_P \\frac{\\sum_i W(PX^i,PX^i)}{\\sum_{i,j\\neq i} W(PX^i,PX^j)}
+ \mathbf{P} = \mathop{\arg \min}_\mathbf{P} \quad
+ \frac{\sum\limits_i W(P \mathbf{X}^i, P \mathbf{X}^i)}{\sum\limits_{i, j \neq i} W(P \mathbf{X}^i, P \mathbf{X}^j)}
where :
- - :math:`P` is a linear projection operator in the Stiefel(p,d) manifold
+ - :math:`P` is a linear projection operator in the Stiefel(`p`, `d`) manifold
- :math:`W` is entropic regularized Wasserstein distances
- - :math:`X^i` are samples in the dataset corresponding to class i
+ - :math:`\mathbf{X}^i` are samples in the dataset corresponding to class i
Parameters
----------
@@ -138,6 +140,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
else should be a pymanopt.solvers
P0 : ndarray, shape (d, p)
Initial starting point for projection.
+ normalize : bool, optional
+ Normalise the Wasserstaiun distance by the average distance on P0 (default : False)
verbose : int, optional
Print information along iterations.
@@ -148,6 +152,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
proj : callable
Projection function including mean centering.
+
+ .. _references-wda:
References
----------
.. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
@@ -163,6 +169,18 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
# compute uniform weighs
wc = [np.ones((x.shape[0]), dtype=np.float32) / x.shape[0] for x in xc]
+ # pre-compute reg_c,c'
+ if P0 is not None and normalize:
+ regmean = np.zeros((len(xc), len(xc)))
+ for i, xi in enumerate(xc):
+ xi = np.dot(xi, P0)
+ for j, xj in enumerate(xc[i:]):
+ xj = np.dot(xj, P0)
+ M = dist(xi, xj)
+ regmean[i, j] = np.sum(M) / (len(xi) * len(xj))
+ else:
+ regmean = np.ones((len(xc), len(xc)))
+
def cost(P):
# wda loss
loss_b = 0
@@ -173,7 +191,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
for j, xj in enumerate(xc[i:]):
xj = np.dot(xj, P)
M = dist(xi, xj)
- G = sinkhorn(wc[i], wc[j + i], M, reg, k)
+ G = sinkhorn(wc[i], wc[j + i], M, reg * regmean[i, j], k)
if j == 0:
loss_w += np.sum(G * M)
else:
@@ -198,3 +216,119 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
return (X - mx.reshape((1, -1))).dot(Popt)
return Popt, proj
+
+
+def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0):
+ r"""
+ Projection Robust Wasserstein Distance :ref:`[32] <references-projection-robust-wasserstein>`
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \max_{U \in St(d, k)} \ \min_{\pi \in \Pi(\mu,\nu)} \quad \sum_{i,j} \pi_{i,j}
+ \|U^T(\mathbf{x}_i - \mathbf{y}_j)\|^2 - \mathrm{reg} \cdot H(\pi)
+
+ - :math:`U` is a linear projection operator in the Stiefel(`d`, `k`) manifold
+ - :math:`H(\pi)` is entropy regularizer
+ - :math:`\mathbf{x}_i`, :math:`\mathbf{y}_j` are samples of measures :math:`\mu` and :math:`\nu` respectively
+
+ Parameters
+ ----------
+ X : ndarray, shape (n, d)
+ Samples from measure :math:`\mu`
+ Y : ndarray, shape (n, d)
+ Samples from measure :math:`\nu`
+ a : ndarray, shape (n, )
+ weights for measure :math:`\mu`
+ b : ndarray, shape (n, )
+ weights for measure :math:`\nu`
+ tau : float
+ stepsize for Riemannian Gradient Descent
+ U0 : ndarray, shape (d, p)
+ Initial starting point for projection.
+ reg : float, optional
+ Regularization term >0 (entropic regularization)
+ k : int
+ Subspace dimension
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : int, optional
+ Print information along iterations.
+
+ Returns
+ -------
+ pi : ndarray, shape (n, n)
+ Optimal transportation matrix for the given parameters
+ U : ndarray, shape (d, k)
+ Projection operator.
+
+
+ .. _references-projection-robust-wasserstein:
+ References
+ ----------
+ .. [32] Huang, M. , Ma S. & Lai L. (2021).
+ A Riemannian Block Coordinate Descent Method for Computing
+ the Projection Robust Wasserstein Distance, ICML.
+ """ # noqa
+
+ # initialization
+ n, d = X.shape
+ m, d = Y.shape
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ u = np.ones(n) / n
+ v = np.ones(m) / m
+ ones = np.ones((n, m))
+
+ assert d > k
+
+ if U0 is None:
+ U = np.random.randn(d, k)
+ U, _ = np.linalg.qr(U)
+ else:
+ U = U0
+
+ def Vpi(X, Y, a, b, pi):
+ # Return the second order matrix of the displacements: sum_ij { (pi)_ij (X_i-Y_j)(X_i-Y_j)^T }.
+ A = X.T.dot(pi).dot(Y)
+ return X.T.dot(np.diag(a)).dot(X) + Y.T.dot(np.diag(np.sum(pi, 0))).dot(Y) - A - A.T
+
+ err = 1
+ iter = 0
+
+ while err > stopThr and iter < maxiter:
+
+ # Projected cost matrix
+ UUT = U.dot(U.T)
+ M = np.diag(np.diag(X.dot(UUT.dot(X.T)))).dot(ones) + ones.dot(
+ np.diag(np.diag(Y.dot(UUT.dot(Y.T))))) - 2 * X.dot(UUT.dot(Y.T))
+
+ A = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=A)
+ np.exp(A, out=A)
+
+ # Sinkhorn update
+ Ap = (1 / a).reshape(-1, 1) * A
+ AtransposeU = np.dot(A.T, u)
+ v = np.divide(b, AtransposeU)
+ u = 1. / np.dot(Ap, v)
+ pi = u.reshape((-1, 1)) * A * v.reshape((1, -1))
+
+ V = Vpi(X, Y, a, b, pi)
+
+ # Riemannian gradient descent
+ G = 2 / reg * V.dot(U)
+ GTU = G.T.dot(U)
+ xi = G - U.dot(GTU + GTU.T) / 2 # Riemannian gradient
+ U, _ = np.linalg.qr(U + tau * xi) # Retraction by QR decomposition
+
+ grad_norm = np.linalg.norm(xi)
+ err = max(reg * grad_norm, np.linalg.norm(np.sum(pi, 0) - b, 1))
+
+ f_val = np.trace(U.T.dot(V.dot(U)))
+ if verbose:
+ print('RBCD Iteration: ', iter, ' error', err, '\t fval: ', f_val)
+
+ iter = iter + 1
+
+ return pi, U
diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py
index 7478fb9..12db605 100644
--- a/ot/gpu/__init__.py
+++ b/ot/gpu/__init__.py
@@ -7,7 +7,13 @@ The GPU backend in handled by `cupy
<https://cupy.chainer.org/>`_.
.. warning::
- Note that by default the module is not import in :mod:`ot`. In order to
+ This module is now deprecated and will be removed in future releases. POT
+ now privides a backend mechanism that allows for solving prolem on GPU wth
+ the pytorch backend.
+
+
+.. warning::
+ Note that by default the module is not imported in :mod:`ot`. In order to
use it you need to explicitely import :mod:`ot.gpu` .
By default, the functions in this module accept and return numpy arrays
@@ -25,6 +31,8 @@ result of the function with parameter ``to_numpy=False``.
#
# License: MIT License
+import warnings
+
from . import bregman
from . import da
from .bregman import sinkhorn
@@ -34,7 +42,7 @@ from . import utils
from .utils import dist, to_gpu, to_np
-
+warnings.warn('This module is deprecated and will be removed in the next minor release of POT', category=DeprecationWarning)
__all__ = ["utils", "dist", "sinkhorn",
diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py
index 2e2df83..76af00e 100644
--- a/ot/gpu/bregman.py
+++ b/ot/gpu/bregman.py
@@ -15,7 +15,7 @@ from . import utils
def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
verbose=False, log=False, to_numpy=True, **kwargs):
- """
+ r"""
Solve the entropic regularization optimal transport on GPU
If the input matrix are in numpy format, they will be uploaded to the
@@ -54,7 +54,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -148,13 +148,15 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
# we can speed up the process by checking for the error only all
# the 10th iterations
if nbb:
- err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
- np.sum((v - vprev)**2) / np.sum((v)**2)
+ err = np.sqrt(
+ np.sum((u - uprev)**2) / np.sum((u)**2)
+ + np.sum((v - vprev)**2) / np.sum((v)**2)
+ )
else:
# compute right marginal tmp2= (diag(u)Kdiag(v))^T1
tmp2 = np.sum(u[:, None] * K * v[None, :], 0)
#tmp2=np.einsum('i,ij,j->j', u, K, v)
- err = np.linalg.norm(tmp2 - b)**2 # violation of marginal
+ err = np.linalg.norm(tmp2 - b) # violation of marginal
if log:
log['err'].append(err)
diff --git a/ot/gpu/da.py b/ot/gpu/da.py
index 4a98038..7adb830 100644
--- a/ot/gpu/da.py
+++ b/ot/gpu/da.py
@@ -120,7 +120,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
labels_a2 = cp.asnumpy(labels_a)
classes = npp.unique(labels_a2)
for c in classes:
- idxc, = utils.to_gpu(npp.where(labels_a2 == c))
+ idxc = utils.to_gpu(*npp.where(labels_a2 == c))
indices_labels.append(idxc)
W = np.zeros(M.shape)
diff --git a/ot/gromov.py b/ot/gromov.py
index 4427a96..ea667e4 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -14,63 +14,85 @@ import numpy as np
from .bregman import sinkhorn
-from .utils import dist, UndefinedParameter
+from .utils import dist, UndefinedParameter, list_to_array
from .optim import cg
+from .lp import emd_1d, emd
+from .utils import check_random_state
+from .backend import get_backend
def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
- """Return loss matrices and tensors for Gromov-Wasserstein fast computation
+ r"""Return loss matrices and tensors for Gromov-Wasserstein fast computation
- Returns the value of \mathcal{L}(C1,C2) \otimes T with the selected loss
- function as the loss function of Gromow-Wasserstein discrepancy.
+ Returns the value of :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` with the
+ selected loss function as the loss function of Gromow-Wasserstein discrepancy.
- The matrices are computed as described in Proposition 1 in [12]
+ The matrices are computed as described in Proposition 1 in :ref:`[12] <references-init-matrix>`
Where :
- * C1 : Metric cost matrix in the source space
- * C2 : Metric cost matrix in the target space
- * T : A coupling between those two spaces
-
- The square-loss function L(a,b)=|a-b|^2 is read as :
- L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
- * f1(a)=(a^2)
- * f2(b)=(b^2)
- * h1(a)=a
- * h2(b)=2*b
-
- The kl-loss function L(a,b)=a*log(a/b)-a+b is read as :
- L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
- * f1(a)=a*log(a)-a
- * f2(b)=b
- * h1(a)=a
- * h2(b)=log(b)
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{T}`: A coupling between those two spaces
+
+ The square-loss function :math:`L(a, b) = |a - b|^2` is read as :
+
+ .. math::
+
+ L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b)
+
+ \mathrm{with} \ f_1(a) &= a^2
+
+ f_2(b) &= b^2
+
+ h_1(a) &= a
+
+ h_2(b) &= 2b
+
+ The kl-loss function :math:`L(a, b) = a \log\left(\frac{a}{b}\right) - a + b` is read as :
+
+ .. math::
+
+ L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b)
+
+ \mathrm{with} \ f_1(a) &= a \log(a) - a
+
+ f_2(b) &= b
+
+ h_1(a) &= a
+
+ h_2(b) &= \log(b)
Parameters
----------
- C1 : ndarray, shape (ns, ns)
+ C1 : array-like, shape (ns, ns)
Metric cost matrix in the source space
- C2 : ndarray, shape (nt, nt)
- Metric costfr matrix in the target space
- T : ndarray, shape (ns, nt)
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ T : array-like, shape (ns, nt)
Coupling between source and target spaces
- p : ndarray, shape (ns,)
+ p : array-like, shape (ns,)
Returns
-------
- constC : ndarray, shape (ns, nt)
- Constant C matrix in Eq. (6)
- hC1 : ndarray, shape (ns, ns)
- h1(C1) matrix in Eq. (6)
- hC2 : ndarray, shape (nt, nt)
- h2(C) matrix in Eq. (6)
+ constC : array-like, shape (ns, nt)
+ Constant :math:`\mathbf{C}` matrix in Eq. (6)
+ hC1 : array-like, shape (ns, ns)
+ :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
+ hC2 : array-like, shape (nt, nt)
+ :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
+
+ .. _references-init-matrix:
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
"""
+ C1, C2, p, q = list_to_array(C1, C2, p, q)
+ nx = get_backend(C1, C2, p, q)
if loss_fun == 'square_loss':
def f1(a):
@@ -86,7 +108,7 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
return 2 * b
elif loss_fun == 'kl_loss':
def f1(a):
- return a * np.log(a + 1e-15) - a
+ return a * nx.log(a + 1e-15) - a
def f2(b):
return b
@@ -95,12 +117,16 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
return a
def h2(b):
- return np.log(b + 1e-15)
-
- constC1 = np.dot(np.dot(f1(C1), p.reshape(-1, 1)),
- np.ones(len(q)).reshape(1, -1))
- constC2 = np.dot(np.ones(len(p)).reshape(-1, 1),
- np.dot(q.reshape(1, -1), f2(C2).T))
+ return nx.log(b + 1e-15)
+
+ constC1 = nx.dot(
+ nx.dot(f1(C1), nx.reshape(p, (-1, 1))),
+ nx.ones((1, len(q)), type_as=q)
+ )
+ constC2 = nx.dot(
+ nx.ones((len(p), 1), type_as=p),
+ nx.dot(nx.reshape(q, (1, -1)), f2(C2).T)
+ )
constC = constC1 + constC2
hC1 = h1(C1)
hC2 = h2(C2)
@@ -109,61 +135,70 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
def tensor_product(constC, hC1, hC2, T):
- """Return the tensor for Gromov-Wasserstein fast computation
+ r"""Return the tensor for Gromov-Wasserstein fast computation
- The tensor is computed as described in Proposition 1 Eq. (6) in [12].
+ The tensor is computed as described in Proposition 1 Eq. (6) in :ref:`[12] <references-tensor-product>`
Parameters
----------
- constC : ndarray, shape (ns, nt)
- Constant C matrix in Eq. (6)
- hC1 : ndarray, shape (ns, ns)
- h1(C1) matrix in Eq. (6)
- hC2 : ndarray, shape (nt, nt)
- h2(C) matrix in Eq. (6)
+ constC : array-like, shape (ns, nt)
+ Constant :math:`\mathbf{C}` matrix in Eq. (6)
+ hC1 : array-like, shape (ns, ns)
+ :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
+ hC2 : array-like, shape (nt, nt)
+ :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
Returns
-------
- tens : ndarray, shape (ns, nt)
- \mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
+ tens : array-like, shape (`ns`, `nt`)
+ :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` tensor-matrix multiplication result
+
+ .. _references-tensor-product:
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
"""
- A = -np.dot(hC1, T).dot(hC2.T)
+ constC, hC1, hC2, T = list_to_array(constC, hC1, hC2, T)
+ nx = get_backend(constC, hC1, hC2, T)
+
+ A = - nx.dot(
+ nx.dot(hC1, T), hC2.T
+ )
tens = constC + A
# tens -= tens.min()
return tens
def gwloss(constC, hC1, hC2, T):
- """Return the Loss for Gromov-Wasserstein
+ r"""Return the Loss for Gromov-Wasserstein
- The loss is computed as described in Proposition 1 Eq. (6) in [12].
+ The loss is computed as described in Proposition 1 Eq. (6) in :ref:`[12] <references-gwloss>`
Parameters
----------
- constC : ndarray, shape (ns, nt)
- Constant C matrix in Eq. (6)
- hC1 : ndarray, shape (ns, ns)
- h1(C1) matrix in Eq. (6)
- hC2 : ndarray, shape (nt, nt)
- h2(C) matrix in Eq. (6)
- T : ndarray, shape (ns, nt)
- Current value of transport matrix T
+ constC : array-like, shape (ns, nt)
+ Constant :math:`\mathbf{C}` matrix in Eq. (6)
+ hC1 : array-like, shape (ns, ns)
+ :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
+ hC2 : array-like, shape (nt, nt)
+ :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
+ T : array-like, shape (ns, nt)
+ Current value of transport matrix :math:`\mathbf{T}`
Returns
-------
loss : float
Gromov Wasserstein loss
+
+ .. _references-gwloss:
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
@@ -171,33 +206,38 @@ def gwloss(constC, hC1, hC2, T):
tens = tensor_product(constC, hC1, hC2, T)
- return np.sum(tens * T)
+ tens, T = list_to_array(tens, T)
+ nx = get_backend(tens, T)
+
+ return nx.sum(tens * T)
def gwggrad(constC, hC1, hC2, T):
- """Return the gradient for Gromov-Wasserstein
+ r"""Return the gradient for Gromov-Wasserstein
- The gradient is computed as described in Proposition 2 in [12].
+ The gradient is computed as described in Proposition 2 in :ref:`[12] <references-gwggrad>`
Parameters
----------
- constC : ndarray, shape (ns, nt)
- Constant C matrix in Eq. (6)
- hC1 : ndarray, shape (ns, ns)
- h1(C1) matrix in Eq. (6)
- hC2 : ndarray, shape (nt, nt)
- h2(C) matrix in Eq. (6)
- T : ndarray, shape (ns, nt)
- Current value of transport matrix T
+ constC : array-like, shape (ns, nt)
+ Constant :math:`\mathbf{C}` matrix in Eq. (6)
+ hC1 : array-like, shape (ns, ns)
+ :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
+ hC2 : array-like, shape (nt, nt)
+ :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
+ T : array-like, shape (ns, nt)
+ Current value of transport matrix :math:`\mathbf{T}`
Returns
-------
- grad : ndarray, shape (ns, nt)
+ grad : array-like, shape (`ns`, `nt`)
Gromov Wasserstein gradient
+
+ .. _references-gwggrad:
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
@@ -207,89 +247,109 @@ def gwggrad(constC, hC1, hC2, T):
def update_square_loss(p, lambdas, T, Cs):
- """
- Updates C according to the L2 Loss kernel with the S Ts couplings
- calculated at each iteration
+ r"""
+ Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s`
+ couplings calculated at each iteration
Parameters
----------
- p : ndarray, shape (N,)
+ p : array-like, shape (N,)
Masses in the targeted barycenter.
lambdas : list of float
- List of the S spaces' weights.
- T : list of S np.ndarray of shape (ns,N)
- The S Ts couplings calculated at each iteration.
- Cs : list of S ndarray, shape(ns,ns)
+ List of the `S` spaces' weights.
+ T : list of S array-like of shape (ns,N)
+ The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
+ Cs : list of S array-like, shape(ns,ns)
Metric cost matrices.
Returns
----------
- C : ndarray, shape (nt, nt)
- Updated C matrix.
+ C : array-like, shape (`nt`, `nt`)
+ Updated :math:`\mathbf{C}` matrix.
"""
- tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s])
- for s in range(len(T))])
- ppt = np.outer(p, p)
+ T = list_to_array(*T)
+ Cs = list_to_array(*Cs)
+ p = list_to_array(p)
+ nx = get_backend(p, *T, *Cs)
- return np.divide(tmpsum, ppt)
+ tmpsum = sum([
+ lambdas[s] * nx.dot(
+ nx.dot(T[s].T, Cs[s]),
+ T[s]
+ ) for s in range(len(T))
+ ])
+ ppt = nx.outer(p, p)
+
+ return tmpsum / ppt
def update_kl_loss(p, lambdas, T, Cs):
- """
- Updates C according to the KL Loss kernel with the S Ts couplings calculated at each iteration
+ r"""
+ Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
Parameters
----------
- p : ndarray, shape (N,)
+ p : array-like, shape (N,)
Weights in the targeted barycenter.
- lambdas : list of the S spaces' weights
- T : list of S np.ndarray of shape (ns,N)
- The S Ts couplings calculated at each iteration.
- Cs : list of S ndarray, shape(ns,ns)
+ lambdas : list of float
+ List of the `S` spaces' weights
+ T : list of S array-like of shape (ns,N)
+ The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
+ Cs : list of S array-like, shape(ns,ns)
Metric cost matrices.
Returns
----------
- C : ndarray, shape (ns,ns)
- updated C matrix
+ C : array-like, shape (`ns`, `ns`)
+ updated :math:`\mathbf{C}` matrix
"""
- tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s])
- for s in range(len(T))])
- ppt = np.outer(p, p)
+ Cs = list_to_array(*Cs)
+ T = list_to_array(*T)
+ p = list_to_array(p)
+ nx = get_backend(p, *T, *Cs)
- return np.exp(np.divide(tmpsum, ppt))
+ tmpsum = sum([
+ lambdas[s] * nx.dot(
+ nx.dot(T[s].T, Cs[s]),
+ T[s]
+ ) for s in range(len(T))
+ ])
+ ppt = nx.outer(p, p)
+ return nx.exp(tmpsum / ppt)
-def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
- """
- Returns the gromov-wasserstein transport between (C1,p) and (C2,q)
+
+def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs):
+ r"""
+ Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
The function solves the following optimization problem:
.. math::
- GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+ \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
Where :
- - C1 : Metric cost matrix in the source space
- - C2 : Metric cost matrix in the target space
- - p : distribution in the source space
- - q : distribution in the target space
- - L : loss function to account for the misfit between the similarity matrices
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity matrices
Parameters
----------
- C1 : ndarray, shape (ns, ns)
+ C1 : array-like, shape (ns, ns)
Metric cost matrix in the source space
- C2 : ndarray, shape (nt, nt)
- Metric costfr matrix in the target space
- p : ndarray, shape (ns,)
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
Distribution in the source space
- q : ndarray, shape (nt,)
+ q : array-like, shape (nt,)
Distribution in the target space
loss_fun : str
loss function used for the solver either 'square_loss' or 'kl_loss'
-
max_iter : int, optional
Max number of iterations
tol : float, optional
@@ -299,22 +359,23 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs
log : bool, optional
record log if True
armijo : bool, optional
- If True the steps of the line-search is found via an armijo research. Else closed form is used.
- If there is convergence issues use False.
+ If True the step of the line-search is found via an armijo research. Else closed form is used.
+ If there are convergence issues use False.
**kwargs : dict
parameters can be directly passed to the ot.optim.cg solver
Returns
-------
- T : ndarray, shape (ns, nt)
- Doupling between the two spaces that minimizes:
- \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+ T : array-like, shape (`ns`, `nt`)
+ Coupling between the two spaces that minimizes:
+
+ :math:`\sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}`
log : dict
Convergence information and loss.
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
@@ -323,6 +384,15 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs
mathematics 11.4 (2011): 417-487.
"""
+ p, q = list_to_array(p, q)
+
+ p0, q0, C10, C20 = p, q, C1, C2
+ nx = get_backend(p0, q0, C10, C20)
+
+ p = nx.to_numpy(p)
+ q = nx.to_numpy(q)
+ C1 = nx.to_numpy(C10)
+ C2 = nx.to_numpy(C20)
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
@@ -336,37 +406,45 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs
if log:
res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
- log['gw_dist'] = gwloss(constC, hC1, hC2, res)
- return res, log
+ log['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, res), type_as=C10)
+ log['u'] = nx.from_numpy(log['u'], type_as=C10)
+ log['v'] = nx.from_numpy(log['v'], type_as=C10)
+ return nx.from_numpy(res, type_as=C10), log
else:
- return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ return nx.from_numpy(cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=False, **kwargs), type_as=C10)
-def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
- """
- Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q)
+def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs):
+ r"""
+ Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
The function solves the following optimization problem:
.. math::
- GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+ GW = \min_\mathbf{T} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
Where :
- - C1 : Metric cost matrix in the source space
- - C2 : Metric cost matrix in the target space
- - p : distribution in the source space
- - q : distribution in the target space
- - L : loss function to account for the misfit between the similarity matrices
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity
+ matrices
+
+ Note that when using backends, this loss function is differentiable wrt the
+ marices and weights for quadratic loss using the gradients from [38]_.
Parameters
----------
- C1 : ndarray, shape (ns, ns)
+ C1 : array-like, shape (ns, ns)
Metric cost matrix in the source space
- C2 : ndarray, shape (nt, nt)
+ C2 : array-like, shape (nt, nt)
Metric cost matrix in the target space
- p : ndarray, shape (ns,)
+ p : array-like, shape (ns,)
Distribution in the source space.
- q : ndarray, shape (nt,)
+ q : array-like, shape (nt,)
Distribution in the target space.
loss_fun : str
loss function used for the solver either 'square_loss' or 'kl_loss'
@@ -379,8 +457,8 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg
log : bool, optional
record log if True
armijo : bool, optional
- If True the steps of the line-search is found via an armijo research. Else closed form is used.
- If there is convergence issues use False.
+ If True the step of the line-search is found via an armijo research. Else closed form is used.
+ If there are convergence issues use False.
Returns
-------
@@ -391,7 +469,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
@@ -399,7 +477,20 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg
metric approach to object matching. Foundations of computational
mathematics 11.4 (2011): 417-487.
+ .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
+ Graph Dictionary Learning, International Conference on Machine Learning
+ (ICML), 2021.
+
"""
+ p, q = list_to_array(p, q)
+
+ p0, q0, C10, C20 = p, q, C1, C2
+ nx = get_backend(p0, q0, C10, C20)
+
+ p = nx.to_numpy(p)
+ q = nx.to_numpy(q)
+ C1 = nx.to_numpy(C10)
+ C2 = nx.to_numpy(C20)
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
@@ -410,53 +501,71 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg
def df(G):
return gwggrad(constC, hC1, hC2, G)
- res, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
- log_gw['gw_dist'] = gwloss(constC, hC1, hC2, res)
- log_gw['T'] = res
+
+ T, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+
+ T0 = nx.from_numpy(T, type_as=C10)
+
+ log_gw['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, T), type_as=C10)
+ log_gw['u'] = nx.from_numpy(log_gw['u'], type_as=C10)
+ log_gw['v'] = nx.from_numpy(log_gw['v'], type_as=C10)
+ log_gw['T'] = T0
+
+ gw = log_gw['gw_dist']
+
+ if loss_fun == 'square_loss':
+ gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T))
+ gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T))
+ gw = nx.set_gradients(gw, (p0, q0, C10, C20),
+ (log_gw['u'], log_gw['v'], gC1, gC2))
+
if log:
- return log_gw['gw_dist'], log_gw
+ return gw, log_gw
else:
- return log_gw['gw_dist']
+ return gw
def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
- """
- Computes the FGW transport between two graphs see [24]
+ r"""
+ Computes the FGW transport between two graphs (see :ref:`[24] <references-fused-gromov-wasserstein>`)
.. math::
- \gamma = arg\min_\gamma (1-\\alpha)*<\gamma,M>_F + \\alpha* \sum_{i,j,k,l}
- L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+ \gamma = \mathop{\arg \min}_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F +
+ \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
- s.t. \gamma 1 = p
- \gamma^T 1= q
- \gamma\geq 0
+ s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{\gamma} &\geq 0
where :
- - M is the (ns,nt) metric cost matrix
- - p and q are source and target weights (sum to 1)
- - L is a loss function to account for the misfit between the similarity matrices
- The algorithm used for solving the problem is conditional gradient as discussed in [24]_
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
+ - `L` is a loss function to account for the misfit between the similarity matrices
+
+ The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] <references-fused-gromov-wasserstein>`
Parameters
----------
- M : ndarray, shape (ns, nt)
+ M : array-like, shape (ns, nt)
Metric cost matrix between features across domains
- C1 : ndarray, shape (ns, ns)
+ C1 : array-like, shape (ns, ns)
Metric cost matrix representative of the structure in the source space
- C2 : ndarray, shape (nt, nt)
+ C2 : array-like, shape (nt, nt)
Metric cost matrix representative of the structure in the target space
- p : ndarray, shape (ns,)
+ p : array-like, shape (ns,)
Distribution in the source space
- q : ndarray, shape (nt,)
+ q : array-like, shape (nt,)
Distribution in the target space
loss_fun : str, optional
Loss function used for the solver
alpha : float, optional
Trade-off parameter (0 < alpha < 1)
armijo : bool, optional
- If True the steps of the line-search is found via an armijo research. Else closed form is used.
- If there is convergence issues use False.
+ If True the step of the line-search is found via an armijo research. Else closed form is used.
+ If there are convergence issues use False.
log : bool, optional
record log if True
**kwargs : dict
@@ -464,18 +573,30 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
Returns
-------
- gamma : ndarray, shape (ns, nt)
+ gamma : array-like, shape (`ns`, `nt`)
Optimal transportation matrix for the given parameters.
log : dict
Log dictionary return only if log==True in parameters.
+
+ .. _references-fused-gromov-wasserstein:
References
----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
and Courty Nicolas "Optimal Transport for structured data with
application on graphs", International Conference on Machine Learning
(ICML). 2019.
"""
+ p, q = list_to_array(p, q)
+
+ p0, q0, C10, C20, M0 = p, q, C1, C2, M
+ nx = get_backend(p0, q0, C10, C20, M0)
+
+ p = nx.to_numpy(p)
+ q = nx.to_numpy(q)
+ C1 = nx.to_numpy(C10)
+ C2 = nx.to_numpy(C20)
+ M = nx.to_numpy(M0)
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
@@ -489,69 +610,98 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
if log:
res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
- log['fgw_dist'] = log['loss'][::-1][0]
- return res, log
+
+ fgw_dist = nx.from_numpy(log['loss'][-1], type_as=C10)
+
+ log['fgw_dist'] = fgw_dist
+ log['u'] = nx.from_numpy(log['u'], type_as=C10)
+ log['v'] = nx.from_numpy(log['v'], type_as=C10)
+ return nx.from_numpy(res, type_as=C10), log
+
else:
- return cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs), type_as=C10)
def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
- """
- Computes the FGW distance between two graphs see [24]
+ r"""
+ Computes the FGW distance between two graphs see (see :ref:`[24] <references-fused-gromov-wasserstein2>`)
.. math::
- \min_\gamma (1-\\alpha)*<\gamma,M>_F + \\alpha* \sum_{i,j,k,l}
- L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+ \min_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+ \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q}
- s.t. \gamma 1 = p
- \gamma^T 1= q
- \gamma\geq 0
+ \mathbf{\gamma} &\geq 0
where :
- - M is the (ns,nt) metric cost matrix
- - p and q are source and target weights (sum to 1)
- - L is a loss function to account for the misfit between the similarity matrices
- The algorithm used for solving the problem is conditional gradient as discussed in [1]_
+
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
+ - `L` is a loss function to account for the misfit between the similarity matrices
+
+ The algorithm used for solving the problem is conditional gradient as
+ discussed in :ref:`[24] <references-fused-gromov-wasserstein2>`
+
+ Note that when using backends, this loss function is differentiable wrt the
+ marices and weights for quadratic loss using the gradients from [38]_.
Parameters
----------
- M : ndarray, shape (ns, nt)
+ M : array-like, shape (ns, nt)
Metric cost matrix between features across domains
- C1 : ndarray, shape (ns, ns)
- Metric cost matrix respresentative of the structure in the source space.
- C2 : ndarray, shape (nt, nt)
- Metric cost matrix espresentative of the structure in the target space.
- p : ndarray, shape (ns,)
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix representative of the structure in the source space.
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix representative of the structure in the target space.
+ p : array-like, shape (ns,)
Distribution in the source space.
- q : ndarray, shape (nt,)
+ q : array-like, shape (nt,)
Distribution in the target space.
loss_fun : str, optional
Loss function used for the solver.
alpha : float, optional
Trade-off parameter (0 < alpha < 1)
armijo : bool, optional
- If True the steps of the line-search is found via an armijo research.
- Else closed form is used. If there is convergence issues use False.
+ If True the step of the line-search is found via an armijo research.
+ Else closed form is used. If there are convergence issues use False.
log : bool, optional
Record log if True.
**kwargs : dict
- Parameters can be directly pased to the ot.optim.cg solver.
+ Parameters can be directly passed to the ot.optim.cg solver.
Returns
-------
- gamma : ndarray, shape (ns, nt)
- Optimal transportation matrix for the given parameters.
+ fgw-distance : float
+ Fused gromov wasserstein distance for the given parameters.
log : dict
Log dictionary return only if log==True in parameters.
+
+ .. _references-fused-gromov-wasserstein2:
References
----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
and Courty Nicolas
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
+
+ .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
+ Graph Dictionary Learning, International Conference on Machine Learning
+ (ICML), 2021.
"""
+ p, q = list_to_array(p, q)
+
+ p0, q0, C10, C20, M0 = p, q, C1, C2, M
+ nx = get_backend(p0, q0, C10, C20, M0)
+
+ p = nx.to_numpy(p)
+ q = nx.to_numpy(q)
+ C1 = nx.to_numpy(C10)
+ C2 = nx.to_numpy(C20)
+ M = nx.to_numpy(M0)
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
@@ -563,50 +713,462 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
def df(G):
return gwggrad(constC, hC1, hC2, G)
- res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ T, log_fgw = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+
+ fgw_dist = nx.from_numpy(log_fgw['loss'][-1], type_as=C10)
+
+ T0 = nx.from_numpy(T, type_as=C10)
+
+ log_fgw['fgw_dist'] = fgw_dist
+ log_fgw['u'] = nx.from_numpy(log_fgw['u'], type_as=C10)
+ log_fgw['v'] = nx.from_numpy(log_fgw['v'], type_as=C10)
+ log_fgw['T'] = T0
+
+ if loss_fun == 'square_loss':
+ gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T))
+ gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T))
+ fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0),
+ (log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0))
+
if log:
- log['fgw_dist'] = log['loss'][::-1][0]
- log['T'] = res
- return log['fgw_dist'], log
+ return fgw_dist, log_fgw
else:
- return log['fgw_dist']
+ return fgw_dist
-def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
- max_iter=1000, tol=1e-9, verbose=False, log=False):
+def GW_distance_estimation(C1, C2, p, q, loss_fun, T,
+ nb_samples_p=None, nb_samples_q=None, std=True, random_state=None):
+ r"""
+ Returns an approximation of the gromov-wasserstein cost between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
+ with a fixed transport plan :math:`\mathbf{T}`.
+
+ The function gives an unbiased approximation of the following equation:
+
+ .. math::
+
+ GW = \sum_{i,j,k,l} L(\mathbf{C_{1}}_{i,k}, \mathbf{C_{2}}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - `L` : Loss function to account for the misfit between the similarity matrices
+ - :math:`\mathbf{T}`: Matrix with marginal :math:`\mathbf{p}` and :math:`\mathbf{q}`
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space
+ q : array-like, shape (nt,)
+ Distribution in the target space
+ loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}`
+ Loss function used for the distance, the transport plan does not depend on the loss function
+ T : csr or array-like, shape (ns, nt)
+ Transport plan matrix, either a sparse csr or a dense matrix
+ nb_samples_p : int, optional
+ `nb_samples_p` is the number of samples (without replacement) along the first dimension of :math:`\mathbf{T}`
+ nb_samples_q : int, optional
+ `nb_samples_q` is the number of samples along the second dimension of :math:`\mathbf{T}`, for each sample along the first
+ std : bool, optional
+ Standard deviation associated with the prediction of the gromov-wasserstein cost
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
+
+ Returns
+ -------
+ : float
+ Gromov-wasserstein cost
+
+ References
+ ----------
+ .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
+ "Sampled Gromov Wasserstein."
+ Machine Learning Journal (MLJ). 2021.
+
+ """
+ C1, C2, p, q = list_to_array(C1, C2, p, q)
+ nx = get_backend(C1, C2, p, q)
+
+ generator = check_random_state(random_state)
+
+ len_p = p.shape[0]
+ len_q = q.shape[0]
+
+ # It is always better to sample from the biggest distribution first.
+ if len_p < len_q:
+ p, q = q, p
+ len_p, len_q = len_q, len_p
+ C1, C2 = C2, C1
+ T = T.T
+
+ if nb_samples_p is None:
+ if nx.issparse(T):
+ # If T is sparse, it probably mean that PoGroW was used, thus the number of sample is reduced
+ nb_samples_p = min(int(5 * (len_p * np.log(len_p)) ** 0.5), len_p)
+ else:
+ nb_samples_p = len_p
+ else:
+ # The number of sample along the first dimension is without replacement.
+ nb_samples_p = min(nb_samples_p, len_p)
+ if nb_samples_q is None:
+ nb_samples_q = 1
+ if std:
+ nb_samples_q = max(2, nb_samples_q)
+
+ index_k = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
+ index_l = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
+
+ index_i = generator.choice(len_p, size=nb_samples_p, p=p, replace=False)
+ index_j = generator.choice(len_p, size=nb_samples_p, p=p, replace=False)
+
+ for i in range(nb_samples_p):
+ if nx.issparse(T):
+ T_indexi = nx.reshape(nx.todense(T[index_i[i], :]), (-1,))
+ T_indexj = nx.reshape(nx.todense(T[index_j[i], :]), (-1,))
+ else:
+ T_indexi = T[index_i[i], :]
+ T_indexj = T[index_j[i], :]
+ # For each of the row sampled, the column is sampled.
+ index_k[i] = generator.choice(
+ len_q,
+ size=nb_samples_q,
+ p=T_indexi / nx.sum(T_indexi),
+ replace=True
+ )
+ index_l[i] = generator.choice(
+ len_q,
+ size=nb_samples_q,
+ p=T_indexj / nx.sum(T_indexj),
+ replace=True
+ )
+
+ list_value_sample = nx.stack([
+ loss_fun(
+ C1[np.ix_(index_i, index_j)],
+ C2[np.ix_(index_k[:, n], index_l[:, n])]
+ ) for n in range(nb_samples_q)
+ ], axis=2)
+
+ if std:
+ std_value = nx.sum(nx.std(list_value_sample, axis=2) ** 2) ** 0.5
+ return nx.mean(list_value_sample), std_value / (nb_samples_p * nb_samples_p)
+ else:
+ return nx.mean(list_value_sample)
+
+
+def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun,
+ alpha=1, max_iter=100, threshold_plan=0, log=False, verbose=False, random_state=None):
+ r"""
+ Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a stochastic Frank-Wolfe.
+ This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times PN^2)` time complexity with `P` the number of Sinkhorn iterations.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{T} &\geq 0
+
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity matrices
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space
+ q : array-like, shape (nt,)
+ Distribution in the target space
+ loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}`
+ Loss function used for the distance, the transport plan does not depend on the loss function
+ alpha : float
+ Step of the Frank-Wolfe algorithm, should be between 0 and 1
+ max_iter : int, optional
+ Max number of iterations
+ threshold_plan : float, optional
+ Deleting very small values in the transport plan. If above zero, it violates the marginal constraints.
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ Gives the distance estimated and the standard deviation
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
+
+ Returns
+ -------
+ T : array-like, shape (`ns`, `nt`)
+ Optimal coupling between the two spaces
+
+ References
+ ----------
+ .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
+ "Sampled Gromov Wasserstein."
+ Machine Learning Journal (MLJ). 2021.
+
"""
- Returns the gromov-wasserstein transport between (C1,p) and (C2,q)
+ C1, C2, p, q = list_to_array(C1, C2, p, q)
+ nx = get_backend(C1, C2, p, q)
+
+ len_p = p.shape[0]
+ len_q = q.shape[0]
+
+ generator = check_random_state(random_state)
+
+ index = np.zeros(2, dtype=int)
- (C1,p) and (C2,q)
+ # Initialize with default marginal
+ index[0] = generator.choice(len_p, size=1, p=p)
+ index[1] = generator.choice(len_q, size=1, p=q)
+ T = nx.tocsr(emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False))
+
+ best_gw_dist_estimated = np.inf
+ for cpt in range(max_iter):
+ index[0] = generator.choice(len_p, size=1, p=p)
+ T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,))
+ index[1] = generator.choice(len_q, size=1, p=T_index0 / T_index0.sum())
+
+ if alpha == 1:
+ T = nx.tocsr(
+ emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False)
+ )
+ else:
+ new_T = nx.tocsr(
+ emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False)
+ )
+ T = (1 - alpha) * T + alpha * new_T
+ # To limit the number of non 0, the values below the threshold are set to 0.
+ T = nx.eliminate_zeros(T, threshold=threshold_plan)
+
+ if cpt % 10 == 0 or cpt == (max_iter - 1):
+ gw_dist_estimated = GW_distance_estimation(
+ C1=C1, C2=C2, loss_fun=loss_fun,
+ p=p, q=q, T=T, std=False, random_state=generator
+ )
+
+ if gw_dist_estimated < best_gw_dist_estimated:
+ best_gw_dist_estimated = gw_dist_estimated
+ best_T = nx.copy(T)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Best gw estimated') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, best_gw_dist_estimated))
+
+ if log:
+ log = {}
+ log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(
+ C1=C1, C2=C2, loss_fun=loss_fun,
+ p=p, q=q, T=best_T, random_state=generator
+ )
+ return best_T, log
+ return best_T
+
+
+def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun,
+ nb_samples_grad=100, epsilon=1, max_iter=500, log=False, verbose=False,
+ random_state=None):
+ r"""
+ Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a 1-stochastic Frank-Wolfe.
+ This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times N \log(N))` time complexity by relying on the 1D Optimal Transport solver.
The function solves the following optimization problem:
.. math::
- GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T))
+ \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
- s.t. T 1 = p
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
- T^T 1= q
+ \mathbf{T}^T \mathbf{1} &= \mathbf{q}
- T\geq 0
+ \mathbf{T} &\geq 0
Where :
- - C1 : Metric cost matrix in the source space
- - C2 : Metric cost matrix in the target space
- - p : distribution in the source space
- - q : distribution in the target space
- - L : loss function to account for the misfit between the similarity matrices
- - H : entropy
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity matrices
Parameters
----------
- C1 : ndarray, shape (ns, ns)
+ C1 : array-like, shape (ns, ns)
Metric cost matrix in the source space
- C2 : ndarray, shape (nt, nt)
- Metric costfr matrix in the target space
- p : ndarray, shape (ns,)
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
Distribution in the source space
- q : ndarray, shape (nt,)
+ q : array-like, shape (nt,)
+ Distribution in the target space
+ loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}`
+ Loss function used for the distance, the transport plan does not depend on the loss function
+ nb_samples_grad : int
+ Number of samples to approximate the gradient
+ epsilon : float
+ Weight of the Kullback-Leibler regularization
+ max_iter : int, optional
+ Max number of iterations
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ Gives the distance estimated and the standard deviation
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
+
+ Returns
+ -------
+ T : array-like, shape (`ns`, `nt`)
+ Optimal coupling between the two spaces
+
+ References
+ ----------
+ .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
+ "Sampled Gromov Wasserstein."
+ Machine Learning Journal (MLJ). 2021.
+
+ """
+ C1, C2, p, q = list_to_array(C1, C2, p, q)
+ nx = get_backend(C1, C2, p, q)
+
+ len_p = p.shape[0]
+ len_q = q.shape[0]
+
+ generator = check_random_state(random_state)
+
+ # The most natural way to define nb_sample is with a simple integer.
+ if isinstance(nb_samples_grad, int):
+ if nb_samples_grad > len_p:
+ # As the sampling along the first dimension is done without replacement, the rest is reported to the second
+ # dimension.
+ nb_samples_grad_p, nb_samples_grad_q = len_p, nb_samples_grad // len_p
+ else:
+ nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad, 1
+ else:
+ nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad
+ T = nx.outer(p, q)
+ # continue_loop allows to stop the loop if there is several successive small modification of T.
+ continue_loop = 0
+
+ # The gradient of GW is more complex if the two matrices are not symmetric.
+ C_are_symmetric = nx.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and nx.allclose(C2, C2.T, rtol=1e-10, atol=1e-10)
+
+ for cpt in range(max_iter):
+ index0 = generator.choice(len_p, size=nb_samples_grad_p, p=p, replace=False)
+ Lik = 0
+ for i, index0_i in enumerate(index0):
+ index1 = generator.choice(len_q,
+ size=nb_samples_grad_q,
+ p=T[index0_i, :] / nx.sum(T[index0_i, :]),
+ replace=False)
+ # If the matrices C are not symmetric, the gradient has 2 terms, thus the term is chosen randomly.
+ if (not C_are_symmetric) and generator.rand(1) > 0.5:
+ Lik += nx.mean(loss_fun(
+ C1[:, [index0[i]] * nb_samples_grad_q][:, None, :],
+ C2[:, index1][None, :, :]
+ ), axis=2)
+ else:
+ Lik += nx.mean(loss_fun(
+ C1[[index0[i]] * nb_samples_grad_q, :][:, :, None],
+ C2[index1, :][:, None, :]
+ ), axis=0)
+
+ max_Lik = nx.max(Lik)
+ if max_Lik == 0:
+ continue
+ # This division by the max is here to facilitate the choice of epsilon.
+ Lik /= max_Lik
+
+ if epsilon > 0:
+ # Set to infinity all the numbers below exp(-200) to avoid log of 0.
+ log_T = nx.log(nx.clip(T, np.exp(-200), 1))
+ log_T = nx.where(log_T == -200, -np.inf, log_T)
+ Lik = Lik - epsilon * log_T
+
+ try:
+ new_T = sinkhorn(a=p, b=q, M=Lik, reg=epsilon)
+ except (RuntimeWarning, UserWarning):
+ print("Warning catched in Sinkhorn: Return last stable T")
+ break
+ else:
+ new_T = emd(a=p, b=q, M=Lik)
+
+ change_T = nx.mean((T - new_T) ** 2)
+ if change_T <= 10e-20:
+ continue_loop += 1
+ if continue_loop > 100: # Number max of low modifications of T
+ T = nx.copy(new_T)
+ break
+ else:
+ continue_loop = 0
+
+ if verbose and cpt % 10 == 0:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', '||T_n - T_{n+1}||') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, change_T))
+ T = nx.copy(new_T)
+
+ if log:
+ log = {}
+ log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(
+ C1=C1, C2=C2, loss_fun=loss_fun,
+ p=p, q=q, T=T, random_state=generator
+ )
+ return T, log
+ return T
+
+
+def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
+ max_iter=1000, tol=1e-9, verbose=False, log=False):
+ r"""
+ Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{GW} = \mathop{\arg\min}_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T}))
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{T} &\geq 0
+
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity matrices
+ - `H`: entropy
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space
+ q : array-like, shape (nt,)
Distribution in the target space
loss_fun : string
Loss function used for the solver either 'square_loss' or 'kl_loss'
@@ -623,21 +1185,20 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
Returns
-------
- T : ndarray, shape (ns, nt)
+ T : array-like, shape (`ns`, `nt`)
Optimal coupling between the two spaces
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
"""
+ C1, C2, p, q = list_to_array(C1, C2, p, q)
+ nx = get_backend(C1, C2, p, q)
- C1 = np.asarray(C1, dtype=np.float64)
- C2 = np.asarray(C2, dtype=np.float64)
-
- T = np.outer(p, q) # Initialization
+ T = nx.outer(p, q)
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
@@ -654,12 +1215,12 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
# compute the gradient
tens = gwggrad(constC, hC1, hC2, T)
- T = sinkhorn(p, q, tens, epsilon)
+ T = sinkhorn(p, q, tens, epsilon, method='sinkhorn')
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = np.linalg.norm(T - Tprev)
+ err = nx.norm(T - Tprev)
if log:
log['err'].append(err)
@@ -681,33 +1242,33 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
max_iter=1000, tol=1e-9, verbose=False, log=False):
- """
- Returns the entropic gromov-wasserstein discrepancy between the two measured similarity matrices
-
- (C1,p) and (C2,q)
+ r"""
+ Returns the entropic gromov-wasserstein discrepancy between the two measured similarity matrices :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
The function solves the following optimization problem:
.. math::
- GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T))
+ GW = \min_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})
+ \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T}))
Where :
- - C1 : Metric cost matrix in the source space
- - C2 : Metric cost matrix in the target space
- - p : distribution in the source space
- - q : distribution in the target space
- - L : loss function to account for the misfit between the similarity matrices
- - H : entropy
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity matrices
+ - `H`: entropy
Parameters
----------
- C1 : ndarray, shape (ns, ns)
+ C1 : array-like, shape (ns, ns)
Metric cost matrix in the source space
- C2 : ndarray, shape (nt, nt)
- Metric costfr matrix in the target space
- p : ndarray, shape (ns,)
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
Distribution in the source space
- q : ndarray, shape (nt,)
+ q : array-like, shape (nt,)
Distribution in the target space
loss_fun : str
Loss function used for the solver either 'square_loss' or 'kl_loss'
@@ -729,7 +1290,7 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
@@ -746,76 +1307,79 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
- max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None):
- """
- Returns the gromov-wasserstein barycenters of S measured similarity matrices
-
- (Cs)_{s=1}^{s=S}
+ max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None):
+ r"""
+ Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}`
The function solves the following optimization problem:
.. math::
- C = argmin_{C\in R^{NxN}} \sum_s \lambda_s GW(C,C_s,p,p_s)
+ \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s)
Where :
- - :math:`C_s` : metric cost matrix
- - :math:`p_s` : distribution
+ - :math:`\mathbf{C}_s`: metric cost matrix
+ - :math:`\mathbf{p}_s`: distribution
Parameters
----------
N : int
Size of the targeted barycenter
- Cs : list of S np.ndarray of shape (ns,ns)
+ Cs : list of S array-like of shape (ns,ns)
Metric cost matrices
- ps : list of S np.ndarray of shape (ns,)
- Sample weights in the S spaces
- p : ndarray, shape(N,)
+ ps : list of S array-like of shape (ns,)
+ Sample weights in the `S` spaces
+ p : array-like, shape(N,)
Weights in the targeted barycenter
lambdas : list of float
- List of the S spaces' weights.
+ List of the `S` spaces' weights.
loss_fun : callable
Tensor-matrix multiplication function based on specific loss function.
update : callable
- function(p,lambdas,T,Cs) that updates C according to a specific Kernel
- with the S Ts couplings calculated at each iteration
+ function(:math:`\mathbf{p}`, lambdas, :math:`\mathbf{T}`, :math:`\mathbf{Cs}`) that updates
+ :math:`\mathbf{C}` according to a specific Kernel with the `S` :math:`\mathbf{T}_s` couplings
+ calculated at each iteration
epsilon : float
Regularization term >0
max_iter : int, optional
Max number of iterations
tol : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations.
log : bool, optional
Record log if True.
- init_C : bool | ndarray, shape (N, N)
- Random initial value for the C matrix provided by user.
+ init_C : bool | array-like, shape (N, N)
+ Random initial value for the :math:`\mathbf{C}` matrix provided by user.
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
Returns
-------
- C : ndarray, shape (N, N)
+ C : array-like, shape (`N`, `N`)
Similarity matrix in the barycenter space (permutated arbitrarily)
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
"""
+ Cs = list_to_array(*Cs)
+ ps = list_to_array(*ps)
+ p = list_to_array(p)
+ nx = get_backend(*Cs, *ps, p)
S = len(Cs)
- Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
- lambdas = np.asarray(lambdas, dtype=np.float64)
-
# Initialization of C : random SPD matrix (if not provided by user)
if init_C is None:
- # XXX use random state
- xalea = np.random.randn(N, 2)
+ generator = check_random_state(random_state)
+ xalea = generator.randn(N, 2)
C = dist(xalea, xalea)
C /= C.max()
+ C = nx.from_numpy(C, type_as=p)
else:
C = init_C
@@ -828,7 +1392,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
Cprev = C
T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
- max_iter, 1e-5, verbose, log) for s in range(S)]
+ max_iter, 1e-4, verbose, log) for s in range(S)]
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)
@@ -838,7 +1402,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = np.linalg.norm(C - Cprev)
+ err = nx.norm(C - Cprev)
error.append(err)
if log:
@@ -856,72 +1420,78 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
- max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None):
- """
- Returns the gromov-wasserstein barycenters of S measured similarity matrices
-
- (Cs)_{s=1}^{s=S}
+ max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None):
+ r"""
+ Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}`
- The function solves the following optimization problem with block
- coordinate descent:
+ The function solves the following optimization problem with block coordinate descent:
.. math::
- C = argmin_C\in R^NxN \sum_s \lambda_s GW(C,Cs,p,ps)
+
+ \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s)
Where :
- - Cs : metric cost matrix
- - ps : distribution
+ - :math:`\mathbf{C}_s`: metric cost matrix
+ - :math:`\mathbf{p}_s`: distribution
Parameters
----------
N : int
Size of the targeted barycenter
- Cs : list of S np.ndarray of shape (ns, ns)
+ Cs : list of S array-like of shape (ns, ns)
Metric cost matrices
- ps : list of S np.ndarray of shape (ns,)
- Sample weights in the S spaces
- p : ndarray, shape (N,)
+ ps : list of S array-like of shape (ns,)
+ Sample weights in the `S` spaces
+ p : array-like, shape (N,)
Weights in the targeted barycenter
lambdas : list of float
- List of the S spaces' weights
- loss_fun : tensor-matrix multiplication function based on specific loss function
- update : function(p,lambdas,T,Cs) that updates C according to a specific Kernel
- with the S Ts couplings calculated at each iteration
+ List of the `S` spaces' weights
+ loss_fun : callable
+ tensor-matrix multiplication function based on specific loss function
+ update : callable
+ function(:math:`\mathbf{p}`, lambdas, :math:`\mathbf{T}`, :math:`\mathbf{Cs}`) that updates
+ :math:`\mathbf{C}` according to a specific Kernel with the `S` :math:`\mathbf{T}_s` couplings
+ calculated at each iteration
max_iter : int, optional
Max number of iterations
tol : float, optional
- Stop threshol on error (>0).
+ Stop threshold on error (>0).
verbose : bool, optional
Print information along iterations.
log : bool, optional
Record log if True.
- init_C : bool | ndarray, shape(N,N)
- Random initial value for the C matrix provided by user.
+ init_C : bool | array-like, shape(N,N)
+ Random initial value for the :math:`\mathbf{C}` matrix provided by user.
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
Returns
-------
- C : ndarray, shape (N, N)
+ C : array-like, shape (`N`, `N`)
Similarity matrix in the barycenter space (permutated arbitrarily)
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
"""
- S = len(Cs)
+ Cs = list_to_array(*Cs)
+ ps = list_to_array(*ps)
+ p = list_to_array(p)
+ nx = get_backend(*Cs, *ps, p)
- Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
- lambdas = np.asarray(lambdas, dtype=np.float64)
+ S = len(Cs)
# Initialization of C : random SPD matrix (if not provided by user)
if init_C is None:
- # XXX : should use a random state and not use the global seed
- xalea = np.random.randn(N, 2)
+ generator = check_random_state(random_state)
+ xalea = generator.randn(N, 2)
C = dist(xalea, xalea)
C /= C.max()
+ C = nx.from_numpy(C, type_as=p)
else:
C = init_C
@@ -944,7 +1514,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = np.linalg.norm(C - Cprev)
+ err = nx.norm(C - Cprev)
error.append(err)
if log:
@@ -963,21 +1533,21 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
p=None, loss_fun='square_loss', max_iter=100, tol=1e-9,
- verbose=False, log=False, init_C=None, init_X=None):
- """Compute the fgw barycenter as presented eq (5) in [24].
+ verbose=False, log=False, init_C=None, init_X=None, random_state=None):
+ r"""Compute the fgw barycenter as presented eq (5) in :ref:`[24] <references-fgw-barycenters>`
Parameters
----------
- N : integer
+ N : int
Desired number of samples of the target barycenter
- Ys: list of ndarray, each element has shape (ns,d)
+ Ys: list of array-like, each element has shape (ns,d)
Features of all samples
- Cs : list of ndarray, each element has shape (ns,ns)
+ Cs : list of array-like, each element has shape (ns,ns)
Structure matrices of all samples
- ps : list of ndarray, each element has shape (ns,)
+ ps : list of array-like, each element has shape (ns,)
Masses of all samples.
lambdas : list of float
- List of the S spaces' weights
+ List of the `S` spaces' weights
alpha : float
Alpha parameter for the fgw distance
fixed_structure : bool
@@ -989,46 +1559,51 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
max_iter : int, optional
Max number of iterations
tol : float, optional
- Stop threshol on error (>0).
+ Stop threshold on error (>0).
verbose : bool, optional
Print information along iterations.
log : bool, optional
Record log if True.
- init_C : ndarray, shape (N,N), optional
+ init_C : array-like, shape (N,N), optional
Initialization for the barycenters' structure matrix. If not set
a random init is used.
- init_X : ndarray, shape (N,d), optional
+ init_X : array-like, shape (N,d), optional
Initialization for the barycenters' features. If not set a
random init is used.
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
Returns
-------
- X : ndarray, shape (N, d)
+ X : array-like, shape (`N`, `d`)
Barycenters' features
- C : ndarray, shape (N, N)
+ C : array-like, shape (`N`, `N`)
Barycenters' structure matrix
- log_: dict
+ log : dict
Only returned when log=True. It contains the keys:
- T : list of (N,ns) transport matrices
- Ms : all distance matrices between the feature of the barycenter and the
- other features dist(X,Ys) shape (N,ns)
+ - :math:`\mathbf{T}`: list of (`N`, `ns`) transport matrices
+ - :math:`(\mathbf{M}_s)_s`: all distance matrices between the feature of the barycenter and the other features :math:`(dist(\mathbf{X}, \mathbf{Y}_s))_s` shape (`N`, `ns`)
+
+
+ .. _references-fgw-barycenters:
References
----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
and Courty Nicolas
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
"""
+ Cs = list_to_array(*Cs)
+ ps = list_to_array(*ps)
+ Ys = list_to_array(*Ys)
+ p = list_to_array(p)
+ nx = get_backend(*Cs, *Ys, *ps)
+
S = len(Cs)
d = Ys[0].shape[1] # dimension on the node features
if p is None:
- p = np.ones(N) / N
-
- Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
- Ys = [np.asarray(Ys[s], dtype=np.float64) for s in range(S)]
-
- lambdas = np.asarray(lambdas, dtype=np.float64)
+ p = nx.ones(N, type_as=Cs[0]) / N
if fixed_structure:
if init_C is None:
@@ -1037,8 +1612,10 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
C = init_C
else:
if init_C is None:
- xalea = np.random.randn(N, 2)
+ generator = check_random_state(random_state)
+ xalea = generator.randn(N, 2)
C = dist(xalea, xalea)
+ C = nx.from_numpy(C, type_as=ps[0])
else:
C = init_C
@@ -1049,13 +1626,13 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
X = init_X
else:
if init_X is None:
- X = np.zeros((N, d))
+ X = nx.zeros((N, d), type_as=ps[0])
else:
X = init_X
- T = [np.outer(p, q) for q in ps]
+ T = [nx.outer(p, q) for q in ps]
- Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))] # Ms is N,ns
+ Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
cpt = 0
err_feature = 1
@@ -1075,20 +1652,19 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
Ys_temp = [y.T for y in Ys]
X = update_feature_matrix(lambdas, Ys_temp, T, p).T
- Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]
+ Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
if not fixed_structure:
if loss_fun == 'square_loss':
T_temp = [t.T for t in T]
- C = update_sructure_matrix(p, lambdas, T_temp, Cs)
+ C = update_structure_matrix(p, lambdas, T_temp, Cs)
T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha,
numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
# T is N,ns
- err_feature = np.linalg.norm(X - Xprev.reshape(N, d))
- err_structure = np.linalg.norm(C - Cprev)
-
+ err_feature = nx.norm(X - nx.reshape(Xprev, (N, d)))
+ err_structure = nx.norm(C - Cprev)
if log:
log_['err_feature'].append(err_feature)
log_['err_structure'].append(err_structure)
@@ -1114,64 +1690,80 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
return X, C
-def update_sructure_matrix(p, lambdas, T, Cs):
- """Updates C according to the L2 Loss kernel with the S Ts couplings.
+def update_structure_matrix(p, lambdas, T, Cs):
+ r"""Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings.
It is calculated at each iteration
Parameters
----------
- p : ndarray, shape (N,)
+ p : array-like, shape (N,)
Masses in the targeted barycenter.
lambdas : list of float
- List of the S spaces' weights.
- T : list of S ndarray of shape (ns, N)
- The S Ts couplings calculated at each iteration.
- Cs : list of S ndarray, shape (ns, ns)
- Metric cost matrices.
+ List of the `S` spaces' weights.
+ T : list of S array-like of shape (ns, N)
+ The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
+ Cs : list of S array-like, shape (ns, ns)
+ Metric cost matrices.
Returns
-------
- C : ndarray, shape (nt, nt)
- Updated C matrix.
+ C : array-like, shape (`nt`, `nt`)
+ Updated :math:`\mathbf{C}` matrix.
"""
- tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))])
- ppt = np.outer(p, p)
+ p = list_to_array(p)
+ T = list_to_array(*T)
+ Cs = list_to_array(*Cs)
+ nx = get_backend(*Cs, *T, p)
- return np.divide(tmpsum, ppt)
+ tmpsum = sum([
+ lambdas[s] * nx.dot(
+ nx.dot(T[s].T, Cs[s]),
+ T[s]
+ ) for s in range(len(T))
+ ])
+ ppt = nx.outer(p, p)
+ return tmpsum / ppt
def update_feature_matrix(lambdas, Ys, Ts, p):
- """Updates the feature with respect to the S Ts couplings.
+ r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings.
See "Solving the barycenter problem with Block Coordinate Descent (BCD)"
- in [24] calculated at each iteration
+ in :ref:`[24] <references-update-feature-matrix>` calculated at each iteration
Parameters
----------
- p : ndarray, shape (N,)
+ p : array-like, shape (N,)
masses in the targeted barycenter
lambdas : list of float
- List of the S spaces' weights
- Ts : list of S np.ndarray(ns,N)
- the S Ts couplings calculated at each iteration
- Ys : list of S ndarray, shape(d,ns)
+ List of the `S` spaces' weights
+ Ts : list of S array-like, shape (ns,N)
+ The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
+ Ys : list of S array-like, shape (d,ns)
The features.
Returns
-------
- X : ndarray, shape (d, N)
+ X : array-like, shape (`d`, `N`)
+
+ .. _references-update-feature-matrix:
References
----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
- and Courty Nicolas
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
"""
- p = np.array(1. / p).reshape(-1,)
-
- tmpsum = sum([lambdas[s] * np.dot(Ys[s], Ts[s].T) * p[None, :] for s in range(len(Ts))])
-
+ p = list_to_array(p)
+ Ts = list_to_array(*Ts)
+ Ys = list_to_array(*Ys)
+ nx = get_backend(*Ys, *Ts, p)
+
+ p = 1. / p
+ tmpsum = sum([
+ lambdas[s] * nx.dot(Ys[s], Ts[s].T) * p[None, :]
+ for s in range(len(Ts))
+ ])
return tmpsum
diff --git a/ot/helpers/__init__.py b/ot/helpers/__init__.py
new file mode 100644
index 0000000..b948671
--- /dev/null
+++ b/ot/helpers/__init__.py
@@ -0,0 +1,3 @@
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
diff --git a/ot/helpers/openmp_helpers.py b/ot/helpers/openmp_helpers.py
new file mode 100644
index 0000000..a6ad38b
--- /dev/null
+++ b/ot/helpers/openmp_helpers.py
@@ -0,0 +1,85 @@
+"""Helpers for OpenMP support during the build."""
+
+# This code is adapted for a large part from the astropy openmp helpers, which
+# can be found at: https://github.com/astropy/extension-helpers/blob/master/extension_helpers/_openmp_helpers.py # noqa
+
+
+import os
+import sys
+import textwrap
+import subprocess
+
+from distutils.errors import CompileError, LinkError
+
+from pre_build_helpers import compile_test_program
+
+
+def get_openmp_flag(compiler):
+ """Get openmp flags for a given compiler"""
+
+ if hasattr(compiler, 'compiler'):
+ compiler = compiler.compiler[0]
+ else:
+ compiler = compiler.__class__.__name__
+
+ if sys.platform == "win32" and ('icc' in compiler or 'icl' in compiler):
+ omp_flag = ['/Qopenmp']
+ elif sys.platform == "win32":
+ omp_flag = ['/openmp']
+ elif sys.platform in ("darwin", "linux") and "icc" in compiler:
+ omp_flag = ['-qopenmp']
+ elif sys.platform == "darwin" and 'openmp' in os.getenv('CPPFLAGS', ''):
+ omp_flag = []
+ else:
+ # Default flag for GCC and clang:
+ omp_flag = ['-fopenmp']
+ if sys.platform.startswith("darwin"):
+ omp_flag += ["-Xpreprocessor", "-lomp"]
+ return omp_flag
+
+
+def check_openmp_support():
+ """Check whether OpenMP test code can be compiled and run"""
+
+ code = textwrap.dedent(
+ """\
+ #include <omp.h>
+ #include <stdio.h>
+ int main(void) {
+ #pragma omp parallel
+ printf("nthreads=%d\\n", omp_get_num_threads());
+ return 0;
+ }
+ """)
+
+ extra_preargs = os.getenv('LDFLAGS', None)
+ if extra_preargs is not None:
+ extra_preargs = extra_preargs.strip().split(" ")
+ extra_preargs = [
+ flag for flag in extra_preargs
+ if flag.startswith(('-L', '-Wl,-rpath', '-l'))]
+
+ extra_postargs = get_openmp_flag
+
+ try:
+ output, compile_flags = compile_test_program(
+ code,
+ extra_preargs=extra_preargs,
+ extra_postargs=extra_postargs
+ )
+
+ if output and 'nthreads=' in output[0]:
+ nthreads = int(output[0].strip().split('=')[1])
+ openmp_supported = len(output) == nthreads
+ elif "PYTHON_CROSSENV" in os.environ:
+ # Since we can't run the test program when cross-compiling
+ # assume that openmp is supported if the program can be
+ # compiled.
+ openmp_supported = True
+ else:
+ openmp_supported = False
+
+ except (CompileError, LinkError, subprocess.CalledProcessError):
+ openmp_supported = False
+ compile_flags = []
+ return openmp_supported, compile_flags
diff --git a/ot/helpers/pre_build_helpers.py b/ot/helpers/pre_build_helpers.py
new file mode 100644
index 0000000..93ecd6a
--- /dev/null
+++ b/ot/helpers/pre_build_helpers.py
@@ -0,0 +1,87 @@
+"""Helpers to check build environment before actual build of POT"""
+
+import os
+import sys
+import glob
+import tempfile
+import setuptools # noqa
+import subprocess
+
+from distutils.dist import Distribution
+from distutils.sysconfig import customize_compiler
+from numpy.distutils.ccompiler import new_compiler
+from numpy.distutils.command.config_compiler import config_cc
+
+
+def _get_compiler():
+ """Get a compiler equivalent to the one that will be used to build POT
+ Handles compiler specified as follows:
+ - python setup.py build_ext --compiler=<compiler>
+ - CC=<compiler> python setup.py build_ext
+ """
+ dist = Distribution({'script_name': os.path.basename(sys.argv[0]),
+ 'script_args': sys.argv[1:],
+ 'cmdclass': {'config_cc': config_cc}})
+
+ cmd_opts = dist.command_options.get('build_ext')
+ if cmd_opts is not None and 'compiler' in cmd_opts:
+ compiler = cmd_opts['compiler'][1]
+ else:
+ compiler = None
+
+ ccompiler = new_compiler(compiler=compiler)
+ customize_compiler(ccompiler)
+
+ return ccompiler
+
+
+def compile_test_program(code, extra_preargs=[], extra_postargs=[]):
+ """Check that some C code can be compiled and run"""
+ ccompiler = _get_compiler()
+
+ # extra_(pre/post)args can be a callable to make it possible to get its
+ # value from the compiler
+ if callable(extra_preargs):
+ extra_preargs = extra_preargs(ccompiler)
+ if callable(extra_postargs):
+ extra_postargs = extra_postargs(ccompiler)
+
+ start_dir = os.path.abspath('.')
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ try:
+ os.chdir(tmp_dir)
+
+ # Write test program
+ with open('test_program.c', 'w') as f:
+ f.write(code)
+
+ os.mkdir('objects')
+
+ # Compile, test program
+ ccompiler.compile(['test_program.c'], output_dir='objects',
+ extra_postargs=extra_postargs)
+
+ # Link test program
+ objects = glob.glob(
+ os.path.join('objects', '*' + ccompiler.obj_extension))
+ ccompiler.link_executable(objects, 'test_program',
+ extra_preargs=extra_preargs,
+ extra_postargs=extra_postargs)
+
+ if "PYTHON_CROSSENV" not in os.environ:
+ # Run test program if not cross compiling
+ # will raise a CalledProcessError if return code was non-zero
+ output = subprocess.check_output('./test_program')
+ output = output.decode(
+ sys.stdout.encoding or 'utf-8').splitlines()
+ else:
+ # Return an empty output if we are cross compiling
+ # as we cannot run the test_program
+ output = []
+ except Exception:
+ raise
+ finally:
+ os.chdir(start_dir)
+
+ return output, extra_postargs
diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h
index c0fe7a3..8a1f9ac 100644
--- a/ot/lp/EMD.h
+++ b/ot/lp/EMD.h
@@ -18,19 +18,18 @@
#include <iostream>
#include <vector>
-#include "network_simplex_simple.h"
-using namespace lemon;
typedef unsigned int node_id_type;
enum ProblemType {
INFEASIBLE,
OPTIMAL,
UNBOUNDED,
- MAX_ITER_REACHED
+ MAX_ITER_REACHED
};
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter);
+int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads);
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp
index bc873ed..2bdc172 100644
--- a/ot/lp/EMD_wrapper.cpp
+++ b/ot/lp/EMD_wrapper.cpp
@@ -12,16 +12,22 @@
*
*/
+
+#include "network_simplex_simple.h"
+#include "network_simplex_simple_omp.h"
#include "EMD.h"
+#include <cstdint>
int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
double* alpha, double* beta, double *cost, int maxIter) {
- // beware M and C anre strored in row major C style!!!
- int n, m, i, cur;
+ // beware M and C are stored in row major C style!!!
+
+ using namespace lemon;
+ int n, m, cur;
typedef FullBipartiteDigraph Digraph;
- DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
+ DIGRAPH_TYPEDEFS(Digraph);
// Get the number of non zero coordinates for r and c
n=0;
@@ -48,7 +54,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
std::vector<int> indI(n), indJ(m);
std::vector<double> weights1(n), weights2(m);
Digraph di(n, m);
- NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter);
+ NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter);
// Set supply and demand, don't account for 0 values (faster)
@@ -76,10 +82,12 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
net.supplyMap(&weights1[0], n, &weights2[0], m);
// Set the cost of each edge
+ int64_t idarc = 0;
for (int i=0; i<n; i++) {
for (int j=0; j<m; j++) {
double val=*(D+indI[i]*n2+indJ[j]);
- net.setCost(di.arcFromId(i*m+j), val);
+ net.setCost(di.arcFromId(idarc), val);
+ ++idarc;
}
}
@@ -87,12 +95,13 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
// Solve the problem with the network simplex algorithm
int ret=net.run();
+ int i, j;
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
*cost = 0;
Arc a; di.first(a);
for (; a != INVALID; di.next(a)) {
- int i = di.source(a);
- int j = di.target(a);
+ i = di.source(a);
+ j = di.target(a);
double flow = net.flow(a);
*cost += flow * (*(D+indI[i]*n2+indJ[j-n]));
*(G+indI[i]*n2+indJ[j-n]) = flow;
@@ -106,3 +115,104 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
return ret;
}
+
+
+
+
+
+
+int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
+ double* alpha, double* beta, double *cost, int maxIter, int numThreads) {
+ // beware M and C are stored in row major C style!!!
+
+ using namespace lemon_omp;
+ int n, m, cur;
+
+ typedef FullBipartiteDigraph Digraph;
+ DIGRAPH_TYPEDEFS(Digraph);
+
+ // Get the number of non zero coordinates for r and c
+ n=0;
+ for (int i=0; i<n1; i++) {
+ double val=*(X+i);
+ if (val>0) {
+ n++;
+ }else if(val<0){
+ return INFEASIBLE;
+ }
+ }
+ m=0;
+ for (int i=0; i<n2; i++) {
+ double val=*(Y+i);
+ if (val>0) {
+ m++;
+ }else if(val<0){
+ return INFEASIBLE;
+ }
+ }
+
+ // Define the graph
+
+ std::vector<int> indI(n), indJ(m);
+ std::vector<double> weights1(n), weights2(m);
+ Digraph di(n, m);
+ NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter, numThreads);
+
+ // Set supply and demand, don't account for 0 values (faster)
+
+ cur=0;
+ for (int i=0; i<n1; i++) {
+ double val=*(X+i);
+ if (val>0) {
+ weights1[ cur ] = val;
+ indI[cur++]=i;
+ }
+ }
+
+ // Demand is actually negative supply...
+
+ cur=0;
+ for (int i=0; i<n2; i++) {
+ double val=*(Y+i);
+ if (val>0) {
+ weights2[ cur ] = -val;
+ indJ[cur++]=i;
+ }
+ }
+
+
+ net.supplyMap(&weights1[0], n, &weights2[0], m);
+
+ // Set the cost of each edge
+ int64_t idarc = 0;
+ for (int i=0; i<n; i++) {
+ for (int j=0; j<m; j++) {
+ double val=*(D+indI[i]*n2+indJ[j]);
+ net.setCost(di.arcFromId(idarc), val);
+ ++idarc;
+ }
+ }
+
+
+ // Solve the problem with the network simplex algorithm
+
+ int ret=net.run();
+ int i, j;
+ if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
+ *cost = 0;
+ Arc a; di.first(a);
+ for (; a != INVALID; di.next(a)) {
+ i = di.source(a);
+ j = di.target(a);
+ double flow = net.flow(a);
+ *cost += flow * (*(D+indI[i]*n2+indJ[j-n]));
+ *(G+indI[i]*n2+indJ[j-n]) = flow;
+ *(alpha + indI[i]) = -net.potential(i);
+ *(beta + indJ[j-n]) = net.potential(j);
+ }
+
+ }
+
+
+ return ret;
+}
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 514a607..5da897d 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -8,25 +8,50 @@ Solvers for the original linear program OT problem
#
# License: MIT License
+import os
import multiprocessing
import sys
import numpy as np
-from scipy.sparse import coo_matrix
+import warnings
from . import cvx
from .cvx import barycenter
+
# import compiled emd
from .emd_wrap import emd_c, check_result, emd_1d_sorted
-from ..utils import dist
+from .solver_1d import emd_1d, emd2_1d, wasserstein_1d
+
+from ..utils import dist, list_to_array
from ..utils import parmap
+from ..backend import get_backend
-__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
+__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted',
'emd_1d', 'emd2_1d', 'wasserstein_1d']
+def check_number_threads(numThreads):
+ """Checks whether or not the requested number of threads has a valid value.
+
+ Parameters
+ ----------
+ numThreads : int or str
+ The requested number of threads, should either be a strictly positive integer or "max" or None
+
+ Returns
+ -------
+ numThreads : int
+ Corrected number of threads
+ """
+ if (numThreads is None) or (isinstance(numThreads, str) and numThreads.lower() == 'max'):
+ return -1
+ if (not isinstance(numThreads, int)) or numThreads < 1:
+ raise ValueError('numThreads should either be "max" or a strictly positive integer')
+ return numThreads
+
+
def center_ot_dual(alpha0, beta0, a=None, b=None):
- r"""Center dual OT potentials w.r.t. theirs weights
+ r"""Center dual OT potentials w.r.t. their weights
The main idea of this function is to find unique dual potentials
that ensure some kind of centering/fairness. The main idea is to find dual potentials that lead to the same final objective value for both source and targets (see below for more details). It will help having
@@ -37,7 +62,7 @@ def center_ot_dual(alpha0, beta0, a=None, b=None):
is the following:
.. math::
- \alpha^T a= \beta^T b
+ \alpha^T \mathbf{a} = \beta^T \mathbf{b}
in addition to the OT problem constraints.
@@ -45,11 +70,11 @@ def center_ot_dual(alpha0, beta0, a=None, b=None):
a constant from both :math:`\alpha_0` and :math:`\beta_0`.
.. math::
- c=\frac{\beta0^T b-\alpha_0^T a}{1^Tb+1^Ta}
+ c &= \frac{\beta_0^T \mathbf{b} - \alpha_0^T \mathbf{a}}{\mathbf{1}^T \mathbf{b} + \mathbf{1}^T \mathbf{a}}
- \alpha=\alpha_0+c
+ \alpha &= \alpha_0 + c
- \beta=\beta0+c
+ \beta &= \beta_0 + c
Parameters
----------
@@ -92,35 +117,35 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M):
The feasible values are computed efficiently but rather coarsely.
.. warning::
- This function is necessary because the C++ solver in emd_c
- discards all samples in the distributions with
- zeros weights. This means that while the primal variable (transport
+ This function is necessary because the C++ solver in `emd_c`
+ discards all samples in the distributions with
+ zeros weights. This means that while the primal variable (transport
matrix) is exact, the solver only returns feasible dual potentials
- on the samples with weights different from zero.
+ on the samples with weights different from zero.
First we compute the constraints violations:
.. math::
- V=\alpha+\beta^T-M
+ \mathbf{V} = \alpha + \beta^T - \mathbf{M}
- Next we compute the max amount of violation per row (alpha) and
- columns (beta)
+ Next we compute the max amount of violation per row (:math:`\alpha`) and
+ columns (:math:`beta`)
.. math::
- v^a_i=\max_j V_{i,j}
+ \mathbf{v^a}_i = \max_j \mathbf{V}_{i,j}
- v^b_j=\max_i V_{i,j}
+ \mathbf{v^b}_j = \max_i \mathbf{V}_{i,j}
Finally we update the dual potential with 0 weights if a
constraint is violated
.. math::
- \alpha_i = \alpha_i -v^a_i \quad \text{ if } a_i=0 \text{ and } v^a_i>0
+ \alpha_i = \alpha_i - \mathbf{v^a}_i \quad \text{ if } \mathbf{a}_i=0 \text{ and } \mathbf{v^a}_i>0
- \beta_j = \beta_j -v^b_j \quad \text{ if } b_j=0 \text{ and } v^b_j>0
+ \beta_j = \beta_j - \mathbf{v^b}_j \quad \text{ if } \mathbf{b}_j=0 \text{ and } \mathbf{v^b}_j > 0
In the end the dual potentials are centered using function
- :ref:`center_ot_dual`.
+ :py:func:`ot.lp.center_ot_dual`.
Note that all those updates do not change the objective value of the
solution but provide dual potentials that do not violate the constraints.
@@ -172,54 +197,62 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M):
return center_ot_dual(alpha, beta, a, b)
-def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
+def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
r"""Solves the Earth Movers distance problem and returns the OT matrix
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} = \mathbf{b}
+
+ \gamma \geq 0
- \gamma\geq 0
where :
- - M is the metric cost matrix
- - a and b are the sample weights
+ - :math:`\mathbf{M}` is the metric cost matrix
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
- .. warning::
- Note that the M matrix needs to be a C-order numpy.array in float64
- format.
+ .. warning:: Note that the :math:`\mathbf{M}` matrix in numpy needs to be a C-order
+ numpy.array in float64 format. It will be converted if not in this
+ format
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
- Uses the algorithm proposed in [1]_
+ Uses the algorithm proposed in :ref:`[1] <references-emd>`.
Parameters
----------
- a : (ns,) numpy.ndarray, float64
+ a : (ns,) array-like, float
Source histogram (uniform weight if empty list)
- b : (nt,) numpy.ndarray, float64
+ b : (nt,) array-like, float
Target histogram (uniform weight if empty list)
- M : (ns,nt) numpy.ndarray, float64
- Loss matrix (c-order array with type float64)
+ M : (ns,nt) array-like, float
+ Loss matrix (c-order array in numpy with type float64)
numItermax : int, optional (default=100000)
The maximum number of iterations before stopping the optimization
algorithm if it has not converged.
log: bool, optional (default=False)
- If True, returns a dictionary containing the cost and dual
- variables. Otherwise returns only the optimal transportation matrix.
+ If True, returns a dictionary containing the cost and dual variables.
+ Otherwise returns only the optimal transportation matrix.
center_dual: boolean, optional (default=True)
If True, centers the dual potential using function
:ref:`center_ot_dual`.
+ numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
+ If compiled with OpenMP, chooses the number of threads to parallelize.
+ "max" selects the highest number possible.
Returns
-------
- gamma: (ns x nt) numpy.ndarray
- Optimal transportation matrix for the given parameters
- log: dict
- If input log is true, a dictionary containing the cost and dual
- variables and exit status
+ gamma: array-like, shape (ns, nt)
+ Optimal transportation matrix for the given
+ parameters
+ log: dict, optional
+ If input log is true, a dictionary containing the
+ cost and dual variables and exit status
Examples
@@ -232,26 +265,39 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
>>> a=[.5,.5]
>>> b=[.5,.5]
>>> M=[[0.,1.],[1.,0.]]
- >>> ot.emd(a,b,M)
+ >>> ot.emd(a, b, M)
array([[0.5, 0. ],
[0. , 0.5]])
+
+ .. _references-emd:
References
----------
-
- .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W.
- (2011, December). Displacement interpolation using Lagrangian mass
- transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p.
- 158). ACM.
+ .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011,
+ December). Displacement interpolation using Lagrangian mass transport.
+ In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
See Also
--------
ot.bregman.sinkhorn : Entropic regularized OT
- ot.optim.cg : General regularized OT"""
+ ot.optim.cg : General regularized OT
+ """
+
+ # convert to numpy if list
+ a, b, M = list_to_array(a, b, M)
+
+ a0, b0, M0 = a, b, M
+ nx = get_backend(M0, a0, b0)
+ # convert to numpy
+ M = nx.to_numpy(M)
+ a = nx.to_numpy(a)
+ b = nx.to_numpy(b)
+
+ # ensure float64
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64, order='C')
# if empty array given then use uniform distributions
if len(a) == 0:
@@ -262,81 +308,91 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
"Dimension mismatch, check dimensions of M with a and b"
+ # ensure that same mass
+ np.testing.assert_almost_equal(a.sum(0),
+ b.sum(0), err_msg='a and b vector must have the same sum')
+ b = b * a.sum() / b.sum()
+
asel = a != 0
bsel = b != 0
- G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
+ numThreads = check_number_threads(numThreads)
+
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads)
if center_dual:
u, v = center_ot_dual(u, v, a, b)
if np.any(~asel) or np.any(~bsel):
u, v = estimate_dual_null_weights(u, v, a, b, M)
-
+
result_code_string = check_result(result_code)
if log:
log = {}
log['cost'] = cost
- log['u'] = u
- log['v'] = v
+ log['u'] = nx.from_numpy(u, type_as=a0)
+ log['v'] = nx.from_numpy(v, type_as=b0)
log['warning'] = result_code_string
log['result_code'] = result_code
- return G, log
- return G
+ return nx.from_numpy(G, type_as=M0), log
+ return nx.from_numpy(G, type_as=M0)
-def emd2(a, b, M, processes=multiprocessing.cpu_count(),
+def emd2(a, b, M, processes=1,
numItermax=100000, log=False, return_matrix=False,
- center_dual=True):
+ center_dual=True, numThreads=1):
r"""Solves the Earth Movers distance problem and returns the loss
.. math::
- \min_\gamma <\gamma,M>_F
+ \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F
+
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
- s.t. \gamma 1 = a
+ \gamma^T \mathbf{1} = \mathbf{b}
- \gamma^T 1= b
+ \gamma \geq 0
- \gamma\geq 0
where :
- - M is the metric cost matrix
- - a and b are the sample weights
+ - :math:`\mathbf{M}` is the metric cost matrix
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
- .. warning::
- Note that the M matrix needs to be a C-order numpy.array in float64
- format.
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
- Uses the algorithm proposed in [1]_
+ Uses the algorithm proposed in :ref:`[1] <references-emd2>`.
Parameters
----------
- a : (ns,) numpy.ndarray, float64
+ a : (ns,) array-like, float64
Source histogram (uniform weight if empty list)
- b : (nt,) numpy.ndarray, float64
+ b : (nt,) array-like, float64
Target histogram (uniform weight if empty list)
- M : (ns,nt) numpy.ndarray, float64
- Loss matrix (c-order array with type float64)
- processes : int, optional (default=nb cpu)
- Nb of processes used for multiple emd computation (not used on windows)
+ M : (ns,nt) array-like, float64
+ Loss matrix (for numpy c-order array with type float64)
+ processes : int, optional (default=1)
+ Nb of processes used for multiple emd computation (deprecated)
numItermax : int, optional (default=100000)
The maximum number of iterations before stopping the optimization
algorithm if it has not converged.
log: boolean, optional (default=False)
- If True, returns a dictionary containing the cost and dual
+ If True, returns a dictionary containing dual
variables. Otherwise returns only the optimal transportation cost.
return_matrix: boolean, optional (default=False)
If True, returns the optimal transportation matrix in the log.
center_dual: boolean, optional (default=True)
If True, centers the dual potential using function
:ref:`center_ot_dual`.
+ numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
+ If compiled with OpenMP, chooses the number of threads to parallelize.
+ "max" selects the highest number possible.
Returns
-------
- gamma: (ns x nt) ndarray
- Optimal transportation matrix for the given parameters
- log: dictnp
- If input log is true, a dictionary containing the cost and dual
+ W: float, array-like
+ Optimal transportation loss for the given parameters
+ log: dict
+ If input log is true, a dictionary containing dual
variables and exit status
@@ -354,9 +410,10 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
>>> ot.emd2(a,b,M)
0.0
+
+ .. _references-emd2:
References
----------
-
.. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W.
(2011, December). Displacement interpolation using Lagrangian mass
transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p.
@@ -365,15 +422,22 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
See Also
--------
ot.bregman.sinkhorn : Entropic regularized OT
- ot.optim.cg : General regularized OT"""
+ ot.optim.cg : General regularized OT
+ """
+
+ a, b, M = list_to_array(a, b, M)
+
+ a0, b0, M0 = a, b, M
+ nx = get_backend(M0, a0, b0)
+
+ # convert to numpy
+ M = nx.to_numpy(M)
+ a = nx.to_numpy(a)
+ b = nx.to_numpy(b)
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
-
- # problem with pikling Forks
- if sys.platform.endswith('win32'):
- processes = 1
+ M = np.asarray(M, dtype=np.float64, order='C')
# if empty array given then use uniform distributions
if len(a) == 0:
@@ -386,11 +450,13 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
asel = a != 0
+ numThreads = check_number_threads(numThreads)
+
if log or return_matrix:
def f(b):
bsel = b != 0
-
- G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
+
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads)
if center_dual:
u, v = center_ot_dual(u, v, a, b)
@@ -400,17 +466,20 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
result_code_string = check_result(result_code)
log = {}
+ G = nx.from_numpy(G, type_as=M0)
if return_matrix:
log['G'] = G
- log['u'] = u
- log['v'] = v
+ log['u'] = nx.from_numpy(u, type_as=a0)
+ log['v'] = nx.from_numpy(v, type_as=b0)
log['warning'] = result_code_string
log['result_code'] = result_code
+ cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0),
+ (a0, b0, M0), (log['u'], log['v'], G))
return [cost, log]
else:
def f(b):
bsel = b != 0
- G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads)
if center_dual:
u, v = center_ot_dual(u, v, a, b)
@@ -418,6 +487,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
if np.any(~asel) or np.any(~bsel):
u, v = estimate_dual_null_weights(u, v, a, b, M)
+ G = nx.from_numpy(G, type_as=M0)
+ cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0),
+ (a0, b0, M0), (nx.from_numpy(u, type_as=a0),
+ nx.from_numpy(v, type_as=b0), G))
+
check_result(result_code)
return cost
@@ -426,35 +500,53 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
nb = b.shape[1]
if processes > 1:
- res = parmap(f, [b[:, i] for i in range(nb)], processes)
- else:
- res = list(map(f, [b[:, i].copy() for i in range(nb)]))
+ warnings.warn(
+ "The 'processes' parameter has been deprecated. "
+ "Multiprocessing should be done outside of POT."
+ )
+ res = list(map(f, [b[:, i].copy() for i in range(nb)]))
return res
def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100,
- stopThr=1e-7, verbose=False, log=None):
- """
- Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
+ stopThr=1e-7, verbose=False, log=None, numThreads=1):
+ r"""
+ Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance), formally:
+
+ .. math::
+ \min_\mathbf{X} \quad \sum_{i=1}^N w_i W_2^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i)
+
+ where :
+
+ - :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one
+ - the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the empirical measures weights and sum to one for each :math:`i`
+ - the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}` are the empirical measures atoms locations
+ - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter
+
+ This problem is considered in :ref:`[1] <references-free-support-barycenter>` (Algorithm 2).
+ There are two differences with the following codes:
- The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms.
- This problem is considered in [1] (Algorithm 2). There are two differences with the following codes:
- we do not optimize over the weights
- - we do not do line search for the locations updates, we use i.e. theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting.
+ - we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in
+ :ref:`[1] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete
+ implementation of the fixed-point algorithm of
+ :ref:`[2] <references-free-support-barycenter>` proposed in the continuous setting.
Parameters
----------
- measures_locations : list of (k_i,d) numpy.ndarray
- The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list)
- measures_weights : list of (k_i,) numpy.ndarray
- Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure
+ measures_locations : list of N (k_i,d) numpy.ndarray
+ The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space
+ (:math:`k_i` can be different for each element of the list)
+ measures_weights : list of N (k_i,) numpy.ndarray
+ Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one
+ representing the weights of each discrete input measure
X_init : (k,d) np.ndarray
- Initialization of the support locations (on k atoms) of the barycenter
+ Initialization of the support locations (on `k` atoms) of the barycenter
b : (k,) np.ndarray
Initialization of the weights of the barycenter (non-negatives, sum to 1)
- weights : (k,) np.ndarray
+ weights : (N,) np.ndarray
Initialization of the coefficients of the barycenter (non-negatives, sum to 1)
numItermax : int, optional
@@ -465,15 +557,20 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
Print information along iterations
log : bool, optional
record log if True
+ numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
+ If compiled with OpenMP, chooses the number of threads to parallelize.
+ "max" selects the highest number possible.
+
Returns
-------
X : (k,d) np.ndarray
Support locations (on k atoms) of the barycenter
+
+ .. _references-free-support-barycenter:
References
----------
-
.. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
.. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
@@ -504,7 +601,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights,
weights.tolist()):
M_i = dist(X, measure_locations_i)
- T_i = emd(b, measure_weights_i, M_i)
+ T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads)
T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)
displacement_square_norm = np.sum(np.square(T_sum - X))
@@ -523,287 +620,3 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
return X, log_dict
else:
return X
-
-
-def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
- log=False):
- r"""Solves the Earth Movers distance problem between 1d measures and returns
- the OT matrix
-
-
- .. math::
- \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
-
- s.t. \gamma 1 = a,
- \gamma^T 1= b,
- \gamma\geq 0
- where :
-
- - d is the metric
- - x_a and x_b are the samples
- - a and b are the sample weights
-
- When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
-
- Uses the algorithm detailed in [1]_
-
- Parameters
- ----------
- x_a : (ns,) or (ns, 1) ndarray, float64
- Source dirac locations (on the real line)
- x_b : (nt,) or (ns, 1) ndarray, float64
- Target dirac locations (on the real line)
- a : (ns,) ndarray, float64, optional
- Source histogram (default is uniform weight)
- b : (nt,) ndarray, float64, optional
- Target histogram (default is uniform weight)
- metric: str, optional (default='sqeuclidean')
- Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
- Due to implementation details, this function runs faster when
- `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
- p: float, optional (default=1.0)
- The p-norm to apply for if metric='minkowski'
- dense: boolean, optional (default=True)
- If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
- Otherwise returns a sparse representation using scipy's `coo_matrix`
- format. Due to implementation details, this function runs faster when
- `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
- are used.
- log: boolean, optional (default=False)
- If True, returns a dictionary containing the cost.
- Otherwise returns only the optimal transportation matrix.
-
- Returns
- -------
- gamma: (ns, nt) ndarray
- Optimal transportation matrix for the given parameters
- log: dict
- If input log is True, a dictionary containing the cost
-
-
- Examples
- --------
-
- Simple example with obvious solution. The function emd_1d accepts lists and
- performs automatic conversion to numpy arrays
-
- >>> import ot
- >>> a=[.5, .5]
- >>> b=[.5, .5]
- >>> x_a = [2., 0.]
- >>> x_b = [0., 3.]
- >>> ot.emd_1d(x_a, x_b, a, b)
- array([[0. , 0.5],
- [0.5, 0. ]])
- >>> ot.emd_1d(x_a, x_b)
- array([[0. , 0.5],
- [0.5, 0. ]])
-
- References
- ----------
-
- .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
- Transport", 2018.
-
- See Also
- --------
- ot.lp.emd : EMD for multidimensional distributions
- ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the
- transportation matrix)
- """
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- x_a = np.asarray(x_a, dtype=np.float64)
- x_b = np.asarray(x_b, dtype=np.float64)
-
- assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \
- "emd_1d should only be used with monodimensional data"
- assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \
- "emd_1d should only be used with monodimensional data"
-
- # if empty array given then use uniform distributions
- if a.ndim == 0 or len(a) == 0:
- a = np.ones((x_a.shape[0],), dtype=np.float64) / x_a.shape[0]
- if b.ndim == 0 or len(b) == 0:
- b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]
-
- x_a_1d = x_a.reshape((-1,))
- x_b_1d = x_b.reshape((-1,))
- perm_a = np.argsort(x_a_1d)
- perm_b = np.argsort(x_b_1d)
-
- G_sorted, indices, cost = emd_1d_sorted(a[perm_a], b[perm_b],
- x_a_1d[perm_a], x_b_1d[perm_b],
- metric=metric, p=p)
- G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])),
- shape=(a.shape[0], b.shape[0]))
- if dense:
- G = G.toarray()
- if log:
- log = {'cost': cost}
- return G, log
- return G
-
-
-def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
- log=False):
- r"""Solves the Earth Movers distance problem between 1d measures and returns
- the loss
-
-
- .. math::
- \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
-
- s.t. \gamma 1 = a,
- \gamma^T 1= b,
- \gamma\geq 0
- where :
-
- - d is the metric
- - x_a and x_b are the samples
- - a and b are the sample weights
-
- When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
-
- Uses the algorithm detailed in [1]_
-
- Parameters
- ----------
- x_a : (ns,) or (ns, 1) ndarray, float64
- Source dirac locations (on the real line)
- x_b : (nt,) or (ns, 1) ndarray, float64
- Target dirac locations (on the real line)
- a : (ns,) ndarray, float64, optional
- Source histogram (default is uniform weight)
- b : (nt,) ndarray, float64, optional
- Target histogram (default is uniform weight)
- metric: str, optional (default='sqeuclidean')
- Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
- Due to implementation details, this function runs faster when
- `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
- are used.
- p: float, optional (default=1.0)
- The p-norm to apply for if metric='minkowski'
- dense: boolean, optional (default=True)
- If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
- Otherwise returns a sparse representation using scipy's `coo_matrix`
- format. Only used if log is set to True. Due to implementation details,
- this function runs faster when dense is set to False.
- log: boolean, optional (default=False)
- If True, returns a dictionary containing the transportation matrix.
- Otherwise returns only the loss.
-
- Returns
- -------
- loss: float
- Cost associated to the optimal transportation
- log: dict
- If input log is True, a dictionary containing the Optimal transportation
- matrix for the given parameters
-
-
- Examples
- --------
-
- Simple example with obvious solution. The function emd2_1d accepts lists and
- performs automatic conversion to numpy arrays
-
- >>> import ot
- >>> a=[.5, .5]
- >>> b=[.5, .5]
- >>> x_a = [2., 0.]
- >>> x_b = [0., 3.]
- >>> ot.emd2_1d(x_a, x_b, a, b)
- 0.5
- >>> ot.emd2_1d(x_a, x_b)
- 0.5
-
- References
- ----------
-
- .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
- Transport", 2018.
-
- See Also
- --------
- ot.lp.emd2 : EMD for multidimensional distributions
- ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix
- instead of the cost)
- """
- # If we do not return G (log==False), then we should not to cast it to dense
- # (useless overhead)
- G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p,
- dense=dense and log, log=True)
- cost = log_emd['cost']
- if log:
- log_emd = {'G': G}
- return cost, log_emd
- return cost
-
-
-def wasserstein_1d(x_a, x_b, a=None, b=None, p=1.):
- r"""Solves the p-Wasserstein distance problem between 1d measures and returns
- the distance
-
- .. math::
- \min_\gamma \left( \sum_i \sum_j \gamma_{ij} \|x_a[i] - x_b[j]\|^p \right)^{1/p}
-
- s.t. \gamma 1 = a,
- \gamma^T 1= b,
- \gamma\geq 0
-
- where :
-
- - x_a and x_b are the samples
- - a and b are the sample weights
-
- Uses the algorithm detailed in [1]_
-
- Parameters
- ----------
- x_a : (ns,) or (ns, 1) ndarray, float64
- Source dirac locations (on the real line)
- x_b : (nt,) or (ns, 1) ndarray, float64
- Target dirac locations (on the real line)
- a : (ns,) ndarray, float64, optional
- Source histogram (default is uniform weight)
- b : (nt,) ndarray, float64, optional
- Target histogram (default is uniform weight)
- p: float, optional (default=1.0)
- The order of the p-Wasserstein distance to be computed
-
- Returns
- -------
- dist: float
- p-Wasserstein distance
-
-
- Examples
- --------
-
- Simple example with obvious solution. The function wasserstein_1d accepts
- lists and performs automatic conversion to numpy arrays
-
- >>> import ot
- >>> a=[.5, .5]
- >>> b=[.5, .5]
- >>> x_a = [2., 0.]
- >>> x_b = [0., 3.]
- >>> ot.wasserstein_1d(x_a, x_b, a, b)
- 0.5
- >>> ot.wasserstein_1d(x_a, x_b)
- 0.5
-
- References
- ----------
-
- .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
- Transport", 2018.
-
- See Also
- --------
- ot.lp.emd_1d : EMD for 1d distributions
- """
- cost_emd = emd2_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p,
- dense=False, log=False)
- return np.power(cost_emd, 1. / p)
diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py
index 8e763be..869d450 100644
--- a/ot/lp/cvx.py
+++ b/ot/lp/cvx.py
@@ -27,7 +27,7 @@ def scipy_sparse_to_spmatrix(A):
def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'):
- """Compute the Wasserstein barycenter of distributions A
+ r"""Compute the Wasserstein barycenter of distributions A
The function solves the following optimization problem [16]:
@@ -76,7 +76,6 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po
.. [16] Agueh, M., & Carlier, G. (2011). Barycenters in the Wasserstein space. SIAM Journal on Mathematical Analysis, 43(2), 904-924.
-
"""
if weights is None:
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index c167964..42e08f4 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -20,6 +20,7 @@ import warnings
cdef extern from "EMD.h":
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) nogil
+ int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads) nogil
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
@@ -38,7 +39,7 @@ def check_result(result_code):
@cython.boundscheck(False)
@cython.wraparound(False)
-def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter):
+def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, int numThreads):
"""
Solves the Earth Movers distance problem and returns the optimal transport matrix
@@ -97,8 +98,6 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([0, 0])
cdef np.ndarray[double, ndim=1, mode="c"] Gv=np.zeros(0)
- cdef np.ndarray[long, ndim=1, mode="c"] iG=np.zeros(0,dtype=np.int)
- cdef np.ndarray[long, ndim=1, mode="c"] jG=np.zeros(0,dtype=np.int)
if not len(a):
a=np.ones((n1,))/n1
@@ -111,8 +110,10 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
# calling the function
with nogil:
- result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
-
+ if numThreads == 1:
+ result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
+ else:
+ result_code = EMD_wrap_omp(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter, numThreads)
return G, cost, alpha, beta, result_code
@@ -157,22 +158,22 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
cost associated to the optimal transportation
"""
cdef double cost = 0.
- cdef int n = u_weights.shape[0]
- cdef int m = v_weights.shape[0]
+ cdef Py_ssize_t n = u_weights.shape[0]
+ cdef Py_ssize_t m = v_weights.shape[0]
- cdef int i = 0
+ cdef Py_ssize_t i = 0
cdef double w_i = u_weights[0]
- cdef int j = 0
+ cdef Py_ssize_t j = 0
cdef double w_j = v_weights[0]
cdef double m_ij = 0.
cdef np.ndarray[double, ndim=1, mode="c"] G = np.zeros((n + m - 1, ),
dtype=np.float64)
- cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2),
- dtype=np.int)
- cdef int cur_idx = 0
- while i < n and j < m:
+ cdef np.ndarray[long long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2),
+ dtype=np.int64)
+ cdef Py_ssize_t cur_idx = 0
+ while True:
if metric == 'sqeuclidean':
m_ij = (u[i] - v[j]) * (u[i] - v[j])
elif metric == 'cityblock' or metric == 'euclidean':
@@ -188,6 +189,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
indices[cur_idx, 0] = i
indices[cur_idx, 1] = j
i += 1
+ if i == n:
+ break
w_j -= w_i
w_i = u_weights[i]
else:
@@ -196,7 +199,10 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
indices[cur_idx, 0] = i
indices[cur_idx, 1] = j
j += 1
+ if j == m:
+ break
w_i -= w_j
w_j = v_weights[j]
cur_idx += 1
+ cur_idx += 1
return G[:cur_idx], indices[:cur_idx], cost
diff --git a/ot/lp/full_bipartitegraph.h b/ot/lp/full_bipartitegraph.h
index 87a1bec..713ccb5 100644
--- a/ot/lp/full_bipartitegraph.h
+++ b/ot/lp/full_bipartitegraph.h
@@ -23,10 +23,10 @@
*
*/
-#ifndef LEMON_FULL_BIPARTITE_GRAPH_H
-#define LEMON_FULL_BIPARTITE_GRAPH_H
+#pragma once
#include "core.h"
+#include <cstdint>
///\ingroup graphs
///\file
@@ -44,16 +44,16 @@ namespace lemon {
//class Node;
typedef int Node;
//class Arc;
- typedef long long Arc;
+ typedef int64_t Arc;
protected:
int _node_num;
- long long _arc_num;
+ int64_t _arc_num;
FullBipartiteDigraphBase() {}
- void construct(int n1, int n2) { _node_num = n1+n2; _arc_num = n1 * n2; _n1=n1; _n2=n2;}
+ void construct(int n1, int n2) { _node_num = n1+n2; _arc_num = (int64_t)n1 * (int64_t)n2; _n1=n1; _n2=n2;}
public:
@@ -65,25 +65,25 @@ namespace lemon {
Arc arc(const Node& s, const Node& t) const {
if (s<_n1 && t>=_n1)
- return Arc(s * _n2 + (t-_n1) );
+ return Arc((int64_t)s * (int64_t)_n2 + (int64_t)(t-_n1) );
else
return Arc(-1);
}
int nodeNum() const { return _node_num; }
- long long arcNum() const { return _arc_num; }
+ int64_t arcNum() const { return _arc_num; }
int maxNodeId() const { return _node_num - 1; }
- long long maxArcId() const { return _arc_num - 1; }
+ int64_t maxArcId() const { return _arc_num - 1; }
Node source(Arc arc) const { return arc / _n2; }
Node target(Arc arc) const { return (arc % _n2) + _n1; }
static int id(Node node) { return node; }
- static long long id(Arc arc) { return arc; }
+ static int64_t id(Arc arc) { return arc; }
static Node nodeFromId(int id) { return Node(id);}
- static Arc arcFromId(int id) { return Arc(id);}
+ static Arc arcFromId(int64_t id) { return Arc(id);}
Arc findArc(Node s, Node t, Arc prev = -1) const {
@@ -136,7 +136,7 @@ namespace lemon {
///
/// \brief A directed full graph class.
///
- /// FullBipartiteDigraph is a simple and fast implmenetation of directed full
+ /// FullBipartiteDigraph is a simple and fast implementation of directed full
/// (complete) graphs. It contains an arc from each node to each node
/// (including a loop for each node), therefore the number of arcs
/// is the square of the number of nodes.
@@ -203,13 +203,10 @@ namespace lemon {
/// \brief Number of nodes.
int nodeNum() const { return Parent::nodeNum(); }
/// \brief Number of arcs.
- long long arcNum() const { return Parent::arcNum(); }
+ int64_t arcNum() const { return Parent::arcNum(); }
};
} //namespace lemon
-
-
-#endif //LEMON_FULL_GRAPH_H
diff --git a/ot/lp/full_bipartitegraph_omp.h b/ot/lp/full_bipartitegraph_omp.h
new file mode 100644
index 0000000..8cbed0b
--- /dev/null
+++ b/ot/lp/full_bipartitegraph_omp.h
@@ -0,0 +1,234 @@
+/* -*- mode: C++; indent-tabs-mode: nil; -*-
+ *
+ * This file has been adapted by Nicolas Bonneel (2013),
+ * from full_graph.h from LEMON, a generic C++ optimization library,
+ * to implement a lightweight fully connected bipartite graph. A previous
+ * version of this file is used as part of the Displacement Interpolation
+ * project,
+ * Web: http://www.cs.ubc.ca/labs/imager/tr/2011/DisplacementInterpolation/
+ *
+ *
+ **** Original file Copyright Notice :
+ * Copyright (C) 2003-2010
+ * Egervary Jeno Kombinatorikus Optimalizalasi Kutatocsoport
+ * (Egervary Research Group on Combinatorial Optimization, EGRES).
+ *
+ * Permission to use, modify and distribute this software is granted
+ * provided that this copyright notice appears in all copies. For
+ * precise terms see the accompanying LICENSE file.
+ *
+ * This software is provided "AS IS" with no warranty of any kind,
+ * express or implied, and with no claim as to its suitability for any
+ * purpose.
+ *
+ */
+
+#pragma once
+
+#include <cstdint>
+
+///\ingroup graphs
+///\file
+///\brief FullBipartiteDigraph and FullBipartiteGraph classes.
+
+
+namespace lemon_omp {
+
+ ///This \c \#define creates convenient type definitions for the following
+ ///types of \c Digraph: \c Node, \c NodeIt, \c Arc, \c ArcIt, \c InArcIt,
+ ///\c OutArcIt, \c BoolNodeMap, \c IntNodeMap, \c DoubleNodeMap,
+ ///\c BoolArcMap, \c IntArcMap, \c DoubleArcMap.
+ ///
+ ///\note If the graph type is a dependent type, ie. the graph type depend
+ ///on a template parameter, then use \c TEMPLATE_DIGRAPH_TYPEDEFS()
+ ///macro.
+#define DIGRAPH_TYPEDEFS(Digraph) \
+ typedef Digraph::Node Node; \
+ typedef Digraph::Arc Arc; \
+
+
+ ///Create convenience typedefs for the digraph types and iterators
+
+ ///\see DIGRAPH_TYPEDEFS
+ ///
+ ///\note Use this macro, if the graph type is a dependent type,
+ ///ie. the graph type depend on a template parameter.
+#define TEMPLATE_DIGRAPH_TYPEDEFS(Digraph) \
+ typedef typename Digraph::Node Node; \
+ typedef typename Digraph::Arc Arc; \
+
+
+ class FullBipartiteDigraphBase {
+ public:
+
+ typedef FullBipartiteDigraphBase Digraph;
+
+ //class Node;
+ typedef int Node;
+ //class Arc;
+ typedef int64_t Arc;
+
+ protected:
+
+ int _node_num;
+ int64_t _arc_num;
+
+ FullBipartiteDigraphBase() {}
+
+ void construct(int n1, int n2) { _node_num = n1+n2; _arc_num = (int64_t)n1 * (int64_t)n2; _n1=n1; _n2=n2;}
+
+ public:
+
+ int _n1, _n2;
+
+
+ Node operator()(int ix) const { return Node(ix); }
+ static int index(const Node& node) { return node; }
+
+ Arc arc(const Node& s, const Node& t) const {
+ if (s<_n1 && t>=_n1)
+ return Arc((int64_t)s * (int64_t)_n2 + (int64_t)(t-_n1) );
+ else
+ return Arc(-1);
+ }
+
+ int nodeNum() const { return _node_num; }
+ int64_t arcNum() const { return _arc_num; }
+
+ int maxNodeId() const { return _node_num - 1; }
+ int64_t maxArcId() const { return _arc_num - 1; }
+
+ Node source(Arc arc) const { return arc / _n2; }
+ Node target(Arc arc) const { return (arc % _n2) + _n1; }
+
+ static int id(Node node) { return node; }
+ static int64_t id(Arc arc) { return arc; }
+
+ static Node nodeFromId(int id) { return Node(id);}
+ static Arc arcFromId(int64_t id) { return Arc(id);}
+
+
+ Arc findArc(Node s, Node t, Arc prev = -1) const {
+ return prev == -1 ? arc(s, t) : -1;
+ }
+
+ void first(Node& node) const {
+ node = _node_num - 1;
+ }
+
+ static void next(Node& node) {
+ --node;
+ }
+
+ void first(Arc& arc) const {
+ arc = _arc_num - 1;
+ }
+
+ static void next(Arc& arc) {
+ --arc;
+ }
+
+ void firstOut(Arc& arc, const Node& node) const {
+ if (node>=_n1)
+ arc = -1;
+ else
+ arc = (node + 1) * _n2 - 1;
+ }
+
+ void nextOut(Arc& arc) const {
+ if (arc % _n2 == 0) arc = 0;
+ --arc;
+ }
+
+ void firstIn(Arc& arc, const Node& node) const {
+ if (node<_n1)
+ arc = -1;
+ else
+ arc = _arc_num + node - _node_num;
+ }
+
+ void nextIn(Arc& arc) const {
+ arc -= _n2;
+ if (arc < 0) arc = -1;
+ }
+
+ };
+
+ /// \ingroup graphs
+ ///
+ /// \brief A directed full graph class.
+ ///
+ /// FullBipartiteDigraph is a simple and fast implmenetation of directed full
+ /// (complete) graphs. It contains an arc from each node to each node
+ /// (including a loop for each node), therefore the number of arcs
+ /// is the square of the number of nodes.
+ /// This class is completely static and it needs constant memory space.
+ /// Thus you can neither add nor delete nodes or arcs, however
+ /// the structure can be resized using resize().
+ ///
+ /// This type fully conforms to the \ref concepts::Digraph "Digraph concept".
+ /// Most of its member functions and nested classes are documented
+ /// only in the concept class.
+ ///
+ /// This class provides constant time counting for nodes and arcs.
+ ///
+ /// \note FullBipartiteDigraph and FullBipartiteGraph classes are very similar,
+ /// but there are two differences. While this class conforms only
+ /// to the \ref concepts::Digraph "Digraph" concept, FullBipartiteGraph
+ /// conforms to the \ref concepts::Graph "Graph" concept,
+ /// moreover FullBipartiteGraph does not contain a loop for each
+ /// node as this class does.
+ ///
+ /// \sa FullBipartiteGraph
+ class FullBipartiteDigraph : public FullBipartiteDigraphBase {
+ typedef FullBipartiteDigraphBase Parent;
+
+ public:
+
+ /// \brief Default constructor.
+ ///
+ /// Default constructor. The number of nodes and arcs will be zero.
+ FullBipartiteDigraph() { construct(0,0); }
+
+ /// \brief Constructor
+ ///
+ /// Constructor.
+ /// \param n The number of the nodes.
+ FullBipartiteDigraph(int n1, int n2) { construct(n1, n2); }
+
+
+ /// \brief Returns the node with the given index.
+ ///
+ /// Returns the node with the given index. Since this structure is
+ /// completely static, the nodes can be indexed with integers from
+ /// the range <tt>[0..nodeNum()-1]</tt>.
+ /// The index of a node is the same as its ID.
+ /// \sa index()
+ Node operator()(int ix) const { return Parent::operator()(ix); }
+
+ /// \brief Returns the index of the given node.
+ ///
+ /// Returns the index of the given node. Since this structure is
+ /// completely static, the nodes can be indexed with integers from
+ /// the range <tt>[0..nodeNum()-1]</tt>.
+ /// The index of a node is the same as its ID.
+ /// \sa operator()()
+ static int index(const Node& node) { return Parent::index(node); }
+
+ /// \brief Returns the arc connecting the given nodes.
+ ///
+ /// Returns the arc connecting the given nodes.
+ /*Arc arc(Node u, Node v) const {
+ return Parent::arc(u, v);
+ }*/
+
+ /// \brief Number of nodes.
+ int nodeNum() const { return Parent::nodeNum(); }
+ /// \brief Number of arcs.
+ int64_t arcNum() const { return Parent::arcNum(); }
+ };
+
+
+
+
+} //namespace lemon_omp
diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h
index 5d93040..3b46b9b 100644
--- a/ot/lp/network_simplex_simple.h
+++ b/ot/lp/network_simplex_simple.h
@@ -25,15 +25,17 @@
*
*/
-#ifndef LEMON_NETWORK_SIMPLEX_SIMPLE_H
-#define LEMON_NETWORK_SIMPLEX_SIMPLE_H
+#pragma once
+#undef DEBUG_LVL
#define DEBUG_LVL 0
#if DEBUG_LVL>0
#include <iomanip>
#endif
-
+#undef EPSILON
+#undef _EPSILON
+#undef MAX_DEBUG_ITER
#define EPSILON 2.2204460492503131e-15
#define _EPSILON 1e-8
#define MAX_DEBUG_ITER 100000
@@ -50,6 +52,7 @@
#include <vector>
#include <limits>
#include <algorithm>
+#include <iostream>
#include <cstdio>
#ifdef HASHMAP
#include <hash_map>
@@ -63,6 +66,8 @@
//#include "sparse_array_n.h"
#include "full_bipartitegraph.h"
+#undef INVALIDNODE
+#undef INVALID
#define INVALIDNODE -1
#define INVALID (-1)
@@ -76,16 +81,16 @@ namespace lemon {
class SparseValueVector
{
public:
- SparseValueVector(int n=0)
+ SparseValueVector(size_t n=0)
{
}
- void resize(int n=0){};
- T operator[](const int id) const
+ void resize(size_t n=0){};
+ T operator[](const size_t id) const
{
#ifdef HASHMAP
- typename stdext::hash_map<int,T>::const_iterator it = data.find(id);
+ typename stdext::hash_map<size_t,T>::const_iterator it = data.find(id);
#else
- typename std::map<int,T>::const_iterator it = data.find(id);
+ typename std::map<size_t,T>::const_iterator it = data.find(id);
#endif
if (it==data.end())
return 0;
@@ -93,16 +98,16 @@ namespace lemon {
return it->second;
}
- ProxyObject<T> operator[](const int id)
+ ProxyObject<T> operator[](const size_t id)
{
return ProxyObject<T>( this, id );
}
//private:
#ifdef HASHMAP
- stdext::hash_map<int,T> data;
+ stdext::hash_map<size_t,T> data;
#else
- std::map<int,T> data;
+ std::map<size_t,T> data;
#endif
};
@@ -110,7 +115,7 @@ namespace lemon {
template <typename T>
class ProxyObject {
public:
- ProxyObject( SparseValueVector<T> *v, int idx ){_v=v; _idx=idx;};
+ ProxyObject( SparseValueVector<T> *v, size_t idx ){_v=v; _idx=idx;};
ProxyObject<T> & operator=( const T &v ) {
// If we get here, we know that operator[] was called to perform a write access,
// so we can insert an item in the vector if needed
@@ -123,9 +128,9 @@ namespace lemon {
// If we get here, we know that operator[] was called to perform a read access,
// so we can simply return the existing object
#ifdef HASHMAP
- typename stdext::hash_map<int,T>::iterator it = _v->data.find(_idx);
+ typename stdext::hash_map<size_t,T>::iterator it = _v->data.find(_idx);
#else
- typename std::map<int,T>::iterator it = _v->data.find(_idx);
+ typename std::map<size_t,T>::iterator it = _v->data.find(_idx);
#endif
if (it==_v->data.end())
return 0;
@@ -137,9 +142,9 @@ namespace lemon {
{
if (val==0) return;
#ifdef HASHMAP
- typename stdext::hash_map<int,T>::iterator it = _v->data.find(_idx);
+ typename stdext::hash_map<size_t,T>::iterator it = _v->data.find(_idx);
#else
- typename std::map<int,T>::iterator it = _v->data.find(_idx);
+ typename std::map<size_t,T>::iterator it = _v->data.find(_idx);
#endif
if (it==_v->data.end())
_v->data[_idx] = val;
@@ -156,9 +161,9 @@ namespace lemon {
{
if (val==0) return;
#ifdef HASHMAP
- typename stdext::hash_map<int,T>::iterator it = _v->data.find(_idx);
+ typename stdext::hash_map<size_t,T>::iterator it = _v->data.find(_idx);
#else
- typename std::map<int,T>::iterator it = _v->data.find(_idx);
+ typename std::map<size_t,T>::iterator it = _v->data.find(_idx);
#endif
if (it==_v->data.end())
_v->data[_idx] = -val;
@@ -173,7 +178,7 @@ namespace lemon {
}
SparseValueVector<T> *_v;
- int _idx;
+ size_t _idx;
};
@@ -204,7 +209,7 @@ namespace lemon {
///
/// \tparam GR The digraph type the algorithm runs on.
/// \tparam V The number type used for flow amounts, capacity bounds
- /// and supply values in the algorithm. By default, it is \c int.
+ /// and supply values in the algorithm. By default, it is \c int64_t.
/// \tparam C The number type used for costs and potentials in the
/// algorithm. By default, it is the same as \c V.
///
@@ -214,7 +219,7 @@ namespace lemon {
/// \note %NetworkSimplexSimple provides five different pivot rule
/// implementations, from which the most efficient one is used
/// by default. For more information, see \ref PivotRule.
- template <typename GR, typename V = int, typename C = V, typename NodesType = unsigned short int>
+ template <typename GR, typename V = int, typename C = V, typename NodesType = unsigned short int, typename ArcsType = int64_t>
class NetworkSimplexSimple
{
public:
@@ -228,7 +233,7 @@ namespace lemon {
/// mixed order in the internal data structure.
/// In special cases, it could lead to better overall performance,
/// but it is usually slower. Therefore it is disabled by default.
- NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, long long nb_arcs,int maxiters) :
+ NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters) :
_graph(graph), //_arc_id(graph),
_arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs),
MAX(std::numeric_limits<Value>::max()),
@@ -288,11 +293,11 @@ namespace lemon {
private:
- int max_iter;
+ size_t max_iter;
TEMPLATE_DIGRAPH_TYPEDEFS(GR);
typedef std::vector<int> IntVector;
- typedef std::vector<NodesType> UHalfIntVector;
+ typedef std::vector<ArcsType> ArcVector;
typedef std::vector<Value> ValueVector;
typedef std::vector<Cost> CostVector;
// typedef SparseValueVector<Cost> CostVector;
@@ -315,9 +320,9 @@ namespace lemon {
// Data related to the underlying digraph
const GR &_graph;
int _node_num;
- int _arc_num;
- int _all_arc_num;
- int _search_arc_num;
+ ArcsType _arc_num;
+ ArcsType _all_arc_num;
+ ArcsType _search_arc_num;
// Parameters of the problem
SupplyType _stype;
@@ -325,9 +330,9 @@ namespace lemon {
inline int _node_id(int n) const {return _node_num-n-1;} ;
- //IntArcMap _arc_id;
- UHalfIntVector _source;
- UHalfIntVector _target;
+// IntArcMap _arc_id;
+ IntVector _source; // keep nodes as integers
+ IntVector _target;
bool _arc_mixing;
public:
// Node and arc data
@@ -341,7 +346,7 @@ namespace lemon {
private:
// Data for storing the spanning tree structure
IntVector _parent;
- IntVector _pred;
+ ArcVector _pred;
IntVector _thread;
IntVector _rev_thread;
IntVector _succ_num;
@@ -349,17 +354,17 @@ namespace lemon {
IntVector _dirty_revs;
BoolVector _forward;
StateVector _state;
- int _root;
+ ArcsType _root;
// Temporary data used in the current pivot iteration
- int in_arc, join, u_in, v_in, u_out, v_out;
- int first, second, right, last;
- int stem, par_stem, new_stem;
+ ArcsType in_arc, join, u_in, v_in, u_out, v_out;
+ ArcsType first, second, right, last;
+ ArcsType stem, par_stem, new_stem;
Value delta;
const Value MAX;
- int mixingCoeff;
+ ArcsType mixingCoeff;
public:
@@ -373,27 +378,27 @@ namespace lemon {
private:
// thank you to DVK and MizardX from StackOverflow for this function!
- inline int sequence(int k) const {
- int smallv = (k > num_total_big_subsequence_numbers) & 1;
+ inline ArcsType sequence(ArcsType k) const {
+ ArcsType smallv = (k > num_total_big_subsequence_numbers) & 1;
k -= num_total_big_subsequence_numbers * smallv;
- int subsequence_length2 = subsequence_length- smallv;
- int subsequence_num = (k / subsequence_length2) + num_big_subseqiences * smallv;
- int subsequence_offset = (k % subsequence_length2) * mixingCoeff;
+ ArcsType subsequence_length2 = subsequence_length- smallv;
+ ArcsType subsequence_num = (k / subsequence_length2) + num_big_subseqiences * smallv;
+ ArcsType subsequence_offset = (k % subsequence_length2) * mixingCoeff;
return subsequence_offset + subsequence_num;
}
- int subsequence_length;
- int num_big_subseqiences;
- int num_total_big_subsequence_numbers;
+ ArcsType subsequence_length;
+ ArcsType num_big_subseqiences;
+ ArcsType num_total_big_subsequence_numbers;
- inline int getArcID(const Arc &arc) const
+ inline ArcsType getArcID(const Arc &arc) const
{
//int n = _arc_num-arc._id-1;
- int n = _arc_num-GR::id(arc)-1;
+ ArcsType n = _arc_num-GR::id(arc)-1;
- //int a = mixingCoeff*(n%mixingCoeff) + n/mixingCoeff;
- //int b = _arc_id[arc];
+ //ArcsType a = mixingCoeff*(n%mixingCoeff) + n/mixingCoeff;
+ //ArcsType b = _arc_id[arc];
if (_arc_mixing)
return sequence(n);
else
@@ -401,16 +406,16 @@ namespace lemon {
}
// finally unused because too slow
- inline int getSource(const int arc) const
+ inline ArcsType getSource(const ArcsType arc) const
{
- //int a = _source[arc];
+ //ArcsType a = _source[arc];
//return a;
- int n = _arc_num-arc-1;
+ ArcsType n = _arc_num-arc-1;
if (_arc_mixing)
n = mixingCoeff*(n%mixingCoeff) + n/mixingCoeff;
- int b;
+ ArcsType b;
if (n>=0)
b = _node_id(_graph.source(GR::arcFromId( n ) ));
else
@@ -436,17 +441,17 @@ namespace lemon {
private:
// References to the NetworkSimplexSimple class
- const UHalfIntVector &_source;
- const UHalfIntVector &_target;
+ const IntVector &_source;
+ const IntVector &_target;
const CostVector &_cost;
const StateVector &_state;
const CostVector &_pi;
- int &_in_arc;
- int _search_arc_num;
+ ArcsType &_in_arc;
+ ArcsType _search_arc_num;
// Pivot rule data
- int _block_size;
- int _next_arc;
+ ArcsType _block_size;
+ ArcsType _next_arc;
NetworkSimplexSimple &_ns;
public:
@@ -460,17 +465,16 @@ namespace lemon {
{
// The main parameters of the pivot rule
const double BLOCK_SIZE_FACTOR = 1.0;
- const int MIN_BLOCK_SIZE = 10;
+ const ArcsType MIN_BLOCK_SIZE = 10;
- _block_size = std::max( int(BLOCK_SIZE_FACTOR *
- std::sqrt(double(_search_arc_num))),
- MIN_BLOCK_SIZE );
+ _block_size = std::max(ArcsType(BLOCK_SIZE_FACTOR * std::sqrt(double(_search_arc_num))), MIN_BLOCK_SIZE);
}
+
// Find next entering arc
bool findEnteringArc() {
Cost c, min = 0;
- int e;
- int cnt = _block_size;
+ ArcsType e;
+ ArcsType cnt = _block_size;
double a;
for (e = _next_arc; e != _search_arc_num; ++e) {
c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]);
@@ -516,7 +520,7 @@ namespace lemon {
int _init_nb_nodes;
- long long _init_nb_arcs;
+ ArcsType _init_nb_arcs;
/// \name Parameters
/// The parameters of the algorithm can be specified using these
@@ -736,7 +740,7 @@ namespace lemon {
for (int i = 0; i != _node_num; ++i) {
_supply[i] = 0;
}
- for (int i = 0; i != _arc_num; ++i) {
+ for (ArcsType i = 0; i != _arc_num; ++i) {
_cost[i] = 1;
}
_stype = GEQ;
@@ -745,7 +749,7 @@ namespace lemon {
- int divid (int x, int y)
+ int64_t divid (int64_t x, int64_t y)
{
return (x-x%y)/y;
}
@@ -775,7 +779,7 @@ namespace lemon {
_node_num = _init_nb_nodes;
_arc_num = _init_nb_arcs;
int all_node_num = _node_num + 1;
- int max_arc_num = _arc_num + 2 * _node_num;
+ ArcsType max_arc_num = _arc_num + 2 * _node_num;
_source.resize(max_arc_num);
_target.resize(max_arc_num);
@@ -798,13 +802,13 @@ namespace lemon {
//_arc_mixing=false;
if (_arc_mixing) {
// Store the arcs in a mixed order
- int k = std::max(int(std::sqrt(double(_arc_num))), 10);
+ const ArcsType k = std::max(ArcsType(std::sqrt(double(_arc_num))), ArcsType(10));
mixingCoeff = k;
subsequence_length = _arc_num / mixingCoeff + 1;
num_big_subseqiences = _arc_num % mixingCoeff;
num_total_big_subsequence_numbers = subsequence_length * num_big_subseqiences;
- int i = 0, j = 0;
+ ArcsType i = 0, j = 0;
Arc a; _graph.first(a);
for (; a != INVALID; _graph.next(a)) {
_source[i] = _node_id(_graph.source(a));
@@ -814,7 +818,7 @@ namespace lemon {
}
} else {
// Store the arcs in the original order
- int i = 0;
+ ArcsType i = 0;
Arc a; _graph.first(a);
for (; a != INVALID; _graph.next(a), ++i) {
_source[i] = _node_id(_graph.source(a));
@@ -856,7 +860,7 @@ namespace lemon {
Number totalCost() const {
Number c = 0;
for (ArcIt a(_graph); a != INVALID; ++a) {
- int i = getArcID(a);
+ int64_t i = getArcID(a);
c += Number(_flow[i]) * Number(_cost[i]);
}
return c;
@@ -867,15 +871,15 @@ namespace lemon {
Number c = 0;
/*#ifdef HASHMAP
- typename stdext::hash_map<int, Value>::const_iterator it;
+ typename stdext::hash_map<int64_t, Value>::const_iterator it;
#else
- typename std::map<int, Value>::const_iterator it;
+ typename std::map<int64_t, Value>::const_iterator it;
#endif
for (it = _flow.data.begin(); it!=_flow.data.end(); ++it)
c += Number(it->second) * Number(_cost[it->first]);
return c;*/
- for (unsigned long i=0; i<_flow.size(); i++)
+ for (ArcsType i=0; i<_flow.size(); i++)
c += _flow[i] * Number(_cost[i]);
return c;
@@ -944,14 +948,14 @@ namespace lemon {
// Initialize internal data structures
bool init() {
if (_node_num == 0) return false;
-
+
// Check the sum of supply values
_sum_supply = 0;
for (int i = 0; i != _node_num; ++i) {
_sum_supply += _supply[i];
}
if ( fabs(_sum_supply) > _EPSILON ) return false;
-
+
_sum_supply = 0;
// Initialize artifical cost
@@ -960,14 +964,14 @@ namespace lemon {
ART_COST = std::numeric_limits<Cost>::max() / 2 + 1;
} else {
ART_COST = 0;
- for (int i = 0; i != _arc_num; ++i) {
+ for (ArcsType i = 0; i != _arc_num; ++i) {
if (_cost[i] > ART_COST) ART_COST = _cost[i];
}
ART_COST = (ART_COST + 1) * _node_num;
}
// Initialize arc maps
- for (int i = 0; i != _arc_num; ++i) {
+ for (ArcsType i = 0; i != _arc_num; ++i) {
//_flow[i] = 0; //by default, the sparse matrix is empty
_state[i] = STATE_LOWER;
}
@@ -988,7 +992,7 @@ namespace lemon {
// EQ supply constraints
_search_arc_num = _arc_num;
_all_arc_num = _arc_num + _node_num;
- for (int u = 0, e = _arc_num; u != _node_num; ++u, ++e) {
+ for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) {
_parent[u] = _root;
_pred[u] = e;
_thread[u] = u + 1;
@@ -1016,8 +1020,8 @@ namespace lemon {
else if (_sum_supply > 0) {
// LEQ supply constraints
_search_arc_num = _arc_num + _node_num;
- int f = _arc_num + _node_num;
- for (int u = 0, e = _arc_num; u != _node_num; ++u, ++e) {
+ ArcsType f = _arc_num + _node_num;
+ for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) {
_parent[u] = _root;
_thread[u] = u + 1;
_rev_thread[u + 1] = u;
@@ -1054,8 +1058,8 @@ namespace lemon {
else {
// GEQ supply constraints
_search_arc_num = _arc_num + _node_num;
- int f = _arc_num + _node_num;
- for (int u = 0, e = _arc_num; u != _node_num; ++u, ++e) {
+ ArcsType f = _arc_num + _node_num;
+ for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) {
_parent[u] = _root;
_thread[u] = u + 1;
_rev_thread[u + 1] = u;
@@ -1120,9 +1124,9 @@ namespace lemon {
second = _source[in_arc];
}
delta = INF;
- int result = 0;
+ char result = 0;
Value d;
- int e;
+ ArcsType e;
// Search the cycle along the path form the first node to the root
for (int u = first; u != join; u = _parent[u]) {
@@ -1239,7 +1243,7 @@ namespace lemon {
// Update _rev_thread using the new _thread values
for (int i = 0; i != int(_dirty_revs.size()); ++i) {
- u = _dirty_revs[i];
+ int u = _dirty_revs[i];
_rev_thread[_thread[u]] = u;
}
@@ -1257,7 +1261,7 @@ namespace lemon {
u = w;
}
_pred[u_in] = in_arc;
- _forward[u_in] = ((unsigned int)u_in == _source[in_arc]);
+ _forward[u_in] = (u_in == _source[in_arc]);
_succ_num[u_in] = old_succ_num;
// Set limits for updating _last_succ form v_in and v_out
@@ -1328,7 +1332,7 @@ namespace lemon {
if (_sum_supply > 0) total -= _sum_supply;
if (total <= 0) return true;
- IntVector arc_vector;
+ ArcVector arc_vector;
if (_sum_supply >= 0) {
if (supply_nodes.size() == 1 && demand_nodes.size() == 1) {
// Perform a reverse graph search from the sink to the source
@@ -1345,7 +1349,7 @@ namespace lemon {
Arc a; _graph.firstIn(a, v);
for (; a != INVALID; _graph.nextIn(a)) {
if (reached[u = _graph.source(a)]) continue;
- int j = getArcID(a);
+ ArcsType j = getArcID(a);
if (INF >= total) {
arc_vector.push_back(j);
reached[u] = true;
@@ -1355,7 +1359,7 @@ namespace lemon {
}
} else {
// Find the min. cost incomming arc for each demand node
- for (int i = 0; i != int(demand_nodes.size()); ++i) {
+ for (int i = 0; i != demand_nodes.size(); ++i) {
Node v = demand_nodes[i];
Cost c, min_cost = std::numeric_limits<Cost>::max();
Arc min_arc = INVALID;
@@ -1393,7 +1397,7 @@ namespace lemon {
}
// Perform heuristic initial pivots
- for (int i = 0; i != int(arc_vector.size()); ++i) {
+ for (ArcsType i = 0; i != arc_vector.size(); ++i) {
in_arc = arc_vector[i];
// l'erreur est probablement ici...
if (_state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -
@@ -1423,7 +1427,7 @@ namespace lemon {
// Perform heuristic initial pivots
if (!initialPivots()) return UNBOUNDED;
- int iter_number=0;
+ size_t iter_number=0;
//pivot.setDantzig(true);
// Execute the Network Simplex algorithm
while (pivot.findEnteringArc()) {
@@ -1443,7 +1447,7 @@ namespace lemon {
double a;
a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]);
a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]);
- for (int i=0; i<_flow.size(); i++) {
+ for (int64_t i=0; i<_flow.size(); i++) {
sumFlow+=_state[i]*_flow[i];
}
std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << iter_number << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n";
@@ -1482,12 +1486,12 @@ namespace lemon {
double a;
a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]);
a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]);
- for (int i=0; i<_flow.size(); i++) {
+ for (int64_t i=0; i<_flow.size(); i++) {
sumFlow+=_state[i]*_flow[i];
}
-
+
std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << niter << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n";
-
+
std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n";
std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n";
@@ -1505,9 +1509,9 @@ namespace lemon {
#endif
// Check feasibility
if( retVal == OPTIMAL){
- for (int e = _search_arc_num; e != _all_arc_num; ++e) {
+ for (ArcsType e = _search_arc_num; e != _all_arc_num; ++e) {
if (_flow[e] != 0){
- if (abs(_flow[e]) > EPSILON)
+ if (fabs(_flow[e]) > _EPSILON) // change of the original code following issue #126
return INFEASIBLE;
else
_flow[e]=0;
@@ -1521,20 +1525,20 @@ namespace lemon {
if (_sum_supply == 0) {
if (_stype == GEQ) {
Cost max_pot = -std::numeric_limits<Cost>::max();
- for (int i = 0; i != _node_num; ++i) {
+ for (ArcsType i = 0; i != _node_num; ++i) {
if (_pi[i] > max_pot) max_pot = _pi[i];
}
if (max_pot > 0) {
- for (int i = 0; i != _node_num; ++i)
+ for (ArcsType i = 0; i != _node_num; ++i)
_pi[i] -= max_pot;
}
} else {
Cost min_pot = std::numeric_limits<Cost>::max();
- for (int i = 0; i != _node_num; ++i) {
+ for (ArcsType i = 0; i != _node_num; ++i) {
if (_pi[i] < min_pot) min_pot = _pi[i];
}
if (min_pot < 0) {
- for (int i = 0; i != _node_num; ++i)
+ for (ArcsType i = 0; i != _node_num; ++i)
_pi[i] -= min_pot;
}
}
@@ -1548,5 +1552,3 @@ namespace lemon {
///@}
} //namespace lemon
-
-#endif //LEMON_NETWORK_SIMPLEX_H
diff --git a/ot/lp/network_simplex_simple_omp.h b/ot/lp/network_simplex_simple_omp.h
new file mode 100644
index 0000000..87e4c05
--- /dev/null
+++ b/ot/lp/network_simplex_simple_omp.h
@@ -0,0 +1,1699 @@
+/* -*- mode: C++; indent-tabs-mode: nil; -*-
+*
+*
+* This file has been adapted by Nicolas Bonneel (2013),
+* from network_simplex.h from LEMON, a generic C++ optimization library,
+* to implement a lightweight network simplex for mass transport, more
+* memory efficient than the original file. A previous version of this file
+* is used as part of the Displacement Interpolation project,
+* Web: http://www.cs.ubc.ca/labs/imager/tr/2011/DisplacementInterpolation/
+*
+* Revisions:
+* March 2015: added OpenMP parallelization
+* March 2017: included Antoine Rolet's trick to make it more robust
+* April 2018: IMPORTANT bug fix + uses 64bit integers (slightly slower but less risks of overflows), updated to a newer version of the algo by LEMON, sparse flow by default + minor edits.
+*
+*
+**** Original file Copyright Notice :
+*
+* Copyright (C) 2003-2010
+* Egervary Jeno Kombinatorikus Optimalizalasi Kutatocsoport
+* (Egervary Research Group on Combinatorial Optimization, EGRES).
+*
+* Permission to use, modify and distribute this software is granted
+* provided that this copyright notice appears in all copies. For
+* precise terms see the accompanying LICENSE file.
+*
+* This software is provided "AS IS" with no warranty of any kind,
+* express or implied, and with no claim as to its suitability for any
+* purpose.
+*
+*/
+
+#pragma once
+#undef DEBUG_LVL
+#define DEBUG_LVL 0
+
+#if DEBUG_LVL>0
+#include <iomanip>
+#endif
+
+#undef EPSILON
+#undef _EPSILON
+#undef MAX_DEBUG_ITER
+#define EPSILON std::numeric_limits<Cost>::epsilon()*10
+#define _EPSILON 1e-8
+#define MAX_DEBUG_ITER 100000
+
+/// \ingroup min_cost_flow_algs
+///
+/// \file
+/// \brief Network Simplex algorithm for finding a minimum cost flow.
+
+// if your compiler has troubles with unorderedmaps, just comment the following line to use a slower std::map instead
+#define HASHMAP // now handled with unorderedmaps instead of stdext::hash_map. Should be better supported.
+
+#define SPARSE_FLOW // a sparse flow vector will be 10-15% slower for small problems but uses less memory and becomes faster for large problems (40k total nodes)
+
+#include <vector>
+#include <limits>
+#include <algorithm>
+#include <iostream>
+#ifdef HASHMAP
+#include <unordered_map>
+#else
+#include <map>
+#endif
+//#include "core.h"
+//#include "lmath.h"
+
+#ifdef OMP
+#include <omp.h>
+#endif
+#include <cmath>
+
+
+//#include "sparse_array_n.h"
+#include "full_bipartitegraph_omp.h"
+
+#undef INVALIDNODE
+#undef INVALID
+#define INVALIDNODE -1
+#define INVALID (-1)
+
+namespace lemon_omp {
+
+ int64_t max_threads = -1;
+
+ template <typename T>
+ class ProxyObject;
+
+ template<typename T>
+ class SparseValueVector
+ {
+ public:
+ SparseValueVector(size_t n = 0) // parameter n for compatibility with standard vectors
+ {
+ }
+ void resize(size_t n = 0) {};
+ T operator[](const size_t id) const
+ {
+#ifdef HASHMAP
+ typename std::unordered_map<size_t, T>::const_iterator it = data.find(id);
+#else
+ typename std::map<size_t, T>::const_iterator it = data.find(id);
+#endif
+ if (it == data.end())
+ return 0;
+ else
+ return it->second;
+ }
+
+ ProxyObject<T> operator[](const size_t id)
+ {
+ return ProxyObject<T>(this, id);
+ }
+
+ //private:
+#ifdef HASHMAP
+ std::unordered_map<size_t, T> data;
+#else
+ std::map<size_t, T> data;
+#endif
+
+ };
+
+ template <typename T>
+ class ProxyObject {
+ public:
+ ProxyObject(SparseValueVector<T> *v, size_t idx) { _v = v; _idx = idx; };
+ ProxyObject<T> & operator=(const T &v) {
+ // If we get here, we know that operator[] was called to perform a write access,
+ // so we can insert an item in the vector if needed
+ if (v != 0)
+ _v->data[_idx] = v;
+ return *this;
+ }
+
+ operator T() {
+ // If we get here, we know that operator[] was called to perform a read access,
+ // so we can simply return the existing object
+#ifdef HASHMAP
+ typename std::unordered_map<size_t, T>::iterator it = _v->data.find(_idx);
+#else
+ typename std::map<size_t, T>::iterator it = _v->data.find(_idx);
+#endif
+ if (it == _v->data.end())
+ return 0;
+ else
+ return it->second;
+ }
+
+ void operator+=(T val)
+ {
+ if (val == 0) return;
+#ifdef HASHMAP
+ typename std::unordered_map<size_t, T>::iterator it = _v->data.find(_idx);
+#else
+ typename std::map<size_t, T>::iterator it = _v->data.find(_idx);
+#endif
+ if (it == _v->data.end())
+ _v->data[_idx] = val;
+ else
+ {
+ T sum = it->second + val;
+ if (sum == 0)
+ _v->data.erase(it);
+ else
+ it->second = sum;
+ }
+ }
+ void operator-=(T val)
+ {
+ if (val == 0) return;
+#ifdef HASHMAP
+ typename std::unordered_map<size_t, T>::iterator it = _v->data.find(_idx);
+#else
+ typename std::map<size_t, T>::iterator it = _v->data.find(_idx);
+#endif
+ if (it == _v->data.end())
+ _v->data[_idx] = -val;
+ else
+ {
+ T sum = it->second - val;
+ if (sum == 0)
+ _v->data.erase(it);
+ else
+ it->second = sum;
+ }
+ }
+
+ SparseValueVector<T> *_v;
+ size_t _idx;
+ };
+
+
+
+ /// \addtogroup min_cost_flow_algs
+ /// @{
+
+ /// \brief Implementation of the primal Network Simplex algorithm
+ /// for finding a \ref min_cost_flow "minimum cost flow".
+ ///
+ /// \ref NetworkSimplexSimple implements the primal Network Simplex algorithm
+ /// for finding a \ref min_cost_flow "minimum cost flow"
+ /// \ref amo93networkflows, \ref dantzig63linearprog,
+ /// \ref kellyoneill91netsimplex.
+ /// This algorithm is a highly efficient specialized version of the
+ /// linear programming simplex method directly for the minimum cost
+ /// flow problem.
+ ///
+ /// In general, %NetworkSimplexSimple is the fastest implementation available
+ /// in LEMON for this problem.
+ /// Moreover, it supports both directions of the supply/demand inequality
+ /// constraints. For more information, see \ref SupplyType.
+ ///
+ /// Most of the parameters of the problem (except for the digraph)
+ /// can be given using separate functions, and the algorithm can be
+ /// executed using the \ref run() function. If some parameters are not
+ /// specified, then default values will be used.
+ ///
+ /// \tparam GR The digraph type the algorithm runs on.
+ /// \tparam V The number type used for flow amounts, capacity bounds
+ /// and supply values in the algorithm. By default, it is \c int.
+ /// \tparam C The number type used for costs and potentials in the
+ /// algorithm. By default, it is the same as \c V.
+ ///
+ /// \warning Both number types must be signed and all input data must
+ /// be integer.
+ ///
+ /// \note %NetworkSimplexSimple provides five different pivot rule
+ /// implementations, from which the most efficient one is used
+ /// by default. For more information, see \ref PivotRule.
+ template <typename GR, typename V = int, typename C = V, typename ArcsType = int64_t>
+ class NetworkSimplexSimple
+ {
+ public:
+
+ /// \brief Constructor.
+ ///
+ /// The constructor of the class.
+ ///
+ /// \param graph The digraph the algorithm runs on.
+ /// \param arc_mixing Indicate if the arcs have to be stored in a
+ /// mixed order in the internal data structure.
+ /// In special cases, it could lead to better overall performance,
+ /// but it is usually slower. Therefore it is disabled by default.
+ NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters = 0, int numThreads=-1) :
+ _graph(graph), //_arc_id(graph),
+ _arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs),
+ MAX(std::numeric_limits<Value>::max()),
+ INF(std::numeric_limits<Value>::has_infinity ?
+ std::numeric_limits<Value>::infinity() : MAX)
+ {
+ // Reset data structures
+ reset();
+ max_iter = maxiters;
+#ifdef OMP
+ if (max_threads < 0) {
+ max_threads = omp_get_max_threads();
+ }
+ if (numThreads > 0 && numThreads<=max_threads){
+ num_threads = numThreads;
+ } else if (numThreads == -1 || numThreads>max_threads) {
+ num_threads = max_threads;
+ } else {
+ num_threads = 1;
+ }
+ omp_set_num_threads(num_threads);
+#else
+ num_threads = 1;
+#endif
+ }
+
+ /// The type of the flow amounts, capacity bounds and supply values
+ typedef V Value;
+ /// The type of the arc costs
+ typedef C Cost;
+
+ public:
+ /// \brief Problem type constants for the \c run() function.
+ ///
+ /// Enum type containing the problem type constants that can be
+ /// returned by the \ref run() function of the algorithm.
+ enum ProblemType {
+ /// The problem has no feasible solution (flow).
+ INFEASIBLE,
+ /// The problem has optimal solution (i.e. it is feasible and
+ /// bounded), and the algorithm has found optimal flow and node
+ /// potentials (primal and dual solutions).
+ OPTIMAL,
+ /// The objective function of the problem is unbounded, i.e.
+ /// there is a directed cycle having negative total cost and
+ /// infinite upper bound.
+ UNBOUNDED,
+ // The maximum number of iteration has been reached
+ MAX_ITER_REACHED
+ };
+
+ /// \brief Constants for selecting the type of the supply constraints.
+ ///
+ /// Enum type containing constants for selecting the supply type,
+ /// i.e. the direction of the inequalities in the supply/demand
+ /// constraints of the \ref min_cost_flow "minimum cost flow problem".
+ ///
+ /// The default supply type is \c GEQ, the \c LEQ type can be
+ /// selected using \ref supplyType().
+ /// The equality form is a special case of both supply types.
+ enum SupplyType {
+ /// This option means that there are <em>"greater or equal"</em>
+ /// supply/demand constraints in the definition of the problem.
+ GEQ,
+ /// This option means that there are <em>"less or equal"</em>
+ /// supply/demand constraints in the definition of the problem.
+ LEQ
+ };
+
+
+
+ private:
+ size_t max_iter;
+ int num_threads;
+ TEMPLATE_DIGRAPH_TYPEDEFS(GR);
+
+ typedef std::vector<int> IntVector;
+ typedef std::vector<ArcsType> ArcVector;
+ typedef std::vector<Value> ValueVector;
+ typedef std::vector<Cost> CostVector;
+ // typedef SparseValueVector<Cost> CostVector;
+ typedef std::vector<char> BoolVector;
+ // Note: vector<char> is used instead of vector<bool> for efficiency reasons
+
+ // State constants for arcs
+ enum ArcState {
+ STATE_UPPER = -1,
+ STATE_TREE = 0,
+ STATE_LOWER = 1
+ };
+
+ typedef std::vector<signed char> StateVector;
+ // Note: vector<signed char> is used instead of vector<ArcState> for
+ // efficiency reasons
+
+ private:
+
+ // Data related to the underlying digraph
+ const GR &_graph;
+ int _node_num;
+ ArcsType _arc_num;
+ ArcsType _all_arc_num;
+ ArcsType _search_arc_num;
+
+ // Parameters of the problem
+ SupplyType _stype;
+ Value _sum_supply;
+
+ inline int _node_id(int n) const { return _node_num - n - 1; };
+
+ //IntArcMap _arc_id;
+ IntVector _source; // keep nodes as integers
+ IntVector _target;
+ bool _arc_mixing;
+
+ // Node and arc data
+ CostVector _cost;
+ ValueVector _supply;
+#ifdef SPARSE_FLOW
+ SparseValueVector<Value> _flow;
+#else
+ ValueVector _flow;
+#endif
+
+ CostVector _pi;
+
+ // Data for storing the spanning tree structure
+ IntVector _parent;
+ ArcVector _pred;
+ IntVector _thread;
+ IntVector _rev_thread;
+ IntVector _succ_num;
+ IntVector _last_succ;
+ IntVector _dirty_revs;
+ BoolVector _forward;
+ StateVector _state;
+ ArcsType _root;
+
+ // Temporary data used in the current pivot iteration
+ ArcsType in_arc, join, u_in, v_in, u_out, v_out;
+ ArcsType first, second, right, last;
+ ArcsType stem, par_stem, new_stem;
+ Value delta;
+
+ const Value MAX;
+
+ ArcsType mixingCoeff;
+
+ public:
+
+ /// \brief Constant for infinite upper bounds (capacities).
+ ///
+ /// Constant for infinite upper bounds (capacities).
+ /// It is \c std::numeric_limits<Value>::infinity() if available,
+ /// \c std::numeric_limits<Value>::max() otherwise.
+ const Value INF;
+
+ private:
+
+ // thank you to DVK and MizardX from StackOverflow for this function!
+ inline ArcsType sequence(ArcsType k) const {
+ ArcsType smallv = (k > num_total_big_subsequence_numbers) & 1;
+
+ k -= num_total_big_subsequence_numbers * smallv;
+ ArcsType subsequence_length2 = subsequence_length - smallv;
+ ArcsType subsequence_num = (k / subsequence_length2) + num_big_subsequences * smallv;
+ ArcsType subsequence_offset = (k % subsequence_length2) * mixingCoeff;
+
+ return subsequence_offset + subsequence_num;
+ }
+ ArcsType subsequence_length;
+ ArcsType num_big_subsequences;
+ ArcsType num_total_big_subsequence_numbers;
+
+ inline ArcsType getArcID(const Arc &arc) const
+ {
+ //int n = _arc_num-arc._id-1;
+ ArcsType n = _arc_num - GR::id(arc) - 1;
+
+ //ArcsType a = mixingCoeff*(n%mixingCoeff) + n/mixingCoeff;
+ //ArcsType b = _arc_id[arc];
+ if (_arc_mixing)
+ return sequence(n);
+ else
+ return n;
+ }
+
+ // finally unused because too slow
+ inline ArcsType getSource(const ArcsType arc) const
+ {
+ //ArcsType a = _source[arc];
+ //return a;
+
+ ArcsType n = _arc_num - arc - 1;
+ if (_arc_mixing)
+ n = mixingCoeff*(n%mixingCoeff) + n / mixingCoeff;
+
+ ArcsType b;
+ if (n >= 0)
+ b = _node_id(_graph.source(GR::arcFromId(n)));
+ else
+ {
+ n = arc + 1 - _arc_num;
+ if (n <= _node_num)
+ b = _node_num;
+ else
+ if (n >= _graph._n1)
+ b = _graph._n1;
+ else
+ b = _graph._n1 - n;
+ }
+
+ return b;
+ }
+
+
+
+ // Implementation of the Block Search pivot rule
+ class BlockSearchPivotRule
+ {
+ private:
+
+ // References to the NetworkSimplexSimple class
+ const IntVector &_source;
+ const IntVector &_target;
+ const CostVector &_cost;
+ const StateVector &_state;
+ const CostVector &_pi;
+ ArcsType &_in_arc;
+ ArcsType _search_arc_num;
+
+ // Pivot rule data
+ ArcsType _block_size;
+ ArcsType _next_arc;
+ NetworkSimplexSimple &_ns;
+
+ public:
+
+ // Constructor
+ BlockSearchPivotRule(NetworkSimplexSimple &ns) :
+ _source(ns._source), _target(ns._target),
+ _cost(ns._cost), _state(ns._state), _pi(ns._pi),
+ _in_arc(ns.in_arc), _search_arc_num(ns._search_arc_num),
+ _next_arc(0), _ns(ns)
+ {
+ // The main parameters of the pivot rule
+ const double BLOCK_SIZE_FACTOR = 1;
+ const ArcsType MIN_BLOCK_SIZE = 10;
+
+ _block_size = std::max(ArcsType(BLOCK_SIZE_FACTOR * std::sqrt(double(_search_arc_num))), MIN_BLOCK_SIZE);
+ }
+
+ // Find next entering arc
+ bool findEnteringArc() {
+ Cost min_val = 0;
+
+ ArcsType N = _ns.num_threads;
+
+ std::vector<Cost> minArray(N, 0);
+ std::vector<ArcsType> arcId(N);
+ ArcsType bs = (ArcsType)ceil(_block_size / (double)N);
+
+ for (ArcsType i = 0; i < _search_arc_num; i += _block_size) {
+
+ ArcsType e;
+ int j;
+#pragma omp parallel
+ {
+#ifdef OMP
+ int t = omp_get_thread_num();
+#else
+ int t = 0;
+#endif
+
+#pragma omp for schedule(static, bs) lastprivate(e)
+ for (j = 0; j < std::min(i + _block_size, _search_arc_num) - i; j++) {
+ e = (_next_arc + i + j); if (e >= _search_arc_num) e -= _search_arc_num;
+ Cost c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]);
+ if (c < minArray[t]) {
+ minArray[t] = c;
+ arcId[t] = e;
+ }
+ }
+ }
+ for (int j = 0; j < N; j++) {
+ if (minArray[j] < min_val) {
+ min_val = minArray[j];
+ _in_arc = arcId[j];
+ }
+ }
+ Cost a = std::abs(_pi[_source[_in_arc]]) > std::abs(_pi[_target[_in_arc]]) ? std::abs(_pi[_source[_in_arc]]) : std::abs(_pi[_target[_in_arc]]);
+ a = a > std::abs(_cost[_in_arc]) ? a : std::abs(_cost[_in_arc]);
+ if (min_val < -EPSILON*a) {
+ _next_arc = e;
+ return true;
+ }
+ }
+
+ Cost a = fabs(_pi[_source[_in_arc]]) > fabs(_pi[_target[_in_arc]]) ? fabs(_pi[_source[_in_arc]]) : fabs(_pi[_target[_in_arc]]);
+ a = a > fabs(_cost[_in_arc]) ? a : fabs(_cost[_in_arc]);
+ if (min_val >= -EPSILON*a) return false;
+
+ return true;
+ }
+
+
+ // Find next entering arc
+ /*bool findEnteringArc() {
+ Cost min_val = 0;
+ int N = omp_get_max_threads();
+ std::vector<Cost> minArray(N);
+ std::vector<ArcsType> arcId(N);
+
+ ArcsType bs = (ArcsType)ceil(_block_size / (double)N);
+ for (ArcsType i = 0; i < _search_arc_num; i += _block_size) {
+
+ ArcsType maxJ = std::min(i + _block_size, _search_arc_num) - i;
+ ArcsType j;
+#pragma omp parallel
+ {
+ int t = omp_get_thread_num();
+ Cost minV = 0;
+ ArcsType arcStart = _next_arc + i;
+ ArcsType arc = -1;
+#pragma omp for schedule(static, bs)
+ for (j = 0; j < maxJ; j++) {
+ ArcsType e = arcStart + j; if (e >= _search_arc_num) e -= _search_arc_num;
+ Cost c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]);
+ if (c < minV) {
+ minV = c;
+ arc = e;
+ }
+ }
+
+ minArray[t] = minV;
+ arcId[t] = arc;
+ }
+ for (int j = 0; j < N; j++) {
+ if (minArray[j] < min_val) {
+ min_val = minArray[j];
+ _in_arc = arcId[j];
+ }
+ }
+
+ //FIX by Antoine Rolet to avoid precision issues
+ Cost a = std::max(std::abs(_cost[_in_arc]), std::max(std::abs(_pi[_source[_in_arc]]), std::abs(_pi[_target[_in_arc]])));
+ if (min_val <-std::numeric_limits<Cost>::epsilon()*a) {
+ _next_arc = _next_arc + i + maxJ - 1;
+ if (_next_arc >= _search_arc_num) _next_arc -= _search_arc_num;
+ return true;
+ }
+ }
+
+ if (min_val >= 0) {
+ return false;
+ }
+
+ return true;
+ }*/
+
+
+ /*bool findEnteringArc() {
+ Cost c, min = 0;
+ int cnt = _block_size;
+ int e, min_arc = _next_arc;
+ for (e = _next_arc; e < _search_arc_num; ++e) {
+ c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]);
+ if (c < min) {
+ min = c;
+ min_arc = e;
+
+ }
+ if (--cnt == 0) {
+ if (min < 0) break;
+ cnt = _block_size;
+
+ }
+
+ }
+ if (min == 0 || cnt > 0) {
+ for (e = 0; e < _next_arc; ++e) {
+ c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]);
+ if (c < min) {
+ min = c;
+ min_arc = e;
+
+ }
+ if (--cnt == 0) {
+ if (min < 0) break;
+ cnt = _block_size;
+
+ }
+
+ }
+
+ }
+ if (min >= 0) return false;
+ _in_arc = min_arc;
+ _next_arc = e;
+ return true;
+ }*/
+
+
+
+ }; //class BlockSearchPivotRule
+
+
+
+ public:
+
+
+
+ int _init_nb_nodes;
+ ArcsType _init_nb_arcs;
+
+ /// \name Parameters
+ /// The parameters of the algorithm can be specified using these
+ /// functions.
+
+ /// @{
+
+
+ /// \brief Set the costs of the arcs.
+ ///
+ /// This function sets the costs of the arcs.
+ /// If it is not used before calling \ref run(), the costs
+ /// will be set to \c 1 on all arcs.
+ ///
+ /// \param map An arc map storing the costs.
+ /// Its \c Value type must be convertible to the \c Cost type
+ /// of the algorithm.
+ ///
+ /// \return <tt>(*this)</tt>
+ template<typename CostMap>
+ NetworkSimplexSimple& costMap(const CostMap& map) {
+ Arc a; _graph.first(a);
+ for (; a != INVALID; _graph.next(a)) {
+ _cost[getArcID(a)] = map[a];
+ }
+ return *this;
+ }
+
+
+ /// \brief Set the costs of one arc.
+ ///
+ /// This function sets the costs of one arcs.
+ /// Done for memory reasons
+ ///
+ /// \param arc An arc.
+ /// \param arc A cost
+ ///
+ /// \return <tt>(*this)</tt>
+ template<typename Value>
+ NetworkSimplexSimple& setCost(const Arc& arc, const Value cost) {
+ _cost[getArcID(arc)] = cost;
+ return *this;
+ }
+
+
+ /// \brief Set the supply values of the nodes.
+ ///
+ /// This function sets the supply values of the nodes.
+ /// If neither this function nor \ref stSupply() is used before
+ /// calling \ref run(), the supply of each node will be set to zero.
+ ///
+ /// \param map A node map storing the supply values.
+ /// Its \c Value type must be convertible to the \c Value type
+ /// of the algorithm.
+ ///
+ /// \return <tt>(*this)</tt>
+ template<typename SupplyMap>
+ NetworkSimplexSimple& supplyMap(const SupplyMap& map) {
+ Node n; _graph.first(n);
+ for (; n != INVALIDNODE; _graph.next(n)) {
+ _supply[_node_id(n)] = map[n];
+ }
+ return *this;
+ }
+ template<typename SupplyMap>
+ NetworkSimplexSimple& supplyMap(const SupplyMap* map1, int n1, const SupplyMap* map2, int n2) {
+ Node n; _graph.first(n);
+ for (; n != INVALIDNODE; _graph.next(n)) {
+ if (n<n1)
+ _supply[_node_id(n)] = map1[n];
+ else
+ _supply[_node_id(n)] = map2[n - n1];
+ }
+ return *this;
+ }
+ template<typename SupplyMap>
+ NetworkSimplexSimple& supplyMapAll(SupplyMap val1, int n1, SupplyMap val2, int n2) {
+ Node n; _graph.first(n);
+ for (; n != INVALIDNODE; _graph.next(n)) {
+ if (n<n1)
+ _supply[_node_id(n)] = val1;
+ else
+ _supply[_node_id(n)] = val2;
+ }
+ return *this;
+ }
+
+ /// \brief Set single source and target nodes and a supply value.
+ ///
+ /// This function sets a single source node and a single target node
+ /// and the required flow value.
+ /// If neither this function nor \ref supplyMap() is used before
+ /// calling \ref run(), the supply of each node will be set to zero.
+ ///
+ /// Using this function has the same effect as using \ref supplyMap()
+ /// with such a map in which \c k is assigned to \c s, \c -k is
+ /// assigned to \c t and all other nodes have zero supply value.
+ ///
+ /// \param s The source node.
+ /// \param t The target node.
+ /// \param k The required amount of flow from node \c s to node \c t
+ /// (i.e. the supply of \c s and the demand of \c t).
+ ///
+ /// \return <tt>(*this)</tt>
+ NetworkSimplexSimple& stSupply(const Node& s, const Node& t, Value k) {
+ for (int i = 0; i != _node_num; ++i) {
+ _supply[i] = 0;
+ }
+ _supply[_node_id(s)] = k;
+ _supply[_node_id(t)] = -k;
+ return *this;
+ }
+
+ /// \brief Set the type of the supply constraints.
+ ///
+ /// This function sets the type of the supply/demand constraints.
+ /// If it is not used before calling \ref run(), the \ref GEQ supply
+ /// type will be used.
+ ///
+ /// For more information, see \ref SupplyType.
+ ///
+ /// \return <tt>(*this)</tt>
+ NetworkSimplexSimple& supplyType(SupplyType supply_type) {
+ _stype = supply_type;
+ return *this;
+ }
+
+ /// @}
+
+ /// \name Execution Control
+ /// The algorithm can be executed using \ref run().
+
+ /// @{
+
+ /// \brief Run the algorithm.
+ ///
+ /// This function runs the algorithm.
+ /// The paramters can be specified using functions \ref lowerMap(),
+ /// \ref upperMap(), \ref costMap(), \ref supplyMap(), \ref stSupply(),
+ /// \ref supplyType().
+ /// For example,
+ /// \code
+ /// NetworkSimplexSimple<ListDigraph> ns(graph);
+ /// ns.lowerMap(lower).upperMap(upper).costMap(cost)
+ /// .supplyMap(sup).run();
+ /// \endcode
+ ///
+ /// This function can be called more than once. All the given parameters
+ /// are kept for the next call, unless \ref resetParams() or \ref reset()
+ /// is used, thus only the modified parameters have to be set again.
+ /// If the underlying digraph was also modified after the construction
+ /// of the class (or the last \ref reset() call), then the \ref reset()
+ /// function must be called.
+ ///
+ /// \param pivot_rule The pivot rule that will be used during the
+ /// algorithm. For more information, see \ref PivotRule.
+ ///
+ /// \return \c INFEASIBLE if no feasible flow exists,
+ /// \n \c OPTIMAL if the problem has optimal solution
+ /// (i.e. it is feasible and bounded), and the algorithm has found
+ /// optimal flow and node potentials (primal and dual solutions),
+ /// \n \c UNBOUNDED if the objective function of the problem is
+ /// unbounded, i.e. there is a directed cycle having negative total
+ /// cost and infinite upper bound.
+ ///
+ /// \see ProblemType, PivotRule
+ /// \see resetParams(), reset()
+ ProblemType run() {
+#if DEBUG_LVL>0
+ std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "\nUNBOUNDED = " << UNBOUNDED << "\nMAX_ITER_REACHED" << MAX_ITER_REACHED << "\n" ;
+#endif
+ if (!init()) return INFEASIBLE;
+#if DEBUG_LVL>0
+ std::cout << "Init done, starting iterations\n";
+#endif
+
+ return start();
+ }
+
+ /// \brief Reset all the parameters that have been given before.
+ ///
+ /// This function resets all the paramaters that have been given
+ /// before using functions \ref lowerMap(), \ref upperMap(),
+ /// \ref costMap(), \ref supplyMap(), \ref stSupply(), \ref supplyType().
+ ///
+ /// It is useful for multiple \ref run() calls. Basically, all the given
+ /// parameters are kept for the next \ref run() call, unless
+ /// \ref resetParams() or \ref reset() is used.
+ /// If the underlying digraph was also modified after the construction
+ /// of the class or the last \ref reset() call, then the \ref reset()
+ /// function must be used, otherwise \ref resetParams() is sufficient.
+ ///
+ /// For example,
+ /// \code
+ /// NetworkSimplexSimple<ListDigraph> ns(graph);
+ ///
+ /// // First run
+ /// ns.lowerMap(lower).upperMap(upper).costMap(cost)
+ /// .supplyMap(sup).run();
+ ///
+ /// // Run again with modified cost map (resetParams() is not called,
+ /// // so only the cost map have to be set again)
+ /// cost[e] += 100;
+ /// ns.costMap(cost).run();
+ ///
+ /// // Run again from scratch using resetParams()
+ /// // (the lower bounds will be set to zero on all arcs)
+ /// ns.resetParams();
+ /// ns.upperMap(capacity).costMap(cost)
+ /// .supplyMap(sup).run();
+ /// \endcode
+ ///
+ /// \return <tt>(*this)</tt>
+ ///
+ /// \see reset(), run()
+ NetworkSimplexSimple& resetParams() {
+ for (int i = 0; i != _node_num; ++i) {
+ _supply[i] = 0;
+ }
+ for (ArcsType i = 0; i != _arc_num; ++i) {
+ _cost[i] = 1;
+ }
+ _stype = GEQ;
+ return *this;
+ }
+
+
+ /// \brief Reset the internal data structures and all the parameters
+ /// that have been given before.
+ ///
+ /// This function resets the internal data structures and all the
+ /// paramaters that have been given before using functions \ref lowerMap(),
+ /// \ref upperMap(), \ref costMap(), \ref supplyMap(), \ref stSupply(),
+ /// \ref supplyType().
+ ///
+ /// It is useful for multiple \ref run() calls. Basically, all the given
+ /// parameters are kept for the next \ref run() call, unless
+ /// \ref resetParams() or \ref reset() is used.
+ /// If the underlying digraph was also modified after the construction
+ /// of the class or the last \ref reset() call, then the \ref reset()
+ /// function must be used, otherwise \ref resetParams() is sufficient.
+ ///
+ /// See \ref resetParams() for examples.
+ ///
+ /// \return <tt>(*this)</tt>
+ ///
+ /// \see resetParams(), run()
+ NetworkSimplexSimple& reset() {
+ // Resize vectors
+ _node_num = _init_nb_nodes;
+ _arc_num = _init_nb_arcs;
+ int all_node_num = _node_num + 1;
+ ArcsType max_arc_num = _arc_num + 2 * _node_num;
+
+ _source.resize(max_arc_num);
+ _target.resize(max_arc_num);
+
+ _cost.resize(max_arc_num);
+ _supply.resize(all_node_num);
+ _flow.resize(max_arc_num);
+ _pi.resize(all_node_num);
+
+ _parent.resize(all_node_num);
+ _pred.resize(all_node_num);
+ _forward.resize(all_node_num);
+ _thread.resize(all_node_num);
+ _rev_thread.resize(all_node_num);
+ _succ_num.resize(all_node_num);
+ _last_succ.resize(all_node_num);
+ _state.resize(max_arc_num);
+
+
+ //_arc_mixing=false;
+ if (_arc_mixing && _node_num > 1) {
+ // Store the arcs in a mixed order
+ //ArcsType k = std::max(ArcsType(std::sqrt(double(_arc_num))), ArcsType(10));
+ const ArcsType k = std::max(ArcsType(_arc_num / _node_num), ArcsType(3));
+ mixingCoeff = k;
+ subsequence_length = _arc_num / mixingCoeff + 1;
+ num_big_subsequences = _arc_num % mixingCoeff;
+ num_total_big_subsequence_numbers = subsequence_length * num_big_subsequences;
+
+#pragma omp parallel for schedule(static)
+ for (Arc a = 0; a <= _graph.maxArcId(); a++) { // --a <=> _graph.next(a) , -1 == INVALID
+ ArcsType i = sequence(_graph.maxArcId()-a);
+ _source[i] = _node_id(_graph.source(a));
+ _target[i] = _node_id(_graph.target(a));
+ }
+ } else {
+ // Store the arcs in the original order
+ ArcsType i = 0;
+ Arc a; _graph.first(a);
+ for (; a != INVALID; _graph.next(a), ++i) {
+ _source[i] = _node_id(_graph.source(a));
+ _target[i] = _node_id(_graph.target(a));
+ //_arc_id[a] = i;
+ }
+ }
+
+ // Reset parameters
+ resetParams();
+ return *this;
+ }
+
+ /// @}
+
+ /// \name Query Functions
+ /// The results of the algorithm can be obtained using these
+ /// functions.\n
+ /// The \ref run() function must be called before using them.
+
+ /// @{
+
+ /// \brief Return the total cost of the found flow.
+ ///
+ /// This function returns the total cost of the found flow.
+ /// Its complexity is O(e).
+ ///
+ /// \note The return type of the function can be specified as a
+ /// template parameter. For example,
+ /// \code
+ /// ns.totalCost<double>();
+ /// \endcode
+ /// It is useful if the total cost cannot be stored in the \c Cost
+ /// type of the algorithm, which is the default return type of the
+ /// function.
+ ///
+ /// \pre \ref run() must be called before using this function.
+ /*template <typename Number>
+ Number totalCost() const {
+ Number c = 0;
+ for (ArcIt a(_graph); a != INVALID; ++a) {
+ int i = getArcID(a);
+ c += Number(_flow[i]) * Number(_cost[i]);
+ }
+ return c;
+ }*/
+
+ template <typename Number>
+ Number totalCost() const {
+ Number c = 0;
+
+#ifdef SPARSE_FLOW
+ #ifdef HASHMAP
+ typename std::unordered_map<size_t, Value>::const_iterator it;
+ #else
+ typename std::map<size_t, Value>::const_iterator it;
+ #endif
+ for (it = _flow.data.begin(); it!=_flow.data.end(); ++it)
+ c += Number(it->second) * Number(_cost[it->first]);
+ return c;
+#else
+ for (ArcsType i = 0; i<_flow.size(); i++)
+ c += _flow[i] * Number(_cost[i]);
+ return c;
+#endif
+ }
+
+#ifndef DOXYGEN
+ Cost totalCost() const {
+ return totalCost<Cost>();
+ }
+#endif
+
+ /// \brief Return the flow on the given arc.
+ ///
+ /// This function returns the flow on the given arc.
+ ///
+ /// \pre \ref run() must be called before using this function.
+ Value flow(const Arc& a) const {
+ return _flow[getArcID(a)];
+ }
+
+ /// \brief Return the flow map (the primal solution).
+ ///
+ /// This function copies the flow value on each arc into the given
+ /// map. The \c Value type of the algorithm must be convertible to
+ /// the \c Value type of the map.
+ ///
+ /// \pre \ref run() must be called before using this function.
+ template <typename FlowMap>
+ void flowMap(FlowMap &map) const {
+ Arc a; _graph.first(a);
+ for (; a != INVALID; _graph.next(a)) {
+ map.set(a, _flow[getArcID(a)]);
+ }
+ }
+
+ /// \brief Return the potential (dual value) of the given node.
+ ///
+ /// This function returns the potential (dual value) of the
+ /// given node.
+ ///
+ /// \pre \ref run() must be called before using this function.
+ Cost potential(const Node& n) const {
+ return _pi[_node_id(n)];
+ }
+
+ /// \brief Return the potential map (the dual solution).
+ ///
+ /// This function copies the potential (dual value) of each node
+ /// into the given map.
+ /// The \c Cost type of the algorithm must be convertible to the
+ /// \c Value type of the map.
+ ///
+ /// \pre \ref run() must be called before using this function.
+ template <typename PotentialMap>
+ void potentialMap(PotentialMap &map) const {
+ Node n; _graph.first(n);
+ for (; n != INVALID; _graph.next(n)) {
+ map.set(n, _pi[_node_id(n)]);
+ }
+ }
+
+ /// @}
+
+ private:
+
+ // Initialize internal data structures
+ bool init() {
+ if (_node_num == 0) return false;
+
+ // Check the sum of supply values
+ _sum_supply = 0;
+ for (int i = 0; i != _node_num; ++i) {
+ _sum_supply += _supply[i];
+ }
+ /*if (!((_stype == GEQ && _sum_supply <= 0) ||
+ (_stype == LEQ && _sum_supply >= 0))) return false;*/
+
+
+ // Initialize artifical cost
+ Cost ART_COST;
+ if (std::numeric_limits<Cost>::is_exact) {
+ ART_COST = std::numeric_limits<Cost>::max() / 2 + 1;
+ } else {
+ ART_COST = 0;
+ for (ArcsType i = 0; i != _arc_num; ++i) {
+ if (_cost[i] > ART_COST) ART_COST = _cost[i];
+ }
+ ART_COST = (ART_COST + 1) * _node_num;
+ }
+
+ // Initialize arc maps
+ for (ArcsType i = 0; i != _arc_num; ++i) {
+#ifndef SPARSE_FLOW
+ _flow[i] = 0; //by default, the sparse matrix is empty
+#endif
+ _state[i] = STATE_LOWER;
+ }
+#ifdef SPARSE_FLOW
+ _flow = SparseValueVector<Value>();
+#endif
+
+ // Set data for the artificial root node
+ _root = _node_num;
+ _parent[_root] = -1;
+ _pred[_root] = -1;
+ _thread[_root] = 0;
+ _rev_thread[0] = _root;
+ _succ_num[_root] = _node_num + 1;
+ _last_succ[_root] = _root - 1;
+ _supply[_root] = -_sum_supply;
+ _pi[_root] = 0;
+
+ // Add artificial arcs and initialize the spanning tree data structure
+ if (_sum_supply == 0) {
+ // EQ supply constraints
+ _search_arc_num = _arc_num;
+ _all_arc_num = _arc_num + _node_num;
+ for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) {
+ _parent[u] = _root;
+ _pred[u] = e;
+ _thread[u] = u + 1;
+ _rev_thread[u + 1] = u;
+ _succ_num[u] = 1;
+ _last_succ[u] = u;
+ _state[e] = STATE_TREE;
+ if (_supply[u] >= 0) {
+ _forward[u] = true;
+ _pi[u] = 0;
+ _source[e] = u;
+ _target[e] = _root;
+ _flow[e] = _supply[u];
+ _cost[e] = 0;
+ } else {
+ _forward[u] = false;
+ _pi[u] = ART_COST;
+ _source[e] = _root;
+ _target[e] = u;
+ _flow[e] = -_supply[u];
+ _cost[e] = ART_COST;
+ }
+ }
+ } else if (_sum_supply > 0) {
+ // LEQ supply constraints
+ _search_arc_num = _arc_num + _node_num;
+ ArcsType f = _arc_num + _node_num;
+ for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) {
+ _parent[u] = _root;
+ _thread[u] = u + 1;
+ _rev_thread[u + 1] = u;
+ _succ_num[u] = 1;
+ _last_succ[u] = u;
+ if (_supply[u] >= 0) {
+ _forward[u] = true;
+ _pi[u] = 0;
+ _pred[u] = e;
+ _source[e] = u;
+ _target[e] = _root;
+ _flow[e] = _supply[u];
+ _cost[e] = 0;
+ _state[e] = STATE_TREE;
+ } else {
+ _forward[u] = false;
+ _pi[u] = ART_COST;
+ _pred[u] = f;
+ _source[f] = _root;
+ _target[f] = u;
+ _flow[f] = -_supply[u];
+ _cost[f] = ART_COST;
+ _state[f] = STATE_TREE;
+ _source[e] = u;
+ _target[e] = _root;
+ //_flow[e] = 0; //by default, the sparse matrix is empty
+ _cost[e] = 0;
+ _state[e] = STATE_LOWER;
+ ++f;
+ }
+ }
+ _all_arc_num = f;
+ } else {
+ // GEQ supply constraints
+ _search_arc_num = _arc_num + _node_num;
+ ArcsType f = _arc_num + _node_num;
+ for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) {
+ _parent[u] = _root;
+ _thread[u] = u + 1;
+ _rev_thread[u + 1] = u;
+ _succ_num[u] = 1;
+ _last_succ[u] = u;
+ if (_supply[u] <= 0) {
+ _forward[u] = false;
+ _pi[u] = 0;
+ _pred[u] = e;
+ _source[e] = _root;
+ _target[e] = u;
+ _flow[e] = -_supply[u];
+ _cost[e] = 0;
+ _state[e] = STATE_TREE;
+ } else {
+ _forward[u] = true;
+ _pi[u] = -ART_COST;
+ _pred[u] = f;
+ _source[f] = u;
+ _target[f] = _root;
+ _flow[f] = _supply[u];
+ _state[f] = STATE_TREE;
+ _cost[f] = ART_COST;
+ _source[e] = _root;
+ _target[e] = u;
+ //_flow[e] = 0; //by default, the sparse matrix is empty
+ _cost[e] = 0;
+ _state[e] = STATE_LOWER;
+ ++f;
+ }
+ }
+ _all_arc_num = f;
+ }
+
+ return true;
+ }
+
+ // Find the join node
+ void findJoinNode() {
+ int u = _source[in_arc];
+ int v = _target[in_arc];
+ while (u != v) {
+ if (_succ_num[u] < _succ_num[v]) {
+ u = _parent[u];
+ } else {
+ v = _parent[v];
+ }
+ }
+ join = u;
+ }
+
+ // Find the leaving arc of the cycle and returns true if the
+ // leaving arc is not the same as the entering arc
+ bool findLeavingArc() {
+ // Initialize first and second nodes according to the direction
+ // of the cycle
+ if (_state[in_arc] == STATE_LOWER) {
+ first = _source[in_arc];
+ second = _target[in_arc];
+ } else {
+ first = _target[in_arc];
+ second = _source[in_arc];
+ }
+ delta = INF;
+ char result = 0;
+ Value d;
+ ArcsType e;
+
+ // Search the cycle along the path form the first node to the root
+ for (int u = first; u != join; u = _parent[u]) {
+ e = _pred[u];
+ d = _forward[u] ? _flow[e] : INF;
+ if (d < delta) {
+ delta = d;
+ u_out = u;
+ result = 1;
+ }
+ }
+ // Search the cycle along the path form the second node to the root
+ for (int u = second; u != join; u = _parent[u]) {
+ e = _pred[u];
+ d = _forward[u] ? INF : _flow[e];
+ if (d <= delta) {
+ delta = d;
+ u_out = u;
+ result = 2;
+ }
+ }
+
+ if (result == 1) {
+ u_in = first;
+ v_in = second;
+ } else {
+ u_in = second;
+ v_in = first;
+ }
+ return result != 0;
+ }
+
+ // Change _flow and _state vectors
+ void changeFlow(bool change) {
+ // Augment along the cycle
+ if (delta > 0) {
+ Value val = _state[in_arc] * delta;
+ _flow[in_arc] += val;
+ for (int u = _source[in_arc]; u != join; u = _parent[u]) {
+ _flow[_pred[u]] += _forward[u] ? -val : val;
+ }
+ for (int u = _target[in_arc]; u != join; u = _parent[u]) {
+ _flow[_pred[u]] += _forward[u] ? val : -val;
+ }
+ }
+ // Update the state of the entering and leaving arcs
+ if (change) {
+ _state[in_arc] = STATE_TREE;
+ _state[_pred[u_out]] =
+ (_flow[_pred[u_out]] == 0) ? STATE_LOWER : STATE_UPPER;
+ } else {
+ _state[in_arc] = -_state[in_arc];
+ }
+ }
+
+ // Update the tree structure
+ void updateTreeStructure() {
+ int old_rev_thread = _rev_thread[u_out];
+ int old_succ_num = _succ_num[u_out];
+ int old_last_succ = _last_succ[u_out];
+ v_out = _parent[u_out];
+
+ // Check if u_in and u_out coincide
+ if (u_in == u_out) {
+ // Update _parent, _pred, _pred_dir
+ _parent[u_in] = v_in;
+ _pred[u_in] = in_arc;
+ _forward[u_in] = (u_in == _source[in_arc]);
+
+ // Update _thread and _rev_thread
+ if (_thread[v_in] != u_out) {
+ ArcsType after = _thread[old_last_succ];
+ _thread[old_rev_thread] = after;
+ _rev_thread[after] = old_rev_thread;
+ after = _thread[v_in];
+ _thread[v_in] = u_out;
+ _rev_thread[u_out] = v_in;
+ _thread[old_last_succ] = after;
+ _rev_thread[after] = old_last_succ;
+ }
+ } else {
+ // Handle the case when old_rev_thread equals to v_in
+ // (it also means that join and v_out coincide)
+ int thread_continue = old_rev_thread == v_in ?
+ _thread[old_last_succ] : _thread[v_in];
+
+ // Update _thread and _parent along the stem nodes (i.e. the nodes
+ // between u_in and u_out, whose parent have to be changed)
+ int stem = u_in; // the current stem node
+ int par_stem = v_in; // the new parent of stem
+ int next_stem; // the next stem node
+ int last = _last_succ[u_in]; // the last successor of stem
+ int before, after = _thread[last];
+ _thread[v_in] = u_in;
+ _dirty_revs.clear();
+ _dirty_revs.push_back(v_in);
+ while (stem != u_out) {
+ // Insert the next stem node into the thread list
+ next_stem = _parent[stem];
+ _thread[last] = next_stem;
+ _dirty_revs.push_back(last);
+
+ // Remove the subtree of stem from the thread list
+ before = _rev_thread[stem];
+ _thread[before] = after;
+ _rev_thread[after] = before;
+
+ // Change the parent node and shift stem nodes
+ _parent[stem] = par_stem;
+ par_stem = stem;
+ stem = next_stem;
+
+ // Update last and after
+ last = _last_succ[stem] == _last_succ[par_stem] ?
+ _rev_thread[par_stem] : _last_succ[stem];
+ after = _thread[last];
+ }
+ _parent[u_out] = par_stem;
+ _thread[last] = thread_continue;
+ _rev_thread[thread_continue] = last;
+ _last_succ[u_out] = last;
+
+ // Remove the subtree of u_out from the thread list except for
+ // the case when old_rev_thread equals to v_in
+ if (old_rev_thread != v_in) {
+ _thread[old_rev_thread] = after;
+ _rev_thread[after] = old_rev_thread;
+ }
+
+ // Update _rev_thread using the new _thread values
+ for (int i = 0; i != int(_dirty_revs.size()); ++i) {
+ int u = _dirty_revs[i];
+ _rev_thread[_thread[u]] = u;
+ }
+
+ // Update _pred, _pred_dir, _last_succ and _succ_num for the
+ // stem nodes from u_out to u_in
+ int tmp_sc = 0, tmp_ls = _last_succ[u_out];
+ for (int u = u_out, p = _parent[u]; u != u_in; u = p, p = _parent[u]) {
+ _pred[u] = _pred[p];
+ _forward[u] = !_forward[p];
+ tmp_sc += _succ_num[u] - _succ_num[p];
+ _succ_num[u] = tmp_sc;
+ _last_succ[p] = tmp_ls;
+ }
+ _pred[u_in] = in_arc;
+ _forward[u_in] = (u_in == _source[in_arc]);
+ _succ_num[u_in] = old_succ_num;
+ }
+
+ // Update _last_succ from v_in towards the root
+ int up_limit_out = _last_succ[join] == v_in ? join : -1;
+ int last_succ_out = _last_succ[u_out];
+ for (int u = v_in; u != -1 && _last_succ[u] == v_in; u = _parent[u]) {
+ _last_succ[u] = last_succ_out;
+ }
+
+ // Update _last_succ from v_out towards the root
+ if (join != old_rev_thread && v_in != old_rev_thread) {
+ for (int u = v_out; u != up_limit_out && _last_succ[u] == old_last_succ;
+ u = _parent[u]) {
+ _last_succ[u] = old_rev_thread;
+ }
+ } else if (last_succ_out != old_last_succ) {
+ for (int u = v_out; u != up_limit_out && _last_succ[u] == old_last_succ;
+ u = _parent[u]) {
+ _last_succ[u] = last_succ_out;
+ }
+ }
+
+ // Update _succ_num from v_in to join
+ for (int u = v_in; u != join; u = _parent[u]) {
+ _succ_num[u] += old_succ_num;
+ }
+ // Update _succ_num from v_out to join
+ for (int u = v_out; u != join; u = _parent[u]) {
+ _succ_num[u] -= old_succ_num;
+ }
+ }
+
+ void updatePotential() {
+ Cost sigma = _pi[v_in] - _pi[u_in] -
+ ((_forward[u_in])?_cost[in_arc]:(-_cost[in_arc]));
+ int end = _thread[_last_succ[u_in]];
+ for (int u = u_in; u != end; u = _thread[u]) {
+ _pi[u] += sigma;
+ }
+ }
+
+
+ // Heuristic initial pivots
+ bool initialPivots() {
+ Value curr, total = 0;
+ std::vector<Node> supply_nodes, demand_nodes;
+ Node u; _graph.first(u);
+ for (; u != INVALIDNODE; _graph.next(u)) {
+ curr = _supply[_node_id(u)];
+ if (curr > 0) {
+ total += curr;
+ supply_nodes.push_back(u);
+ } else if (curr < 0) {
+ demand_nodes.push_back(u);
+ }
+ }
+ if (_sum_supply > 0) total -= _sum_supply;
+ if (total <= 0) return true;
+
+ ArcVector arc_vector;
+ if (_sum_supply >= 0) {
+ if (supply_nodes.size() == 1 && demand_nodes.size() == 1) {
+ // Perform a reverse graph search from the sink to the source
+ //typename GR::template NodeMap<bool> reached(_graph, false);
+ BoolVector reached(_node_num, false);
+ Node s = supply_nodes[0], t = demand_nodes[0];
+ std::vector<Node> stack;
+ reached[t] = true;
+ stack.push_back(t);
+ while (!stack.empty()) {
+ Node u, v = stack.back();
+ stack.pop_back();
+ if (v == s) break;
+ Arc a; _graph.firstIn(a, v);
+ for (; a != INVALID; _graph.nextIn(a)) {
+ if (reached[u = _graph.source(a)]) continue;
+ ArcsType j = getArcID(a);
+ arc_vector.push_back(j);
+ reached[u] = true;
+ stack.push_back(u);
+ }
+ }
+ } else {
+ arc_vector.resize(demand_nodes.size());
+ // Find the min. cost incomming arc for each demand node
+#pragma omp parallel for
+ for (int i = 0; i < demand_nodes.size(); ++i) {
+ Node v = demand_nodes[i];
+ Cost min_cost = std::numeric_limits<Cost>::max();
+ Arc min_arc = INVALID;
+ Arc a; _graph.firstIn(a, v);
+ for (; a != INVALID; _graph.nextIn(a)) {
+ Cost c = _cost[getArcID(a)];
+ if (c < min_cost) {
+ min_cost = c;
+ min_arc = a;
+ }
+ }
+ arc_vector[i] = getArcID(min_arc);
+ }
+ arc_vector.erase(std::remove(arc_vector.begin(), arc_vector.end(), INVALID), arc_vector.end());
+ }
+ } else {
+ arc_vector.resize(supply_nodes.size());
+ // Find the min. cost outgoing arc for each supply node
+#pragma omp parallel for
+ for (int i = 0; i < int(supply_nodes.size()); ++i) {
+ Node u = supply_nodes[i];
+ Cost min_cost = std::numeric_limits<Cost>::max();
+ Arc min_arc = INVALID;
+ Arc a; _graph.firstOut(a, u);
+ for (; a != INVALID; _graph.nextOut(a)) {
+ Cost c = _cost[getArcID(a)];
+ if (c < min_cost) {
+ min_cost = c;
+ min_arc = a;
+ }
+ }
+ arc_vector[i] = getArcID(min_arc);
+ }
+ arc_vector.erase(std::remove(arc_vector.begin(), arc_vector.end(), INVALID), arc_vector.end());
+ }
+
+ // Perform heuristic initial pivots
+ for (ArcsType i = 0; i != ArcsType(arc_vector.size()); ++i) {
+ in_arc = arc_vector[i];
+ if (_state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -
+ _pi[_target[in_arc]]) >= 0) continue;
+ findJoinNode();
+ bool change = findLeavingArc();
+ if (delta >= MAX) return false;
+ changeFlow(change);
+ if (change) {
+ updateTreeStructure();
+ updatePotential();
+ }
+ }
+ return true;
+ }
+
+ // Execute the algorithm
+ ProblemType start() {
+ return start<BlockSearchPivotRule>();
+ }
+
+ template <typename PivotRuleImpl>
+ ProblemType start() {
+ PivotRuleImpl pivot(*this);
+ ProblemType retVal = OPTIMAL;
+
+ // Perform heuristic initial pivots
+ if (!initialPivots()) return UNBOUNDED;
+
+ size_t iter_number = 0;
+ // Execute the Network Simplex algorithm
+ while (pivot.findEnteringArc()) {
+ if ((++iter_number <= max_iter&&max_iter > 0) || max_iter<=0) {
+#if DEBUG_LVL>0
+ if(iter_number>MAX_DEBUG_ITER)
+ break;
+ if(iter_number%1000==0||iter_number%1000==1){
+ Cost curCost=totalCost();
+ Value sumFlow=0;
+ Cost a;
+ a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]);
+ a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]);
+ for (int i=0; i<_flow.size(); i++) {
+ sumFlow+=_state[i]*_flow[i];
+ }
+ std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << iter_number << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n";
+ std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n";
+ std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n";
+ std::cout << _cost[in_arc] << "\n";
+ std::cout << _pi[_source[in_arc]] << "\n";
+ std::cout << _pi[_target[in_arc]] << "\n";
+ std::cout << a << "\n";
+ }
+#endif
+
+ findJoinNode();
+ bool change = findLeavingArc();
+ if (delta >= MAX) return UNBOUNDED;
+ changeFlow(change);
+ if (change) {
+ updateTreeStructure();
+ updatePotential();
+ }
+
+#if DEBUG_LVL>0
+ else{
+ std::cout << "No change\n";
+ }
+#endif
+
+#if DEBUG_LVL>1
+ std::cout << "Arc in = (" << _source[in_arc] << ", " << _target[in_arc] << ")\n";
+#endif
+
+
+ } else {
+ char errMess[1000];
+ sprintf( errMess, "RESULT MIGHT BE INACURATE\nMax number of iteration reached, currently \%d. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher\n",iter_number );
+ std::cerr << errMess;
+ retVal = MAX_ITER_REACHED;
+ break;
+ }
+
+ }
+
+
+
+#if DEBUG_LVL>0
+ Cost curCost=totalCost();
+ Value sumFlow=0;
+ Cost a;
+ a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]);
+ a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]);
+ for (int i=0; i<_flow.size(); i++) {
+ sumFlow+=_state[i]*_flow[i];
+ }
+
+ std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << niter << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n";
+
+ std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n";
+ std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n";
+
+#endif
+
+
+
+#if DEBUG_LVL>1
+ sumFlow=0;
+ for (int i=0; i<_flow.size(); i++) {
+ sumFlow+=_state[i]*_flow[i];
+ if (_state[i]==STATE_TREE) {
+ std::cout << "Non zero value at (" << _node_num+1-_source[i] << ", " << _node_num+1-_target[i] << ")\n";
+ }
+ }
+ std::cout << "Sum of the flow " << sumFlow << "\n"<< niter <<" iterations, current cost=" << totalCost() << "\n";
+#endif
+
+
+
+ //Check feasibility
+ if(retVal == OPTIMAL){
+ for (ArcsType e = _search_arc_num; e != _all_arc_num; ++e) {
+ if (_flow[e] != 0){
+ if (fabs(_flow[e]) > _EPSILON) // change of the original code following issue #126
+ return INFEASIBLE;
+ else
+ _flow[e]=0;
+ }
+ }
+ }
+
+ // Shift potentials to meet the requirements of the GEQ/LEQ type
+ // optimality conditions
+ if (_sum_supply == 0) {
+ if (_stype == GEQ) {
+ Cost max_pot = -std::numeric_limits<Cost>::max();
+ for (ArcsType i = 0; i != _node_num; ++i) {
+ if (_pi[i] > max_pot) max_pot = _pi[i];
+ }
+ if (max_pot > 0) {
+ for (ArcsType i = 0; i != _node_num; ++i)
+ _pi[i] -= max_pot;
+ }
+ } else {
+ Cost min_pot = std::numeric_limits<Cost>::max();
+ for (ArcsType i = 0; i != _node_num; ++i) {
+ if (_pi[i] < min_pot) min_pot = _pi[i];
+ }
+ if (min_pot < 0) {
+ for (ArcsType i = 0; i != _node_num; ++i)
+ _pi[i] -= min_pot;
+ }
+ }
+ }
+
+ return retVal;
+ }
+
+ }; //class NetworkSimplexSimple
+
+ ///@}
+
+} //namespace lemon_omp
diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py
new file mode 100644
index 0000000..8b4d0c3
--- /dev/null
+++ b/ot/lp/solver_1d.py
@@ -0,0 +1,367 @@
+# -*- coding: utf-8 -*-
+"""
+Exact solvers for the 1D Wasserstein distance using cvxopt
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Author: Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+import warnings
+
+from .emd_wrap import emd_1d_sorted
+from ..backend import get_backend
+from ..utils import list_to_array
+
+
+def quantile_function(qs, cws, xs):
+ r""" Computes the quantile function of an empirical distribution
+
+ Parameters
+ ----------
+ qs: array-like, shape (n,)
+ Quantiles at which the quantile function is evaluated
+ cws: array-like, shape (m, ...)
+ cumulative weights of the 1D empirical distribution, if batched, must be similar to xs
+ xs: array-like, shape (n, ...)
+ locations of the 1D empirical distribution, batched against the `xs.ndim - 1` first dimensions
+
+ Returns
+ -------
+ q: array-like, shape (..., n)
+ The quantiles of the distribution
+ """
+ nx = get_backend(qs, cws)
+ n = xs.shape[0]
+ if nx.__name__ == 'torch':
+ # this is to ensure the best performance for torch searchsorted
+ # and avoid a warninng related to non-contiguous arrays
+ cws = cws.T.contiguous()
+ qs = qs.T.contiguous()
+ else:
+ cws = cws.T
+ qs = qs.T
+ idx = nx.searchsorted(cws, qs).T
+ return nx.take_along_axis(xs, nx.clip(idx, 0, n - 1), axis=0)
+
+
+def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True):
+ r"""
+ Computes the 1 dimensional OT loss [15] between two (batched) empirical
+ distributions
+
+ .. math:
+ OT_{loss} = \int_0^1 |cdf_u^{-1}(q) cdf_v^{-1}(q)|^p dq
+
+ It is formally the p-Wasserstein distance raised to the power p.
+ We do so in a vectorized way by first building the individual quantile functions then integrating them.
+
+ This function should be preferred to `emd_1d` whenever the backend is
+ different to numpy, and when gradients over
+ either sample positions or weights are required.
+
+ Parameters
+ ----------
+ u_values: array-like, shape (n, ...)
+ locations of the first empirical distribution
+ v_values: array-like, shape (m, ...)
+ locations of the second empirical distribution
+ u_weights: array-like, shape (n, ...), optional
+ weights of the first empirical distribution, if None then uniform weights are used
+ v_weights: array-like, shape (m, ...), optional
+ weights of the second empirical distribution, if None then uniform weights are used
+ p: int, optional
+ order of the ground metric used, should be at least 1 (see [2, Chap. 2], default is 1
+ require_sort: bool, optional
+ sort the distributions atoms locations, if False we will consider they have been sorted prior to being passed to
+ the function, default is True
+
+ Returns
+ -------
+ cost: float/array-like, shape (...)
+ the batched EMD
+
+ References
+ ----------
+ .. [15] Peyré, G., & Cuturi, M. (2018). Computational Optimal Transport.
+
+ """
+
+ assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p)
+
+ if u_weights is not None and v_weights is not None:
+ nx = get_backend(u_values, v_values, u_weights, v_weights)
+ else:
+ nx = get_backend(u_values, v_values)
+
+ n = u_values.shape[0]
+ m = v_values.shape[0]
+
+ if u_weights is None:
+ u_weights = nx.full(u_values.shape, 1. / n)
+ elif u_weights.ndim != u_values.ndim:
+ u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
+ if v_weights is None:
+ v_weights = nx.full(v_values.shape, 1. / m)
+ elif v_weights.ndim != v_values.ndim:
+ v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)
+
+ if require_sort:
+ u_sorter = nx.argsort(u_values, 0)
+ u_values = nx.take_along_axis(u_values, u_sorter, 0)
+
+ v_sorter = nx.argsort(v_values, 0)
+ v_values = nx.take_along_axis(v_values, v_sorter, 0)
+
+ u_weights = nx.take_along_axis(u_weights, u_sorter, 0)
+ v_weights = nx.take_along_axis(v_weights, v_sorter, 0)
+
+ u_cumweights = nx.cumsum(u_weights, 0)
+ v_cumweights = nx.cumsum(v_weights, 0)
+
+ qs = nx.sort(nx.concatenate((u_cumweights, v_cumweights), 0), 0)
+ u_quantiles = quantile_function(qs, u_cumweights, u_values)
+ v_quantiles = quantile_function(qs, v_cumweights, v_values)
+ qs = nx.zero_pad(qs, pad_width=[(1, 0)] + (qs.ndim - 1) * [(0, 0)])
+ delta = qs[1:, ...] - qs[:-1, ...]
+ diff_quantiles = nx.abs(u_quantiles - v_quantiles)
+
+ if p == 1:
+ return nx.sum(delta * nx.abs(diff_quantiles), axis=0)
+ return nx.sum(delta * nx.power(diff_quantiles, p), axis=0)
+
+
+def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
+ log=False):
+ r"""Solves the Earth Movers distance problem between 1d measures and returns
+ the OT matrix
+
+
+ .. math::
+ \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
+
+ s.t. \gamma 1 = a,
+ \gamma^T 1= b,
+ \gamma\geq 0
+ where :
+
+ - d is the metric
+ - x_a and x_b are the samples
+ - a and b are the sample weights
+
+ When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
+
+ Uses the algorithm detailed in [1]_
+
+ Parameters
+ ----------
+ x_a : (ns,) or (ns, 1) ndarray, float64
+ Source dirac locations (on the real line)
+ x_b : (nt,) or (ns, 1) ndarray, float64
+ Target dirac locations (on the real line)
+ a : (ns,) ndarray, float64, optional
+ Source histogram (default is uniform weight)
+ b : (nt,) ndarray, float64, optional
+ Target histogram (default is uniform weight)
+ metric: str, optional (default='sqeuclidean')
+ Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
+ Due to implementation details, this function runs faster when
+ `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
+ p: float, optional (default=1.0)
+ The p-norm to apply for if metric='minkowski'
+ dense: boolean, optional (default=True)
+ If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
+ Otherwise returns a sparse representation using scipy's `coo_matrix`
+ format. Due to implementation details, this function runs faster when
+ `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
+ are used.
+ log: boolean, optional (default=False)
+ If True, returns a dictionary containing the cost.
+ Otherwise returns only the optimal transportation matrix.
+
+ Returns
+ -------
+ gamma: (ns, nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log: dict
+ If input log is True, a dictionary containing the cost
+
+
+ Examples
+ --------
+
+ Simple example with obvious solution. The function emd_1d accepts lists and
+ performs automatic conversion to numpy arrays
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> x_a = [2., 0.]
+ >>> x_b = [0., 3.]
+ >>> ot.emd_1d(x_a, x_b, a, b)
+ array([[0. , 0.5],
+ [0.5, 0. ]])
+ >>> ot.emd_1d(x_a, x_b)
+ array([[0. , 0.5],
+ [0.5, 0. ]])
+
+ References
+ ----------
+
+ .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+ Transport", 2018.
+
+ See Also
+ --------
+ ot.lp.emd : EMD for multidimensional distributions
+ ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the
+ transportation matrix)
+ """
+ a, b, x_a, x_b = list_to_array(a, b, x_a, x_b)
+ nx = get_backend(x_a, x_b)
+
+ assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \
+ "emd_1d should only be used with monodimensional data"
+ assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \
+ "emd_1d should only be used with monodimensional data"
+
+ # if empty array given then use uniform distributions
+ if a is None or a.ndim == 0 or len(a) == 0:
+ a = nx.ones((x_a.shape[0],), type_as=x_a) / x_a.shape[0]
+ if b is None or b.ndim == 0 or len(b) == 0:
+ b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0]
+
+ # ensure that same mass
+ np.testing.assert_almost_equal(
+ nx.to_numpy(nx.sum(a, axis=0)),
+ nx.to_numpy(nx.sum(b, axis=0)),
+ err_msg='a and b vector must have the same sum'
+ )
+ b = b * nx.sum(a) / nx.sum(b)
+
+ x_a_1d = nx.reshape(x_a, (-1,))
+ x_b_1d = nx.reshape(x_b, (-1,))
+ perm_a = nx.argsort(x_a_1d)
+ perm_b = nx.argsort(x_b_1d)
+
+ G_sorted, indices, cost = emd_1d_sorted(
+ nx.to_numpy(a[perm_a]).astype(np.float64),
+ nx.to_numpy(b[perm_b]).astype(np.float64),
+ nx.to_numpy(x_a_1d[perm_a]).astype(np.float64),
+ nx.to_numpy(x_b_1d[perm_b]).astype(np.float64),
+ metric=metric, p=p
+ )
+
+ G = nx.coo_matrix(
+ G_sorted,
+ perm_a[indices[:, 0]],
+ perm_b[indices[:, 1]],
+ shape=(a.shape[0], b.shape[0]),
+ type_as=x_a
+ )
+ if dense:
+ G = nx.todense(G)
+ elif str(nx) == "jax":
+ warnings.warn("JAX does not support sparse matrices, converting to dense")
+ if log:
+ log = {'cost': nx.from_numpy(cost, type_as=x_a)}
+ return G, log
+ return G
+
+
+def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
+ log=False):
+ r"""Solves the Earth Movers distance problem between 1d measures and returns
+ the loss
+
+
+ .. math::
+ \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
+
+ s.t. \gamma 1 = a,
+ \gamma^T 1= b,
+ \gamma\geq 0
+ where :
+
+ - d is the metric
+ - x_a and x_b are the samples
+ - a and b are the sample weights
+
+ When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
+
+ Uses the algorithm detailed in [1]_
+
+ Parameters
+ ----------
+ x_a : (ns,) or (ns, 1) ndarray, float64
+ Source dirac locations (on the real line)
+ x_b : (nt,) or (ns, 1) ndarray, float64
+ Target dirac locations (on the real line)
+ a : (ns,) ndarray, float64, optional
+ Source histogram (default is uniform weight)
+ b : (nt,) ndarray, float64, optional
+ Target histogram (default is uniform weight)
+ metric: str, optional (default='sqeuclidean')
+ Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
+ Due to implementation details, this function runs faster when
+ `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
+ are used.
+ p: float, optional (default=1.0)
+ The p-norm to apply for if metric='minkowski'
+ dense: boolean, optional (default=True)
+ If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
+ Otherwise returns a sparse representation using scipy's `coo_matrix`
+ format. Only used if log is set to True. Due to implementation details,
+ this function runs faster when dense is set to False.
+ log: boolean, optional (default=False)
+ If True, returns a dictionary containing the transportation matrix.
+ Otherwise returns only the loss.
+
+ Returns
+ -------
+ loss: float
+ Cost associated to the optimal transportation
+ log: dict
+ If input log is True, a dictionary containing the Optimal transportation
+ matrix for the given parameters
+
+
+ Examples
+ --------
+
+ Simple example with obvious solution. The function emd2_1d accepts lists and
+ performs automatic conversion to numpy arrays
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> x_a = [2., 0.]
+ >>> x_b = [0., 3.]
+ >>> ot.emd2_1d(x_a, x_b, a, b)
+ 0.5
+ >>> ot.emd2_1d(x_a, x_b)
+ 0.5
+
+ References
+ ----------
+
+ .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+ Transport", 2018.
+
+ See Also
+ --------
+ ot.lp.emd2 : EMD for multidimensional distributions
+ ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix
+ instead of the cost)
+ """
+ # If we do not return G (log==False), then we should not to cast it to dense
+ # (useless overhead)
+ G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p,
+ dense=dense and log, log=True)
+ cost = log_emd['cost']
+ if log:
+ log_emd = {'G': G}
+ return cost, log_emd
+ return cost
diff --git a/ot/optim.py b/ot/optim.py
index b9ca891..bd8ca26 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -12,34 +12,36 @@ import numpy as np
from scipy.optimize.linesearch import scalar_search_armijo
from .lp import emd
from .bregman import sinkhorn
+from ot.utils import list_to_array
+from .backend import get_backend
# The corresponding scipy function does not work for matrices
def line_search_armijo(f, xk, pk, gfk, old_fval,
args=(), c1=1e-4, alpha0=0.99):
- """
+ r"""
Armijo linesearch function that works with matrices
- find an approximate minimum of f(xk+alpha*pk) that satifies the
+ Find an approximate minimum of :math:`f(x_k + \alpha \cdot p_k)` that satisfies the
armijo conditions.
Parameters
----------
f : callable
loss function
- xk : ndarray
+ xk : array-like
initial position
- pk : ndarray
+ pk : array-like
descent direction
- gfk : ndarray
- gradient of f at xk
+ gfk : array-like
+ gradient of `f` at :math:`x_k`
old_fval : float
- loss value at xk
+ loss value at :math:`x_k`
args : tuple, optional
- arguments given to f
+ arguments given to `f`
c1 : float, optional
- c1 const in armijo rule (>0)
+ :math:`c_1` const in armijo rule (>0)
alpha0 : float, optional
initial step (>0)
@@ -53,7 +55,13 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
loss value at step alpha
"""
- xk = np.atleast_1d(xk)
+
+ xk, pk, gfk = list_to_array(xk, pk, gfk)
+ nx = get_backend(xk, pk)
+
+ if len(xk.shape) == 0:
+ xk = nx.reshape(xk, (-1,))
+
fc = [0]
def phi(alpha1):
@@ -65,10 +73,13 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
else:
phi0 = old_fval
- derphi0 = np.sum(pk * gfk) # Quickfix for matrices
+ derphi0 = nx.sum(pk * gfk) # Quickfix for matrices
alpha, phi1 = scalar_search_armijo(
phi, phi0, derphi0, c1=c1, alpha0=alpha0)
+ # scalar_search_armijo can return alpha > 1
+ if alpha is not None:
+ alpha = min(1, alpha)
return alpha, fc[0], phi1
@@ -76,55 +87,64 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
"""
Solve the linesearch in the FW iterations
+
Parameters
----------
cost : method
Cost in the FW for the linesearch
- G : ndarray, shape(ns,nt)
+ G : array-like, shape(ns,nt)
The transport map at a given iteration of the FW
- deltaG : ndarray (ns,nt)
+ deltaG : array-like (ns,nt)
Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration
- Mi : ndarray (ns,nt)
+ Mi : array-like (ns,nt)
Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost
- f_val : float
- Value of the cost at G
+ f_val : float
+ Value of the cost at `G`
armijo : bool, optional
- If True the steps of the line-search is found via an armijo research. Else closed form is used.
- If there is convergence issues use False.
- C1 : ndarray (ns,ns), optional
+ If True the steps of the line-search is found via an armijo research. Else closed form is used.
+ If there is convergence issues use False.
+ C1 : array-like (ns,ns), optional
Structure matrix in the source domain. Only used and necessary when armijo=False
- C2 : ndarray (nt,nt), optional
+ C2 : array-like (nt,nt), optional
Structure matrix in the target domain. Only used and necessary when armijo=False
reg : float, optional
- Regularization parameter. Only used and necessary when armijo=False
- Gc : ndarray (ns,nt)
+ Regularization parameter. Only used and necessary when armijo=False
+ Gc : array-like (ns,nt)
Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False
- constC : ndarray (ns,nt)
- Constant for the gromov cost. See [24]. Only used and necessary when armijo=False
- M : ndarray (ns,nt), optional
+ constC : array-like (ns,nt)
+ Constant for the gromov cost. See :ref:`[24] <references-solve-linesearch>`. Only used and necessary when armijo=False
+ M : array-like (ns,nt), optional
Cost matrix between the features. Only used and necessary when armijo=False
+
Returns
-------
alpha : float
- The optimal step size of the FW
+ The optimal step size of the FW
fc : int
- nb of function call. Useless here
- f_val : float
- The value of the cost for the next iteration
+ nb of function call. Useless here
+ f_val : float
+ The value of the cost for the next iteration
+
+
+ .. _references-solve-linesearch:
References
----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
- and Courty Nicolas
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
"""
if armijo:
alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val)
else: # requires symetric matrices
- dot1 = np.dot(C1, deltaG)
- dot12 = dot1.dot(C2)
- a = -2 * reg * np.sum(dot12 * deltaG)
- b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG))
+ G, deltaG, C1, C2, constC, M = list_to_array(G, deltaG, C1, C2, constC, M)
+ if isinstance(M, int) or isinstance(M, float):
+ nx = get_backend(G, deltaG, C1, C2, constC)
+ else:
+ nx = get_backend(G, deltaG, C1, C2, constC, M)
+
+ dot = nx.dot(nx.dot(C1, deltaG), C2)
+ a = -2 * reg * nx.sum(dot * deltaG)
+ b = nx.sum((M + reg * constC) * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2) * deltaG))
c = cost(G)
alpha = solve_1d_linesearch_quad(a, b, c)
@@ -136,48 +156,49 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs):
- """
+ r"""
Solve the general regularized OT problem with conditional gradient
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg*f(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot f(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma\geq 0
+ \gamma &\geq 0
where :
- - M is the (ns,nt) metric cost matrix
- - :math:`f` is the regularization term ( and df is its gradient)
- - a and b are source and target weights (sum to 1)
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`f` is the regularization term (and `df` is its gradient)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
- The algorithm used for solving the problem is conditional gradient as discussed in [1]_
+ The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>`
Parameters
----------
- a : ndarray, shape (ns,)
+ a : array-like, shape (ns,)
samples weights in the source domain
- b : ndarray, shape (nt,)
+ b : array-like, shape (nt,)
samples in the target domain
- M : ndarray, shape (ns, nt)
+ M : array-like, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
- G0 : ndarray, shape (ns,nt), optional
+ G0 : array-like, shape (ns,nt), optional
initial guess (default is indep joint density)
numItermax : int, optional
Max number of iterations
numItermaxEmd : int, optional
Max number of iterations for emd
stopThr : float, optional
- Stop threshol on the relative variation (>0)
+ Stop threshold on the relative variation (>0)
stopThr2 : float, optional
- Stop threshol on the absolute variation (>0)
+ Stop threshold on the absolute variation (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -193,6 +214,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
log dictionary return only if log==True in parameters
+ .. _references-cg:
References
----------
@@ -204,6 +226,11 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
ot.bregman.sinkhorn : Entropic regularized optimal transport
"""
+ a, b, M, G0 = list_to_array(a, b, M, G0)
+ if isinstance(M, int) or isinstance(M, float):
+ nx = get_backend(a, b)
+ else:
+ nx = get_backend(a, b, M)
loop = 1
@@ -211,12 +238,12 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
log = {'loss': []}
if G0 is None:
- G = np.outer(a, b)
+ G = nx.outer(a, b)
else:
G = G0
def cost(G):
- return np.sum(M * G) + reg * f(G)
+ return nx.sum(M * G) + reg * f(G)
f_val = cost(G)
if log:
@@ -237,15 +264,17 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
# problem linearization
Mi = M + reg * df(G)
# set M positive
- Mi += Mi.min()
+ Mi += nx.min(Mi)
# solve linear program
- Gc = emd(a, b, Mi, numItermax=numItermaxEmd)
+ Gc, logemd = emd(a, b, Mi, numItermax=numItermaxEmd, log=True)
deltaG = Gc - G
# line search
alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs)
+ if alpha is None:
+ alpha = 0.0
G = G + alpha * deltaG
@@ -268,6 +297,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
if log:
+ log.update(logemd)
return G, log
else:
return G
@@ -275,51 +305,52 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False):
- """
+ r"""
Solve the general regularized OT problem with the generalized conditional gradient
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg1\cdot\Omega(\gamma) + reg2\cdot f(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg_1}\cdot\Omega(\gamma) + \mathrm{reg_2}\cdot f(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma\geq 0
+ \gamma &\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`f` is the regularization term ( and df is its gradient)
- - a and b are source and target weights (sum to 1)
+ - :math:`f` is the regularization term (and `df` is its gradient)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
- The algorithm used for solving the problem is the generalized conditional gradient as discussed in [5,7]_
+ The algorithm used for solving the problem is the generalized conditional gradient as discussed in :ref:`[5, 7] <references-gcg>`
Parameters
----------
- a : ndarray, shape (ns,)
+ a : array-like, shape (ns,)
samples weights in the source domain
- b : ndarrayv (nt,)
+ b : array-like, (nt,)
samples in the target domain
- M : ndarray, shape (ns, nt)
+ M : array-like, shape (ns, nt)
loss matrix
reg1 : float
Entropic Regularization term >0
reg2 : float
Second Regularization term >0
- G0 : ndarray, shape (ns, nt), optional
+ G0 : array-like, shape (ns, nt), optional
initial guess (default is indep joint density)
numItermax : int, optional
Max number of iterations
numInnerItermax : int, optional
Max number of iterations of Sinkhorn
stopThr : float, optional
- Stop threshol on the relative variation (>0)
+ Stop threshold on the relative variation (>0)
stopThr2 : float, optional
- Stop threshol on the absolute variation (>0)
+ Stop threshold on the absolute variation (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -332,9 +363,13 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
log : dict
log dictionary return only if log==True in parameters
+
+ .. _references-gcg:
References
----------
+
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
+
.. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
See Also
@@ -342,6 +377,8 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
ot.optim.cg : conditional gradient
"""
+ a, b, M, G0 = list_to_array(a, b, M, G0)
+ nx = get_backend(a, b, M)
loop = 1
@@ -349,12 +386,12 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
log = {'loss': []}
if G0 is None:
- G = np.outer(a, b)
+ G = nx.outer(a, b)
else:
G = G0
def cost(G):
- return np.sum(M * G) + reg1 * np.sum(G * np.log(G)) + reg2 * f(G)
+ return nx.sum(M * G) + reg1 * nx.sum(G * nx.log(G)) + reg2 * f(G)
f_val = cost(G)
if log:
@@ -382,7 +419,7 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
deltaG = Gc - G
# line search
- dcost = Mi + reg1 * (1 + np.log(G)) # ??
+ dcost = Mi + reg1 * (1 + nx.log(G)) # ??
alpha, fc, f_val = line_search_armijo(cost, G, deltaG, dcost, f_val)
G = G + alpha * deltaG
@@ -413,10 +450,12 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
def solve_1d_linesearch_quad(a, b, c):
- """
- For any convex or non-convex 1d quadratic function f, solve on [0,1] the following problem:
+ r"""
+ For any convex or non-convex 1d quadratic function `f`, solve the following problem:
+
.. math::
- \argmin f(x)=a*x^{2}+b*x+c
+
+ \mathop{\arg \min}_{0 \leq x \leq 1} \quad f(x) = ax^{2} + bx + c
Parameters
----------
diff --git a/ot/partial.py b/ot/partial.py
index eb707d8..b7093e4 100755
--- a/ot/partial.py
+++ b/ot/partial.py
@@ -20,13 +20,16 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
The function considers the following problem:
.. math::
- \gamma = \arg\min_\gamma <\gamma,(M-\lambda)>_F
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, (\mathbf{M} - \lambda) \rangle_F
- s.t.
- \gamma\geq 0 \\
- \gamma 1 \leq a\\
- \gamma^T 1 \leq b\\
- 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\}
+ .. math::
+ s.t. \ \gamma \mathbf{1} &\leq \mathbf{a}
+
+ \gamma^T \mathbf{1} &\leq \mathbf{b}
+
+ \gamma &\geq 0
+
+ \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
or equivalently (see Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X.
@@ -34,33 +37,32 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
metrics. Foundations of Computational Mathematics, 18(1), 1-44.)
.. math::
- \gamma = \arg\min_\gamma <\gamma,M>_F + \sqrt(\lambda/2)
- (\|\gamma 1 - a\|_1 + \|\gamma^T 1 - b\|_1)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \sqrt{\frac{\lambda}{2} (\|\gamma \mathbf{1} - \mathbf{a}\|_1 + \|\gamma^T \mathbf{1} - \mathbf{b}\|_1)}
- s.t.
- \gamma\geq 0 \\
+ s.t. \ \gamma \geq 0
where :
- - M is the metric cost matrix
- - a and b are source and target unbalanced distributions
- - :math:`\lambda` is the lagragian cost. Tuning its value allows attaining
- a given mass to be transported m
+ - :math:`\mathbf{M}` is the metric cost matrix
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
+ - :math:`\lambda` is the lagrangian cost. Tuning its value allows attaining
+ a given mass to be transported `m`
- The formulation of the problem has been proposed in [28]_
+ The formulation of the problem has been proposed in :ref:`[28] <references-partial-wasserstein-lagrange>`
Parameters
----------
a : np.ndarray (dim_a,)
- Unnormalized histogram of dimension dim_a
+ Unnormalized histogram of dimension `dim_a`
b : np.ndarray (dim_b,)
- Unnormalized histograms of dimension dim_b
+ Unnormalized histograms of dimension `dim_b`
M : np.ndarray (dim_a, dim_b)
cost matrix for the quadratic cost
reg_m : float, optional
- Lagragian cost
+ Lagrangian cost
nb_dummies : int, optional, default:1
number of reservoir points to be added (to avoid numerical
instabilities, increase its value if an error is raised)
@@ -69,6 +71,7 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
**kwargs : dict
parameters can be directly passed to the emd solver
+
.. warning::
When dealing with a large number of points, the EMD solver may face
some instabilities, especially when the mass associated to the dummy
@@ -77,7 +80,7 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
Returns
-------
- gamma : (dim_a x dim_b) ndarray
+ gamma : (dim_a, dim_b) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if `log` is `True`
@@ -97,9 +100,10 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
array([[0.1, 0. ],
[0. , 0. ]])
+
+ .. _references-partial-wasserstein-lagrange:
References
----------
-
.. [28] Caffarelli, L. A., & McCann, R. J. (2010) Free boundaries in
optimal transport and Monge-Ampere obstacle problems. Annals of
mathematics, 673-730.
@@ -162,27 +166,30 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
The function considers the following problem:
.. math::
- \gamma = \arg\min_\gamma <\gamma,M>_F
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F
- s.t.
- \gamma\geq 0 \\
- \gamma 1 \leq a\\
- \gamma^T 1 \leq b\\
- 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\}
+ .. math::
+ s.t. \ \gamma \mathbf{1} &\leq \mathbf{a}
+
+ \gamma^T \mathbf{1} &\leq \mathbf{b}
+
+ \gamma &\geq 0
+
+ \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
where :
- - M is the metric cost matrix
- - a and b are source and target unbalanced distributions
- - m is the amount of mass to be transported
+ - :math:`\mathbf{M}` is the metric cost matrix
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
+ - `m` is the amount of mass to be transported
Parameters
----------
a : np.ndarray (dim_a,)
- Unnormalized histogram of dimension dim_a
+ Unnormalized histogram of dimension `dim_a`
b : np.ndarray (dim_b,)
- Unnormalized histograms of dimension dim_b
+ Unnormalized histograms of dimension `dim_b`
M : np.ndarray (dim_a, dim_b)
cost matrix for the quadratic cost
m : float, optional
@@ -205,7 +212,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
Returns
-------
- :math:`gamma` : (dim_a x dim_b) ndarray
+ gamma : (dim_a, dim_b) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if `log` is `True`
@@ -230,9 +237,9 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
.. [28] Caffarelli, L. A., & McCann, R. J. (2010) Free boundaries in
optimal transport and Monge-Ampere obstacle problems. Annals of
mathematics, 673-730.
- .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov-
- Wasserstein with Applications on Positive-Unlabeled Learning".
- arXiv preprint arXiv:2002.08276.
+ .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal
+ Transport with Applications on Positive-Unlabeled Learning".
+ NeurIPS.
See Also
--------
@@ -254,7 +261,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies)
a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies)
M_extended = np.zeros((len(a_extended), len(b_extended)))
- M_extended[-1, -1] = np.max(M) * 1e5
+ M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e5
M_extended[:len(a), :len(b)] = M
gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True,
@@ -278,27 +285,30 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
The function considers the following problem:
.. math::
- \gamma = \arg\min_\gamma <\gamma,M>_F
+ \gamma = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F
+
+ .. math::
+ s.t. \ \gamma \mathbf{1} &\leq \mathbf{a}
- s.t.
- \gamma\geq 0 \\
- \gamma 1 \leq a\\
- \gamma^T 1 \leq b\\
- 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\}
+ \gamma^T \mathbf{1} &\leq \mathbf{b}
+
+ \gamma &\geq 0
+
+ \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
where :
- - M is the metric cost matrix
- - a and b are source and target unbalanced distributions
- - m is the amount of mass to be transported
+ - :math:`\mathbf{M}` is the metric cost matrix
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
+ - `m` is the amount of mass to be transported
Parameters
----------
a : np.ndarray (dim_a,)
- Unnormalized histogram of dimension dim_a
+ Unnormalized histogram of dimension `dim_a`
b : np.ndarray (dim_b,)
- Unnormalized histograms of dimension dim_b
+ Unnormalized histograms of dimension `dim_b`
M : np.ndarray (dim_a, dim_b)
cost matrix for the quadratic cost
m : float, optional
@@ -321,8 +331,8 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
Returns
-------
- :math:`gamma` : (dim_a x dim_b) ndarray
- Optimal transportation matrix for the given parameters
+ GW: float
+ partial GW discrepancy
log : dict
log dictionary returned only if `log` is `True`
@@ -344,14 +354,13 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
.. [28] Caffarelli, L. A., & McCann, R. J. (2010) Free boundaries in
optimal transport and Monge-Ampere obstacle problems. Annals of
mathematics, 673-730.
- .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov-
- Wasserstein with Applications on Positive-Unlabeled Learning".
- arXiv preprint arXiv:2002.08276.
+ .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal
+ Transport with Applications on Positive-Unlabeled Learning".
+ NeurIPS.
"""
partial_gw, log_w = partial_wasserstein(a, b, M, m, nb_dummies, log=True,
**kwargs)
-
log_w['T'] = partial_gw
if log:
@@ -361,8 +370,8 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
def gwgrad_partial(C1, C2, T):
- """Compute the GW gradient. Note: we can not use the trick in [12]_ as
- the marginals may not sum to 1.
+ """Compute the GW gradient. Note: we can not use the trick in :ref:`[12] <references-gwgrad-partial>`
+ as the marginals may not sum to 1.
Parameters
----------
@@ -380,6 +389,8 @@ def gwgrad_partial(C1, C2, T):
numpy.array of shape (n_p+nb_dummies, n_u)
gradient
+
+ .. _references-gwgrad-partial:
References
----------
.. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
@@ -426,22 +437,25 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
The function considers the following problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F
- s.t. \gamma 1 \leq a \\
- \gamma^T 1 \leq b \\
- \gamma\geq 0 \\
- 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\} \\
+ .. math::
+ s.t. \ \gamma \mathbf{1} &\leq \mathbf{a}
+
+ \gamma^T \mathbf{1} &\leq \mathbf{b}
+
+ \gamma &\geq 0
+
+ \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
where :
- - M is the metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)
- =\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are the sample weights
- - m is the amount of mass to be transported
+ - :math:`\mathbf{M}` is the metric cost matrix
+ - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
+ - `m` is the amount of mass to be transported
- The formulation of the problem has been proposed in [29]_
+ The formulation of the problem has been proposed in :ref:`[29] <references-partial-gromov-wasserstein>`
Parameters
@@ -455,7 +469,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
q : ndarray, shape (nt,)
Distribution in the target space
m : float, optional
- Amount of mass to be transported (default: min (|p|_1, |q|_1))
+ Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
nb_dummies : int, optional
Number of dummy points to add (avoid instabilities in the EMD solver)
G0 : ndarray, shape (ns, nt), optional
@@ -477,7 +491,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
Returns
-------
- gamma : (dim_a x dim_b) ndarray
+ gamma : (dim_a, dim_b) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if `log` is `True`
@@ -501,14 +515,16 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
>>> np.round(partial_gromov_wasserstein(C1, C2, a, b, m=0.25),2)
array([[0. , 0. , 0. , 0. ],
[0. , 0. , 0. , 0. ],
- [0. , 0. , 0. , 0. ],
- [0. , 0. , 0. , 0.25]])
+ [0. , 0. , 0.25, 0. ],
+ [0. , 0. , 0. , 0. ]])
+
+ .. _references-partial-gromov-wasserstein:
References
----------
- .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov-
- Wasserstein with Applications on Positive-Unlabeled Learning".
- arXiv preprint arXiv:2002.08276.
+ .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal
+ Transport with Applications on Positive-Unlabeled Learning".
+ NeurIPS.
"""
@@ -530,20 +546,18 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
cpt = 0
err = 1
- eps = 1e-20
+
if log:
log = {'err': []}
while (err > tol and cpt < numItermax):
- Gprev = G0
+ Gprev = np.copy(G0)
M = gwgrad_partial(C1, C2, G0)
- M[M < eps] = np.quantile(M, thres)
-
M_emd = np.zeros(dim_G_extended)
M_emd[:len(p), :len(q)] = M
- M_emd[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e5
+ M_emd[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e2
M_emd = np.asarray(M_emd, dtype=np.float64)
Gc, logemd = emd(p_extended, q_extended, M_emd, log=True, **kwargs)
@@ -565,6 +579,22 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
print('{:5d}|{:8e}|{:8e}'.format(cpt, err,
gwloss_partial(C1, C2, G0)))
+ deltaG = G0 - Gprev
+ a = gwloss_partial(C1, C2, deltaG)
+ b = 2 * np.sum(M * deltaG)
+ if b > 0: # due to numerical precision
+ gamma = 0
+ cpt = numItermax
+ elif a > 0:
+ gamma = min(1, np.divide(-b, 2.0 * a))
+ else:
+ if (a + b) < 0:
+ gamma = 1
+ else:
+ gamma = 0
+ cpt = numItermax
+
+ G0 = Gprev + gamma * deltaG
cpt += 1
if log:
@@ -584,22 +614,25 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
The function considers the following problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F
+ GW = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F
- s.t. \gamma 1 \leq a \\
- \gamma^T 1 \leq b \\
- \gamma\geq 0 \\
- 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\} \\
+ .. math::
+ s.t. \ \gamma \mathbf{1} &\leq \mathbf{a}
+
+ \gamma^T \mathbf{1} &\leq \mathbf{b}
+
+ \gamma &\geq 0
+
+ \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
where :
- - M is the metric cost matrix
- - :math:`\Omega` is the entropic regularization term
- :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are the sample weights
- - m is the amount of mass to be transported
+ - :math:`\mathbf{M}` is the metric cost matrix
+ - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
+ - `m` is the amount of mass to be transported
- The formulation of the problem has been proposed in [29]_
+ The formulation of the problem has been proposed in :ref:`[29] <references-partial-gromov-wasserstein2>`
Parameters
@@ -613,7 +646,7 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
q : ndarray, shape (nt,)
Distribution in the target space
m : float, optional
- Amount of mass to be transported (default: min (|p|_1, |q|_1))
+ Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
nb_dummies : int, optional
Number of dummy points to add (avoid instabilities in the EMD solver)
G0 : ndarray, shape (ns, nt), optional
@@ -642,7 +675,7 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
Returns
-------
- partial_gw_dist : (dim_a x dim_b) ndarray
+ partial_gw_dist : float
partial GW discrepancy
log : dict
log dictionary returned only if `log` is `True`
@@ -663,11 +696,13 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
>>> np.round(partial_gromov_wasserstein2(C1, C2, a, b, m=0.25),2)
0.0
+
+ .. _references-partial-gromov-wasserstein2:
References
----------
- .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov-
- Wasserstein with Applications on Positive-Unlabeled Learning".
- arXiv preprint arXiv:2002.08276.
+ .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal
+ Transport with Applications on Positive-Unlabeled Learning".
+ NeurIPS.
"""
@@ -693,30 +728,29 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
The function considers the following problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot\Omega(\gamma)
- s.t. \gamma 1 \leq a \\
- \gamma^T 1 \leq b \\
- \gamma\geq 0 \\
- 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\} \\
+ s.t. \gamma \mathbf{1} &\leq \mathbf{a} \\
+ \gamma^T \mathbf{1} &\leq \mathbf{b} \\
+ \gamma &\geq 0 \\
+ \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} \\
where :
- - M is the metric cost matrix
- - :math:`\Omega` is the entropic regularization term
- :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are the sample weights
- - m is the amount of mass to be transported
+ - :math:`\mathbf{M}` is the metric cost matrix
+ - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
+ - `m` is the amount of mass to be transported
- The formulation of the problem has been proposed in [3]_ (prop. 5)
+ The formulation of the problem has been proposed in :ref:`[3] <references-entropic-partial-wasserstein>` (prop. 5)
Parameters
----------
a : np.ndarray (dim_a,)
- Unnormalized histogram of dimension dim_a
+ Unnormalized histogram of dimension `dim_a`
b : np.ndarray (dim_b,)
- Unnormalized histograms of dimension dim_b
+ Unnormalized histograms of dimension `dim_b`
M : np.ndarray (dim_a, dim_b)
cost matrix
reg : float
@@ -735,7 +769,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
Returns
-------
- gamma : (dim_a x dim_b) ndarray
+ gamma : (dim_a, dim_b) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if `log` is `True`
@@ -751,6 +785,8 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
array([[0.06, 0.02],
[0.01, 0. ]])
+
+ .. _references-entropic-partial-wasserstein:
References
----------
.. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
@@ -825,32 +861,34 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None,
numItermax=1000, tol=1e-7, log=False,
verbose=False):
r"""
- Returns the partial Gromov-Wasserstein transport between (C1,p) and (C2,q)
+ Returns the partial Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
The function solves the following optimization problem:
.. math::
- GW = \arg\min_{\gamma} \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})\cdot
- \gamma_{i,j}\cdot\gamma_{k,l} + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_{\gamma} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})\cdot
+ \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma)
- s.t.
- \gamma\geq 0 \\
- \gamma 1 \leq a\\
- \gamma^T 1 \leq b\\
- 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\}
+ .. math::
+ s.t. \ \gamma &\geq 0
+
+ \gamma \mathbf{1} &\leq \mathbf{a}
+
+ \gamma^T \mathbf{1} &\leq \mathbf{b}
+
+ \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
where :
- - C1 is the metric cost matrix in the source space
- - C2 is the metric cost matrix in the target space
- - p and q are the sample weights
- - L : quadratic loss function
- - :math:`\Omega` is the entropic regularization term
- :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - m is the amount of mass to be transported
+ - :math:`\mathbf{C_1}` is the metric cost matrix in the source space
+ - :math:`\mathbf{C_2}` is the metric cost matrix in the target space
+ - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights
+ - `L`: quadratic loss function
+ - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - `m` is the amount of mass to be transported
- The formulation of the GW problem has been proposed in [12]_ and the
- partial GW in [29]_.
+ The formulation of the GW problem has been proposed in :ref:`[12] <references-entropic-partial-gromov-wassertein>` and the partial GW in :ref:`[29] <references-entropic-partial-gromov-wassertein>`
Parameters
----------
@@ -865,7 +903,7 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None,
reg: float
entropic regularization parameter
m : float, optional
- Amount of mass to be transported (default: min (|p|_1, |q|_1))
+ Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
G0 : ndarray, shape (ns, nt), optional
Initialisation of the transportation matrix
numItermax : int, optional
@@ -887,12 +925,12 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None,
>>> y = np.array([3,2,98,199]).reshape((-1,1))
>>> C1 = sp.spatial.distance.cdist(x, x)
>>> C2 = sp.spatial.distance.cdist(y, y)
- >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b,50), 2)
+ >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b, 50), 2)
array([[0.12, 0.13, 0. , 0. ],
[0.13, 0.12, 0. , 0. ],
[0. , 0. , 0.25, 0. ],
[0. , 0. , 0. , 0.25]])
- >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b, 50, m=0.25), 2)
+ >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b, 50,0.25), 2)
array([[0.02, 0.03, 0. , 0.03],
[0.03, 0.03, 0. , 0.03],
[0. , 0. , 0.03, 0. ],
@@ -900,19 +938,22 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None,
Returns
-------
- :math: `gamma` : (dim_a x dim_b) ndarray
+ :math: `gamma` : (dim_a, dim_b) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if `log` is `True`
+
+ .. _references-entropic-partial-gromov-wassertein:
References
----------
.. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
- .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov-
- Wasserstein with Applications on Positive-Unlabeled Learning".
- arXiv preprint arXiv:2002.08276.
+
+ .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal
+ Transport with Applications on Positive-Unlabeled Learning".
+ NeurIPS.
See Also
--------
@@ -964,33 +1005,33 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None,
numItermax=1000, tol=1e-7, log=False,
verbose=False):
r"""
- Returns the partial Gromov-Wasserstein discrepancy between (C1,p) and
- (C2,q)
+ Returns the partial Gromov-Wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
The function solves the following optimization problem:
.. math::
- GW = \arg\min_{\gamma} \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})\cdot
- \gamma_{i,j}\cdot\gamma_{k,l} + reg\cdot\Omega(\gamma)
+ GW = \min_{\gamma} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})\cdot
+ \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma)
- s.t.
- \gamma\geq 0 \\
- \gamma 1 \leq a\\
- \gamma^T 1 \leq b\\
- 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\}
+ .. math::
+ s.t. \ \gamma &\geq 0
+
+ \gamma \mathbf{1} &\leq \mathbf{a}
+
+ \gamma^T \mathbf{1} &\leq \mathbf{b}
+
+ \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
where :
- - C1 is the metric cost matrix in the source space
- - C2 is the metric cost matrix in the target space
- - p and q are the sample weights
- - L : quadratic loss function
- - :math:`\Omega` is the entropic regularization term
- :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - m is the amount of mass to be transported
+ - :math:`\mathbf{C_1}` is the metric cost matrix in the source space
+ - :math:`\mathbf{C_2}` is the metric cost matrix in the target space
+ - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights
+ - `L` : quadratic loss function
+ - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - `m` is the amount of mass to be transported
- The formulation of the GW problem has been proposed in [12]_ and the
- partial GW in [29]_.
+ The formulation of the GW problem has been proposed in :ref:`[12] <references-entropic-partial-gromov-wassertein2>` and the partial GW in :ref:`[29] <references-entropic-partial-gromov-wassertein2>`
Parameters
@@ -1006,7 +1047,7 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None,
reg: float
entropic regularization parameter
m : float, optional
- Amount of mass to be transported (default: min (|p|_1, |q|_1))
+ Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
G0 : ndarray, shape (ns, nt), optional
Initialisation of the transportation matrix
numItermax : int, optional
@@ -1039,14 +1080,17 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None,
>>> np.round(entropic_partial_gromov_wasserstein2(C1, C2, a, b,50), 2)
1.87
+
+ .. _references-entropic-partial-gromov-wassertein2:
References
----------
.. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
- .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov-
- Wasserstein with Applications on Positive-Unlabeled Learning".
- arXiv preprint arXiv:2002.08276.
+
+ .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal
+ Transport with Applications on Positive-Unlabeled Learning".
+ NeurIPS.
"""
partial_gw, log_gw = entropic_partial_gromov_wasserstein(C1, C2, p, q, reg,
diff --git a/ot/plot.py b/ot/plot.py
index ad436b4..3e3bed7 100644
--- a/ot/plot.py
+++ b/ot/plot.py
@@ -18,10 +18,10 @@ from matplotlib import gridspec
def plot1D_mat(a, b, M, title=''):
- """ Plot matrix M with the source and target 1D distribution
+ """ Plot matrix :math:`\mathbf{M}` with the source and target 1D distribution
- Creates a subplot with the source distribution a on the left and
- target distribution b on the tot. The matrix M is shown in between.
+ Creates a subplot with the source distribution :math:`\mathbf{a}` on the left and
+ target distribution :math:`\mathbf{b}` on the top. The matrix :math:`\mathbf{M}` is shown in between.
Parameters
@@ -61,10 +61,10 @@ def plot1D_mat(a, b, M, title=''):
def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs):
- """ Plot matrix M in 2D with lines using alpha values
+ """ Plot matrix :math:`\mathbf{G}` in 2D with lines using alpha values
Plot lines between source and target 2D samples with a color
- proportional to the value of the matrix G between samples.
+ proportional to the value of the matrix :math:`\mathbf{G}` between samples.
Parameters
diff --git a/ot/regpath.py b/ot/regpath.py
new file mode 100644
index 0000000..269937a
--- /dev/null
+++ b/ot/regpath.py
@@ -0,0 +1,827 @@
+# -*- coding: utf-8 -*-
+"""
+Regularization path OT solvers
+"""
+
+# Author: Haoran Wu <haoran.wu@univ-ubs.fr>
+# License: MIT License
+
+import numpy as np
+import scipy.sparse as sp
+
+
+def recast_ot_as_lasso(a, b, C):
+ r"""This function recasts the l2-penalized UOT problem as a Lasso problem
+
+ Recall the l2-penalized UOT problem defined in [Chapel et al., 2021]
+ .. math::
+ UOT = \min_T <C, T> + \lambda \|T 1_m - a\|_2^2 +
+ \lambda \|T^T 1_n - b\|_2^2
+ s.t.
+ T \geq 0
+ where :
+ - C is the (dim_a, dim_b) metric cost matrix
+ - :math:`\lambda` is the l2-regularization coefficient
+ - a and b are source and target distributions
+ - T is the transport plan to optimize
+
+ The problem above can be reformulated to a non-negative penalized
+ linear regression problem, particularly Lasso
+ .. math::
+ UOT2 = \min_t \gamma c^T t + 0.5 * \|H t - y\|_2^2
+ s.t.
+ t \geq 0
+ where :
+ - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C)
+ - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient
+ - y is the concatenation of vectors a and b, defined as y^T = [a^T b^T]
+ - H is a (dim_a + dim_b, dim_a * dim_b) metric matrix,
+ see [Chapel et al., 2021] for the design of H. The matrix product H t
+ computes both the source marginal and the target marginal.
+ - t is a (dim_a * dim_b, ) metric vector (flattened version of T)
+ Parameters
+ ----------
+ a : np.ndarray (dim_a,)
+ Histogram of dimension dim_a
+ b : np.ndarray (dim_b,)
+ Histogram of dimension dim_b
+ C : np.ndarray, shape (dim_a, dim_b)
+ Cost matrix
+ Returns
+ -------
+ H : np.ndarray (dim_a+dim_b, dim_a*dim_b)
+ Auxiliary matrix constituted by 0 and 1
+ y : np.ndarray (ns + nt, )
+ Concatenation of histogram a and histogram b
+ c : np.ndarray (ns * nt, )
+ Flattened array of cost matrix
+ Examples
+ --------
+ >>> import ot
+ >>> a = np.array([0.2, 0.3, 0.5])
+ >>> b = np.array([0.1, 0.9])
+ >>> C = np.array([[16., 25.], [28., 16.], [40., 36.]])
+ >>> H, y, c = ot.regpath.recast_ot_as_lasso(a, b, C)
+ >>> H.toarray()
+ array([[1., 1., 0., 0., 0., 0.],
+ [0., 0., 1., 1., 0., 0.],
+ [0., 0., 0., 0., 1., 1.],
+ [1., 0., 1., 0., 1., 0.],
+ [0., 1., 0., 1., 0., 1.]])
+ >>> y
+ array([0.2, 0.3, 0.5, 0.1, 0.9])
+ >>> c
+ array([16., 25., 28., 16., 40., 36.])
+
+ References
+ ----------
+ [Chapel et al., 2021]:
+ Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression.
+ """
+
+ dim_a = np.shape(a)[0]
+ dim_b = np.shape(b)[0]
+ y = np.concatenate((a, b))
+ c = C.flatten()
+ jHa = np.arange(dim_a * dim_b)
+ iHa = np.repeat(np.arange(dim_a), dim_b)
+ jHb = np.arange(dim_a * dim_b)
+ iHb = np.tile(np.arange(dim_b), dim_a) + dim_a
+ j = np.concatenate((jHa, jHb))
+ i = np.concatenate((iHa, iHb))
+ H = sp.csc_matrix((np.ones(dim_a * dim_b * 2), (i, j)),
+ shape=(dim_a + dim_b, dim_a * dim_b))
+ return H, y, c
+
+
+def recast_semi_relaxed_as_lasso(a, b, C):
+ r"""This function recasts the semi-relaxed l2-UOT problem as Lasso problem
+
+ .. math::
+ semi-relaxed UOT = \min_T <C, T> + \lambda \|T 1_m - a\|_2^2
+ s.t.
+ T^T 1_n = b
+ t \geq 0
+ where :
+ - C is the (dim_a, dim_b) metric cost matrix
+ - :math:`\lambda` is the l2-regularization coefficient
+ - a and b are source and target distributions
+ - T is the transport plan to optimize
+
+ The problem above can be reformulated as follows
+ .. math::
+ semi-relaxed UOT2 = \min_t \gamma c^T t + 0.5 * \|H_r t - a\|_2^2
+ s.t.
+ H_c t = b
+ t \geq 0
+ where :
+ - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C)
+ - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient
+ - H_r is a (dim_a, dim_a * dim_b) metric matrix,
+ which computes the sum along the rows of transport plan T
+ - H_c is a (dim_b, dim_a * dim_b) metric matrix,
+ which computes the sum along the columns of transport plan T
+ - t is a (dim_a * dim_b, ) metric vector (flattened version of T)
+ Parameters
+ ----------
+ a : np.ndarray (dim_a,)
+ Histogram of dimension dim_a
+ b : np.ndarray (dim_b,)
+ Histogram of dimension dim_b
+ C : np.ndarray, shape (dim_a, dim_b)
+ Cost matrix
+ Returns
+ -------
+ Hr : np.ndarray (dim_a, dim_a * dim_b)
+ Auxiliary matrix constituted by 0 and 1, which computes
+ the sum along the rows of transport plan T
+ Hc : np.ndarray (dim_b, dim_a * dim_b)
+ Auxiliary matrix constituted by 0 and 1, which computes
+ the sum along the columns of transport plan T
+ c : np.ndarray (ns * nt, )
+ Flattened array of cost matrix
+ Examples
+ --------
+ >>> import ot
+ >>> a = np.array([0.2, 0.3, 0.5])
+ >>> b = np.array([0.1, 0.9])
+ >>> C = np.array([[16., 25.], [28., 16.], [40., 36.]])
+ >>> Hr,Hc,c = ot.regpath.recast_semi_relaxed_as_lasso(a, b, C)
+ >>> Hr.toarray()
+ array([[1., 1., 0., 0., 0., 0.],
+ [0., 0., 1., 1., 0., 0.],
+ [0., 0., 0., 0., 1., 1.]])
+ >>> Hc.toarray()
+ array([[1., 0., 1., 0., 1., 0.],
+ [0., 1., 0., 1., 0., 1.]])
+ >>> c
+ array([16., 25., 28., 16., 40., 36.])
+ """
+
+ dim_a = np.shape(a)[0]
+ dim_b = np.shape(b)[0]
+
+ c = C.flatten()
+ jHr = np.arange(dim_a * dim_b)
+ iHr = np.repeat(np.arange(dim_a), dim_b)
+ jHc = np.arange(dim_a * dim_b)
+ iHc = np.tile(np.arange(dim_b), dim_a)
+
+ Hr = sp.csc_matrix((np.ones(dim_a * dim_b), (iHr, jHr)),
+ shape=(dim_a, dim_a * dim_b))
+ Hc = sp.csc_matrix((np.ones(dim_a * dim_b), (iHc, jHc)),
+ shape=(dim_b, dim_a * dim_b))
+
+ return Hr, Hc, c
+
+
+def ot_next_gamma(phi, delta, HtH, Hty, c, active_index, current_gamma):
+ r""" This function computes the next value of gamma if a variable
+ will be added in next iteration of the regularization path
+
+ We look for the largest value of gamma such that
+ the gradient of an inactive variable vanishes
+ .. math::
+ \max_{i \in \bar{A}} \frac{h_i^T(H_A \phi - y)}{h_i^T H_A \delta - c_i}
+ where :
+ - A is the current active set
+ - h_i is the ith column of auxiliary matrix H
+ - H_A is the sub-matrix constructed by the columns of H
+ whose indices belong to the active set A
+ - c_i is the ith element of cost vector c
+ - y is the concatenation of source and target distribution
+ - :math:`\phi` is the intercept of the solutions in current iteration
+ - :math:`\delta` is the slope of the solutions in current iteration
+ Parameters
+ ----------
+ phi : np.ndarray (|A|, )
+ Intercept of the solutions in current iteration (t is piecewise linear)
+ delta : np.ndarray (|A|, )
+ Slope of the solutions in current iteration (t is piecewise linear)
+ HtH : np.ndarray (dim_a * dim_b, dim_a * dim_b)
+ Matrix product of H^T H
+ Hty : np.ndarray (dim_a + dim_b, )
+ Matrix product of H^T y
+ c: np.ndarray (dim_a * dim_b, )
+ Flattened array of cost matrix C
+ active_index : list
+ Indices of active variables
+ current_gamma : float
+ Value of regularization coefficient at the start of current iteration
+ Returns
+ -------
+ next_gamma : float
+ Value of gamma if a variable is added to active set in next iteration
+ next_active_index : int
+ Index of variable to be activated
+ References
+ ----------
+ [Chapel et al., 2021]:
+ Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression.
+ """
+ M = (HtH[:, active_index].dot(phi) - Hty) / \
+ (HtH[:, active_index].dot(delta) - c + 1e-16)
+ M[active_index] = 0
+ M[M > (current_gamma - 1e-10 * current_gamma)] = 0
+ return np.max(M), np.argmax(M)
+
+
+def semi_relaxed_next_gamma(phi, delta, phi_u, delta_u, HrHr, Hc, Hra,
+ c, active_index, current_gamma):
+ r""" This function computes the next value of gamma when a variable is
+ active in the regularization path of semi-relaxed UOT.
+
+ By taking the Lagrangian form of the problem, we obtain a similar update
+ as the two-sided relaxed UOT
+ .. math::
+ \max_{i \in \bar{A}} \frac{h_{r i}^T(H_{r A} \phi - a) + h_{c i}^T
+ \phi_u}{h_{r i}^T H_{r A} \delta + h_{c i} \delta_u - c_i}
+ where :
+ - A is the current active set
+ - h_{r i} is the ith column of the matrix H_r
+ - h_{c i} is the ith column of the matrix H_c
+ - H_{r A} is the sub-matrix constructed by the columns of H_r
+ whose indices belong to the active set A
+ - c_i is the ith element of cost vector c
+ - y is the concatenation of source and target distribution
+ - :math:`\phi` is the intercept of the solutions in current iteration
+ - :math:`\delta` is the slope of the solutions in current iteration
+ - :math:`\phi_u` is the intercept of Lagrange parameter in current
+ iteration
+ - :math:`\delta_u` is the slope of Lagrange parameter in current iteration
+ Parameters
+ ----------
+ phi : np.ndarray (|A|, )
+ Intercept of the solutions in current iteration (t is piecewise linear)
+ delta : np.ndarray (|A|, )
+ Slope of the solutions in current iteration (t is piecewise linear)
+ phi_u : np.ndarray (dim_b, )
+ Intercept of the Lagrange parameter in current iteration (also linear)
+ delta_u : np.ndarray (dim_b, )
+ Slope of the Lagrange parameter in current iteration (also linear)
+ HrHr : np.ndarray (dim_a * dim_b, dim_a * dim_b)
+ Matrix product of H_r^T H_r
+ Hc : np.ndarray (dim_b, dim_a * dim_b)
+ Matrix that computes the sum along the columns of transport plan T
+ Hra : np.ndarray (dim_a * dim_b, )
+ Matrix product of H_r^T a
+ c: np.ndarray (dim_a * dim_b, )
+ Flattened array of cost matrix C
+ active_index : list
+ Indices of active variables
+ current_gamma : float
+ Value of regularization coefficient at the start of current iteration
+ Returns
+ -------
+ next_gamma : float
+ Value of gamma if a variable is added to active set in next iteration
+ next_active_index : int
+ Index of variable to be activated
+ References
+ ----------
+ [Chapel et al., 2021]:
+ Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression.
+ """
+
+ M = (HrHr[:, active_index].dot(phi) - Hra + Hc.T.dot(phi_u)) / \
+ (HrHr[:, active_index].dot(delta) - c + Hc.T.dot(delta_u) + 1e-16)
+ M[active_index] = 0
+ M[M > (current_gamma - 1e-10 * current_gamma)] = 0
+ return np.max(M), np.argmax(M)
+
+
+def compute_next_removal(phi, delta, current_gamma):
+ r""" This function computes the next value of gamma if a variable
+ is removed in next iteration of regularization path
+
+ We look for the largest value of gamma such that
+ an element of current solution vanishes
+ .. math::
+ \max_{j \in A} \frac{\phi_j}{\delta_j}
+ where :
+ - A is the current active set
+ - phi_j is the jth element of the intercept of current solution
+ - delta_j is the jth elemnt of the slope of current solution
+ Parameters
+ ----------
+ phi : np.ndarray (|A|, )
+ Intercept of the solutions in current iteration (t is piecewise linear)
+ delta : np.ndarray (|A|, )
+ Slope of the solutions in current iteration (t is piecewise linear)
+ current_gamma : float
+ Value of regularization coefficient at the start of current iteration
+ Returns
+ -------
+ next_removal_gamma : float
+ Value of gamma if a variable is removed in next iteration
+ next_removal_index : int
+ Index of the variable to remove in next iteration
+ References
+ ----------
+ [Chapel et al., 2021]:
+ Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression.
+ """
+ r_candidate = phi / (delta - 1e-16)
+ r_candidate[r_candidate >= (1 - 1e-8) * current_gamma] = 0
+ return np.max(r_candidate), np.argmax(r_candidate)
+
+
+def complement_schur(M_current, b, d, id_pop):
+ r""" This function computes the inverse of matrix in regularization path
+ using Schur complement
+
+ Two cases may arise: Firstly one variable is added to the active set
+ .. math::
+ M_{k+1}^{-1} =
+ \begin{bmatrix}
+ M_{k}^{-1} + s^{-1} M_{k}^{-1} b b^T M_{k}^{-1} & -s^{-1} \\
+ - s^{-1} b^T M_{k}^{-1} & s^{-1}
+ \end{bmatrix}
+ where :
+ - :math:`M_k^{-1}` is the inverse of matrix in previous iteration and
+ :math:`M_k` is the upper left block matrix in Schur formulation
+ - b is the upper right block matrix in Schur formulation. In our case,
+ b is reduced to a column vector and b^T is the lower left block matrix
+ - s is the Schur complement, given by
+ :math:`s = d - b^T M_{k}^{-1} b` in our case
+
+ Secondly, one variable is removed from the active set
+ .. math::
+ M_{k+1}^{-1} = M^{-1}_{A_k \backslash q} -
+ \frac{r_{-q,q} r^{T}_{-q,q}}{r_{q,q}}
+ where :
+ - q is the index of column and row to delete
+ - :math:`M^{-1}_{A_k \backslash q}` is the previous inverse matrix
+ without qth column and qth row
+ - r_{-q,q} is the qth column of :math:`M^{-1}_{k}` without the qth element
+ - r_{q, q} is the element of qth column and qth row in :math:`M^{-1}_{k}`
+ Parameters
+ ----------
+ M_current : np.ndarray (|A|-1, |A|-1)
+ Inverse matrix in previous iteration
+ b : np.ndarray (|A|-1, )
+ Upper right matrix in Schur complement, a column vector in our case
+ d : float
+ Lower right matrix in Schur complement, a scalar in our case
+ id_pop
+ Index of the variable to be removed, equal to -1
+ if none of the variables is deleted in current iteration
+ Returns
+ -------
+ M : np.ndarray (|A|, |A|)
+ Inverse matrix needed in current iteration
+ References
+ ----------
+ [Chapel et al., 2021]:
+ Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression.
+ """
+ if b is None:
+ b = M_current[id_pop, :]
+ b = np.delete(b, id_pop)
+ M_del = np.delete(M_current, id_pop, 0)
+ a = M_del[:, id_pop]
+ M_del = np.delete(M_del, id_pop, 1)
+ M = M_del - np.outer(a, b) / M_current[id_pop, id_pop]
+ else:
+ n = b.shape[0] + 1
+ if np.shape(b)[0] == 0:
+ M = np.array([[0.5]])
+ else:
+ X = M_current.dot(b)
+ s = d - b.T.dot(X)
+ M = np.zeros((n, n))
+ M[:-1, :-1] = M_current + X.dot(X.T) / s
+ X_ravel = X.ravel()
+ M[-1, :-1] = -X_ravel / s
+ M[:-1, -1] = -X_ravel / s
+ M[-1, -1] = 1 / s
+ return M
+
+
+def construct_augmented_H(active_index, m, Hc, HrHr):
+ r""" This function construct an augmented matrix for the first iteration of
+ semi-relaxed regularization path
+
+ .. math::
+ Augmented_H =
+ \begin{bmatrix}
+ 0 & H_{c A} \\
+ H_{c A}^T & H_{r A}^T H_{r A}
+ \end{bmatrix}
+ where :
+ - H_{r A} is the sub-matrix constructed by the columns of H_r
+ whose indices belong to the active set A
+ - H_{c A} is the sub-matrix constructed by the columns of H_c
+ whose indices belong to the active set A
+ Parameters
+ ----------
+ active_index : list
+ Indices of active variables
+ m : int
+ Length of the target distribution
+ Hc : np.ndarray (dim_b, dim_a * dim_b)
+ Matrix that computes the sum along the columns of transport plan T
+ HrHr : np.ndarray (dim_a * dim_b, dim_a * dim_b)
+ Matrix product of H_r^T H_r
+ Returns
+ -------
+ H_augmented : np.ndarray (dim_b + |A|, dim_b + |A|)
+ Augmented matrix for the first iteration of the semi-relaxed
+ regularization path
+ """
+ Hc_sub = Hc[:, active_index].toarray()
+ HrHr_sub = HrHr[:, active_index]
+ HrHr_sub = HrHr_sub[active_index, :].toarray()
+ H_augmented = np.block([[np.zeros((m, m)), Hc_sub], [Hc_sub.T, HrHr_sub]])
+ return H_augmented
+
+
+def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
+ itmax=50000):
+ r"""This function gives the regularization path of l2-penalized UOT problem
+
+ The problem to optimize is the Lasso reformulation of the l2-penalized UOT:
+ .. math::
+ \min_t \gamma c^T t + 0.5 * \|H t - y\|_2^2
+ s.t.
+ t \geq 0
+ where :
+ - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C)
+ - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient
+ - y is the concatenation of vectors a and b, defined as y^T = [a^T b^T]
+ - H is a (dim_a + dim_b, dim_a * dim_b) metric matrix,
+ see [Chapel et al., 2021] for the design of H. The matrix product Ht
+ computes both the source marginal and the target marginal.
+ - t is a (dim_a * dim_b, ) metric vector (flattened version of T)
+ Parameters
+ ----------
+ a : np.ndarray (dim_a,)
+ Histogram of dimension dim_a
+ b : np.ndarray (dim_b,)
+ Histogram of dimension dim_b
+ C : np.ndarray, shape (dim_a, dim_b)
+ Cost matrix
+ reg: float
+ l2-regularization coefficient
+ itmax: int
+ Maximum number of iteration
+ Returns
+ -------
+ t : np.ndarray (dim_a*dim_b, )
+ Flattened vector of optimal transport matrix
+ t_list : list
+ List of solutions in regularization path
+ gamma_list : list
+ List of regularization coefficient in regularization path
+ Examples
+ --------
+ >>> import ot
+ >>> import numpy as np
+ >>> n = 3
+ >>> xs = np.array([1., 2., 3.]).reshape((n, 1))
+ >>> xt = np.array([5., 6., 7.]).reshape((n, 1))
+ >>> C = ot.dist(xs, xt)
+ >>> C /= C.max()
+ >>> a = np.array([0.2, 0.5, 0.3])
+ >>> b = np.array([0.2, 0.5, 0.3])
+ >>> t, _, _ = ot.regpath.fully_relaxed_path(a, b, C, 1e-4)
+ >>> t
+ array([1.99958333e-01, 0.00000000e+00, 0.00000000e+00, 3.88888889e-05,
+ 4.99938889e-01, 0.00000000e+00, 0.00000000e+00, 3.88888889e-05,
+ 2.99958333e-01])
+
+ References
+ ----------
+ [Chapel et al., 2021]:
+ Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression.
+ """
+
+ n = np.shape(a)[0]
+ m = np.shape(b)[0]
+ H, y, c = recast_ot_as_lasso(a, b, C)
+ HtH = H.T.dot(H)
+ Hty = H.T.dot(y)
+ n_iter = 1
+
+ # initialization
+ M0 = Hty / c
+ gamma_list = [np.max(M0)]
+ active_index = [np.argmax(M0)]
+ t_list = [np.zeros((n * m,))]
+ H_inv = np.array([[]])
+ add_col = np.array([])
+ id_pop = -1
+
+ while n_iter < itmax and gamma_list[-1] > reg:
+ H_inv = complement_schur(H_inv, add_col, 2., id_pop)
+ current_gamma = gamma_list[-1]
+
+ # compute the intercept and slope of solutions in current iteration
+ # t = phi - gamma * delta
+ phi = H_inv.dot(Hty[active_index])
+ delta = H_inv.dot(c[active_index])
+ gamma, ik = ot_next_gamma(phi, delta, HtH, Hty, c, active_index,
+ current_gamma)
+
+ # compute the next lambda when removing a point from the active set
+ alt_gamma, id_pop = compute_next_removal(phi, delta, current_gamma)
+
+ # if the positivity constraint is violated, we remove id_pop
+ # from active set, otherwise we add ik to active set
+ if alt_gamma > gamma:
+ gamma = alt_gamma
+ else:
+ id_pop = -1
+
+ # compute the solution of current segment
+ tA = phi - gamma * delta
+ sol = np.zeros((n * m, ))
+ sol[active_index] = tA
+
+ if id_pop != -1:
+ active_index.pop(id_pop)
+ add_col = None
+ else:
+ active_index.append(ik)
+ add_col = HtH[active_index[:-1], ik].toarray()
+
+ gamma_list.append(gamma)
+ t_list.append(sol)
+ n_iter += 1
+
+ if itmax <= n_iter:
+ print('maximum iteration has been reached !')
+
+ # correct the last solution and gamma
+ if len(t_list) > 1:
+ t_final = (t_list[-2] + (t_list[-1] - t_list[-2]) *
+ (reg - gamma_list[-2]) / (gamma_list[-1] - gamma_list[-2]))
+ t_list[-1] = t_final
+ gamma_list[-1] = reg
+ else:
+ gamma_list[-1] = reg
+ print('Regularization path does not exist !')
+
+ return t_list[-1], t_list, gamma_list
+
+
+def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
+ itmax=50000):
+ r"""This function gives the regularization path of semi-relaxed
+ l2-UOT problem
+
+ The problem to optimize is the Lasso reformulation of the l2-penalized UOT:
+ .. math::
+ \min_t \gamma c^T t + 0.5 * \|H_r t - a\|_2^2
+ s.t.
+ H_c t = b
+ t \geq 0
+ where :
+ - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C)
+ - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient
+ - H_r is a (dim_a, dim_a * dim_b) metric matrix,
+ which computes the sum along the rows of transport plan T
+ - H_c is a (dim_b, dim_a * dim_b) metric matrix,
+ which computes the sum along the columns of transport plan T
+ - t is a (dim_a * dim_b, ) metric vector (flattened version of T)
+ Parameters
+ ----------
+ a : np.ndarray (dim_a,)
+ Histogram of dimension dim_a
+ b : np.ndarray (dim_b,)
+ Histogram of dimension dim_b
+ C : np.ndarray, shape (dim_a, dim_b)
+ Cost matrix
+ reg: float (optional)
+ l2-regularization coefficient
+ itmax: int (optional)
+ Maximum number of iteration
+ Returns
+ -------
+ t : np.ndarray (dim_a*dim_b, )
+ Flattened vector of optimal transport matrix
+ t_list : list
+ List of solutions in regularization path
+ gamma_list : list
+ List of regularization coefficient in regularization path
+ Examples
+ --------
+ >>> import ot
+ >>> import numpy as np
+ >>> n = 3
+ >>> xs = np.array([1., 2., 3.]).reshape((n, 1))
+ >>> xt = np.array([5., 6., 7.]).reshape((n, 1))
+ >>> C = ot.dist(xs, xt)
+ >>> C /= C.max()
+ >>> a = np.array([0.2, 0.5, 0.3])
+ >>> b = np.array([0.2, 0.5, 0.3])
+ >>> t, _, _ = ot.regpath.semi_relaxed_path(a, b, C, 1e-4)
+ >>> t
+ array([1.99980556e-01, 0.00000000e+00, 0.00000000e+00, 1.94444444e-05,
+ 4.99980556e-01, 0.00000000e+00, 0.00000000e+00, 1.94444444e-05,
+ 3.00000000e-01])
+
+ References
+ ----------
+ [Chapel et al., 2021]:
+ Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression.
+ """
+
+ n = np.shape(a)[0]
+ m = np.shape(b)[0]
+ Hr, Hc, c = recast_semi_relaxed_as_lasso(a, b, C)
+ Hra = Hr.T.dot(a)
+ HrHr = Hr.T.dot(Hr)
+ n_iter = 1
+ active_index = []
+
+ # initialization
+ for j in range(np.shape(C)[1]):
+ i = np.argmin(C[:, j])
+ active_index.append(i * m + j)
+ gamma_list = []
+ t_list = []
+ current_gamma = np.Inf
+ augmented_H0 = construct_augmented_H(active_index, m, Hc, HrHr)
+ add_col = np.array([])
+ id_pop = -1
+
+ while n_iter < itmax and current_gamma > reg:
+ if n_iter == 1:
+ H_inv = np.linalg.inv(augmented_H0)
+ else:
+ H_inv = complement_schur(H_inv, add_col, 1., id_pop + m)
+ # compute the intercept and slope of solutions in current iteration
+ augmented_phi = H_inv.dot(np.concatenate((b, Hra[active_index])))
+ augmented_delta = H_inv[:, m:].dot(c[active_index])
+ phi = augmented_phi[m:]
+ delta = augmented_delta[m:]
+ phi_u = augmented_phi[0:m]
+ delta_u = augmented_delta[0:m]
+ gamma, ik = semi_relaxed_next_gamma(phi, delta, phi_u, delta_u,
+ HrHr, Hc, Hra, c, active_index,
+ current_gamma)
+
+ # compute the next lambda when removing a point from the active set
+ alt_gamma, id_pop = compute_next_removal(phi, delta, current_gamma)
+
+ # if the positivity constraint is violated, we remove id_pop
+ # from active set, otherwise we add ik to active set
+ if alt_gamma > gamma:
+ gamma = alt_gamma
+ else:
+ id_pop = -1
+
+ # compute the solution of current segment
+ tA = phi - gamma * delta
+ sol = np.zeros((n * m, ))
+ sol[active_index] = tA
+ if id_pop != -1:
+ active_index.pop(id_pop)
+ add_col = None
+ else:
+ active_index.append(ik)
+ add_col = np.concatenate((Hc.toarray()[:, ik],
+ HrHr.toarray()[active_index[:-1], ik]))
+ add_col = add_col[:, np.newaxis]
+
+ gamma_list.append(gamma)
+ t_list.append(sol)
+ current_gamma = gamma
+ n_iter += 1
+
+ if itmax <= n_iter:
+ print('maximum iteration has been reached !')
+
+ # correct the last solution and gamma
+ if len(t_list) > 1:
+ t_final = (t_list[-2] + (t_list[-1] - t_list[-2]) *
+ (reg - gamma_list[-2]) / (gamma_list[-1] - gamma_list[-2]))
+ t_list[-1] = t_final
+ gamma_list[-1] = reg
+ else:
+ gamma_list[-1] = reg
+ print('Regularization path does not exist !')
+
+ return t_list[-1], t_list, gamma_list
+
+
+def regularization_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
+ semi_relaxed=False, itmax=50000):
+ r"""This function combines both the semi-relaxed and the fully-relaxed
+ regularization paths of l2-UOT problem
+
+ Parameters
+ ----------
+ a : np.ndarray (dim_a,)
+ Histogram of dimension dim_a
+ b : np.ndarray (dim_b,)
+ Histogram of dimension dim_b
+ C : np.ndarray, shape (dim_a, dim_b)
+ Cost matrix
+ reg: float (optional)
+ l2-regularization coefficient
+ semi_relaxed : bool (optional)
+ Give the semi-relaxed path if true
+ itmax: int (optional)
+ Maximum number of iteration
+ Returns
+ -------
+ t : np.ndarray (dim_a*dim_b, )
+ Flattened vector of optimal transport matrix
+ t_list : list
+ List of solutions in regularization path
+ gamma_list : list
+ List of regularization coefficient in regularization path
+ References
+ ----------
+ [Chapel et al., 2021]:
+ Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression.
+ """
+ if semi_relaxed:
+ t, t_list, gamma_list = semi_relaxed_path(a, b, C, reg=reg,
+ itmax=itmax)
+ else:
+ t, t_list, gamma_list = fully_relaxed_path(a, b, C, reg=reg,
+ itmax=itmax)
+ return t, t_list, gamma_list
+
+
+def compute_transport_plan(gamma, gamma_list, Pi_list):
+ r""" Given the regularization path, this function computes the transport
+ plan for any value of gamma by the piecewise linearity of the path
+
+ .. math::
+ t(\gamma) = \phi(\gamma) - \gamma \delta(\gamma)
+ where :
+ - :math:`\gamma` is the regularization coefficient
+ - :math:`\phi(\gamma)` is the corresponding intercept
+ - :math:`\delta(\gamma)` is the corresponding slope
+ - t is a (dim_a * dim_b, ) vector (flattened version of transport matrix)
+ Parameters
+ ----------
+ gamma : float
+ Regularization coefficient
+ gamma_list : list
+ List of regularization coefficients in regularization path
+ Pi_list : list
+ List of solutions in regularization path
+ Returns
+ -------
+ t : np.ndarray (dim_a*dim_b, )
+ Transport vector corresponding to the given value of gamma
+ Examples
+ --------
+ >>> import ot
+ >>> import numpy as np
+ >>> n = 3
+ >>> xs = np.array([1., 2., 3.]).reshape((n, 1))
+ >>> xt = np.array([5., 6., 7.]).reshape((n, 1))
+ >>> C = ot.dist(xs, xt)
+ >>> C /= C.max()
+ >>> a = np.array([0.2, 0.5, 0.3])
+ >>> b = np.array([0.2, 0.5, 0.3])
+ >>> t, pi_list, g_list = ot.regpath.regularization_path(a, b, C, reg=1e-4)
+ >>> gamma = 1
+ >>> t2 = ot.regpath.compute_transport_plan(gamma, g_list, pi_list)
+ >>> t2
+ array([0. , 0. , 0. , 0.19722222, 0.05555556,
+ 0. , 0. , 0.24722222, 0. ])
+
+ References
+ ----------
+ [Chapel et al., 2021]:
+ Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression.
+ """
+
+ if gamma >= gamma_list[0]:
+ Pi = Pi_list[0]
+ elif gamma <= gamma_list[-1]:
+ Pi = Pi_list[-1]
+ else:
+ idx = np.where(gamma <= np.array(gamma_list))[0][-1]
+ gamma_k0 = gamma_list[idx]
+ gamma_k1 = gamma_list[idx + 1]
+ pi_k0 = Pi_list[idx]
+ pi_k1 = Pi_list[idx + 1]
+ Pi = pi_k0 + (pi_k1 - pi_k0) * (gamma - gamma_k0) \
+ / (gamma_k1 - gamma_k0)
+ return Pi
diff --git a/ot/sliced.py b/ot/sliced.py
new file mode 100644
index 0000000..cf2d3be
--- /dev/null
+++ b/ot/sliced.py
@@ -0,0 +1,258 @@
+"""
+Sliced OT Distances
+
+"""
+
+# Author: Adrien Corenflos <adrien.corenflos@aalto.fi>
+# Nicolas Courty <ncourty@irisa.fr>
+# Rémi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+
+import numpy as np
+from .backend import get_backend, NumpyBackend
+from .utils import list_to_array
+
+
+def get_random_projections(d, n_projections, seed=None, backend=None, type_as=None):
+ r"""
+ Generates n_projections samples from the uniform on the unit sphere of dimension :math:`d-1`: :math:`\mathcal{U}(\mathcal{S}^{d-1})`
+
+ Parameters
+ ----------
+ d : int
+ dimension of the space
+ n_projections : int
+ number of samples requested
+ seed: int or RandomState, optional
+ Seed used for numpy random number generator
+ backend:
+ Backend to ue for random generation
+
+ Returns
+ -------
+ out: ndarray, shape (d, n_projections)
+ The uniform unit vectors on the sphere
+
+ Examples
+ --------
+ >>> n_projections = 100
+ >>> d = 5
+ >>> projs = get_random_projections(d, n_projections)
+ >>> np.allclose(np.sum(np.square(projs), 0), 1.) # doctest: +NORMALIZE_WHITESPACE
+ True
+
+ """
+
+ if backend is None:
+ nx = NumpyBackend()
+ else:
+ nx = backend
+
+ if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy':
+ projections = seed.randn(d, n_projections)
+ else:
+ if seed is not None:
+ nx.seed(seed)
+ projections = nx.randn(d, n_projections, type_as=type_as)
+
+ projections = projections / nx.sqrt(nx.sum(projections**2, 0, keepdims=True))
+ return projections
+
+
+def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2,
+ projections=None, seed=None, log=False):
+ r"""
+ Computes a Monte-Carlo approximation of the p-Sliced Wasserstein distance
+
+ .. math::
+ \mathcal{SWD}_p(\mu, \nu) = \underset{\theta \sim \mathcal{U}(\mathbb{S}^{d-1})}{\mathbb{E}}\left(\mathcal{W}_p^p(\theta_\# \mu, \theta_\# \nu)\right)^{\frac{1}{p}}
+
+
+ where :
+
+ - :math:`\theta_\# \mu` stands for the pushforwards of the projection :math:`X \in \mathbb{R}^d \mapsto \langle \theta, X \rangle`
+
+
+ Parameters
+ ----------
+ X_s : ndarray, shape (n_samples_a, dim)
+ samples in the source domain
+ X_t : ndarray, shape (n_samples_b, dim)
+ samples in the target domain
+ a : ndarray, shape (n_samples_a,), optional
+ samples weights in the source domain
+ b : ndarray, shape (n_samples_b,), optional
+ samples weights in the target domain
+ n_projections : int, optional
+ Number of projections used for the Monte-Carlo approximation
+ p: float, optional =
+ Power p used for computing the sliced Wasserstein
+ projections: shape (dim, n_projections), optional
+ Projection matrix (n_projections and seed are not used in this case)
+ seed: int or RandomState or None, optional
+ Seed used for random number generator
+ log: bool, optional
+ if True, sliced_wasserstein_distance returns the projections used and their associated EMD.
+
+ Returns
+ -------
+ cost: float
+ Sliced Wasserstein Cost
+ log : dict, optional
+ log dictionary return only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> n_samples_a = 20
+ >>> reg = 0.1
+ >>> X = np.random.normal(0., 1., (n_samples_a, 5))
+ >>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
+ 0.0
+
+ References
+ ----------
+
+ .. [31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
+ """
+ from .lp import wasserstein_1d
+
+ X_s, X_t = list_to_array(X_s, X_t)
+
+ if a is not None and b is not None and projections is None:
+ nx = get_backend(X_s, X_t, a, b)
+ elif a is not None and b is not None and projections is not None:
+ nx = get_backend(X_s, X_t, a, b, projections)
+ elif a is None and b is None and projections is not None:
+ nx = get_backend(X_s, X_t, projections)
+ else:
+ nx = get_backend(X_s, X_t)
+
+ n = X_s.shape[0]
+ m = X_t.shape[0]
+
+ if X_s.shape[1] != X_t.shape[1]:
+ raise ValueError(
+ "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1],
+ X_t.shape[1]))
+
+ if a is None:
+ a = nx.full(n, 1 / n, type_as=X_s)
+ if b is None:
+ b = nx.full(m, 1 / m, type_as=X_s)
+
+ d = X_s.shape[1]
+
+ if projections is None:
+ projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s)
+
+ X_s_projections = nx.dot(X_s, projections)
+ X_t_projections = nx.dot(X_t, projections)
+
+ projected_emd = wasserstein_1d(X_s_projections, X_t_projections, a, b, p=p)
+
+ res = (nx.sum(projected_emd) / n_projections) ** (1.0 / p)
+ if log:
+ return res, {"projections": projections, "projected_emds": projected_emd}
+ return res
+
+
+def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2,
+ projections=None, seed=None, log=False):
+ r"""
+ Computes a Monte-Carlo approximation of the max p-Sliced Wasserstein distance
+
+ .. math::
+ \mathcal{Max-SWD}_p(\mu, \nu) = \underset{\theta _in
+ \mathcal{U}(\mathbb{S}^{d-1})}{\max} [\mathcal{W}_p^p(\theta_\#
+ \mu, \theta_\# \nu)]^{\frac{1}{p}}
+
+ where :
+
+ - :math:`\theta_\# \mu` stands for the pushforwars of the projection :math:`\mathbb{R}^d \ni X \mapsto \langle \theta, X \rangle`
+
+
+ Parameters
+ ----------
+ X_s : ndarray, shape (n_samples_a, dim)
+ samples in the source domain
+ X_t : ndarray, shape (n_samples_b, dim)
+ samples in the target domain
+ a : ndarray, shape (n_samples_a,), optional
+ samples weights in the source domain
+ b : ndarray, shape (n_samples_b,), optional
+ samples weights in the target domain
+ n_projections : int, optional
+ Number of projections used for the Monte-Carlo approximation
+ p: float, optional =
+ Power p used for computing the sliced Wasserstein
+ projections: shape (dim, n_projections), optional
+ Projection matrix (n_projections and seed are not used in this case)
+ seed: int or RandomState or None, optional
+ Seed used for random number generator
+ log: bool, optional
+ if True, sliced_wasserstein_distance returns the projections used and their associated EMD.
+
+ Returns
+ -------
+ cost: float
+ Sliced Wasserstein Cost
+ log : dict, optional
+ log dictionary return only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> n_samples_a = 20
+ >>> reg = 0.1
+ >>> X = np.random.normal(0., 1., (n_samples_a, 5))
+ >>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
+ 0.0
+
+ References
+ ----------
+
+ .. [35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., ... & Schwing, A. G. (2019). Max-sliced wasserstein distance and its use for gans. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656).
+ """
+ from .lp import wasserstein_1d
+
+ X_s, X_t = list_to_array(X_s, X_t)
+
+ if a is not None and b is not None and projections is None:
+ nx = get_backend(X_s, X_t, a, b)
+ elif a is not None and b is not None and projections is not None:
+ nx = get_backend(X_s, X_t, a, b, projections)
+ elif a is None and b is None and projections is not None:
+ nx = get_backend(X_s, X_t, projections)
+ else:
+ nx = get_backend(X_s, X_t)
+
+ n = X_s.shape[0]
+ m = X_t.shape[0]
+
+ if X_s.shape[1] != X_t.shape[1]:
+ raise ValueError(
+ "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1],
+ X_t.shape[1]))
+
+ if a is None:
+ a = nx.full(n, 1 / n, type_as=X_s)
+ if b is None:
+ b = nx.full(m, 1 / m, type_as=X_s)
+
+ d = X_s.shape[1]
+
+ if projections is None:
+ projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s)
+
+ X_s_projections = nx.dot(X_s, projections)
+ X_t_projections = nx.dot(X_t, projections)
+
+ projected_emd = wasserstein_1d(X_s_projections, X_t_projections, a, b, p=p)
+
+ res = nx.max(projected_emd) ** (1.0 / p)
+ if log:
+ return res, {"projections": projections, "projected_emds": projected_emd}
+ return res
diff --git a/ot/smooth.py b/ot/smooth.py
index 81f6a3e..6855005 100644
--- a/ot/smooth.py
+++ b/ot/smooth.py
@@ -47,15 +47,24 @@ from scipy.optimize import minimize
def projection_simplex(V, z=1, axis=None):
- """ Projection of x onto the simplex, scaled by z
+ r""" Projection of :math:`\mathbf{V}` onto the simplex, scaled by `z`
- P(x; z) = argmin_{y >= 0, sum(y) = z} ||y - x||^2
+ .. math::
+ P\left(\mathbf{V}, z\right) = \mathop{\arg \min}_{\substack{\mathbf{y} >= 0 \\ \sum_i \mathbf{y}_i = z}} \quad \|\mathbf{y} - \mathbf{V}\|^2
+
+ Parameters
+ ----------
+ V: ndarray, rank 2
z: float or array
- If array, len(z) must be compatible with V
+ If array, len(z) must be compatible with :math:`\mathbf{V}`
axis: None or int
- - axis=None: project V by P(V.ravel(); z)
- - axis=1: project each V[i] by P(V[i]; z[i])
- - axis=0: project each V[:, j] by P(V[:, j]; z[j])
+ - axis=None: project :math:`\mathbf{V}` by :math:`P(\mathbf{V}.\mathrm{ravel}(), z)`
+ - axis=1: project each :math:`\mathbf{V}_i` by :math:`P(\mathbf{V}_i, z_i)`
+ - axis=0: project each :math:`\mathbf{V}_{:, j}` by :math:`P(\mathbf{V}_{:, j}, z_j)`
+
+ Returns
+ -------
+ projection: ndarray, shape :math:`\mathbf{V}`.shape
"""
if axis == 1:
n_features = V.shape[1]
@@ -77,12 +86,12 @@ def projection_simplex(V, z=1, axis=None):
class Regularization(object):
- """Base class for Regularization objects
+ r"""Base class for Regularization objects
Notes
-----
- This class is not intended for direct use but as aparent for true
- regularizatiojn implementation.
+ This class is not intended for direct use but as apparent for true
+ regularization implementation.
"""
def __init__(self, gamma=1.0):
@@ -98,40 +107,48 @@ class Regularization(object):
self.gamma = gamma
def delta_Omega(X):
- """
- Compute delta_Omega(X[:, j]) for each X[:, j].
- delta_Omega(x) = sup_{y >= 0} y^T x - Omega(y).
+ r"""
+ Compute :math:`\delta_\Omega(\mathbf{X}_{:, j})` for each :math:`\mathbf{X}_{:, j}`.
+
+ .. math::
+ \delta_\Omega(\mathbf{x}) = \sup_{\mathbf{y} >= 0} \
+ \mathbf{y}^T \mathbf{x} - \Omega(\mathbf{y})
Parameters
----------
- X: array, shape = len(a) x len(b)
+ X: array, shape = (len(a), len(b))
Input array.
Returns
-------
- v: array, len(b)
- Values: v[j] = delta_Omega(X[:, j])
- G: array, len(a) x len(b)
- Gradients: G[:, j] = nabla delta_Omega(X[:, j])
+ v: array, (len(b), )
+ Values: :math:`\mathbf{v}_j = \delta_\Omega(\mathbf{X}_{:, j})`
+ G: array, (len(a), len(b))
+ Gradients: :math:`\mathbf{G}_{:, j} = \nabla \delta_\Omega(\mathbf{X}_{:, j})`
"""
raise NotImplementedError
def max_Omega(X, b):
- """
- Compute max_Omega_j(X[:, j]) for each X[:, j].
- max_Omega_j(x) = sup_{y >= 0, sum(y) = 1} y^T x - Omega(b[j] y) / b[j].
+ r"""
+ Compute :math:`\mathrm{max}_{\Omega, j}(\mathbf{X}_{:, j})` for each :math:`\mathbf{X}_{:, j}`.
+
+ .. math::
+ \mathrm{max}_{\Omega, j}(\mathbf{x}) =
+ \sup_{\substack{\mathbf{y} >= 0 \ \sum_i \mathbf{y}_i = 1}}
+ \mathbf{y}^T \mathbf{x} - \frac{1}{\mathbf{b}_j} \Omega(\mathbf{b}_j \mathbf{y})
Parameters
----------
- X: array, shape = len(a) x len(b)
+ X: array, shape = (len(a), len(b))
Input array.
+ b: array, shape = (len(b), )
Returns
-------
- v: array, len(b)
- Values: v[j] = max_Omega_j(X[:, j])
- G: array, len(a) x len(b)
- Gradients: G[:, j] = nabla max_Omega_j(X[:, j])
+ v: array, (len(b), )
+ Values: :math:`\mathbf{v}_j = \mathrm{max}_{\Omega, j}(\mathbf{X}_{:, j})`
+ G: array, (len(a), len(b))
+ Gradients: :math:`\mathbf{G}_{:, j} = \nabla \mathrm{max}_{\Omega, j}(\mathbf{X}_{:, j})`
"""
raise NotImplementedError
@@ -192,7 +209,7 @@ class SquaredL2(Regularization):
def dual_obj_grad(alpha, beta, a, b, C, regul):
- """
+ r"""
Compute objective value and gradients of dual objective.
Parameters
@@ -203,19 +220,19 @@ def dual_obj_grad(alpha, beta, a, b, C, regul):
a: array, shape = len(a)
b: array, shape = len(b)
Input histograms (should be non-negative and sum to 1).
- C: array, shape = len(a) x len(b)
+ C: array, shape = (len(a), len(b))
Ground cost matrix.
regul: Regularization object
- Should implement a delta_Omega(X) method.
+ Should implement a `delta_Omega(X)` method.
Returns
-------
obj: float
Objective value (higher is better).
grad_alpha: array, shape = len(a)
- Gradient w.r.t. alpha.
+ Gradient w.r.t. `alpha`.
grad_beta: array, shape = len(b)
- Gradient w.r.t. beta.
+ Gradient w.r.t. `beta`.
"""
obj = np.dot(alpha, a) + np.dot(beta, b)
grad_alpha = a.copy()
@@ -242,13 +259,13 @@ def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
Parameters
----------
- a: array, shape = len(a)
- b: array, shape = len(b)
+ a: array, shape = (len(a), )
+ b: array, shape = (len(b), )
Input histograms (should be non-negative and sum to 1).
- C: array, shape = len(a) x len(b)
+ C: array, shape = (len(a), len(b))
Ground cost matrix.
regul: Regularization object
- Should implement a delta_Omega(X) method.
+ Should implement a `delta_Omega(X)` method.
method: str
Solver to be used (passed to `scipy.optimize.minimize`).
tol: float
@@ -258,8 +275,8 @@ def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
Returns
-------
- alpha: array, shape = len(a)
- beta: array, shape = len(b)
+ alpha: array, shape = (len(a), )
+ beta: array, shape = (len(b), )
Dual potentials.
"""
@@ -302,10 +319,10 @@ def semi_dual_obj_grad(alpha, a, b, C, regul):
a: array, shape = len(a)
b: array, shape = len(b)
Input histograms (should be non-negative and sum to 1).
- C: array, shape = len(a) x len(b)
+ C: array, shape = (len(a), len(b))
Ground cost matrix.
regul: Regularization object
- Should implement a max_Omega(X) method.
+ Should implement a `max_Omega(X)` method.
Returns
-------
@@ -337,13 +354,13 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
Parameters
----------
- a: array, shape = len(a)
- b: array, shape = len(b)
+ a: array, shape = (len(a), )
+ b: array, shape = (len(b), )
Input histograms (should be non-negative and sum to 1).
- C: array, shape = len(a) x len(b)
+ C: array, shape = (len(a), len(b))
Ground cost matrix.
regul: Regularization object
- Should implement a max_Omega(X) method.
+ Should implement a `max_Omega(X)` method.
method: str
Solver to be used (passed to `scipy.optimize.minimize`).
tol: float
@@ -353,7 +370,7 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
Returns
-------
- alpha: array, shape = len(a)
+ alpha: array, shape = (len(a), )
Semi-dual potentials.
"""
@@ -371,7 +388,7 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
def get_plan_from_dual(alpha, beta, C, regul):
- """
+ r"""
Retrieve optimal transportation plan from optimal dual potentials.
Parameters
@@ -379,14 +396,14 @@ def get_plan_from_dual(alpha, beta, C, regul):
alpha: array, shape = len(a)
beta: array, shape = len(b)
Optimal dual potentials.
- C: array, shape = len(a) x len(b)
+ C: array, shape = (len(a), len(b))
Ground cost matrix.
regul: Regularization object
- Should implement a delta_Omega(X) method.
+ Should implement a `delta_Omega(X)` method.
Returns
-------
- T: array, shape = len(a) x len(b)
+ T: array, shape = (len(a), len(b))
Optimal transportation plan.
"""
X = alpha[:, np.newaxis] + beta - C
@@ -394,7 +411,7 @@ def get_plan_from_dual(alpha, beta, C, regul):
def get_plan_from_semi_dual(alpha, b, C, regul):
- """
+ r"""
Retrieve optimal transportation plan from optimal semi-dual potentials.
Parameters
@@ -403,14 +420,14 @@ def get_plan_from_semi_dual(alpha, b, C, regul):
Optimal semi-dual potentials.
b: array, shape = len(b)
Second input histogram (should be non-negative and sum to 1).
- C: array, shape = len(a) x len(b)
+ C: array, shape = (len(a), len(b))
Ground cost matrix.
regul: Regularization object
- Should implement a delta_Omega(X) method.
+ Should implement a `delta_Omega(X)` method.
Returns
-------
- T: array, shape = len(a) x len(b)
+ T: array, shape = (len(a), len(b))
Optimal transportation plan.
"""
X = alpha[:, np.newaxis] - C
@@ -422,19 +439,21 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
r"""
Solve the regularized OT problem in the dual and return the OT matrix
- The function solves the smooth relaxed dual formulation (7) in [17]_ :
+ The function solves the smooth relaxed dual formulation (7) in
+ :ref:`[17] <references-smooth-ot-dual>`:
.. math::
- \max_{\alpha,\beta}\quad a^T\alpha+b^T\beta-\sum_j\delta_\Omega(\alpha+\beta_j-\mathbf{m}_j)
+ \max_{\alpha,\beta}\quad \mathbf{a}^T\alpha + \mathbf{b}^T\beta -
+ \sum_j \delta_\Omega \left(\alpha+\beta_j-\mathbf{m}_j \right)
where :
- - :math:`\mathbf{m}_j` is the jth column of the cost matrix
+ - :math:`\mathbf{m}_j` is the j-th column of the cost matrix
- :math:`\delta_\Omega` is the convex conjugate of the regularization term :math:`\Omega`
- - a and b are source and target weights (sum to 1)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
The OT matrix can is reconstructed from the gradient of :math:`\delta_\Omega`
- (See [17]_ Proposition 1).
+ (See :ref:`[17] <references-smooth-ot-dual>` Proposition 1).
The optimization algorithm is using gradient decent (L-BFGS by default).
@@ -444,21 +463,25 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
samples weights in the source domain
b : np.ndarray (nt,) or np.ndarray (nt,nbb)
samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
+ and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
+ (return OT loss + dual variables in log)
M : np.ndarray (ns,nt)
loss matrix
reg : float
Regularization term >0
reg_type : str
- Regularization type, can be the following (default ='l2'):
- - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn [2]_)
- - 'l2' : Squared Euclidean regularization
+ Regularization type, can be the following (default ='l2'):
+
+ - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn
+ :ref:`[2] <references-smooth-ot-dual>`)
+
+ - 'l2' : Squared Euclidean regularization
method : str
Solver to use for scipy.optimize.minimize
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -467,15 +490,15 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : (ns, nt) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
+ .. _references-smooth-ot-dual:
References
----------
-
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
.. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
@@ -514,21 +537,23 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=
r"""
Solve the regularized OT problem in the semi-dual and return the OT matrix
- The function solves the smooth relaxed dual formulation (10) in [17]_ :
+ The function solves the smooth relaxed dual formulation (10) in
+ :ref:`[17] <references-smooth-ot-semi-dual>`:
.. math::
- \max_{\alpha}\quad a^T\alpha-OT_\Omega^*(\alpha,b)
+ \max_{\alpha}\quad \mathbf{a}^T\alpha- \mathrm{OT}_\Omega^*(\alpha, \mathbf{b})
where :
.. math::
- OT_\Omega^*(\alpha,b)=\sum_j b_j
+ \mathrm{OT}_\Omega^*(\alpha,b)=\sum_j \mathbf{b}_j
- - :math:`\mathbf{m}_j` is the jth column of the cost matrix
- - :math:`OT_\Omega^*(\alpha,b)` is defined in Eq. (9) in [17]
- - a and b are source and target weights (sum to 1)
+ - :math:`\mathbf{m}_j` is the j-th column of the cost matrix
+ - :math:`\mathrm{OT}_\Omega^*(\alpha,b)` is defined in Eq. (9) in
+ :ref:`[17] <references-smooth-ot-semi-dual>`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
- The OT matrix can is reconstructed using [17]_ Proposition 2.
+ The OT matrix can is reconstructed using :ref:`[17] <references-smooth-ot-semi-dual>` Proposition 2.
The optimization algorithm is using gradient decent (L-BFGS by default).
@@ -538,21 +563,25 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=
samples weights in the source domain
b : np.ndarray (nt,) or np.ndarray (nt,nbb)
samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
+ and fixed:math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
+ (return OT loss + dual variables in log)
M : np.ndarray (ns,nt)
loss matrix
reg : float
Regularization term >0
reg_type : str
- Regularization type, can be the following (default ='l2'):
- - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn [2]_)
- - 'l2' : Squared Euclidean regularization
+ Regularization type, can be the following (default ='l2'):
+
+ - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn
+ :ref:`[2] <references-smooth-ot-semi-dual>`)
+
+ - 'l2' : Squared Euclidean regularization
method : str
Solver to use for scipy.optimize.minimize
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -561,15 +590,15 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : (ns, nt) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
+ .. _references-smooth-ot-semi-dual:
References
----------
-
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
.. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
diff --git a/ot/stochastic.py b/ot/stochastic.py
index 13ed9cc..693675f 100644
--- a/ot/stochastic.py
+++ b/ot/stochastic.py
@@ -18,22 +18,25 @@ import numpy as np
def coordinate_grad_semi_dual(b, M, reg, beta, i):
r'''
- Compute the coordinate gradient update for regularized discrete distributions for (i, :)
+ Compute the coordinate gradient update for regularized discrete distributions for :math:`(i, :)`
The function computes the gradient of the semi dual problem:
.. math::
- \max_v \sum_i (\sum_j v_j * b_j - reg * log(\sum_j exp((v_j - M_{i,j})/reg) * b_j)) * a_i
+ \max_\mathbf{v} \ \sum_i \mathbf{a}_i \left[ \sum_j \mathbf{v}_j \mathbf{b}_j - \mathrm{reg}
+ \cdot \log \left( \sum_j \mathbf{b}_j
+ \exp \left( \frac{\mathbf{v}_j - \mathbf{M}_{i,j}}{\mathrm{reg}}
+ \right) \right) \right]
Where :
- - M is the (ns,nt) metric cost matrix
- - v is a dual variable in R^J
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`\mathbf{v}` is a dual variable in :math:`\mathbb{R}^{nt}`
- reg is the regularization term
- - a and b are source and target weights (sum to 1)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
The algorithm used for solving the problem is the ASGD & SAG algorithms
- as proposed in [18]_ [alg.1 & alg.2]
+ as proposed in :ref:`[18] <references-coordinate-grad-semi-dual>` [alg.1 & alg.2]
Parameters
@@ -47,7 +50,7 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
v : ndarray, shape (nt,)
Dual variable.
i : int
- Picked number i.
+ Picked number `i`.
Returns
-------
@@ -74,12 +77,10 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
[4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
+ .. _references-coordinate-grad-semi-dual:
References
----------
- [Genevay et al., 2016] :
- Stochastic Optimization for Large-scale Optimal Transport,
- Advances in Neural Information Processing Systems (2016),
- arXiv preprint arxiv:1605.08527.
+ .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) Stochastic Optimization for Large-scale Optimal Transport. Advances in Neural Information Processing Systems (2016).
'''
r = M[i, :] - beta
exp_beta = np.exp(-r / reg) * b
@@ -88,29 +89,29 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
- r'''
- Compute the SAG algorithm to solve the regularized discrete measures
- optimal transport max problem
+ r"""
+ Compute the SAG algorithm to solve the regularized discrete measures optimal transport max problem
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
- \gamma^T 1 = b
+ \gamma^T \mathbf{1} = \mathbf{b}
\gamma \geq 0
Where :
- - M is the (ns,nt) metric cost matrix
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
The algorithm used for solving the problem is the SAG algorithm
- as proposed in [18]_ [alg.1]
+ as proposed in :ref:`[18] <references-sag-entropic-transport>` [alg.1]
Parameters
@@ -131,7 +132,7 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
Returns
-------
- v : ndarray, shape (nt,)
+ v : ndarray, shape (`nt`,)
Dual variable.
Examples
@@ -154,14 +155,12 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
[2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
[4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
+
+ .. _references-sag-entropic-transport:
References
----------
-
- [Genevay et al., 2016] :
- Stochastic Optimization for Large-scale Optimal Transport,
- Advances in Neural Information Processing Systems (2016),
- arXiv preprint arxiv:1605.08527.
- '''
+ .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) Stochastic Optimization for Large-scale Optimal Transport. Advances in Neural Information Processing Systems (2016).
+ """
if lr is None:
lr = 1. / max(a / reg)
@@ -187,22 +186,23 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg}\cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \gamma \mathbf{1} = \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} = \mathbf{b}
\gamma \geq 0
Where :
- - M is the (ns,nt) metric cost matrix
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
The algorithm used for solving the problem is the ASGD algorithm
- as proposed in [18]_ [alg.2]
+ as proposed in :ref:`[18] <references-averaged-sgd-entropic-transport>` [alg.2]
Parameters
@@ -220,7 +220,7 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
Returns
-------
- ave_v : ndarray, shape (nt,)
+ ave_v : ndarray, shape (`nt`,)
dual variable
Examples
@@ -243,13 +243,11 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
[2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
[4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
+
+ .. _references-averaged-sgd-entropic-transport:
References
----------
-
- [Genevay et al., 2016] :
- Stochastic Optimization for Large-scale Optimal Transport,
- Advances in Neural Information Processing Systems (2016),
- arXiv preprint arxiv:1605.08527.
+ .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) Stochastic Optimization for Large-scale Optimal Transport. Advances in Neural Information Processing Systems (2016).
'''
if lr is None:
@@ -271,20 +269,21 @@ def c_transform_entropic(b, M, reg, beta):
r'''
The goal is to recover u from the c-transform.
- The function computes the c_transform of a dual variable from the other
+ The function computes the c-transform of a dual variable from the other
dual variable:
.. math::
- u = v^{c,reg} = -reg \sum_j exp((v - M)/reg) b_j
+ \mathbf{u} = \mathbf{v}^{c,reg} = - \mathrm{reg} \sum_j \mathbf{b}_j
+ \exp\left( \frac{\mathbf{v} - \mathbf{M}}{\mathrm{reg}} \right)
Where :
- - M is the (ns,nt) metric cost matrix
- - u, v are dual variables in R^IxR^J
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`\mathbf{u}`, :math:`\mathbf{v}` are dual variables in :math:`\mathbb{R}^{ns} \times \mathbb{R}^{nt}`
- reg is the regularization term
It is used to recover an optimal u from optimal v solving the semi dual
- problem, see Proposition 2.1 of [18]_
+ problem, see Proposition 2.1 of :ref:`[18] <references-c-transform-entropic>`
Parameters
@@ -300,7 +299,7 @@ def c_transform_entropic(b, M, reg, beta):
Returns
-------
- u : ndarray, shape (ns,)
+ u : ndarray, shape (`ns`,)
Dual variable.
Examples
@@ -323,13 +322,11 @@ def c_transform_entropic(b, M, reg, beta):
[2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
[4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
+
+ .. _references-c-transform-entropic:
References
----------
-
- [Genevay et al., 2016] :
- Stochastic Optimization for Large-scale Optimal Transport,
- Advances in Neural Information Processing Systems (2016),
- arXiv preprint arxiv:1605.08527.
+ .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) Stochastic Optimization for Large-scale Optimal Transport. Advances in Neural Information Processing Systems (2016).
'''
n_source = np.shape(M)[0]
@@ -345,27 +342,28 @@ def c_transform_entropic(b, M, reg, beta):
def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
log=False):
r'''
- Compute the transportation matrix to solve the regularized discrete
- measures optimal transport max problem
+ Compute the transportation matrix to solve the regularized discrete measures optimal transport max problem
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} = \mathbf{b}
\gamma \geq 0
Where :
- - M is the (ns,nt) metric cost matrix
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
+
The algorithm used for solving the problem is the SAG or ASGD algorithms
- as proposed in [18]_
+ as proposed in :ref:`[18] <references-solve-semi-dual-entropic>`
Parameters
@@ -419,13 +417,11 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
[2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
[4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
+
+ .. _references-solve-semi-dual-entropic:
References
----------
-
- [Genevay et al., 2016] :
- Stochastic Optimization for Large-scale Optimal Transport,
- Advances in Neural Information Processing Systems (2016),
- arXiv preprint arxiv:1605.08527.
+ .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) Stochastic Optimization for Large-scale Optimal Transport. Advances in Neural Information Processing Systems (2016).
'''
if method.lower() == "sag":
@@ -459,26 +455,30 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
r'''
Computes the partial gradient of the dual optimal transport problem.
- For each (i,j) in a batch of coordinates, the partial gradients are :
+ For each :math:`(i,j)` in a batch of coordinates, the partial gradients are :
.. math::
- \partial_{u_i} F = u_i * b_s/l_{v} - \sum_{j \in B_v} exp((u_i + v_j - M_{i,j})/reg) * a_i * b_j
+ \partial_{\mathbf{u}_i} F = \frac{b_s}{l_v} \mathbf{u}_i -
+ \sum_{j \in B_v} \mathbf{a}_i \mathbf{b}_j
+ \exp\left( \frac{\mathbf{u}_i + \mathbf{v}_j - \mathbf{M}_{i,j}}{\mathrm{reg}} \right)
- \partial_{v_j} F = v_j * b_s/l_{u} - \sum_{i \in B_u} exp((u_i + v_j - M_{i,j})/reg) * a_i * b_j
+ \partial_{\mathbf{v}_j} F = \frac{b_s}{l_u} \mathbf{v}_j -
+ \sum_{i \in B_u} \mathbf{a}_i \mathbf{b}_j
+ \exp\left( \frac{\mathbf{u}_i + \mathbf{v}_j - \mathbf{M}_{i,j}}{\mathrm{reg}} \right)
Where :
- - M is the (ns,nt) metric cost matrix
- - u, v are dual variables in R^ixR^J
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`\mathbf{u}`, :math:`\mathbf{v}` are dual variables in :math:`\mathbb{R}^{ns} \times \mathbb{R}^{nt}`
- reg is the regularization term
- :math:`B_u` and :math:`B_v` are lists of index
- - :math:`b_s` is the size of the batchs :math:`B_u` and :math:`B_v`
- - :math:`l_u` and :math:`l_v` are the lenghts of :math:`B_u` and :math:`B_v`
- - a and b are source and target weights (sum to 1)
+ - :math:`b_s` is the size of the batches :math:`B_u` and :math:`B_v`
+ - :math:`l_u` and :math:`l_v` are the lengths of :math:`B_u` and :math:`B_v`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
The algorithm used for solving the dual problem is the SGD algorithm
- as proposed in [19]_ [alg.1]
+ as proposed in :ref:`[19] <references-batch-grad-dual>` [alg.1]
Parameters
@@ -504,7 +504,7 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
Returns
-------
- grad : ndarray, shape (ns,)
+ grad : ndarray, shape (`ns`,)
partial grad F
Examples
@@ -533,12 +533,11 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
[5.06266486e-02, 2.16230494e-03, 2.26215141e-03, 6.81514609e-04],
[6.06713990e-02, 3.98139808e-02, 5.46829338e-02, 8.62371424e-06]])
+
+ .. _references-batch-grad-dual:
References
----------
-
- [Seguy et al., 2018] :
- International Conference on Learning Representation (2018),
- arXiv preprint arxiv:1711.02283.
+ .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018)
'''
G = - (np.exp((alpha[batch_alpha, None] + beta[None, batch_beta] -
M[batch_alpha, :][:, batch_beta]) / reg) *
@@ -555,25 +554,25 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
r'''
- Compute the sgd algorithm to solve the regularized discrete measures
- optimal transport dual problem
+ Compute the sgd algorithm to solve the regularized discrete measures optimal transport dual problem
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} = \mathbf{b}
\gamma \geq 0
Where :
- - M is the (ns,nt) metric cost matrix
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
Parameters
----------
@@ -632,9 +631,7 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
References
----------
- [Seguy et al., 2018] :
- International Conference on Learning Representation (2018),
- arXiv preprint arxiv:1711.02283.
+ .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018)
'''
n_source = np.shape(M)[0]
@@ -657,25 +654,25 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
log=False):
r'''
- Compute the transportation matrix to solve the regularized discrete measures
- optimal transport dual problem
+ Compute the transportation matrix to solve the regularized discrete measures optimal transport dual problem
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} = \mathbf{b}
\gamma \geq 0
Where :
- - M is the (ns,nt) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
Parameters
----------
@@ -736,10 +733,7 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
References
----------
-
- [Seguy et al., 2018] :
- International Conference on Learning Representation (2018),
- arXiv preprint arxiv:1711.02283.
+ .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018)
'''
opt_alpha, opt_beta = sgd_entropic_regularization(a, b, M, reg, batch_size,
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index e37f10c..15e180b 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -23,29 +23,31 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
+ W = \min_\gamma \ \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg}\cdot\Omega(\gamma) +
+ \mathrm{reg_m} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) +
+ \mathrm{reg_m} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})
s.t.
- \gamma\geq 0
+ \gamma \geq 0
+
where :
- - M is the (dim_a, dim_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization
- term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target unbalanced distributions
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
- KL is the Kullback-Leibler divergence
The algorithm used for solving the problem is the generalized
- Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] <references-sinkhorn-unbalanced>`
Parameters
----------
a : np.ndarray (dim_a,)
- Unnormalized histogram of dimension dim_a
+ Unnormalized histogram of dimension `dim_a`
b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
- One or multiple unnormalized histograms of dimension dim_b
- If many, compute all the OT distances (a, b_i)
+ One or multiple unnormalized histograms of dimension `dim_b`.
+ If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i`
M : np.ndarray (dim_a, dim_b)
loss matrix
reg : float
@@ -58,7 +60,7 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -68,14 +70,14 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
Returns
-------
if n_hists == 1:
- gamma : (dim_a x dim_b) ndarray
+ - gamma : (dim_a, dim_b) ndarray
Optimal transportation matrix for the given parameters
- log : dict
+ - log : dict
log dictionary returned only if `log` is `True`
else:
- ot_distance : (n_hists,) ndarray
- the OT distance between `a` and each of the histograms `b_i`
- log : dict
+ - ot_distance : (n_hists,) ndarray
+ the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
+ - log : dict
log dictionary returned only if `log` is `True`
Examples
@@ -90,9 +92,9 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
[0.18807035, 0.51122823]])
+ .. _references-sinkhorn-unbalanced:
References
----------
-
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
Transport, Advances in Neural Information Processing Systems
(NIPS) 26, 2013
@@ -111,11 +113,11 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
See Also
--------
- ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn [10]
+ ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn :ref:`[10] <references-sinkhorn-unbalanced>`
ot.unbalanced.sinkhorn_stabilized_unbalanced:
- Unbalanced Stabilized sinkhorn [9][10]
+ Unbalanced Stabilized sinkhorn :ref:`[9, 10] <references-sinkhorn-unbalanced>`
ot.unbalanced.sinkhorn_reg_scaling_unbalanced:
- Unbalanced Sinkhorn with epslilon scaling [9][10]
+ Unbalanced Sinkhorn with epslilon scaling :ref:`[9, 10] <references-sinkhorn-unbalanced>`
"""
@@ -151,29 +153,30 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
+ W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg}\cdot\Omega(\gamma) +
+ \mathrm{reg_m} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) +
+ \mathrm{reg_m} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})
s.t.
\gamma\geq 0
where :
- - M is the (dim_a, dim_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term
- :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target unbalanced distributions
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
- KL is the Kullback-Leibler divergence
The algorithm used for solving the problem is the generalized
- Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] <references-sinkhorn-unbalanced2>`
Parameters
----------
a : np.ndarray (dim_a,)
- Unnormalized histogram of dimension dim_a
+ Unnormalized histogram of dimension `dim_a`
b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
- One or multiple unnormalized histograms of dimension dim_b
- If many, compute all the OT distances (a, b_i)
+ One or multiple unnormalized histograms of dimension `dim_b`.
+ If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i`
M : np.ndarray (dim_a, dim_b)
loss matrix
reg : float
@@ -186,7 +189,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -196,7 +199,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
Returns
-------
ot_distance : (n_hists,) ndarray
- the OT distance between `a` and each of the histograms `b_i`
+ the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
log : dict
log dictionary returned only if `log` is `True`
@@ -211,10 +214,9 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
array([0.31912866])
-
+ .. _references-sinkhorn-unbalanced2:
References
----------
-
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
Transport, Advances in Neural Information Processing Systems
(NIPS) 26, 2013
@@ -232,9 +234,9 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
See Also
--------
- ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10]
- ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10]
- ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10]
+ ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn :ref:`[10] <references-sinkhorn-unbalanced2>`
+ ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn :ref:`[9, 10] <references-sinkhorn-unbalanced2>`
+ ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epslilon scaling :ref:`[9, 10] <references-sinkhorn-unbalanced2>`
"""
b = np.asarray(b, dtype=np.float64)
@@ -270,26 +272,29 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \reg_m KL(\gamma 1, a) + \reg_m KL(\gamma^T 1, b)
+ W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg}\cdot\Omega(\gamma) +
+ \mathrm{reg_m} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) +
+ \mathrm{reg_m} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})
s.t.
- \gamma\geq 0
+ \gamma \geq 0
+
where :
- - M is the (dim_a, dim_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target unbalanced distributions
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
- KL is the Kullback-Leibler divergence
- The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+ The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] <references-sinkhorn-knopp-unbalanced>`
Parameters
----------
a : np.ndarray (dim_a,)
- Unnormalized histogram of dimension dim_a
+ Unnormalized histogram of dimension `dim_a`
b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
- One or multiple unnormalized histograms of dimension dim_b
+ One or multiple unnormalized histograms of dimension `dim_b`
If many, compute all the OT distances (a, b_i)
M : np.ndarray (dim_a, dim_b)
loss matrix
@@ -300,7 +305,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (> 0)
+ Stop threshold on error (> 0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -310,15 +315,16 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
Returns
-------
if n_hists == 1:
- gamma : (dim_a x dim_b) ndarray
+ - gamma : (dim_a, dim_b) ndarray
Optimal transportation matrix for the given parameters
- log : dict
+ - log : dict
log dictionary returned only if `log` is `True`
else:
- ot_distance : (n_hists,) ndarray
- the OT distance between `a` and each of the histograms `b_i`
- log : dict
+ - ot_distance : (n_hists,) ndarray
+ the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
+ - log : dict
log dictionary returned only if `log` is `True`
+
Examples
--------
@@ -330,9 +336,10 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
array([[0.51122823, 0.18807035],
[0.18807035, 0.51122823]])
+
+ .. _references-sinkhorn-knopp-unbalanced:
References
----------
-
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
Scaling algorithms for unbalanced transport problems. arXiv preprint
arXiv:1607.05816.
@@ -445,32 +452,34 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
problem and return the loss
The function solves the following optimization problem using log-domain
- stabilization as proposed in [10]:
+ stabilization as proposed in :ref:`[10] <references-sinkhorn-stabilized-unbalanced>`:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
+ W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg}\cdot\Omega(\gamma) +
+ \mathrm{reg_m} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) +
+ \mathrm{reg_m} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})
s.t.
- \gamma\geq 0
+ \gamma \geq 0
+
where :
- - M is the (dim_a, dim_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization
- term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target unbalanced distributions
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
- KL is the Kullback-Leibler divergence
The algorithm used for solving the problem is the generalized
- Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] <references-sinkhorn-stabilized-unbalanced>`
Parameters
----------
a : np.ndarray (dim_a,)
- Unnormalized histogram of dimension dim_a
+ Unnormalized histogram of dimension `dim_a`
b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
- One or multiple unnormalized histograms of dimension dim_b
- If many, compute all the OT distances (a, b_i)
+ One or multiple unnormalized histograms of dimension `dim_b`.
+ If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i`
M : np.ndarray (dim_a, dim_b)
loss matrix
reg : float
@@ -482,7 +491,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -492,14 +501,14 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
Returns
-------
if n_hists == 1:
- gamma : (dim_a x dim_b) ndarray
+ - gamma : (dim_a, dim_b) ndarray
Optimal transportation matrix for the given parameters
- log : dict
+ - log : dict
log dictionary returned only if `log` is `True`
else:
- ot_distance : (n_hists,) ndarray
- the OT distance between `a` and each of the histograms `b_i`
- log : dict
+ - ot_distance : (n_hists,) ndarray
+ the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
+ - log : dict
log dictionary returned only if `log` is `True`
Examples
--------
@@ -512,9 +521,10 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
array([[0.51122823, 0.18807035],
[0.18807035, 0.51122823]])
+
+ .. _references-sinkhorn-stabilized-unbalanced:
References
----------
-
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
@@ -654,29 +664,27 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
numItermax=1000, stopThr=1e-6,
verbose=False, log=False):
- r"""Compute the entropic unbalanced wasserstein barycenter of A with stabilization.
+ r"""Compute the entropic unbalanced wasserstein barycenter of :math:`\mathbf{A}` with stabilization.
The function solves the following optimization problem:
.. math::
- \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i)
+ \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{u_{reg}}(\mathbf{a},\mathbf{a}_i)
where :
- - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
- Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
- - :math:`\mathbf{a}_i` are training distributions in the columns of
- matrix :math:`\mathbf{A}`
- - reg and :math:`\mathbf{M}` are respectively the regularization term and
- the cost matrix for OT
+ - :math:`W_{u_{reg}}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see :py:func:`ot.unbalanced.sinkhorn_unbalanced`)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
- reg_mis the marginal relaxation hyperparameter
- The algorithm used for solving the problem is the generalized
- Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10] <references-barycenter-unbalanced-stabilized>`
Parameters
----------
A : np.ndarray (dim, n_hists)
- `n_hists` training distributions a_i of dimension dim
+ `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim`
M : np.ndarray (dim, dim)
ground metric matrix for OT.
reg : float
@@ -691,7 +699,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (> 0)
+ Stop threshold on error (> 0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -706,9 +714,9 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
log dictionary return only if log==True in parameters
+ .. _references-barycenter-unbalanced-stabilized:
References
----------
-
.. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré,
G. (2015). Iterative Bregman projections for regularized transportation
problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
@@ -806,29 +814,27 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
numItermax=1000, stopThr=1e-6,
verbose=False, log=False):
- r"""Compute the entropic unbalanced wasserstein barycenter of A.
+ r"""Compute the entropic unbalanced wasserstein barycenter of :math:`\mathbf{A}`.
- The function solves the following optimization problem with a
+ The function solves the following optimization problem with :math:`\mathbf{a}`
.. math::
- \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i)
+ \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{u_{reg}}(\mathbf{a},\mathbf{a}_i)
where :
- - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
- Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
- - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
- :math:`\mathbf{A}`
- - reg and :math:`\mathbf{M}` are respectively the regularization term and
- the cost matrix for OT
+ - :math:`W_{u_{reg}}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see :py:func:`ot.unbalanced.sinkhorn_unbalanced`)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
- reg_mis the marginal relaxation hyperparameter
+
The algorithm used for solving the problem is the generalized
- Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10] <references-barycenter-unbalanced-sinkhorn>`
Parameters
----------
A : np.ndarray (dim, n_hists)
- `n_hists` training distributions a_i of dimension dim
+ `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim`
M : np.ndarray (dim, dim)
ground metric matrix for OT.
reg : float
@@ -841,7 +847,7 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (> 0)
+ Stop threshold on error (> 0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -856,9 +862,9 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
log dictionary return only if log==True in parameters
+ .. _references-barycenter-unbalanced-sinkhorn:
References
----------
-
.. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
(2015). Iterative Bregman projections for regularized transportation
problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
@@ -936,29 +942,27 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
numItermax=1000, stopThr=1e-6,
verbose=False, log=False, **kwargs):
- r"""Compute the entropic unbalanced wasserstein barycenter of A.
+ r"""Compute the entropic unbalanced wasserstein barycenter of :math:`\mathbf{A}`.
- The function solves the following optimization problem with a
+ The function solves the following optimization problem with :math:`\mathbf{a}`
.. math::
- \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i)
+ \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{u_{reg}}(\mathbf{a},\mathbf{a}_i)
where :
- - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
- Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
- - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
- :math:`\mathbf{A}`
- - reg and :math:`\mathbf{M}` are respectively the regularization term and
- the cost matrix for OT
+ - :math:`W_{u_{reg}}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see :py:func:`ot.unbalanced.sinkhorn_unbalanced`)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
- reg_mis the marginal relaxation hyperparameter
+
The algorithm used for solving the problem is the generalized
- Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10] <references-barycenter-unbalanced>`
Parameters
----------
A : np.ndarray (dim, n_hists)
- `n_hists` training distributions a_i of dimension dim
+ `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim`
M : np.ndarray (dim, dim)
ground metric matrix for OT.
reg : float
@@ -971,7 +975,7 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (> 0)
+ Stop threshold on error (> 0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -986,9 +990,9 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
log dictionary return only if log==True in parameters
+ .. _references-barycenter-unbalanced:
References
----------
-
.. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
(2015). Iterative Bregman projections for regularized transportation
problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
diff --git a/ot/utils.py b/ot/utils.py
index f9911a1..c878563 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -7,7 +7,6 @@ Various useful functions
#
# License: MIT License
-import multiprocessing
from functools import reduce
import time
@@ -15,67 +14,127 @@ import numpy as np
from scipy.spatial.distance import cdist
import sys
import warnings
-try:
- from inspect import signature
-except ImportError:
- from .externals.funcsigs import signature
+from inspect import signature
+from .backend import get_backend
__time_tic_toc = time.time()
def tic():
- """ Python implementation of Matlab tic() function """
+ r""" Python implementation of Matlab tic() function """
global __time_tic_toc
__time_tic_toc = time.time()
def toc(message='Elapsed time : {} s'):
- """ Python implementation of Matlab toc() function """
+ r""" Python implementation of Matlab toc() function """
t = time.time()
print(message.format(t - __time_tic_toc))
return t - __time_tic_toc
def toq():
- """ Python implementation of Julia toc() function """
+ r""" Python implementation of Julia toc() function """
t = time.time()
return t - __time_tic_toc
def kernel(x1, x2, method='gaussian', sigma=1, **kwargs):
- """Compute kernel matrix"""
+ r"""Compute kernel matrix"""
+
+ nx = get_backend(x1, x2)
+
if method.lower() in ['gaussian', 'gauss', 'rbf']:
- K = np.exp(-dist(x1, x2) / (2 * sigma**2))
+ K = nx.exp(-dist(x1, x2) / (2 * sigma**2))
return K
def laplacian(x):
- """Compute Laplacian matrix"""
+ r"""Compute Laplacian matrix"""
L = np.diag(np.sum(x, axis=0)) - x
return L
-def unif(n):
- """ return a uniform histogram of length n (simplex)
+def list_to_array(*lst):
+ r""" Convert a list if in numpy format """
+ if len(lst) > 1:
+ return [np.array(a) if isinstance(a, list) else a for a in lst]
+ else:
+ return np.array(lst[0]) if isinstance(lst[0], list) else lst[0]
+
+
+def proj_simplex(v, z=1):
+ r"""Compute the closest point (orthogonal projection) on the
+ generalized `(n-1)`-simplex of a vector :math:`\mathbf{v}` wrt. to the Euclidean
+ distance, thus solving:
+
+ .. math::
+ \mathcal{P}(w) \in \mathop{\arg \min}_\gamma \| \gamma - \mathbf{v} \|_2
+
+ s.t. \ \gamma^T \mathbf{1} = z
+
+ \gamma \geq 0
+
+ If :math:`\mathbf{v}` is a 2d array, compute all the projections wrt. axis 0
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
Parameters
----------
+ v : {array-like}, shape (n, d)
+ z : int, optional
+ 'size' of the simplex (each vectors sum to z, 1 by default)
+
+ Returns
+ -------
+ h : ndarray, shape (`n`, `d`)
+ Array of projections on the simplex
+ """
+ nx = get_backend(v)
+ n = v.shape[0]
+ if v.ndim == 1:
+ d1 = 1
+ v = v[:, None]
+ else:
+ d1 = 0
+ d = v.shape[1]
+
+ # sort u in ascending order
+ u = nx.sort(v, axis=0)
+ # take the descending order
+ u = nx.flip(u, 0)
+ cssv = nx.cumsum(u, axis=0) - z
+ ind = nx.arange(n, type_as=v)[:, None] + 1
+ cond = u - cssv / ind > 0
+ rho = nx.sum(cond, 0)
+ theta = cssv[rho - 1, nx.arange(d)] / rho
+ w = nx.maximum(v - theta[None, :], nx.zeros(v.shape, type_as=v))
+ if d1:
+ return w[:, 0]
+ else:
+ return w
+
+
+def unif(n):
+ r"""
+ Return a uniform histogram of length `n` (simplex).
+ Parameters
+ ----------
n : int
number of bins in the histogram
Returns
-------
- h : np.array (n,)
- histogram of length n such that h_i=1/n for all i
-
-
+ h : np.array (`n`,)
+ histogram of length `n` such that :math:`\forall i, \mathbf{h}_i = \frac{1}{n}`
"""
return np.ones((n,)) / n
def clean_zeros(a, b, M):
- """ Remove all components with zeros weights in a and b
+ r""" Remove all components with zeros weights in :math:`\mathbf{a}` and :math:`\mathbf{b}`
"""
M2 = M[a > 0, :][:, b > 0].copy() # copy force c style matrix (froemd)
a2 = a[a > 0]
@@ -84,55 +143,71 @@ def clean_zeros(a, b, M):
def euclidean_distances(X, Y, squared=False):
- """
- Considering the rows of X (and Y=X) as vectors, compute the
+ r"""
+ Considering the rows of :math:`\mathbf{X}` (and :math:`\mathbf{Y} = \mathbf{X}`) as vectors, compute the
distance matrix between each pair of vectors.
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
+
Parameters
----------
- X : {array-like}, shape (n_samples_1, n_features)
- Y : {array-like}, shape (n_samples_2, n_features)
+ X : array-like, shape (n_samples_1, n_features)
+ Y : array-like, shape (n_samples_2, n_features)
squared : boolean, optional
Return squared Euclidean distances.
+
Returns
-------
- distances : {array}, shape (n_samples_1, n_samples_2)
+ distances : array-like, shape (`n_samples_1`, `n_samples_2`)
"""
- XX = np.einsum('ij,ij->i', X, X)[:, np.newaxis]
- YY = np.einsum('ij,ij->i', Y, Y)[np.newaxis, :]
- distances = np.dot(X, Y.T)
- distances *= -2
- distances += XX
- distances += YY
- np.maximum(distances, 0, out=distances)
+
+ nx = get_backend(X, Y)
+
+ a2 = nx.einsum('ij,ij->i', X, X)
+ b2 = nx.einsum('ij,ij->i', Y, Y)
+
+ c = -2 * nx.dot(X, Y.T)
+ c += a2[:, None]
+ c += b2[None, :]
+
+ c = nx.maximum(c, 0)
+
+ if not squared:
+ c = nx.sqrt(c)
+
if X is Y:
- # Ensure that distances between vectors and themselves are set to 0.0.
- # This may not be the case due to floating point rounding errors.
- distances.flat[::distances.shape[0] + 1] = 0.0
- return distances if squared else np.sqrt(distances, out=distances)
+ c = c * (1 - nx.eye(X.shape[0], type_as=c))
+
+ return c
+
+def dist(x1, x2=None, metric='sqeuclidean', p=2):
+ r"""Compute distance between samples in :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`
-def dist(x1, x2=None, metric='sqeuclidean'):
- """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
Parameters
----------
- x1 : ndarray, shape (n1,d)
- matrix with n1 samples of size d
- x2 : array, shape (n2,d), optional
- matrix with n2 samples of size d (if None then x2=x1)
+ x1 : array-like, shape (n1,d)
+ matrix with `n1` samples of size `d`
+ x2 : array-like, shape (n2,d), optional
+ matrix with `n2` samples of size `d` (if None then :math:`\mathbf{x_2} = \mathbf{x_1}`)
metric : str | callable, optional
- Name of the metric to be computed (full list in the doc of scipy), If a string,
- the distance function can be 'braycurtis', 'canberra', 'chebyshev', 'cityblock',
- 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski',
- 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean',
+ 'sqeuclidean' or 'euclidean' on all backends. On numpy the function also
+ accepts from the scipy.spatial.distance.cdist function : 'braycurtis',
+ 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice',
+ 'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis',
+ 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean',
'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'.
Returns
-------
- M : np.array (n1,n2)
+ M : array-like, shape (`n1`, `n2`)
distance matrix computed with given metric
"""
@@ -140,11 +215,17 @@ def dist(x1, x2=None, metric='sqeuclidean'):
x2 = x1
if metric == "sqeuclidean":
return euclidean_distances(x1, x2, squared=True)
- return cdist(x1, x2, metric=metric)
+ elif metric == "euclidean":
+ return euclidean_distances(x1, x2, squared=False)
+ else:
+ if not get_backend(x1, x2).__name__ == 'numpy':
+ raise NotImplementedError()
+ else:
+ return cdist(x1, x2, metric=metric, p=p)
def dist0(n, method='lin_square'):
- """Compute standard cost matrices of size (n, n) for OT problems
+ r"""Compute standard cost matrices of size (`n`, `n`) for OT problems
Parameters
----------
@@ -153,11 +234,11 @@ def dist0(n, method='lin_square'):
method : str, optional
Type of loss matrix chosen from:
- * 'lin_square' : linear sampling between 0 and n-1, quadratic loss
+ * 'lin_square' : linear sampling between 0 and `n-1`, quadratic loss
Returns
-------
- M : ndarray, shape (n1,n2)
+ M : ndarray, shape (`n1`, `n2`)
Distance matrix computed with given metric.
"""
res = 0
@@ -168,7 +249,7 @@ def dist0(n, method='lin_square'):
def cost_normalization(C, norm=None):
- """ Apply normalization to the loss matrix
+ r""" Apply normalization to the loss matrix
Parameters
----------
@@ -180,7 +261,7 @@ def cost_normalization(C, norm=None):
Returns
-------
- C : ndarray, shape (n1, n2)
+ C : ndarray, shape (`n1`, `n2`)
The input cost matrix normalized according to given norm.
"""
@@ -202,23 +283,23 @@ def cost_normalization(C, norm=None):
def dots(*args):
- """ dots function for multiple matrix multiply """
+ r""" dots function for multiple matrix multiply """
return reduce(np.dot, args)
def label_normalization(y, start=0):
- """ Transform labels to start at a given value
+ r""" Transform labels to start at a given value
Parameters
----------
y : array-like, shape (n, )
The vector of labels to be normalized.
start : int
- Desired value for the smallest label in y (default=0)
+ Desired value for the smallest label in :math:`\mathbf{y}` (default=0)
Returns
-------
- y : array-like, shape (n1, )
+ y : array-like, shape (`n1`, )
The input vector of labels normalized according to given start value.
"""
@@ -228,42 +309,15 @@ def label_normalization(y, start=0):
return y
-def fun(f, q_in, q_out):
- """ Utility function for parmap with no serializing problems """
- while True:
- i, x = q_in.get()
- if i is None:
- break
- q_out.put((i, f(x)))
-
-
-def parmap(f, X, nprocs=multiprocessing.cpu_count()):
- """ paralell map for multiprocessing (only map on windows)"""
-
- if not sys.platform.endswith('win32'):
-
- q_in = multiprocessing.Queue(1)
- q_out = multiprocessing.Queue()
-
- proc = [multiprocessing.Process(target=fun, args=(f, q_in, q_out))
- for _ in range(nprocs)]
- for p in proc:
- p.daemon = True
- p.start()
-
- sent = [q_in.put((i, x)) for i, x in enumerate(X)]
- [q_in.put((None, None)) for _ in range(nprocs)]
- res = [q_out.get() for _ in range(len(sent))]
-
- [p.join() for p in proc]
-
- return [x for i, x in sorted(res)]
- else:
- return list(map(f, X))
+def parmap(f, X, nprocs="default"):
+ r""" parallel map for multiprocessing.
+ The function has been deprecated and only performs a regular map.
+ """
+ return list(map(f, X))
def check_params(**kwargs):
- """check_params: check whether some parameters are missing
+ r"""check_params: check whether some parameters are missing
"""
missing_params = []
@@ -284,14 +338,14 @@ def check_params(**kwargs):
def check_random_state(seed):
- """Turn seed into a np.random.RandomState instance
+ r"""Turn `seed` into a np.random.RandomState instance
Parameters
----------
seed : None | int | instance of RandomState
- If seed is None, return the RandomState singleton used by np.random.
- If seed is an int, return a new RandomState instance seeded with seed.
- If seed is already a RandomState instance, return it.
+ If `seed` is None, return the RandomState singleton used by np.random.
+ If `seed` is an int, return a new RandomState instance seeded with `seed`.
+ If `seed` is already a RandomState instance, return it.
Otherwise raise ValueError.
"""
if seed is None or seed is np.random:
@@ -305,18 +359,21 @@ def check_random_state(seed):
class deprecated(object):
- """Decorator to mark a function or class as deprecated.
+ r"""Decorator to mark a function or class as deprecated.
deprecated class from scikit-learn package
https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/deprecation.py
Issue a warning when the function is called/the class is instantiated and
adds a warning to the docstring.
The optional extra argument will be appended to the deprecation message
- and the docstring. Note: to use this with the default value for extra, put
- in an empty of parentheses:
- >>> from ot.deprecation import deprecated # doctest: +SKIP
- >>> @deprecated() # doctest: +SKIP
- ... def some_function(): pass # doctest: +SKIP
+ and the docstring.
+
+ .. note::
+ To use this with the default value for extra, use empty parentheses:
+
+ >>> from ot.deprecation import deprecated # doctest: +SKIP
+ >>> @deprecated() # doctest: +SKIP
+ ... def some_function(): pass # doctest: +SKIP
Parameters
----------
@@ -331,7 +388,7 @@ class deprecated(object):
self.extra = extra
def __call__(self, obj):
- """Call method
+ r"""Call method
Parameters
----------
obj : object
@@ -362,7 +419,7 @@ class deprecated(object):
return cls
def _decorate_fun(self, fun):
- """Decorate function fun"""
+ r"""Decorate function fun"""
msg = "Function %s is deprecated" % fun.__name__
if self.extra:
@@ -388,7 +445,7 @@ class deprecated(object):
def _is_deprecated(func):
- """Helper to check if func is wraped by our deprecated decorator"""
+ r"""Helper to check if func is wraped by our deprecated decorator"""
if sys.version_info < (3, 5):
raise NotImplementedError("This is only available for python3.5 "
"or above")
@@ -402,7 +459,7 @@ def _is_deprecated(func):
class BaseEstimator(object):
- """Base class for most objects in POT
+ r"""Base class for most objects in POT
Code adapted from sklearn BaseEstimator class
@@ -415,7 +472,7 @@ class BaseEstimator(object):
@classmethod
def _get_param_names(cls):
- """Get parameter names for the estimator"""
+ r"""Get parameter names for the estimator"""
# fetch the constructor or the original constructor before
# deprecation wrapping if any
@@ -442,7 +499,7 @@ class BaseEstimator(object):
return sorted([p.name for p in parameters])
def get_params(self, deep=True):
- """Get parameters for this estimator.
+ r"""Get parameters for this estimator.
Parameters
----------
@@ -479,7 +536,7 @@ class BaseEstimator(object):
return out
def set_params(self, **params):
- """Set the parameters of this estimator.
+ r"""Set the parameters of this estimator.
The method works on simple estimators as well as on nested objects
(such as pipelines). The latter have parameters of the form
@@ -519,7 +576,7 @@ class BaseEstimator(object):
class UndefinedParameter(Exception):
- """
+ r"""
Aim at raising an Exception when a undefined parameter is called
"""
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..3f8ae8b
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,3 @@
+[build-system]
+requires = ["setuptools", "wheel", "numpy>=1.16", "cython>=0.23"]
+build-backend = "setuptools.build_meta" \ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 331dd57..4353247 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,4 +7,7 @@ pymanopt==0.2.4; python_version <'3'
pymanopt; python_version >= '3'
cvxopt
scikit-learn
+torch
+jax
+jaxlib
pytest \ No newline at end of file
diff --git a/setup.cfg b/setup.cfg
index 6be91fe..1177faf 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -12,7 +12,7 @@ addopts =
--ignore=docs --ignore=examples --ignore=notebooks
[pycodestyle]
-exclude = __init__.py,*externals*,constants.py,fixes.py
+exclude = __init__.py,constants.py,fixes.py
ignore = E241,E305,W504
[pydocstyle]
diff --git a/setup.py b/setup.py
index 91c24d9..86c7c8d 100755..100644
--- a/setup.py
+++ b/setup.py
@@ -1,21 +1,18 @@
#!/usr/bin/env python
-from setuptools import setup, find_packages
-from codecs import open
-from os import path
-from setuptools.extension import Extension
-from Cython.Build import cythonize
-import numpy
-import re
import os
-import sys
+import re
import subprocess
+import sys
-here = path.abspath(path.dirname(__file__))
+from setuptools import find_packages, setup
+from setuptools.extension import Extension
+import numpy
+from Cython.Build import cythonize
-os.environ["CC"] = "g++"
-os.environ["CXX"] = "g++"
+sys.path.append(os.path.join("ot", "helpers"))
+from openmp_helpers import check_openmp_support
# dirty but working
__version__ = re.search(
@@ -24,74 +21,79 @@ __version__ = re.search(
# The beautiful part is, I don't even need to check exceptions here.
# If something messes up, let the build process fail noisy, BEFORE my release!
-# thanks Pipy for handling markdown now
+# thanks PyPI for handling markdown now
ROOT = os.path.abspath(os.path.dirname(__file__))
-
with open(os.path.join(ROOT, 'README.md'), encoding="utf-8") as f:
README = f.read()
-opt_arg = ["-O3"]
-
# clean cython output is clean is called
if 'clean' in sys.argv[1:]:
if os.path.isfile('ot/lp/emd_wrap.cpp'):
os.remove('ot/lp/emd_wrap.cpp')
-
# add platform dependant optional compilation argument
+openmp_supported, flags = check_openmp_support()
+compile_args = ["/O2" if sys.platform == "win32" else "-O3"]
+link_args = []
+
+if openmp_supported:
+ compile_args += flags + ["/DOMP" if sys.platform == 'win32' else "-DOMP"]
+ link_args += flags
+
if sys.platform.startswith('darwin'):
- opt_arg.append("-stdlib=libc++")
+ compile_args.append("-stdlib=libc++")
sdk_path = subprocess.check_output(['xcrun', '--show-sdk-path'])
os.environ['CFLAGS'] = '-isysroot "{}"'.format(sdk_path.rstrip().decode("utf-8"))
-setup(name='POT',
- version=__version__,
- description='Python Optimal Transport Library',
- long_description=README,
- long_description_content_type='text/markdown',
- author=u'Remi Flamary, Nicolas Courty',
- author_email='remi.flamary@gmail.com, ncourty@gmail.com',
- url='https://github.com/PythonOT/POT',
- packages=find_packages(),
- ext_modules=cythonize(Extension(
- "ot.lp.emd_wrap", # the extension name
- sources=["ot/lp/emd_wrap.pyx", "ot/lp/EMD_wrapper.cpp"], # the Cython source and
- # additional C++ source files
- language="c++", # generate and compile C++ code,
- include_dirs=[numpy.get_include(), os.path.join(ROOT, 'ot/lp')],
- extra_compile_args=opt_arg
- )),
- platforms=['linux', 'macosx', 'windows'],
- download_url='https://github.com/PythonOT/POT/archive/{}.tar.gz'.format(__version__),
- license='MIT',
- scripts=[],
- data_files=[],
- requires=["numpy", "scipy", "cython"],
- setup_requires=["numpy>=1.16", "scipy>=1.0", "cython>=0.23"],
- install_requires=["numpy>=1.16", "scipy>=1.0", "cython>=0.23"],
- classifiers=[
- 'Development Status :: 5 - Production/Stable',
- 'Intended Audience :: Developers',
- 'Intended Audience :: Education',
- 'Intended Audience :: Science/Research',
- 'License :: OSI Approved :: MIT License',
- 'Environment :: Console',
- 'Operating System :: OS Independent',
- 'Operating System :: MacOS',
- 'Operating System :: POSIX',
- 'Programming Language :: Python',
- 'Programming Language :: C++',
- 'Programming Language :: C',
- 'Programming Language :: Cython',
- 'Topic :: Utilities',
- 'Topic :: Scientific/Engineering :: Artificial Intelligence',
- 'Topic :: Scientific/Engineering :: Mathematics',
- 'Topic :: Scientific/Engineering :: Information Analysis',
- 'Programming Language :: Python :: 2',
- 'Programming Language :: Python :: 2.7',
- 'Programming Language :: Python :: 3',
- 'Programming Language :: Python :: 3.4',
- 'Programming Language :: Python :: 3.5',
- 'Programming Language :: Python :: 3.6',
- ]
- )
+setup(
+ name='POT',
+ version=__version__,
+ description='Python Optimal Transport Library',
+ long_description=README,
+ long_description_content_type='text/markdown',
+ author=u'Remi Flamary, Nicolas Courty',
+ author_email='remi.flamary@gmail.com, ncourty@gmail.com',
+ url='https://github.com/PythonOT/POT',
+ packages=find_packages(),
+ ext_modules=cythonize(Extension(
+ name="ot.lp.emd_wrap",
+ sources=["ot/lp/emd_wrap.pyx", "ot/lp/EMD_wrapper.cpp"], # cython/c++ src files
+ language="c++",
+ include_dirs=[numpy.get_include(), os.path.join(ROOT, 'ot/lp')],
+ extra_compile_args=compile_args,
+ extra_link_args=link_args
+ )),
+ platforms=['linux', 'macosx', 'windows'],
+ download_url='https://github.com/PythonOT/POT/archive/{}.tar.gz'.format(__version__),
+ license='MIT',
+ scripts=[],
+ data_files=[],
+ setup_requires=["numpy>=1.16", "cython>=0.23"],
+ install_requires=["numpy>=1.16", "scipy>=1.0"],
+ classifiers=[
+ 'Development Status :: 5 - Production/Stable',
+ 'Intended Audience :: Developers',
+ 'Intended Audience :: Education',
+ 'Intended Audience :: Science/Research',
+ 'License :: OSI Approved :: MIT License',
+ 'Environment :: Console',
+ 'Operating System :: OS Independent',
+ 'Operating System :: POSIX :: Linux',
+ 'Operating System :: MacOS',
+ 'Operating System :: POSIX',
+ 'Operating System :: Microsoft :: Windows',
+ 'Programming Language :: Python',
+ 'Programming Language :: C++',
+ 'Programming Language :: C',
+ 'Programming Language :: Cython',
+ 'Topic :: Utilities',
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
+ 'Topic :: Scientific/Engineering :: Mathematics',
+ 'Topic :: Scientific/Engineering :: Information Analysis',
+ 'Programming Language :: Python :: 3',
+ 'Programming Language :: Python :: 3.6',
+ 'Programming Language :: Python :: 3.7',
+ 'Programming Language :: Python :: 3.8',
+ 'Programming Language :: Python :: 3.9',
+ ]
+)
diff --git a/test/conftest.py b/test/conftest.py
new file mode 100644
index 0000000..987d98e
--- /dev/null
+++ b/test/conftest.py
@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+
+# Configuration file for pytest
+
+# License: MIT License
+
+import pytest
+from ot.backend import jax
+from ot.backend import get_backend_list
+import functools
+
+if jax:
+ from jax.config import config
+ config.update("jax_enable_x64", True)
+
+backend_list = get_backend_list()
+
+
+@pytest.fixture(params=backend_list)
+def nx(request):
+ backend = request.param
+
+ yield backend
+
+
+def skip_arg(arg, value, reason=None, getter=lambda x: x):
+ if isinstance(arg, tuple) or isinstance(arg, list):
+ n = len(arg)
+ else:
+ arg = (arg, )
+ n = 1
+ if n != 1 and (isinstance(value, tuple) or isinstance(value, list)):
+ pass
+ else:
+ value = (value, )
+ if isinstance(getter, tuple) or isinstance(value, list):
+ pass
+ else:
+ getter = [getter] * n
+
+ if reason is None:
+ reason = f"Param {arg} should be skipped for value {value}"
+
+ def wrapper(function):
+
+ @functools.wraps(function)
+ def wrapped(*args, **kwargs):
+ if all(
+ arg[i] in kwargs.keys() and getter[i](kwargs[arg[i]]) == value[i]
+ for i in range(n)
+ ):
+ pytest.skip(reason)
+ return function(*args, **kwargs)
+
+ return wrapped
+
+ return wrapper
+
+
+def pytest_configure(config):
+ pytest.skip_arg = skip_arg
+ pytest.skip_backend = functools.partial(skip_arg, "nx", getter=str)
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
new file mode 100644
index 0000000..cb85cb9
--- /dev/null
+++ b/test/test_1d_solver.py
@@ -0,0 +1,172 @@
+"""Tests for module 1d Wasserstein solver"""
+
+# Author: Adrien Corenflos <adrien.corenflos@aalto.fi>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+import pytest
+
+import ot
+from ot.lp import wasserstein_1d
+
+from ot.backend import get_backend_list
+from scipy.stats import wasserstein_distance
+
+backend_list = get_backend_list()
+
+
+def test_emd_1d_emd2_1d_with_weights():
+ # test emd1d gives similar results as emd
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.randn(n, 1)
+ v = rng.randn(m, 1)
+
+ w_u = rng.uniform(0., 1., n)
+ w_u = w_u / w_u.sum()
+
+ w_v = rng.uniform(0., 1., m)
+ w_v = w_v / w_v.sum()
+
+ M = ot.dist(u, v, metric='sqeuclidean')
+
+ G, log = ot.emd(w_u, w_v, M, log=True)
+ wass = log["cost"]
+ G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True)
+ wass1d = log["cost"]
+ wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False)
+ wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False)
+
+ # check loss is similar
+ np.testing.assert_allclose(wass, wass1d)
+ np.testing.assert_allclose(wass, wass1d_emd2)
+
+ # check loss is similar to scipy's implementation for Euclidean metric
+ wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v)
+ np.testing.assert_allclose(wass_sp, wass1d_euc)
+
+ # check constraints
+ np.testing.assert_allclose(w_u, G.sum(1))
+ np.testing.assert_allclose(w_v, G.sum(0))
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_wasserstein_1d(nx):
+ from scipy.stats import wasserstein_distance
+
+ rng = np.random.RandomState(0)
+
+ n = 100
+ x = np.linspace(0, 5, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ xb = nx.from_numpy(x)
+ rho_ub = nx.from_numpy(rho_u)
+ rho_vb = nx.from_numpy(rho_v)
+
+ # test 1 : wasserstein_1d should be close to scipy W_1 implementation
+ np.testing.assert_almost_equal(wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1),
+ wasserstein_distance(x, x, rho_u, rho_v))
+
+ # test 2 : wasserstein_1d should be close to one when only translating the support
+ np.testing.assert_almost_equal(wasserstein_1d(xb, xb + 1, p=2),
+ 1.)
+
+ # test 3 : arrays test
+ X = np.stack((np.linspace(0, 5, n), np.linspace(0, 5, n) * 10), -1)
+ Xb = nx.from_numpy(X)
+ res = wasserstein_1d(Xb, Xb, rho_ub, rho_vb, p=2)
+ np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4)
+
+
+def test_wasserstein_1d_type_devices(nx):
+
+ rng = np.random.RandomState(0)
+
+ n = 10
+ x = np.linspace(0, 5, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb = nx.from_numpy(x, type_as=tp)
+ rho_ub = nx.from_numpy(rho_u, type_as=tp)
+ rho_vb = nx.from_numpy(rho_v, type_as=tp)
+
+ res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
+
+ nx.assert_same_dtype_device(xb, res)
+
+
+def test_emd_1d_emd2_1d():
+ # test emd1d gives similar results as emd
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.randn(n, 1)
+ v = rng.randn(m, 1)
+
+ M = ot.dist(u, v, metric='sqeuclidean')
+
+ G, log = ot.emd([], [], M, log=True)
+ wass = log["cost"]
+ G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True)
+ wass1d = log["cost"]
+ wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False)
+ wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False)
+
+ # check loss is similar
+ np.testing.assert_allclose(wass, wass1d)
+ np.testing.assert_allclose(wass, wass1d_emd2)
+
+ # check loss is similar to scipy's implementation for Euclidean metric
+ wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
+ np.testing.assert_allclose(wass_sp, wass1d_euc)
+
+ # check constraints
+ np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1))
+ np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))
+
+ # check G is similar
+ np.testing.assert_allclose(G, G_1d, atol=1e-15)
+
+ # check AssertionError is raised if called on non 1d arrays
+ u = np.random.randn(n, 2)
+ v = np.random.randn(m, 2)
+ with pytest.raises(AssertionError):
+ ot.emd_1d(u, v, [], [])
+
+
+def test_emd1d_type_devices(nx):
+
+ rng = np.random.RandomState(0)
+
+ n = 10
+ x = np.linspace(0, 5, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb = nx.from_numpy(x, type_as=tp)
+ rho_ub = nx.from_numpy(rho_u, type_as=tp)
+ rho_vb = nx.from_numpy(rho_v, type_as=tp)
+
+ emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
+ emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
+
+ nx.assert_same_dtype_device(xb, emd)
+ nx.assert_same_dtype_device(xb, emd2)
diff --git a/test/test_backend.py b/test/test_backend.py
new file mode 100644
index 0000000..1832b91
--- /dev/null
+++ b/test/test_backend.py
@@ -0,0 +1,577 @@
+"""Tests for backend module """
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import ot
+import ot.backend
+from ot.backend import torch, jax
+
+import pytest
+
+import numpy as np
+from numpy.testing import assert_array_almost_equal_nulp
+
+from ot.backend import get_backend, get_backend_list, to_numpy
+
+
+def test_get_backend_list():
+
+ lst = get_backend_list()
+
+ assert len(lst) > 0
+ assert isinstance(lst[0], ot.backend.NumpyBackend)
+
+
+def test_to_numpy(nx):
+
+ v = nx.zeros(10)
+ M = nx.ones((10, 10))
+
+ v2 = to_numpy(v)
+ assert isinstance(v2, np.ndarray)
+
+ v2, M2 = to_numpy(v, M)
+ assert isinstance(M2, np.ndarray)
+
+
+def test_get_backend():
+
+ A = np.zeros((3, 2))
+ B = np.zeros((3, 1))
+
+ nx = get_backend(A)
+ assert nx.__name__ == 'numpy'
+
+ nx = get_backend(A, B)
+ assert nx.__name__ == 'numpy'
+
+ # error if no parameters
+ with pytest.raises(ValueError):
+ get_backend()
+
+ # error if unknown types
+ with pytest.raises(ValueError):
+ get_backend(1, 2.0)
+
+ # test torch
+ if torch:
+
+ A2 = torch.from_numpy(A)
+ B2 = torch.from_numpy(B)
+
+ nx = get_backend(A2)
+ assert nx.__name__ == 'torch'
+
+ nx = get_backend(A2, B2)
+ assert nx.__name__ == 'torch'
+
+ # test not unique types in input
+ with pytest.raises(ValueError):
+ get_backend(A, B2)
+
+ if jax:
+
+ A2 = jax.numpy.array(A)
+ B2 = jax.numpy.array(B)
+
+ nx = get_backend(A2)
+ assert nx.__name__ == 'jax'
+
+ nx = get_backend(A2, B2)
+ assert nx.__name__ == 'jax'
+
+ # test not unique types in input
+ with pytest.raises(ValueError):
+ get_backend(A, B2)
+
+
+def test_convert_between_backends(nx):
+
+ A = np.zeros((3, 2))
+ B = np.zeros((3, 1))
+
+ A2 = nx.from_numpy(A)
+ B2 = nx.from_numpy(B)
+
+ assert isinstance(A2, nx.__type__)
+ assert isinstance(B2, nx.__type__)
+
+ nx2 = get_backend(A2, B2)
+
+ assert nx2.__name__ == nx.__name__
+
+ assert_array_almost_equal_nulp(nx.to_numpy(A2), A)
+ assert_array_almost_equal_nulp(nx.to_numpy(B2), B)
+
+
+def test_empty_backend():
+
+ rnd = np.random.RandomState(0)
+ M = rnd.randn(10, 3)
+ v = rnd.randn(3)
+
+ nx = ot.backend.Backend()
+
+ with pytest.raises(NotImplementedError):
+ nx.from_numpy(M)
+ with pytest.raises(NotImplementedError):
+ nx.to_numpy(M)
+ with pytest.raises(NotImplementedError):
+ nx.set_gradients(0, 0, 0)
+ with pytest.raises(NotImplementedError):
+ nx.zeros((10, 3))
+ with pytest.raises(NotImplementedError):
+ nx.ones((10, 3))
+ with pytest.raises(NotImplementedError):
+ nx.arange(10, 1, 2)
+ with pytest.raises(NotImplementedError):
+ nx.full((10, 3), 3.14)
+ with pytest.raises(NotImplementedError):
+ nx.eye((10, 3))
+ with pytest.raises(NotImplementedError):
+ nx.sum(M)
+ with pytest.raises(NotImplementedError):
+ nx.cumsum(M)
+ with pytest.raises(NotImplementedError):
+ nx.max(M)
+ with pytest.raises(NotImplementedError):
+ nx.min(M)
+ with pytest.raises(NotImplementedError):
+ nx.maximum(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.minimum(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.abs(M)
+ with pytest.raises(NotImplementedError):
+ nx.log(M)
+ with pytest.raises(NotImplementedError):
+ nx.exp(M)
+ with pytest.raises(NotImplementedError):
+ nx.sqrt(M)
+ with pytest.raises(NotImplementedError):
+ nx.power(v, 2)
+ with pytest.raises(NotImplementedError):
+ nx.dot(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.norm(M)
+ with pytest.raises(NotImplementedError):
+ nx.exp(M)
+ with pytest.raises(NotImplementedError):
+ nx.any(M)
+ with pytest.raises(NotImplementedError):
+ nx.isnan(M)
+ with pytest.raises(NotImplementedError):
+ nx.isinf(M)
+ with pytest.raises(NotImplementedError):
+ nx.einsum('ij->i', M)
+ with pytest.raises(NotImplementedError):
+ nx.sort(M)
+ with pytest.raises(NotImplementedError):
+ nx.argsort(M)
+ with pytest.raises(NotImplementedError):
+ nx.searchsorted(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.flip(M)
+ with pytest.raises(NotImplementedError):
+ nx.outer(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.clip(M, -1, 1)
+ with pytest.raises(NotImplementedError):
+ nx.repeat(M, 0, 1)
+ with pytest.raises(NotImplementedError):
+ nx.take_along_axis(M, v, 0)
+ with pytest.raises(NotImplementedError):
+ nx.concatenate([v, v])
+ with pytest.raises(NotImplementedError):
+ nx.zero_pad(M, v)
+ with pytest.raises(NotImplementedError):
+ nx.argmax(M)
+ with pytest.raises(NotImplementedError):
+ nx.mean(M)
+ with pytest.raises(NotImplementedError):
+ nx.std(M)
+ with pytest.raises(NotImplementedError):
+ nx.linspace(0, 1, 50)
+ with pytest.raises(NotImplementedError):
+ nx.meshgrid(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.diag(M)
+ with pytest.raises(NotImplementedError):
+ nx.unique([M, M])
+ with pytest.raises(NotImplementedError):
+ nx.logsumexp(M)
+ with pytest.raises(NotImplementedError):
+ nx.stack([M, M])
+ with pytest.raises(NotImplementedError):
+ nx.reshape(M, (5, 3, 2))
+ with pytest.raises(NotImplementedError):
+ nx.seed(42)
+ with pytest.raises(NotImplementedError):
+ nx.rand()
+ with pytest.raises(NotImplementedError):
+ nx.randn()
+ nx.coo_matrix(M, M, M)
+ with pytest.raises(NotImplementedError):
+ nx.issparse(M)
+ with pytest.raises(NotImplementedError):
+ nx.tocsr(M)
+ with pytest.raises(NotImplementedError):
+ nx.eliminate_zeros(M)
+ with pytest.raises(NotImplementedError):
+ nx.todense(M)
+ with pytest.raises(NotImplementedError):
+ nx.where(M, M, M)
+ with pytest.raises(NotImplementedError):
+ nx.copy(M)
+ with pytest.raises(NotImplementedError):
+ nx.allclose(M, M)
+
+
+def test_func_backends(nx):
+
+ rnd = np.random.RandomState(0)
+ M = rnd.randn(10, 3)
+ v = rnd.randn(3)
+ val = np.array([1.0])
+
+ # Sparse tensors test
+ sp_row = np.array([0, 3, 1, 0, 3])
+ sp_col = np.array([0, 3, 1, 2, 2])
+ sp_data = np.array([4, 5, 7, 9, 0])
+
+ lst_tot = []
+
+ for nx in [ot.backend.NumpyBackend(), nx]:
+
+ print('Backend: ', nx.__name__)
+
+ lst_b = []
+ lst_name = []
+
+ Mb = nx.from_numpy(M)
+ vb = nx.from_numpy(v)
+
+ val = nx.from_numpy(val)
+
+ sp_rowb = nx.from_numpy(sp_row)
+ sp_colb = nx.from_numpy(sp_col)
+ sp_datab = nx.from_numpy(sp_data)
+
+ A = nx.set_gradients(val, v, v)
+
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('set_gradients')
+
+ A = nx.zeros((10, 3))
+ A = nx.zeros((10, 3), type_as=Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('zeros')
+
+ A = nx.ones((10, 3))
+ A = nx.ones((10, 3), type_as=Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('ones')
+
+ A = nx.arange(10, 1, 2)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('arange')
+
+ A = nx.full((10, 3), 3.14)
+ A = nx.full((10, 3), 3.14, type_as=Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('full')
+
+ A = nx.eye(10, 3)
+ A = nx.eye(10, 3, type_as=Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('eye')
+
+ A = nx.sum(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('sum')
+
+ A = nx.sum(Mb, axis=1, keepdims=True)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('sum(axis)')
+
+ A = nx.cumsum(Mb, 0)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('cumsum(axis)')
+
+ A = nx.max(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('max')
+
+ A = nx.max(Mb, axis=1, keepdims=True)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('max(axis)')
+
+ A = nx.min(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('min')
+
+ A = nx.min(Mb, axis=1, keepdims=True)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('min(axis)')
+
+ A = nx.maximum(vb, 0)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('maximum')
+
+ A = nx.minimum(vb, 0)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('minimum')
+
+ A = nx.abs(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('abs')
+
+ A = nx.log(A)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('log')
+
+ A = nx.exp(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('exp')
+
+ A = nx.sqrt(nx.abs(Mb))
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('sqrt')
+
+ A = nx.power(Mb, 2)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('power')
+
+ A = nx.dot(vb, vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('dot(v,v)')
+
+ A = nx.dot(Mb, vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('dot(M,v)')
+
+ A = nx.dot(Mb, Mb.T)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('dot(M,M)')
+
+ A = nx.norm(vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('norm')
+
+ A = nx.any(vb > 0)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('any')
+
+ A = nx.isnan(vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('isnan')
+
+ A = nx.isinf(vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('isinf')
+
+ A = nx.einsum('ij->i', Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('einsum(ij->i)')
+
+ A = nx.einsum('ij,j->i', Mb, vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('nx.einsum(ij,j->i)')
+
+ A = nx.einsum('ij->i', Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('nx.einsum(ij->i)')
+
+ A = nx.sort(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('sort')
+
+ A = nx.argsort(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('argsort')
+
+ A = nx.searchsorted(Mb, Mb, 'right')
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('searchsorted')
+
+ A = nx.flip(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('flip')
+
+ A = nx.outer(vb, vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('outer')
+
+ A = nx.clip(vb, 0, 1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('clip')
+
+ A = nx.repeat(Mb, 0)
+ A = nx.repeat(Mb, 2, -1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('repeat')
+
+ A = nx.take_along_axis(vb, nx.arange(3), -1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('take_along_axis')
+
+ A = nx.concatenate((Mb, Mb), -1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('concatenate')
+
+ A = nx.zero_pad(Mb, len(Mb.shape) * [(3, 3)])
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('zero_pad')
+
+ A = nx.argmax(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('argmax')
+
+ A = nx.mean(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('mean')
+
+ A = nx.std(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('std')
+
+ A = nx.linspace(0, 1, 50)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('linspace')
+
+ X, Y = nx.meshgrid(vb, vb)
+ lst_b.append(np.stack([nx.to_numpy(X), nx.to_numpy(Y)]))
+ lst_name.append('meshgrid')
+
+ A = nx.diag(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('diag2D')
+
+ A = nx.diag(vb, 1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('diag1D')
+
+ A = nx.unique(nx.from_numpy(np.stack([M, M])))
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('unique')
+
+ A = nx.logsumexp(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('logsumexp')
+
+ A = nx.stack([Mb, Mb])
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('stack')
+
+ A = nx.reshape(Mb, (5, 3, 2))
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('reshape')
+
+ sp_Mb = nx.coo_matrix(sp_datab, sp_rowb, sp_colb, shape=(4, 4))
+ nx.todense(Mb)
+ lst_b.append(nx.to_numpy(nx.todense(sp_Mb)))
+ lst_name.append('coo_matrix')
+
+ assert not nx.issparse(Mb), 'Assert fail on: issparse (expected False)'
+ assert nx.issparse(sp_Mb) or nx.__name__ == "jax", 'Assert fail on: issparse (expected True)'
+
+ A = nx.tocsr(sp_Mb)
+ lst_b.append(nx.to_numpy(nx.todense(A)))
+ lst_name.append('tocsr')
+
+ A = nx.eliminate_zeros(nx.copy(sp_datab), threshold=5.)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('eliminate_zeros (dense)')
+
+ A = nx.eliminate_zeros(sp_Mb)
+ lst_b.append(nx.to_numpy(nx.todense(A)))
+ lst_name.append('eliminate_zeros (sparse)')
+
+ A = nx.where(Mb >= nx.stack([nx.linspace(0, 1, 10)] * 3, axis=1), Mb, 0.0)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('where')
+
+ A = nx.copy(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('copy')
+
+ assert nx.allclose(Mb, Mb), 'Assert fail on: allclose (expected True)'
+ assert not nx.allclose(2 * Mb, Mb), 'Assert fail on: allclose (expected False)'
+
+ lst_tot.append(lst_b)
+
+ lst_np = lst_tot[0]
+ lst_b = lst_tot[1]
+
+ for a1, a2, name in zip(lst_np, lst_b, lst_name):
+ if not np.allclose(a1, a2):
+ print('Assert fail on: ', name)
+ assert np.allclose(a1, a2, atol=1e-7)
+
+
+def test_random_backends(nx):
+
+ tmp_u = nx.rand()
+
+ assert tmp_u < 1
+
+ tmp_n = nx.randn()
+
+ nx.seed(0)
+ M1 = nx.to_numpy(nx.rand(5, 2))
+ nx.seed(0)
+ M2 = nx.to_numpy(nx.rand(5, 2, type_as=tmp_n))
+
+ assert np.all(M1 >= 0)
+ assert np.all(M1 < 1)
+ assert M1.shape == (5, 2)
+ assert np.allclose(M1, M2)
+
+ nx.seed(0)
+ M1 = nx.to_numpy(nx.randn(5, 2))
+ nx.seed(0)
+ M2 = nx.to_numpy(nx.randn(5, 2, type_as=tmp_u))
+
+ nx.seed(42)
+ v1 = nx.randn()
+ v2 = nx.randn()
+ assert v1 != v2
+
+
+def test_gradients_backends():
+
+ rnd = np.random.RandomState(0)
+ v = rnd.randn(10)
+ c = rnd.randn()
+ e = rnd.randn()
+
+ if torch:
+
+ nx = ot.backend.TorchBackend()
+
+ v2 = torch.tensor(v, requires_grad=True)
+ c2 = torch.tensor(c, requires_grad=True)
+
+ val = c2 * torch.sum(v2 * v2)
+
+ val2 = nx.set_gradients(val, (v2, c2), (v2, c2))
+
+ val2.backward()
+
+ assert torch.equal(v2.grad, v2)
+ assert torch.equal(c2.grad, c2)
+
+ if jax:
+ nx = ot.backend.JaxBackend()
+ with jax.checking_leaks():
+ def fun(a, b, d):
+ val = b * nx.sum(a ** 4) + d
+ return nx.set_gradients(val, (a, b, d), (a, b, 2 * d))
+ grad_val = jax.grad(fun, argnums=(0, 1, 2))(v, c, e)
+
+ np.testing.assert_almost_equal(fun(v, c, e), c * np.sum(v ** 4) + e, decimal=4)
+ np.testing.assert_allclose(grad_val[0], v, atol=1e-4)
+ np.testing.assert_allclose(grad_val[2], 2 * e, atol=1e-4)
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 6aa4e08..830052d 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -2,15 +2,21 @@
# Author: Remi Flamary <remi.flamary@unice.fr>
# Kilian Fatras <kilian.fatras@irisa.fr>
+# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
#
# License: MIT License
+from itertools import product
+
import numpy as np
-import ot
import pytest
+import ot
+from ot.backend import torch
+
-def test_sinkhorn():
+@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False]))
+def test_sinkhorn(verbose, warn):
# test sinkhorn
n = 100
rng = np.random.RandomState(0)
@@ -20,14 +26,189 @@ def test_sinkhorn():
M = ot.dist(x, x)
- G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
+ G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10, verbose=verbose, warn=warn)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
np.testing.assert_allclose(
u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+ with pytest.warns(UserWarning):
+ ot.sinkhorn(u, u, M, 1, stopThr=0, numItermax=1)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized",
+ "sinkhorn_epsilon_scaling",
+ "greenkhorn",
+ "sinkhorn_log"])
+def test_convergence_warning(method):
+ # test sinkhorn
+ n = 100
+ a1 = ot.datasets.make_1D_gauss(n, m=30, s=10)
+ a2 = ot.datasets.make_1D_gauss(n, m=40, s=10)
+ A = np.asarray([a1, a2]).T
+ M = ot.utils.dist0(n)
+
+ with pytest.warns(UserWarning):
+ ot.sinkhorn(a1, a2, M, 1., method=method, stopThr=0, numItermax=1)
+
+ if method in ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"]:
+ with pytest.warns(UserWarning):
+ ot.barycenter(A, M, 1, method=method, stopThr=0, numItermax=1)
+ with pytest.warns(UserWarning):
+ ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1)
+
+
+def test_not_impemented_method():
+ # test sinkhorn
+ w = 10
+ n = w ** 2
+ rng = np.random.RandomState(42)
+ A_img = rng.rand(2, w, w)
+ A_flat = A_img.reshape(n, 2)
+ a1, a2 = A_flat.T
+ M_flat = ot.utils.dist0(n)
+ not_implemented = "new_method"
+ reg = 0.01
+ with pytest.raises(ValueError):
+ ot.sinkhorn(a1, a2, M_flat, reg, method=not_implemented)
+ with pytest.raises(ValueError):
+ ot.sinkhorn2(a1, a2, M_flat, reg, method=not_implemented)
+ with pytest.raises(ValueError):
+ ot.barycenter(A_flat, M_flat, reg, method=not_implemented)
+ with pytest.raises(ValueError):
+ ot.bregman.barycenter_debiased(A_flat, M_flat, reg,
+ method=not_implemented)
+ with pytest.raises(ValueError):
+ ot.bregman.convolutional_barycenter2d(A_img, reg,
+ method=not_implemented)
+ with pytest.raises(ValueError):
+ ot.bregman.convolutional_barycenter2d_debiased(A_img, reg,
+ method=not_implemented)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
+def test_nan_warning(method):
+ # test sinkhorn
+ n = 100
+ a1 = ot.datasets.make_1D_gauss(n, m=30, s=10)
+ a2 = ot.datasets.make_1D_gauss(n, m=40, s=10)
+
+ M = ot.utils.dist0(n)
+ reg = 0
+ with pytest.warns(UserWarning):
+ # warn set to False to avoid catching a convergence warning instead
+ ot.sinkhorn(a1, a2, M, reg, method=method, warn=False)
+
+
+def test_sinkhorn_stabilization():
+ # test sinkhorn
+ n = 100
+ a1 = ot.datasets.make_1D_gauss(n, m=30, s=10)
+ a2 = ot.datasets.make_1D_gauss(n, m=40, s=10)
+ M = ot.utils.dist0(n)
+ reg = 1e-5
+ loss1 = ot.sinkhorn2(a1, a2, M, reg, method="sinkhorn_log")
+ loss2 = ot.sinkhorn2(a1, a2, M, reg, tau=1, method="sinkhorn_stabilized")
+ np.testing.assert_allclose(
+ loss1, loss2, atol=1e-06) # cf convergence sinkhorn
+
+
+@pytest.mark.parametrize("method, verbose, warn",
+ product(["sinkhorn", "sinkhorn_stabilized",
+ "sinkhorn_log"],
+ [True, False], [True, False]))
+def test_sinkhorn_multi_b(method, verbose, warn):
+ # test sinkhorn
+ n = 10
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ b = rng.rand(n, 3)
+ b = b / np.sum(b, 0, keepdims=True)
+
+ M = ot.dist(x, x)
+
+ loss0, log = ot.sinkhorn(u, b, M, .1, method=method, stopThr=1e-10,
+ log=True)
+
+ loss = [ot.sinkhorn2(u, b[:, k], M, .1, method=method, stopThr=1e-10,
+ verbose=verbose, warn=warn) for k in range(3)]
+ # check constraints
+ np.testing.assert_allclose(
+ loss0, loss, atol=1e-4) # cf convergence sinkhorn
+
+
+def test_sinkhorn_backends(nx):
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ G = ot.sinkhorn(a, a, M, 1)
+
+ ab = nx.from_numpy(a)
+ M_nx = nx.from_numpy(M)
+
+ Gb = ot.sinkhorn(ab, ab, M_nx, 1)
+
+ np.allclose(G, nx.to_numpy(Gb))
+
+
+def test_sinkhorn2_backends(nx):
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ G = ot.sinkhorn(a, a, M, 1)
+
+ ab = nx.from_numpy(a)
+ M_nx = nx.from_numpy(M)
+
+ Gb = ot.sinkhorn2(ab, ab, M_nx, 1)
+
+ np.allclose(G, nx.to_numpy(Gb))
+
+
+def test_sinkhorn2_gradients():
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ if torch:
+
+ a1 = torch.tensor(a, requires_grad=True)
+ b1 = torch.tensor(a, requires_grad=True)
+ M1 = torch.tensor(M, requires_grad=True)
+
+ val = ot.sinkhorn2(a1, b1, M1, 1)
+
+ val.backward()
+
+ assert a1.shape == a1.grad.shape
+ assert b1.shape == b1.grad.shape
+ assert M1.shape == M1.grad.shape
+
def test_sinkhorn_empty():
# test sinkhorn
@@ -39,21 +220,27 @@ def test_sinkhorn_empty():
M = ot.dist(x, x)
+ G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, method="sinkhorn_log",
+ verbose=True, log=True)
+ # check constraints
+ np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
+ np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
+
G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, verbose=True, log=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10,
method='sinkhorn_stabilized', verbose=True, log=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
G, log = ot.sinkhorn(
[], [], M, 1, stopThr=1e-10, method='sinkhorn_epsilon_scaling',
verbose=True, log=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
@@ -61,7 +248,8 @@ def test_sinkhorn_empty():
ot.sinkhorn([], [], M, 1, method='greenkhorn', stopThr=1e-10, log=True)
-def test_sinkhorn_variants():
+@pytest.skip_backend("jax")
+def test_sinkhorn_variants(nx):
# test sinkhorn
n = 100
rng = np.random.RandomState(0)
@@ -71,22 +259,131 @@ def test_sinkhorn_variants():
M = ot.dist(x, x)
- G0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
- Gs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10)
- Ges = ot.sinkhorn(
- u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)
- G_green = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10)
+ ub = nx.from_numpy(u)
+ M_nx = nx.from_numpy(M)
+
+ G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
+ Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+ Ges = nx.to_numpy(ot.sinkhorn(
+ ub, ub, M_nx, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10))
+ G_green = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10))
# check values
+ np.testing.assert_allclose(G, G0, atol=1e-05)
+ np.testing.assert_allclose(G, Gl, atol=1e-05)
np.testing.assert_allclose(G0, Gs, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
np.testing.assert_allclose(G0, G_green, atol=1e-5)
- print(G0, G_green)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized",
+ "sinkhorn_epsilon_scaling",
+ "greenkhorn",
+ "sinkhorn_log"])
+@pytest.skip_arg(("nx", "method"), ("jax", "sinkhorn_epsilon_scaling"), reason="jax does not support sinkhorn_epsilon_scaling", getter=str)
+@pytest.skip_arg(("nx", "method"), ("jax", "greenkhorn"), reason="jax does not support greenkhorn", getter=str)
+def test_sinkhorn_variants_dtype_device(nx, method):
+ n = 100
+
+ x = np.random.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ ub = nx.from_numpy(u, type_as=tp)
+ Mb = nx.from_numpy(M, type_as=tp)
+
+ Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+
+ nx.assert_same_dtype_device(Mb, Gb)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"])
+def test_sinkhorn2_variants_dtype_device(nx, method):
+ n = 100
+
+ x = np.random.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ ub = nx.from_numpy(u, type_as=tp)
+ Mb = nx.from_numpy(M, type_as=tp)
+
+ lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+
+ nx.assert_same_dtype_device(Mb, lossb)
+
+
+@pytest.skip_backend("jax")
+def test_sinkhorn_variants_multi_b(nx):
+ # test sinkhorn
+ n = 50
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ b = rng.rand(n, 3)
+ b = b / np.sum(b, 0, keepdims=True)
+
+ M = ot.dist(x, x)
+
+ ub = nx.from_numpy(u)
+ bb = nx.from_numpy(b)
+ M_nx = nx.from_numpy(M)
+
+ G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
+ Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+
+ # check values
+ np.testing.assert_allclose(G, G0, atol=1e-05)
+ np.testing.assert_allclose(G, Gl, atol=1e-05)
+ np.testing.assert_allclose(G0, Gs, atol=1e-05)
+
+
+@pytest.skip_backend("jax")
+def test_sinkhorn2_variants_multi_b(nx):
+ # test sinkhorn
+ n = 50
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ b = rng.rand(n, 3)
+ b = b / np.sum(b, 0, keepdims=True)
+
+ M = ot.dist(x, x)
+
+ ub = nx.from_numpy(u)
+ bb = nx.from_numpy(b)
+ M_nx = nx.from_numpy(M)
+
+ G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
+ Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+
+ # check values
+ np.testing.assert_allclose(G, G0, atol=1e-05)
+ np.testing.assert_allclose(G, Gl, atol=1e-05)
+ np.testing.assert_allclose(G0, Gs, atol=1e-05)
def test_sinkhorn_variants_log():
# test sinkhorn
- n = 100
+ n = 50
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
@@ -95,20 +392,87 @@ def test_sinkhorn_variants_log():
M = ot.dist(x, x)
G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
+ Gl, logl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True)
Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
Ges, loges = ot.sinkhorn(
- u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True)
+ u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True,)
G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True)
# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
+ np.testing.assert_allclose(G0, Gl, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
np.testing.assert_allclose(G0, G_green, atol=1e-5)
- print(G0, G_green)
-@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
-def test_barycenter(method):
+@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False]))
+def test_sinkhorn_variants_log_multib(verbose, warn):
+ # test sinkhorn
+ n = 50
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+ b = rng.rand(n, 3)
+ b = b / np.sum(b, 0, keepdims=True)
+
+ M = ot.dist(x, x)
+
+ G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
+ Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True,
+ verbose=verbose, warn=warn)
+ Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True,
+ verbose=verbose, warn=warn)
+
+ # check values
+ np.testing.assert_allclose(G0, Gs, atol=1e-05)
+ np.testing.assert_allclose(G0, Gl, atol=1e-05)
+
+
+@pytest.mark.parametrize("method, verbose, warn",
+ product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"],
+ [True, False], [True, False]))
+def test_barycenter(nx, method, verbose, warn):
+ n_bins = 100 # nb bins
+
+ # Gaussian distributions
+ a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std
+ a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10)
+
+ # creating matrix A containing all distributions
+ A = np.vstack((a1, a2)).T
+
+ # loss matrix + normalization
+ M = ot.utils.dist0(n_bins)
+ M /= M.max()
+
+ alpha = 0.5 # 0<=alpha<=1
+ weights = np.array([1 - alpha, alpha])
+
+ A_nx = nx.from_numpy(A)
+ M_nx = nx.from_numpy(M)
+ weights_nx = nx.from_numpy(weights)
+ reg = 1e-2
+
+ if nx.__name__ == "jax" and method == "sinkhorn_log":
+ with pytest.raises(NotImplementedError):
+ ot.bregman.barycenter(A_nx, M_nx, reg, weights, method=method)
+ else:
+ # wasserstein
+ bary_wass_np = ot.bregman.barycenter(A, M, reg, weights, method=method, verbose=verbose, warn=warn)
+ bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, weights_nx, method=method, log=True)
+ bary_wass = nx.to_numpy(bary_wass)
+
+ np.testing.assert_allclose(1, np.sum(bary_wass))
+ np.testing.assert_allclose(bary_wass, bary_wass_np)
+
+ ot.bregman.barycenter(A_nx, M_nx, reg, log=True)
+
+
+@pytest.mark.parametrize("method, verbose, warn",
+ product(["sinkhorn", "sinkhorn_log"],
+ [True, False], [True, False]))
+def test_barycenter_debiased(nx, method, verbose, warn):
n_bins = 100 # nb bins
# Gaussian distributions
@@ -125,16 +489,61 @@ def test_barycenter(method):
alpha = 0.5 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])
+ A_nx = nx.from_numpy(A)
+ M_nx = nx.from_numpy(M)
+ weights_nx = nx.from_numpy(weights)
+
# wasserstein
reg = 1e-2
- bary_wass, log = ot.bregman.barycenter(A, M, reg, weights, method=method, log=True)
+ if nx.__name__ == "jax" and method == "sinkhorn_log":
+ with pytest.raises(NotImplementedError):
+ ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method)
+ else:
+ bary_wass_np = ot.bregman.barycenter_debiased(A, M, reg, weights, method=method,
+ verbose=verbose, warn=warn)
+ bary_wass, _ = ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights_nx, method=method, log=True)
+ bary_wass = nx.to_numpy(bary_wass)
+
+ np.testing.assert_allclose(1, np.sum(bary_wass), atol=1e-3)
+ np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-5)
- np.testing.assert_allclose(1, np.sum(bary_wass))
+ ot.bregman.barycenter_debiased(A_nx, M_nx, reg, log=True, verbose=False)
- ot.bregman.barycenter(A, M, reg, log=True, verbose=True)
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
+def test_convergence_warning_barycenters(method):
+ w = 10
+ n_bins = w ** 2 # nb bins
+
+ # Gaussian distributions
+ a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std
+ a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10)
+
+ # creating matrix A containing all distributions
+ A = np.vstack((a1, a2)).T
+ A_img = A.reshape(2, w, w)
+ A_img /= A_img.sum((1, 2))[:, None, None]
+
+ # loss matrix + normalization
+ M = ot.utils.dist0(n_bins)
+ M /= M.max()
-def test_barycenter_stabilization():
+ alpha = 0.5 # 0<=alpha<=1
+ weights = np.array([1 - alpha, alpha])
+ reg = 0.1
+ with pytest.warns(UserWarning):
+ ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, numItermax=1)
+ with pytest.warns(UserWarning):
+ ot.bregman.barycenter(A, M, reg, weights, method=method, numItermax=1)
+ with pytest.warns(UserWarning):
+ ot.bregman.convolutional_barycenter2d(A_img, reg, weights,
+ method=method, numItermax=1)
+ with pytest.warns(UserWarning):
+ ot.bregman.convolutional_barycenter2d_debiased(A_img, reg, weights,
+ method=method, numItermax=1)
+
+
+def test_barycenter_stabilization(nx):
n_bins = 100 # nb bins
# Gaussian distributions
@@ -151,22 +560,64 @@ def test_barycenter_stabilization():
alpha = 0.5 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])
+ A_nx = nx.from_numpy(A)
+ M_nx = nx.from_numpy(M)
+ weights_b = nx.from_numpy(weights)
+
# wasserstein
reg = 1e-2
- bar_stable = ot.bregman.barycenter(A, M, reg, weights,
- method="sinkhorn_stabilized",
- stopThr=1e-8, verbose=True)
- bar = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn",
- stopThr=1e-8, verbose=True)
+ bar_np = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True)
+ bar_stable = nx.to_numpy(ot.bregman.barycenter(
+ A_nx, M_nx, reg, weights_b, method="sinkhorn_stabilized",
+ stopThr=1e-8, verbose=True
+ ))
+ bar = nx.to_numpy(ot.bregman.barycenter(
+ A_nx, M_nx, reg, weights_b, method="sinkhorn",
+ stopThr=1e-8, verbose=True
+ ))
np.testing.assert_allclose(bar, bar_stable)
+ np.testing.assert_allclose(bar, bar_np)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
+def test_wasserstein_bary_2d(nx, method):
+ size = 20 # size of a square image
+ a1 = np.random.rand(size, size)
+ a1 += a1.min()
+ a1 = a1 / np.sum(a1)
+ a2 = np.random.rand(size, size)
+ a2 += a2.min()
+ a2 = a2 / np.sum(a2)
+ # creating matrix A containing all distributions
+ A = np.zeros((2, size, size))
+ A[0, :, :] = a1
+ A[1, :, :] = a2
+ A_nx = nx.from_numpy(A)
-def test_wasserstein_bary_2d():
- size = 100 # size of a square image
- a1 = np.random.randn(size, size)
+ # wasserstein
+ reg = 1e-2
+ if nx.__name__ == "jax" and method == "sinkhorn_log":
+ with pytest.raises(NotImplementedError):
+ ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
+ else:
+ bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method)
+ bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method))
+
+ np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
+ np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
+
+ # help in checking if log and verbose do not bug the function
+ ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
+def test_wasserstein_bary_2d_debiased(nx, method):
+ size = 20 # size of a square image
+ a1 = np.random.rand(size, size)
a1 += a1.min()
a1 = a1 / np.sum(a1)
- a2 = np.random.randn(size, size)
+ a2 = np.random.rand(size, size)
a2 += a2.min()
a2 = a2 / np.sum(a2)
# creating matrix A containing all distributions
@@ -174,17 +625,25 @@ def test_wasserstein_bary_2d():
A[0, :, :] = a1
A[1, :, :] = a2
+ A_nx = nx.from_numpy(A)
+
# wasserstein
reg = 1e-2
- bary_wass = ot.bregman.convolutional_barycenter2d(A, reg)
+ if nx.__name__ == "jax" and method == "sinkhorn_log":
+ with pytest.raises(NotImplementedError):
+ ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
+ else:
+ bary_wass_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method)
+ bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method))
- np.testing.assert_allclose(1, np.sum(bary_wass))
+ np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
+ np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
- # help in checking if log and verbose do not bug the function
- ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
+ # help in checking if log and verbose do not bug the function
+ ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
-def test_unmix():
+def test_unmix(nx):
n_bins = 50 # nb bins
# Gaussian distributions
@@ -204,41 +663,58 @@ def test_unmix():
M0 /= M0.max()
h0 = ot.unif(2)
+ ab = nx.from_numpy(a)
+ Db = nx.from_numpy(D)
+ M_nx = nx.from_numpy(M)
+ M0b = nx.from_numpy(M0)
+ h0b = nx.from_numpy(h0)
+
# wasserstein
reg = 1e-3
- um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01, )
+ um_np = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01)
+ um = nx.to_numpy(ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01))
np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03)
np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03)
+ np.testing.assert_allclose(um, um_np)
- ot.bregman.unmix(a, D, M, M0, h0, reg,
+ ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg,
1, alpha=0.01, log=True, verbose=True)
-def test_empirical_sinkhorn():
+def test_empirical_sinkhorn(nx):
# test sinkhorn
- n = 100
+ n = 10
a = ot.unif(n)
b = ot.unif(n)
- X_s = np.reshape(np.arange(n), (n, 1))
- X_t = np.reshape(np.arange(0, n), (n, 1))
+ X_s = np.reshape(1.0 * np.arange(n), (n, 1))
+ X_t = np.reshape(1.0 * np.arange(0, n), (n, 1))
M = ot.dist(X_s, X_t)
- M_m = ot.dist(X_s, X_t, metric='minkowski')
+ M_m = ot.dist(X_s, X_t, metric='euclidean')
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ X_sb = nx.from_numpy(X_s)
+ X_tb = nx.from_numpy(X_t)
+ M_nx = nx.from_numpy(M, type_as=ab)
+ M_mb = nx.from_numpy(M_m, type_as=ab)
- G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1)
- sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)
+ G_sqe = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1))
+ sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1))
- G_log, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, log=True)
- sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True)
+ G_log, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, log=True)
+ G_log = nx.to_numpy(G_log)
+ sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True)
+ sinkhorn_log = nx.to_numpy(sinkhorn_log)
- G_m = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski')
- sinkhorn_m = ot.sinkhorn(a, b, M_m, 1)
+ G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean'))
+ sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
- loss_emp_sinkhorn = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1)
- loss_sinkhorn = ot.sinkhorn2(a, b, M, 1)
+ loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1))
+ loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))
- # check constratints
+ # check constraints
np.testing.assert_allclose(
sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian
np.testing.assert_allclose(
@@ -254,34 +730,98 @@ def test_empirical_sinkhorn():
np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
-def test_empirical_sinkhorn_divergence():
- # Test sinkhorn divergence
+def test_lazy_empirical_sinkhorn(nx):
+ # test sinkhorn
n = 10
a = ot.unif(n)
b = ot.unif(n)
+ numIterMax = 1000
+
+ X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1))
+ X_t = np.reshape(np.arange(0, n, dtype=np.float64), (n, 1))
+ M = ot.dist(X_s, X_t)
+ M_m = ot.dist(X_s, X_t, metric='euclidean')
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ X_sb = nx.from_numpy(X_s)
+ X_tb = nx.from_numpy(X_t)
+ M_nx = nx.from_numpy(M, type_as=ab)
+ M_mb = nx.from_numpy(M_m, type_as=ab)
+
+ f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
+ f, g = nx.to_numpy(f), nx.to_numpy(g)
+ G_sqe = np.exp(f[:, None] + g[None, :] - M / 1)
+ sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1))
+
+ f, g, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ f, g = nx.to_numpy(f), nx.to_numpy(g)
+ G_log = np.exp(f[:, None] + g[None, :] - M / 0.1)
+ sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True)
+ sinkhorn_log = nx.to_numpy(sinkhorn_log)
+
+ f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1)
+ f, g = nx.to_numpy(f), nx.to_numpy(g)
+ G_m = np.exp(f[:, None] + g[None, :] - M_m / 1)
+ sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
+
+ loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn)
+ loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))
+
+ # check constraints
+ np.testing.assert_allclose(
+ sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian
+ np.testing.assert_allclose(
+ sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidian
+ np.testing.assert_allclose(
+ sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log
+ np.testing.assert_allclose(
+ sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log
+ np.testing.assert_allclose(
+ sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05) # metric euclidian
+ np.testing.assert_allclose(
+ sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian
+ np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
+
+
+def test_empirical_sinkhorn_divergence(nx):
+ # Test sinkhorn divergence
+ n = 10
+ a = np.linspace(1, n, n)
+ a /= a.sum()
+ b = ot.unif(n)
X_s = np.reshape(np.arange(n), (n, 1))
X_t = np.reshape(np.arange(0, n * 2, 2), (n, 1))
M = ot.dist(X_s, X_t)
M_s = ot.dist(X_s, X_s)
M_t = ot.dist(X_t, X_t)
- emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1)
- sinkhorn_div = (ot.sinkhorn2(a, b, M, 1) - 1 / 2 * ot.sinkhorn2(a, a, M_s, 1) - 1 / 2 * ot.sinkhorn2(b, b, M_t, 1))
-
- emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, log=True)
- sink_div_log_ab, log_s_ab = ot.sinkhorn2(a, b, M, 1, log=True)
- sink_div_log_a, log_s_a = ot.sinkhorn2(a, a, M_s, 1, log=True)
- sink_div_log_b, log_s_b = ot.sinkhorn2(b, b, M_t, 1, log=True)
- sink_div_log = sink_div_log_ab - 1 / 2 * (sink_div_log_a + sink_div_log_b)
-
- # check constratints
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ X_sb = nx.from_numpy(X_s)
+ X_tb = nx.from_numpy(X_t)
+ M_nx = nx.from_numpy(M, type_as=ab)
+ M_sb = nx.from_numpy(M_s, type_as=ab)
+ M_tb = nx.from_numpy(M_t, type_as=ab)
+
+ emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb))
+ sinkhorn_div = nx.to_numpy(
+ ot.sinkhorn2(ab, bb, M_nx, 1)
+ - 1 / 2 * ot.sinkhorn2(ab, ab, M_sb, 1)
+ - 1 / 2 * ot.sinkhorn2(bb, bb, M_tb, 1)
+ )
+ emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b)
+
+ # check constraints
+ np.testing.assert_allclose(emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05)
np.testing.assert_allclose(
emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn
- np.testing.assert_allclose(
- emp_sinkhorn_div_log, sink_div_log, atol=1e-05) # cf conv emp sinkhorn
+
+ ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb, log=True)
-def test_stabilized_vs_sinkhorn_multidim():
+def test_stabilized_vs_sinkhorn_multidim(nx):
# test if stable version matches sinkhorn
# for multidimensional inputs
n = 100
@@ -297,12 +837,21 @@ def test_stabilized_vs_sinkhorn_multidim():
M = ot.utils.dist0(n)
M /= np.median(M)
epsilon = 0.1
- G, log = ot.bregman.sinkhorn(a, b, M, reg=epsilon,
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ M_nx = nx.from_numpy(M, type_as=ab)
+
+ G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True)
+ G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon,
method="sinkhorn_stabilized",
log=True)
- G2, log2 = ot.bregman.sinkhorn(a, b, M, epsilon,
+ G = nx.to_numpy(G)
+ G2, log2 = ot.bregman.sinkhorn(ab, bb, M_nx, epsilon,
method="sinkhorn", log=True)
+ G2 = nx.to_numpy(G2)
+ np.testing.assert_allclose(G_np, G2)
np.testing.assert_allclose(G, G2)
@@ -320,8 +869,9 @@ def test_implemented_methods():
# make dists unbalanced
b = ot.utils.unif(n)
A = rng.rand(n, 2)
+ A /= A.sum(0, keepdims=True)
M = ot.dist(x, x)
- epsilon = 1.
+ epsilon = 1.0
for method in IMPLEMENTED_METHODS:
ot.bregman.sinkhorn(a, b, M, epsilon, method=method)
@@ -338,7 +888,9 @@ def test_implemented_methods():
ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
-def test_screenkhorn():
+@pytest.skip_backend("jax")
+@pytest.mark.filterwarnings("ignore:Bottleneck")
+def test_screenkhorn(nx):
# test screenkhorn
rng = np.random.RandomState(0)
n = 100
@@ -347,17 +899,31 @@ def test_screenkhorn():
x = rng.randn(n, 2)
M = ot.dist(x, x)
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ M_nx = nx.from_numpy(M, type_as=ab)
+
+ # np sinkhorn
+ G_sink_np = ot.sinkhorn(a, b, M, 1e-03)
# sinkhorn
- G_sink = ot.sinkhorn(a, b, M, 1e-03)
+ G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-03))
# screenkhorn
- G_screen = ot.bregman.screenkhorn(a, b, M, 1e-03, uniform=True, verbose=True)
+ G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-03, uniform=True, verbose=True))
# check marginals
+ np.testing.assert_allclose(G_sink_np, G_sink)
np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)
-def test_convolutional_barycenter_non_square():
+def test_convolutional_barycenter_non_square(nx):
# test for image with height not equal width
A = np.ones((2, 2, 3)) / (2 * 3)
- b = ot.bregman.convolutional_barycenter2d(A, 1e-03)
+ A_nx = nx.from_numpy(A)
+
+ b_np = ot.bregman.convolutional_barycenter2d(A, 1e-03)
+ b = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, 1e-03))
+
+ np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)
np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)
+ np.testing.assert_allclose(b, b_np)
diff --git a/test/test_da.py b/test/test_da.py
index 3b28119..9f2bb50 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -6,11 +6,18 @@
import numpy as np
from numpy.testing import assert_allclose, assert_equal
+import pytest
import ot
from ot.datasets import make_data_classif
from ot.utils import unif
+try: # test if cudamat installed
+ import sklearn # noqa: F401
+ nosklearn = False
+except ImportError:
+ nosklearn = True
+
def test_sinkhorn_lpl1_transport_class():
"""test_sinkhorn_transport
@@ -99,8 +106,8 @@ def test_sinkhorn_l1l2_transport_class():
"""test_sinkhorn_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 100
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -441,8 +448,8 @@ def test_mapping_transport_class():
"""test_mapping_transport
"""
- ns = 60
- nt = 120
+ ns = 20
+ nt = 30
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -558,6 +565,14 @@ def test_mapping_transport_class():
otda.fit(Xs=Xs, Xt=Xt)
assert len(otda.log_.keys()) != 0
+ # check that it does not crash when derphi is very close to 0
+ np.random.seed(39)
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+ otda = ot.da.MappingTransport(kernel="gaussian", bias=False)
+ otda.fit(Xs=Xs, Xt=Xt)
+ np.random.seed(None)
+
def test_linear_mapping():
ns = 150
@@ -691,6 +706,7 @@ def test_jcpot_barycenter():
np.testing.assert_allclose(prop, [1 - pt, pt], rtol=1e-3, atol=1e-3)
+@pytest.mark.skipif(nosklearn, reason="No sklearn available")
def test_emd_laplace_class():
"""test_emd_laplace_transport
"""
diff --git a/test/test_dr.py b/test/test_dr.py
index c5df287..741f2ad 100644
--- a/test/test_dr.py
+++ b/test/test_dr.py
@@ -1,6 +1,7 @@
"""Tests for module dr on Dimensionality Reduction """
# Author: Remi Flamary <remi.flamary@unice.fr>
+# Minhui Huang <mhhuang@ucdavis.edu>
#
# License: MIT License
@@ -57,3 +58,64 @@ def test_wda():
projwda(xs)
np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p))
+
+
+@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
+def test_wda_normalized():
+
+ n_samples = 100 # nb samples in source and target datasets
+ np.random.seed(0)
+
+ # generate gaussian dataset
+ xs, ys = ot.datasets.make_data_classif('gaussrot', n_samples)
+
+ n_features_noise = 8
+
+ xs = np.hstack((xs, np.random.randn(n_samples, n_features_noise)))
+
+ p = 2
+
+ P0 = np.random.randn(10, p)
+ P0 /= P0.sum(0, keepdims=True)
+
+ Pwda, projwda = ot.dr.wda(xs, ys, p, maxiter=10, P0=P0, normalize=True)
+
+ projwda(xs)
+
+ np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p))
+
+
+@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
+def test_prw():
+ d = 100 # Dimension
+ n = 100 # Number samples
+ k = 3 # Subspace dimension
+ dim = 3
+
+ def fragmented_hypercube(n, d, dim):
+ assert dim <= d
+ assert dim >= 1
+ assert dim == int(dim)
+
+ a = (1. / n) * np.ones(n)
+ b = (1. / n) * np.ones(n)
+
+ # First measure : uniform on the hypercube
+ X = np.random.uniform(-1, 1, size=(n, d))
+
+ # Second measure : fragmentation
+ tmp_y = np.random.uniform(-1, 1, size=(n, d))
+ Y = tmp_y + 2 * np.sign(tmp_y) * np.array(dim * [1] + (d - dim) * [0])
+ return a, b, X, Y
+
+ a, b, X, Y = fragmented_hypercube(n, d, dim)
+
+ tau = 0.002
+ reg = 0.2
+
+ pi, U = ot.dr.projection_robust_wasserstein(X, Y, a, b, tau, reg=reg, k=k, maxiter=1000, verbose=1)
+
+ U0 = np.random.randn(d, k)
+ U0, _ = np.linalg.qr(U0)
+
+ pi, U = ot.dr.projection_robust_wasserstein(X, Y, a, b, tau, U0=U0, reg=reg, k=k, maxiter=1000, verbose=1)
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 43da9fc..c4bc04c 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -8,9 +8,13 @@
import numpy as np
import ot
+from ot.backend import NumpyBackend
+from ot.backend import torch
+import pytest
-def test_gromov():
+
+def test_gromov(nx):
n_samples = 50 # nb samples
mu_s = np.array([0, 0])
@@ -29,37 +33,121 @@ def test_gromov():
C1 /= C1.max()
C2 /= C2.max()
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+
G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True)
+ Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True))
- # check constratints
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples)
- np.testing.assert_allclose(
- G, np.flipud(Id), atol=1e-04)
+ np.testing.assert_allclose(Gb, np.flipud(Id), atol=1e-04)
gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
+ gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=True)
gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
G = log['T']
+ Gb = nx.to_numpy(logb['T'])
- np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
+ np.testing.assert_allclose(gw, gwb, atol=1e-06)
+ np.testing.assert_allclose(gwb, 0, atol=1e-1, rtol=1e-1)
- np.testing.assert_allclose(gw, gw_val, atol=1e-1, rtol=1e-1) # cf log=False
+ np.testing.assert_allclose(gw_val, gw_valb, atol=1e-06)
+ np.testing.assert_allclose(gwb, gw_valb, atol=1e-1, rtol=1e-1) # cf log=False
- # check constratints
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+
+def test_gromov_dtype_device(nx):
+ # setup
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+
+ xt = xs[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ C1b = nx.from_numpy(C1, type_as=tp)
+ C2b = nx.from_numpy(C2, type_as=tp)
+ pb = nx.from_numpy(p, type_as=tp)
+ qb = nx.from_numpy(q, type_as=tp)
+
+ Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+
+ nx.assert_same_dtype_device(C1b, Gb)
+ nx.assert_same_dtype_device(C1b, gw_valb)
+
+
+def test_gromov2_gradients():
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5)
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ if torch:
+
+ p1 = torch.tensor(p, requires_grad=True)
+ q1 = torch.tensor(q, requires_grad=True)
+ C11 = torch.tensor(C1, requires_grad=True)
+ C12 = torch.tensor(C2, requires_grad=True)
+
+ val = ot.gromov_wasserstein2(C11, C12, p1, q1)
+
+ val.backward()
+ assert q1.shape == q1.grad.shape
+ assert p1.shape == p1.grad.shape
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
-def test_entropic_gromov():
+
+@pytest.skip_backend("jax", reason="test very slow with jax backend")
+def test_entropic_gromov(nx):
n_samples = 50 # nb samples
mu_s = np.array([0, 0])
@@ -78,85 +166,278 @@ def test_entropic_gromov():
C1 /= C1.max()
C2 /= C2.max()
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+
G = ot.gromov.entropic_gromov_wasserstein(
C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True)
+ Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein(
+ C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
+ ))
- # check constratints
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
gw, log = ot.gromov.entropic_gromov_wasserstein2(
C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True)
+ gwb, logb = ot.gromov.entropic_gromov_wasserstein2(
+ C1b, C2b, pb, qb, 'kl_loss', epsilon=1e-2, log=True)
G = log['T']
+ Gb = nx.to_numpy(logb['T'])
+ np.testing.assert_allclose(gw, gwb, atol=1e-06)
np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
- # check constratints
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
-def test_gromov_barycenter():
- ns = 50
- nt = 60
+@pytest.skip_backend("jax", reason="test very slow with jax backend")
+def test_entropic_gromov_dtype_device(nx):
+ # setup
+ n_samples = 50 # nb samples
- Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
- Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
- C1 = ot.dist(Xs)
- C2 = ot.dist(Xt)
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
- n_samples = 3
- Cb = ot.gromov.gromov_barycenters(n_samples, [C1, C2],
- [ot.unif(ns), ot.unif(nt)
- ], ot.unif(n_samples), [.5, .5],
- 'square_loss', # 5e-4,
- max_iter=100, tol=1e-3,
- verbose=True)
- np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+ xt = xs[::-1].copy()
- Cb2 = ot.gromov.gromov_barycenters(n_samples, [C1, C2],
- [ot.unif(ns), ot.unif(nt)
- ], ot.unif(n_samples), [.5, .5],
- 'kl_loss', # 5e-4,
- max_iter=100, tol=1e-3)
- np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
-def test_gromov_entropic_barycenter():
- ns = 50
- nt = 60
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ C1b = nx.from_numpy(C1, type_as=tp)
+ C2b = nx.from_numpy(C2, type_as=tp)
+ pb = nx.from_numpy(p, type_as=tp)
+ qb = nx.from_numpy(q, type_as=tp)
+
+ Gb = ot.gromov.entropic_gromov_wasserstein(
+ C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
+ )
+ gw_valb = ot.gromov.entropic_gromov_wasserstein2(
+ C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
+ )
+
+ nx.assert_same_dtype_device(C1b, Gb)
+ nx.assert_same_dtype_device(C1b, gw_valb)
+
+
+def test_pointwise_gromov(nx):
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+
+ xt = xs[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+
+ def loss(x, y):
+ return np.abs(x - y)
+
+ def lossb(x, y):
+ return nx.abs(x - y)
+
+ G, log = ot.gromov.pointwise_gromov_wasserstein(
+ C1, C2, p, q, loss, max_iter=100, log=True, verbose=True, random_state=42)
+ G = NumpyBackend().todense(G)
+ Gb, logb = ot.gromov.pointwise_gromov_wasserstein(
+ C1b, C2b, pb, qb, lossb, max_iter=100, log=True, verbose=True, random_state=42)
+ Gb = nx.to_numpy(nx.todense(Gb))
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(logb['gw_dist_estimated'], 0.0, atol=1e-08)
+ np.testing.assert_allclose(logb['gw_dist_std'], 0.0, atol=1e-08)
+
+ G, log = ot.gromov.pointwise_gromov_wasserstein(
+ C1, C2, p, q, loss, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42)
+ G = NumpyBackend().todense(G)
+ Gb, logb = ot.gromov.pointwise_gromov_wasserstein(
+ C1b, C2b, pb, qb, lossb, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42)
+ Gb = nx.to_numpy(nx.todense(Gb))
+
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(logb['gw_dist_estimated'], 0.10342276348494964, atol=1e-8)
+ np.testing.assert_allclose(logb['gw_dist_std'], 0.0015952535464736394, atol=1e-8)
+
+
+@pytest.skip_backend("jax", reason="test very slow with jax backend")
+def test_sampled_gromov(nx):
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0], dtype=np.float64)
+ cov_s = np.array([[1, 0], [0, 1]], dtype=np.float64)
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+
+ xt = xs[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+
+ def loss(x, y):
+ return np.abs(x - y)
+
+ def lossb(x, y):
+ return nx.abs(x - y)
+
+ G, log = ot.gromov.sampled_gromov_wasserstein(
+ C1, C2, p, q, loss, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42)
+ Gb, logb = ot.gromov.sampled_gromov_wasserstein(
+ C1b, C2b, pb, qb, lossb, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42)
+ Gb = nx.to_numpy(Gb)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(logb['gw_dist_estimated'], 0.05679474884977278, atol=1e-08)
+ np.testing.assert_allclose(logb['gw_dist_std'], 0.0005986592106971995, atol=1e-08)
+
+
+def test_gromov_barycenter(nx):
+ ns = 10
+ nt = 20
Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
C1 = ot.dist(Xs)
C2 = ot.dist(Xt)
-
+ p1 = ot.unif(ns)
+ p2 = ot.unif(nt)
n_samples = 3
- Cb = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
- [ot.unif(ns), ot.unif(nt)
- ], ot.unif(n_samples), [.5, .5],
- 'square_loss', 2e-3,
- max_iter=100, tol=1e-3,
- verbose=True)
- np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+ p = ot.unif(n_samples)
- Cb2 = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
- [ot.unif(ns), ot.unif(nt)
- ], ot.unif(n_samples), [.5, .5],
- 'kl_loss', 2e-3,
- max_iter=100, tol=1e-3)
- np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ p1b = nx.from_numpy(p1)
+ p2b = nx.from_numpy(p2)
+ pb = nx.from_numpy(p)
+
+ Cb = ot.gromov.gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42
+ )
+ Cbb = nx.to_numpy(ot.gromov.gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42
+ ))
+ np.testing.assert_allclose(Cb, Cbb, atol=1e-06)
+ np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples))
+
+ Cb2 = ot.gromov.gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'kl_loss', max_iter=100, tol=1e-3, random_state=42
+ )
+ Cb2b = nx.to_numpy(ot.gromov.gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'kl_loss', max_iter=100, tol=1e-3, random_state=42
+ ))
+ np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06)
+ np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples))
+
+
+@pytest.mark.filterwarnings("ignore:divide")
+def test_gromov_entropic_barycenter(nx):
+ ns = 10
+ nt = 20
+ Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
+ Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
-def test_fgw():
+ C1 = ot.dist(Xs)
+ C2 = ot.dist(Xt)
+ p1 = ot.unif(ns)
+ p2 = ot.unif(nt)
+ n_samples = 2
+ p = ot.unif(n_samples)
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ p1b = nx.from_numpy(p1)
+ p2b = nx.from_numpy(p2)
+ pb = nx.from_numpy(p)
+
+ Cb = ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'square_loss', 1e-3, max_iter=50, tol=1e-3, verbose=True, random_state=42
+ )
+ Cbb = nx.to_numpy(ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'square_loss', 1e-3, max_iter=50, tol=1e-3, verbose=True, random_state=42
+ ))
+ np.testing.assert_allclose(Cb, Cbb, atol=1e-06)
+ np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples))
+
+ Cb2 = ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42
+ )
+ Cb2b = nx.to_numpy(ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42
+ ))
+ np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06)
+ np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples))
+
+
+def test_fgw(nx):
n_samples = 50 # nb samples
mu_s = np.array([0, 0])
@@ -181,33 +462,85 @@ def test_fgw():
M = ot.dist(ys, yt)
M /= M.max()
- G = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5)
+ Mb = nx.from_numpy(M)
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
- # check constratints
+ G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
+ Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True)
+ Gb = nx.to_numpy(Gb)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence fgw
+ p, Gb.sum(1), atol=1e-04) # cf convergence fgw
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence fgw
+ q, Gb.sum(0), atol=1e-04) # cf convergence fgw
Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples)
np.testing.assert_allclose(
- G, np.flipud(Id), atol=1e-04) # cf convergence gromov
+ Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov
fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
+ fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True)
G = log['T']
+ Gb = nx.to_numpy(logb['T'])
- np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1)
+ np.testing.assert_allclose(fgw, fgwb, atol=1e-08)
+ np.testing.assert_allclose(fgwb, 0, atol=1e-1, rtol=1e-1)
- # check constratints
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+
+def test_fgw2_gradients():
+ n_samples = 50 # nb samples
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
-def test_fgw_barycenter():
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5)
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+ M = ot.dist(xs, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ if torch:
+
+ p1 = torch.tensor(p, requires_grad=True)
+ q1 = torch.tensor(q, requires_grad=True)
+ C11 = torch.tensor(C1, requires_grad=True)
+ C12 = torch.tensor(C2, requires_grad=True)
+ M1 = torch.tensor(M, requires_grad=True)
+
+ val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1)
+
+ val.backward()
+
+ assert q1.shape == q1.grad.shape
+ assert p1.shape == p1.grad.shape
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+ assert M1.shape == M1.grad.shape
+
+
+def test_fgw_barycenter(nx):
np.random.seed(42)
ns = 50
@@ -221,30 +554,44 @@ def test_fgw_barycenter():
C1 = ot.dist(Xs)
C2 = ot.dist(Xt)
-
+ p1, p2 = ot.unif(ns), ot.unif(nt)
n_samples = 3
- X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
- fixed_structure=False, fixed_features=False,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3)
- np.testing.assert_allclose(C.shape, (n_samples, n_samples))
- np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+ p = ot.unif(n_samples)
+
+ ysb = nx.from_numpy(ys)
+ ytb = nx.from_numpy(yt)
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ p1b = nx.from_numpy(p1)
+ p2b = nx.from_numpy(p2)
+ pb = nx.from_numpy(p)
+
+ Xb, Cb = ot.gromov.fgw_barycenters(
+ n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, fixed_structure=False,
+ fixed_features=False, p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, random_state=12345
+ )
xalea = np.random.randn(n_samples, 2)
init_C = ot.dist(xalea, xalea)
-
- X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], ps=[ot.unif(ns), ot.unif(nt)], lambdas=[.5, .5], alpha=0.5,
- fixed_structure=True, init_C=init_C, fixed_features=False,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3)
- np.testing.assert_allclose(C.shape, (n_samples, n_samples))
- np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+ init_Cb = nx.from_numpy(init_C)
+
+ Xb, Cb = ot.gromov.fgw_barycenters(
+ n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=[.5, .5],
+ alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False,
+ p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3
+ )
+ Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb)
+ np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1]))
init_X = np.random.randn(n_samples, ys.shape[1])
-
- X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
- fixed_structure=False, fixed_features=True, init_X=init_X,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3)
- np.testing.assert_allclose(C.shape, (n_samples, n_samples))
- np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+ init_Xb = nx.from_numpy(init_X)
+
+ Xb, Cb, logb = ot.gromov.fgw_barycenters(
+ n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5,
+ fixed_structure=False, fixed_features=True, init_X=init_Xb,
+ p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, log=True, random_state=98765
+ )
+ Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb)
+ np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1]))
diff --git a/test/test_helpers.py b/test/test_helpers.py
new file mode 100644
index 0000000..cc4c90e
--- /dev/null
+++ b/test/test_helpers.py
@@ -0,0 +1,26 @@
+"""Tests for helpers functions """
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+import os
+import sys
+
+sys.path.append(os.path.join("ot", "helpers"))
+
+from openmp_helpers import get_openmp_flag, check_openmp_support # noqa
+from pre_build_helpers import _get_compiler, compile_test_program # noqa
+
+
+def test_helpers():
+
+ compiler = _get_compiler()
+
+ get_openmp_flag(compiler)
+
+ s = '#include <stdio.h>\n#include <stdlib.h>\n\nint main(void) {\n\tprintf("Hello world!\\n");\n\treturn 0;\n}'
+ output, _ = compile_test_program(s)
+ assert len(output) == 1 and output[0] == "Hello world!"
+
+ check_openmp_support()
diff --git a/test/test_optim.py b/test/test_optim.py
index 87b0268..4efd9b1 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -8,7 +8,7 @@ import numpy as np
import ot
-def test_conditional_gradient():
+def test_conditional_gradient(nx):
n_bins = 100 # nb bins
np.random.seed(0)
@@ -29,16 +29,26 @@ def test_conditional_gradient():
def df(G):
return G
+ def fb(G):
+ return 0.5 * nx.sum(G ** 2)
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M, type_as=ab)
+
reg = 1e-1
G, log = ot.optim.cg(a, b, M, reg, f, df, verbose=True, log=True)
+ Gb, log = ot.optim.cg(ab, bb, Mb, reg, fb, df, verbose=True, log=True)
+ Gb = nx.to_numpy(Gb)
- np.testing.assert_allclose(a, G.sum(1))
- np.testing.assert_allclose(b, G.sum(0))
+ np.testing.assert_allclose(Gb, G)
+ np.testing.assert_allclose(a, Gb.sum(1))
+ np.testing.assert_allclose(b, Gb.sum(0))
-def test_conditional_gradient2():
- n = 1000 # nb samples
+def test_conditional_gradient_itermax(nx):
+ n = 100 # nb samples
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
@@ -61,16 +71,27 @@ def test_conditional_gradient2():
def df(G):
return G
+ def fb(G):
+ return 0.5 * nx.sum(G ** 2)
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M, type_as=ab)
+
reg = 1e-1
- G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=200000,
+ G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=10000,
verbose=True, log=True)
+ Gb, log = ot.optim.cg(ab, bb, Mb, reg, fb, df, numItermaxEmd=10000,
+ verbose=True, log=True)
+ Gb = nx.to_numpy(Gb)
- np.testing.assert_allclose(a, G.sum(1))
- np.testing.assert_allclose(b, G.sum(0))
+ np.testing.assert_allclose(Gb, G)
+ np.testing.assert_allclose(a, Gb.sum(1))
+ np.testing.assert_allclose(b, Gb.sum(0))
-def test_generalized_conditional_gradient():
+def test_generalized_conditional_gradient(nx):
n_bins = 100 # nb bins
np.random.seed(0)
@@ -91,16 +112,76 @@ def test_generalized_conditional_gradient():
def df(G):
return G
+ def fb(G):
+ return 0.5 * nx.sum(G ** 2)
+
reg1 = 1e-3
reg2 = 1e-1
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M, type_as=ab)
+
G, log = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True, log=True)
+ Gb, log = ot.optim.gcg(ab, bb, Mb, reg1, reg2, fb, df, verbose=True, log=True)
+ Gb = nx.to_numpy(Gb)
- np.testing.assert_allclose(a, G.sum(1), atol=1e-05)
- np.testing.assert_allclose(b, G.sum(0), atol=1e-05)
+ np.testing.assert_allclose(Gb, G)
+ np.testing.assert_allclose(a, Gb.sum(1), atol=1e-05)
+ np.testing.assert_allclose(b, Gb.sum(0), atol=1e-05)
def test_solve_1d_linesearch_quad_funct():
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(1, -1, 0), 0.5)
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 5, 0), 0)
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1)
+
+
+def test_line_search_armijo(nx):
+ xk = np.array([[0.25, 0.25], [0.25, 0.25]])
+ pk = np.array([[-0.25, 0.25], [0.25, -0.25]])
+ gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]])
+ old_fval = -123
+ # Should not throw an exception and return None for alpha
+ alpha, a, b = ot.optim.line_search_armijo(
+ lambda x: 1, nx.from_numpy(xk), nx.from_numpy(pk), nx.from_numpy(gfk), old_fval
+ )
+ alpha_np, anp, bnp = ot.optim.line_search_armijo(
+ lambda x: 1, xk, pk, gfk, old_fval
+ )
+ assert a == anp
+ assert b == bnp
+ assert alpha is None
+
+ # check line search armijo
+ def f(x):
+ return nx.sum((x - 5.0) ** 2)
+
+ def grad(x):
+ return 2 * (x - 5.0)
+
+ xk = nx.from_numpy(np.array([[[-5.0, -5.0]]]))
+ pk = nx.from_numpy(np.array([[[100.0, 100.0]]]))
+ gfk = grad(xk)
+ old_fval = f(xk)
+
+ # chech the case where the optimum is on the direction
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval)
+ np.testing.assert_allclose(alpha, 0.1)
+
+ # check the case where the direction is not far enough
+ pk = nx.from_numpy(np.array([[[3.0, 3.0]]]))
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval, alpha0=1.0)
+ np.testing.assert_allclose(alpha, 1.0)
+
+ # check the case where checking the wrong direction
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, -pk, gfk, old_fval)
+ assert alpha <= 0
+
+ # check the case where the point is not a vector
+ xk = nx.from_numpy(np.array(-5.0))
+ pk = nx.from_numpy(np.array(100.0))
+ gfk = grad(xk)
+ old_fval = f(xk)
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval)
+ np.testing.assert_allclose(alpha, 0.1)
diff --git a/test/test_ot.py b/test/test_ot.py
index b7306f6..92f26a7 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -8,13 +8,13 @@ import warnings
import numpy as np
import pytest
-from scipy.stats import wasserstein_distance
import ot
from ot.datasets import make_1D_gauss as gauss
+from ot.backend import torch
-def test_emd_dimension_mismatch():
+def test_emd_dimension_and_mass_mismatch():
# test emd and emd2 for dimension mismatch
n_samples = 100
n_features = 2
@@ -29,122 +29,125 @@ def test_emd_dimension_mismatch():
np.testing.assert_raises(AssertionError, ot.emd2, a, a, M)
+ b = a.copy()
+ a[0] = 100
+ np.testing.assert_raises(AssertionError, ot.emd, a, b, M)
-def test_emd_emd2():
- # test emd and emd2 for simple identity
- n = 100
+
+def test_emd_backends(nx):
+ n_samples = 100
+ n_features = 2
rng = np.random.RandomState(0)
- x = rng.randn(n, 2)
- u = ot.utils.unif(n)
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
- M = ot.dist(x, x)
+ M = ot.dist(x, y)
- G = ot.emd(u, u, M)
+ G = ot.emd(a, a, M)
- # check G is identity
- np.testing.assert_allclose(G, np.eye(n) / n)
- # check constraints
- np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn
- np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
- w = ot.emd2(u, u, M)
- # check loss=0
- np.testing.assert_allclose(w, 0)
+ Gb = ot.emd(ab, ab, Mb)
+
+ np.allclose(G, nx.to_numpy(Gb))
-def test_emd_1d_emd2_1d():
- # test emd1d gives similar results as emd
- n = 20
- m = 30
+def test_emd2_backends(nx):
+ n_samples = 100
+ n_features = 2
rng = np.random.RandomState(0)
- u = rng.randn(n, 1)
- v = rng.randn(m, 1)
- M = ot.dist(u, v, metric='sqeuclidean')
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
- G, log = ot.emd([], [], M, log=True)
- wass = log["cost"]
- G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True)
- wass1d = log["cost"]
- wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False)
- wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False)
+ val = ot.emd2(a, a, M)
- # check loss is similar
- np.testing.assert_allclose(wass, wass1d)
- np.testing.assert_allclose(wass, wass1d_emd2)
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
- # check loss is similar to scipy's implementation for Euclidean metric
- wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
- np.testing.assert_allclose(wass_sp, wass1d_euc)
+ valb = ot.emd2(ab, ab, Mb)
- # check constraints
- np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1))
- np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))
+ np.allclose(val, nx.to_numpy(valb))
+
+
+def test_emd_emd2_types_devices(nx):
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ ab = nx.from_numpy(a, type_as=tp)
+ Mb = nx.from_numpy(M, type_as=tp)
- # check G is similar
- np.testing.assert_allclose(G, G_1d)
+ Gb = ot.emd(ab, ab, Mb)
- # check AssertionError is raised if called on non 1d arrays
- u = np.random.randn(n, 2)
- v = np.random.randn(m, 2)
- with pytest.raises(AssertionError):
- ot.emd_1d(u, v, [], [])
+ w = ot.emd2(ab, ab, Mb)
+ nx.assert_same_dtype_device(Mb, Gb)
+ nx.assert_same_dtype_device(Mb, w)
-def test_emd_1d_emd2_1d_with_weights():
- # test emd1d gives similar results as emd
- n = 20
- m = 30
+
+def test_emd2_gradients():
+ n_samples = 100
+ n_features = 2
rng = np.random.RandomState(0)
- u = rng.randn(n, 1)
- v = rng.randn(m, 1)
- w_u = rng.uniform(0., 1., n)
- w_u = w_u / w_u.sum()
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
- w_v = rng.uniform(0., 1., m)
- w_v = w_v / w_v.sum()
+ M = ot.dist(x, y)
- M = ot.dist(u, v, metric='sqeuclidean')
+ if torch:
- G, log = ot.emd(w_u, w_v, M, log=True)
- wass = log["cost"]
- G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True)
- wass1d = log["cost"]
- wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False)
- wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False)
+ a1 = torch.tensor(a, requires_grad=True)
+ b1 = torch.tensor(a, requires_grad=True)
+ M1 = torch.tensor(M, requires_grad=True)
- # check loss is similar
- np.testing.assert_allclose(wass, wass1d)
- np.testing.assert_allclose(wass, wass1d_emd2)
+ val = ot.emd2(a1, b1, M1)
- # check loss is similar to scipy's implementation for Euclidean metric
- wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v)
- np.testing.assert_allclose(wass_sp, wass1d_euc)
+ val.backward()
- # check constraints
- np.testing.assert_allclose(w_u, G.sum(1))
- np.testing.assert_allclose(w_v, G.sum(0))
+ assert a1.shape == a1.grad.shape
+ assert b1.shape == b1.grad.shape
+ assert M1.shape == M1.grad.shape
-def test_wass_1d():
- # test emd1d gives similar results as emd
- n = 20
- m = 30
+def test_emd_emd2():
+ # test emd and emd2 for simple identity
+ n = 100
rng = np.random.RandomState(0)
- u = rng.randn(n, 1)
- v = rng.randn(m, 1)
- M = ot.dist(u, v, metric='sqeuclidean')
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
- G, log = ot.emd([], [], M, log=True)
- wass = log["cost"]
+ M = ot.dist(x, x)
- wass1d = ot.wasserstein_1d(u, v, [], [], p=2.)
+ G = ot.emd(u, u, M)
+
+ # check G is identity
+ np.testing.assert_allclose(G, np.eye(n) / n)
+ # check constraints
+ np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn
+ np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn
- # check loss is similar
- np.testing.assert_allclose(np.sqrt(wass), wass1d)
+ w = ot.emd2(u, u, M)
+ # check loss=0
+ np.testing.assert_allclose(w, 0)
def test_emd_empty():
@@ -291,17 +294,7 @@ def test_warnings():
print('Computing {} EMD '.format(1))
ot.emd(a, b, M, numItermax=1)
assert "numItermax" in str(w[-1].message)
- assert len(w) == 1
- a[0] = 100
- print('Computing {} EMD '.format(2))
- ot.emd(a, b, M)
- assert "infeasible" in str(w[-1].message)
- assert len(w) == 2
- a[0] = -1
- print('Computing {} EMD '.format(2))
- ot.emd(a, b, M)
- assert "infeasible" in str(w[-1].message)
- assert len(w) == 3
+ #assert len(w) == 1
def test_dual_variables():
diff --git a/test/test_partial.py b/test/test_partial.py
index 510e081..97c611b 100755
--- a/test/test_partial.py
+++ b/test/test_partial.py
@@ -51,10 +51,12 @@ def test_raise_errors():
ot.partial.partial_gromov_wasserstein(M, M, p, q, m=-1, log=True)
with pytest.raises(ValueError):
- ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=2, log=True)
+ ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=2,
+ log=True)
with pytest.raises(ValueError):
- ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=-1, log=True)
+ ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=-1,
+ log=True)
def test_partial_wasserstein_lagrange():
@@ -102,7 +104,7 @@ def test_partial_wasserstein():
w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m,
log=True, verbose=True)
- # check constratints
+ # check constraints
np.testing.assert_equal(
w0.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
np.testing.assert_equal(
@@ -125,11 +127,11 @@ def test_partial_wasserstein():
np.testing.assert_allclose(w0, w0_val, atol=1e-1, rtol=1e-1)
- # check constratints
+ # check constraints
np.testing.assert_equal(
- G.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein
+ G.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
np.testing.assert_equal(
- G.sum(0) <= q, [True] * len(q)) # cf convergence wasserstein
+ G.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
np.testing.assert_allclose(
np.sum(G), m, atol=1e-04)
@@ -192,7 +194,7 @@ def test_partial_gromov_wasserstein():
100, m=m,
log=True)
- # check constratints
+ # check constraints
np.testing.assert_equal(
res0.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein
np.testing.assert_equal(
diff --git a/test/test_regpath.py b/test/test_regpath.py
new file mode 100644
index 0000000..967c27b
--- /dev/null
+++ b/test/test_regpath.py
@@ -0,0 +1,64 @@
+"""Tests for module regularization path"""
+
+# Author: Haoran Wu <haoran.wu@univ-ubs.fr>
+#
+# License: MIT License
+
+import numpy as np
+import ot
+
+
+def test_fully_relaxed_path():
+
+ n_source = 50 # nb source samples (gaussian)
+ n_target = 40 # nb target samples (gaussian)
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 2]])
+
+ np.random.seed(0)
+ xs = ot.datasets.make_2D_samples_gauss(n_source, mu, cov)
+ xt = ot.datasets.make_2D_samples_gauss(n_target, mu, cov)
+
+ # source and target distributions
+ a = ot.utils.unif(n_source)
+ b = ot.utils.unif(n_target)
+
+ # loss matrix
+ M = ot.dist(xs, xt)
+ M /= M.max()
+
+ t, _, _ = ot.regpath.regularization_path(a, b, M, reg=1e-8,
+ semi_relaxed=False)
+
+ G = t.reshape((n_source, n_target))
+ np.testing.assert_allclose(a, G.sum(1), atol=1e-05)
+ np.testing.assert_allclose(b, G.sum(0), atol=1e-05)
+
+
+def test_semi_relaxed_path():
+
+ n_source = 50 # nb source samples (gaussian)
+ n_target = 40 # nb target samples (gaussian)
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 2]])
+
+ np.random.seed(0)
+ xs = ot.datasets.make_2D_samples_gauss(n_source, mu, cov)
+ xt = ot.datasets.make_2D_samples_gauss(n_target, mu, cov)
+
+ # source and target distributions
+ a = ot.utils.unif(n_source)
+ b = ot.utils.unif(n_target)
+
+ # loss matrix
+ M = ot.dist(xs, xt)
+ M /= M.max()
+
+ t, _, _ = ot.regpath.regularization_path(a, b, M, reg=1e-8,
+ semi_relaxed=True)
+
+ G = t.reshape((n_source, n_target))
+ np.testing.assert_allclose(a, G.sum(1), atol=1e-05)
+ np.testing.assert_allclose(b, G.sum(0), atol=1e-10)
diff --git a/test/test_sliced.py b/test/test_sliced.py
new file mode 100644
index 0000000..245202c
--- /dev/null
+++ b/test/test_sliced.py
@@ -0,0 +1,213 @@
+"""Tests for module sliced"""
+
+# Author: Adrien Corenflos <adrien.corenflos@aalto.fi>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+import pytest
+
+import ot
+from ot.sliced import get_random_projections
+
+
+def test_get_random_projections():
+ rng = np.random.RandomState(0)
+ projections = get_random_projections(1000, 50, rng)
+ np.testing.assert_almost_equal(np.sum(projections ** 2, 0), 1.)
+
+
+def test_sliced_same_dist():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ res = ot.sliced_wasserstein_distance(x, x, u, u, 10, seed=rng)
+ np.testing.assert_almost_equal(res, 0.)
+
+
+def test_sliced_bad_shapes():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ y = rng.randn(n, 4)
+ u = ot.utils.unif(n)
+
+ with pytest.raises(ValueError):
+ _ = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng)
+
+
+def test_sliced_log():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 4)
+ y = rng.randn(n, 4)
+ u = ot.utils.unif(n)
+
+ res, log = ot.sliced_wasserstein_distance(x, y, u, u, 10, p=1, seed=rng, log=True)
+ assert len(log) == 2
+ projections = log["projections"]
+ projected_emds = log["projected_emds"]
+
+ assert projections.shape[1] == len(projected_emds) == 10
+ for emd in projected_emds:
+ assert emd > 0
+
+
+def test_sliced_different_dists():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+ y = rng.randn(n, 2)
+
+ res = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng)
+ assert res > 0.
+
+
+def test_1d_sliced_equals_emd():
+ n = 100
+ m = 120
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 1)
+ a = rng.uniform(0, 1, n)
+ a /= a.sum()
+ y = rng.randn(m, 1)
+ u = ot.utils.unif(m)
+ res = ot.sliced_wasserstein_distance(x, y, a, u, 10, seed=42)
+ expected = ot.emd2_1d(x.squeeze(), y.squeeze(), a, u)
+ np.testing.assert_almost_equal(res ** 2, expected)
+
+
+def test_max_sliced_same_dist():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ res = ot.max_sliced_wasserstein_distance(x, x, u, u, 10, seed=rng)
+ np.testing.assert_almost_equal(res, 0.)
+
+
+def test_max_sliced_different_dists():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+ y = rng.randn(n, 2)
+
+ res, log = ot.max_sliced_wasserstein_distance(x, y, u, u, 10, seed=rng, log=True)
+ assert res > 0.
+
+
+def test_sliced_backend(nx):
+
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ n_projections = 20
+
+ xb = nx.from_numpy(x)
+ yb = nx.from_numpy(y)
+ Pb = nx.from_numpy(P)
+
+ val0 = ot.sliced_wasserstein_distance(x, y, projections=P)
+
+ val = ot.sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0)
+ val2 = ot.sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0)
+
+ assert val > 0
+ assert val == val2
+
+ valb = nx.to_numpy(ot.sliced_wasserstein_distance(xb, yb, projections=Pb))
+
+ assert np.allclose(val0, valb)
+
+
+def test_sliced_backend_type_devices(nx):
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb = nx.from_numpy(x, type_as=tp)
+ yb = nx.from_numpy(y, type_as=tp)
+ Pb = nx.from_numpy(P, type_as=tp)
+
+ valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
+
+ nx.assert_same_dtype_device(xb, valb)
+
+
+def test_max_sliced_backend(nx):
+
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ n_projections = 20
+
+ xb = nx.from_numpy(x)
+ yb = nx.from_numpy(y)
+ Pb = nx.from_numpy(P)
+
+ val0 = ot.max_sliced_wasserstein_distance(x, y, projections=P)
+
+ val = ot.max_sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0)
+ val2 = ot.max_sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0)
+
+ assert val > 0
+ assert val == val2
+
+ valb = nx.to_numpy(ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb))
+
+ assert np.allclose(val0, valb)
+
+
+def test_max_sliced_backend_type_devices(nx):
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb = nx.from_numpy(x, type_as=tp)
+ yb = nx.from_numpy(y, type_as=tp)
+ Pb = nx.from_numpy(P, type_as=tp)
+
+ valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
+
+ nx.assert_same_dtype_device(xb, valb)
diff --git a/test/test_smooth.py b/test/test_smooth.py
index 2afa4f8..31e0b2e 100644
--- a/test/test_smooth.py
+++ b/test/test_smooth.py
@@ -25,16 +25,16 @@ def test_smooth_ot_dual():
Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn
np.testing.assert_allclose(
u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn
- # kl regyularisation
+ # kl regularisation
G = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='kl', stopThr=1e-10)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
np.testing.assert_allclose(
@@ -60,16 +60,16 @@ def test_smooth_ot_semi_dual():
Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn
np.testing.assert_allclose(
u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn
- # kl regyularisation
+ # kl regularisation
G = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='kl', stopThr=1e-10)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
np.testing.assert_allclose(
diff --git a/test/test_stochastic.py b/test/test_stochastic.py
index 155622c..736df32 100644
--- a/test/test_stochastic.py
+++ b/test/test_stochastic.py
@@ -30,7 +30,7 @@ import ot
def test_stochastic_sag():
# test sag
- n = 15
+ n = 10
reg = 1
numItermax = 30000
rng = np.random.RandomState(0)
@@ -43,11 +43,11 @@ def test_stochastic_sag():
G = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "sag",
numItermax=numItermax)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
- u, G.sum(1), atol=1e-04) # cf convergence sag
+ u, G.sum(1), atol=1e-03) # cf convergence sag
np.testing.assert_allclose(
- u, G.sum(0), atol=1e-04) # cf convergence sag
+ u, G.sum(0), atol=1e-03) # cf convergence sag
#############################################################################
@@ -60,9 +60,9 @@ def test_stochastic_sag():
def test_stochastic_asgd():
# test asgd
- n = 15
+ n = 10
reg = 1
- numItermax = 100000
+ numItermax = 10000
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
@@ -73,11 +73,11 @@ def test_stochastic_asgd():
G, log = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd",
numItermax=numItermax, log=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
- u, G.sum(1), atol=1e-03) # cf convergence asgd
+ u, G.sum(1), atol=1e-02) # cf convergence asgd
np.testing.assert_allclose(
- u, G.sum(0), atol=1e-03) # cf convergence asgd
+ u, G.sum(0), atol=1e-02) # cf convergence asgd
#############################################################################
@@ -90,9 +90,9 @@ def test_stochastic_asgd():
def test_sag_asgd_sinkhorn():
# test all algorithms
- n = 15
+ n = 10
reg = 1
- nb_iter = 100000
+ nb_iter = 10000
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
@@ -105,19 +105,19 @@ def test_sag_asgd_sinkhorn():
numItermax=nb_iter)
G_sinkhorn = ot.sinkhorn(u, u, M, reg)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
- G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-03)
+ G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-02)
np.testing.assert_allclose(
- G_sag.sum(0), G_sinkhorn.sum(0), atol=1e-03)
+ G_sag.sum(0), G_sinkhorn.sum(0), atol=1e-02)
np.testing.assert_allclose(
- G_asgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
+ G_asgd.sum(1), G_sinkhorn.sum(1), atol=1e-02)
np.testing.assert_allclose(
- G_asgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
+ G_asgd.sum(0), G_sinkhorn.sum(0), atol=1e-02)
np.testing.assert_allclose(
- G_sag, G_sinkhorn, atol=1e-03) # cf convergence sag
+ G_sag, G_sinkhorn, atol=1e-02) # cf convergence sag
np.testing.assert_allclose(
- G_asgd, G_sinkhorn, atol=1e-03) # cf convergence asgd
+ G_asgd, G_sinkhorn, atol=1e-02) # cf convergence asgd
#############################################################################
@@ -136,7 +136,7 @@ def test_stochastic_dual_sgd():
# test sgd
n = 10
reg = 1
- numItermax = 15000
+ numItermax = 5000
batch_size = 10
rng = np.random.RandomState(0)
@@ -148,7 +148,7 @@ def test_stochastic_dual_sgd():
G, log = ot.stochastic.solve_dual_entropic(u, u, M, reg, batch_size,
numItermax=numItermax, log=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, G.sum(1), atol=1e-03) # cf convergence sgd
np.testing.assert_allclose(
@@ -167,7 +167,7 @@ def test_dual_sgd_sinkhorn():
# test all dual algorithms
n = 10
reg = 1
- nb_iter = 15000
+ nb_iter = 5000
batch_size = 10
rng = np.random.RandomState(0)
@@ -181,13 +181,13 @@ def test_dual_sgd_sinkhorn():
G_sinkhorn = ot.sinkhorn(u, u, M, reg)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
- G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
+ G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-02)
np.testing.assert_allclose(
- G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
+ G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-02)
np.testing.assert_allclose(
- G_sgd, G_sinkhorn, atol=1e-03) # cf convergence sgd
+ G_sgd, G_sinkhorn, atol=1e-02) # cf convergence sgd
# Test gaussian
n = 30
@@ -206,7 +206,7 @@ def test_dual_sgd_sinkhorn():
G_sinkhorn = ot.sinkhorn(a, b, M, reg)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
np.testing.assert_allclose(
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index dfeaad9..e8349d1 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -115,7 +115,8 @@ def test_stabilized_vs_sinkhorn():
G, log = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=epsilon,
method="sinkhorn_stabilized",
reg_m=reg_m,
- log=True)
+ log=True,
+ verbose=True)
G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
method="sinkhorn", log=True)
@@ -138,7 +139,7 @@ def test_unbalanced_barycenter(method):
reg_m = 1.
q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
- method=method, log=True)
+ method=method, log=True, verbose=True)
# check fixed point equations
fi = reg_m / (reg_m + epsilon)
logA = np.log(A + 1e-16)
@@ -173,6 +174,7 @@ def test_barycenter_stabilized_vs_sinkhorn():
reg_m=reg_m, log=True,
tau=100,
method="sinkhorn_stabilized",
+ verbose=True
)
q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
method="sinkhorn",
@@ -182,6 +184,33 @@ def test_barycenter_stabilized_vs_sinkhorn():
q, qstable, atol=1e-05)
+def test_wrong_method():
+
+ n = 10
+ rng = np.random.RandomState(42)
+
+ x = rng.randn(n, 2)
+ a = ot.utils.unif(n)
+
+ # make dists unbalanced
+ b = ot.utils.unif(n) * 1.5
+
+ M = ot.dist(x, x)
+ epsilon = 1.
+ reg_m = 1.
+
+ with pytest.raises(ValueError):
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
+ reg_m=reg_m,
+ method='badmethod',
+ log=True,
+ verbose=True)
+ with pytest.raises(ValueError):
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
+ method='badmethod',
+ verbose=True)
+
+
def test_implemented_methods():
IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling']
diff --git a/test/test_utils.py b/test/test_utils.py
index db9cda6..40f4e49 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -4,10 +4,41 @@
#
# License: MIT License
-
import ot
import numpy as np
import sys
+import pytest
+
+
+def test_proj_simplex(nx):
+ n = 10
+ rng = np.random.RandomState(0)
+
+ # test on matrix when projection is done on axis 0
+ x = rng.randn(n, 2)
+ x1 = nx.from_numpy(x)
+
+ # all projections should sum to 1
+ proj = ot.utils.proj_simplex(x1)
+ l1 = np.sum(nx.to_numpy(proj), axis=0)
+ l2 = np.ones(2)
+ np.testing.assert_allclose(l1, l2, atol=1e-5)
+
+ # all projections should sum to 3
+ proj = ot.utils.proj_simplex(x1, 3)
+ l1 = np.sum(nx.to_numpy(proj), axis=0)
+ l2 = 3 * np.ones(2)
+ np.testing.assert_allclose(l1, l2, atol=1e-5)
+
+ # tets on vector
+ x = rng.randn(n)
+ x1 = nx.from_numpy(x)
+
+ # all projections should sum to 1
+ proj = ot.utils.proj_simplex(x1)
+ l1 = np.sum(nx.to_numpy(proj), axis=0)
+ l2 = np.ones(2)
+ np.testing.assert_allclose(l1, l2, atol=1e-5)
def test_parmap():
@@ -45,8 +76,8 @@ def test_tic_toc():
def test_kernel():
n = 100
-
- x = np.random.randn(n, 2)
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
K = ot.utils.kernel(x, x)
@@ -67,7 +98,8 @@ def test_dist():
n = 100
- x = np.random.randn(n, 2)
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
D = np.zeros((n, n))
for i in range(n):
@@ -77,9 +109,31 @@ def test_dist():
D2 = ot.dist(x, x)
D3 = ot.dist(x)
+ D4 = ot.dist(x, x, metric='minkowski', p=2)
+
+ assert D4[0, 1] == D4[1, 0]
+
# dist shoul return squared euclidean
- np.testing.assert_allclose(D, D2)
- np.testing.assert_allclose(D, D3)
+ np.testing.assert_allclose(D, D2, atol=1e-14)
+ np.testing.assert_allclose(D, D3, atol=1e-14)
+
+
+def test_dist_backends(nx):
+
+ n = 100
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
+ x1 = nx.from_numpy(x)
+
+ lst_metric = ['euclidean', 'sqeuclidean']
+
+ for metric in lst_metric:
+
+ D = ot.dist(x, x, metric=metric)
+ D1 = ot.dist(x1, x1, metric=metric)
+
+ # low atol because jax forces float32
+ np.testing.assert_allclose(D, nx.to_numpy(D1), atol=1e-5)
def test_dist0():
@@ -95,9 +149,11 @@ def test_dots():
n1, n2, n3, n4 = 100, 50, 200, 100
- A = np.random.randn(n1, n2)
- B = np.random.randn(n2, n3)
- C = np.random.randn(n3, n4)
+ rng = np.random.RandomState(0)
+
+ A = rng.randn(n1, n2)
+ B = rng.randn(n2, n3)
+ C = rng.randn(n3, n4)
X1 = ot.utils.dots(A, B, C)
@@ -169,6 +225,13 @@ def test_deprecated_func():
class Class():
pass
+ with pytest.warns(DeprecationWarning):
+ fun()
+
+ with pytest.warns(DeprecationWarning):
+ cl = Class()
+ print(cl)
+
if sys.version_info < (3, 5):
print('Not tested')
else:
@@ -199,4 +262,7 @@ def test_BaseEstimator():
params['first'] = 'spam again'
cl.set_params(**params)
+ with pytest.raises(ValueError):
+ cl.set_params(bibi=10)
+
assert cl.first == 'spam again'